"""
Indexed Color Compression using K-means (8bpp BMP output)
==========================================================
Converts images to 256-color indexed format using K-means clustering.
Outputs as 8-bit BMP with color palette.

Usage: python indexed_kmeans_8bpp.py [input_image]
"""

import numpy as np
from PIL import Image
import sys
import os


def compress_indexed_kmeans(rgb: np.ndarray, num_colors: int = 256
                            ) -> tuple[np.ndarray, np.ndarray]:
    """
    Compress RGB image using K-means clustering for palette generation.

    If the image contains no more than ``num_colors`` distinct colors, the
    palette is built directly from those colors (lossless) and K-means is
    skipped; unused palette rows are left as zeros. Otherwise K-means is
    fitted on at most 100000 sampled pixels and every pixel is assigned to
    its nearest cluster center.

    Args:
        rgb: Input RGB image (H, W, 3)
        num_colors: Number of colors in palette (default 256, must be 1..256
            because indices are stored as uint8)

    Returns:
        indices: Index array (H, W) with palette indices, dtype uint8
        palette: Color palette (num_colors, 3), dtype uint8

    Raises:
        ValueError: If ``rgb`` is not shaped (H, W, 3) or ``num_colors`` is
            outside 1..256.
    """
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(
            f"expected an RGB image of shape (H, W, 3), got {rgb.shape}")
    # Indices are returned as uint8; more than 256 clusters would wrap around.
    if not 1 <= num_colors <= 256:
        raise ValueError(
            f"num_colors must be between 1 and 256, got {num_colors}")

    h, w = rgb.shape[:2]
    pixels = rgb.reshape(-1, 3)

    # Sample pixels for faster clustering (use all if small image).
    # The generator is seeded so that, like KMeans(random_state=42) below,
    # repeated runs on the same image yield the same palette.
    n_pixels = pixels.shape[0]
    if n_pixels > 100000:
        rng = np.random.default_rng(42)
        sample = pixels[rng.choice(n_pixels, 100000, replace=False)]
    else:
        sample = pixels

    # Exact path: K-means cannot be fitted with fewer samples than clusters,
    # and is pointless when every distinct color fits in the palette.
    # Probe the (small) sample first so large photographs skip the
    # full-image uniqueness pass.
    if len(np.unique(sample, axis=0)) <= num_colors:
        unique_colors, inverse = np.unique(pixels, axis=0, return_inverse=True)
        n_unique = len(unique_colors)
        if n_unique <= num_colors:
            print(f"  Image has only {n_unique} distinct colors; "
                  "using exact palette (K-means skipped).")
            palette = np.zeros((num_colors, 3), dtype=np.uint8)
            palette[:n_unique] = np.clip(
                np.rint(unique_colors.astype(np.float64)), 0, 255
            ).astype(np.uint8)
            # reshape() also normalises the inverse array, whose shape for
            # axis-based unique differs between NumPy releases.
            indices = inverse.reshape(h, w).astype(np.uint8)
            return indices, palette

    # Imported lazily so the exact path above works without scikit-learn.
    from sklearn.cluster import KMeans

    print("  Running K-means clustering (this may take a while)...")

    # Run K-means
    kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10, max_iter=300)
    kmeans.fit(sample.astype(np.float32))

    # Get palette from cluster centers; round to nearest instead of
    # truncating so the stored palette is not biased toward darker values.
    palette = np.clip(np.rint(kmeans.cluster_centers_), 0, 255).astype(np.uint8)

    # Assign all pixels to nearest cluster
    indices = kmeans.predict(pixels.astype(np.float32)).astype(np.uint8).reshape(h, w)

    return indices, palette


def decompress_indexed(indices: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """
    Expand an indexed image back into RGB pixels.

    Each entry of ``indices`` selects one row of ``palette``; the selected
    rows are returned in the same layout as ``indices`` with a trailing
    channel axis.

    Args:
        indices: Index array (H, W)
        palette: Color palette (num_colors, 3)

    Returns:
        RGB image (H, W, 3) as uint8
    """
    # Gather one palette row per index, then force the result to 8-bit.
    looked_up = np.take(palette, indices, axis=0)
    return looked_up.astype(np.uint8)


def calculate_psnr(original: np.ndarray, compressed: np.ndarray) -> float:
    """
    Return the Peak Signal-to-Noise Ratio (PSNR) between two images, in dB.

    Both inputs are promoted to float64 before differencing, so unsigned
    8-bit inputs do not wrap around. The peak value is fixed at 255.
    Identical inputs yield infinity. Higher is better; roughly 30-40 dB
    is considered good.
    """
    peak = 255.0

    reference = original.astype(np.float64)
    candidate = compressed.astype(np.float64)
    mean_sq_err = np.mean(np.square(reference - candidate))

    # A perfect reconstruction has no noise; avoid dividing by zero.
    if mean_sq_err == 0:
        return float('inf')

    return 20 * np.log10(peak / np.sqrt(mean_sq_err))


def save_indexed_bmp(indices: np.ndarray, palette: np.ndarray, output_path: str) -> None:
    """
    Save indexed image as 8bpp BMP file.

    Args:
        indices: Index array (H, W); every value must lie in 0..255
        palette: Color palette (num_colors, 3)
        output_path: Output BMP file path

    Raises:
        ValueError: If ``indices`` is not 2-D or holds a value outside 0..255.
    """
    h, w = indices.shape

    # An 8bpp image cannot store anything outside 0..255; fail loudly rather
    # than let the uint8 conversion below wrap the value around.
    if indices.size and (indices.min() < 0 or indices.max() > 255):
        raise ValueError("indices must lie in the range 0..255 for an 8bpp image")

    # Hand Pillow the raw row-major byte buffer instead of a Python list with
    # one int object per pixel, which is slow and memory-hungry on big images.
    pixel_bytes = np.ascontiguousarray(indices, dtype=np.uint8).tobytes()
    img_p = Image.frombytes('P', (w, h), pixel_bytes)

    # Set palette (PIL expects flat list of R,G,B values)
    img_p.putpalette(palette.flatten().tolist())

    # Save as BMP
    img_p.save(output_path, format='BMP')


def main():
    """
    Command-line entry point.

    Reads the input image path from the first non-option argument (default
    "test.png"), quantizes the image to a 256-color palette with K-means,
    prints size and PSNR statistics, and writes
    "<name>_indexed_kmeans.bmp" next to the input file. Exits with status 1
    if the input cannot be opened.
    """
    # Parse arguments
    input_path = "test.png"

    args = sys.argv[1:]
    if args and not args[0].startswith("-"):
        input_path = args[0]

    # Output goes next to the input file. abspath() always yields a
    # non-empty directory, so no fallback to "." is needed.
    input_dir = os.path.dirname(os.path.abspath(input_path))
    base_name = os.path.splitext(os.path.basename(input_path))[0]

    print("Indexed Color K-means Compression (8bpp BMP)")
    print("=" * 50)
    print(f"Input: {input_path}")

    # Load image. The context manager closes the underlying file handle as
    # soon as the pixels have been converted into an independent RGB copy.
    try:
        with Image.open(input_path) as src:
            img = src.convert('RGB')
    except FileNotFoundError:
        print(f"Error: File '{input_path}' not found!")
        sys.exit(1)
    except OSError as exc:
        # Covers unreadable paths and files Pillow cannot identify as images.
        print(f"Error: Could not read '{input_path}': {exc}")
        sys.exit(1)

    rgb = np.array(img)
    h, w = rgb.shape[:2]

    print(f"Image size: {w} x {h}")
    print(f"Original: 24 bpp ({w * h * 24 / 8 / 1024:.1f} KB)")

    # Compress with K-means
    num_colors = 256
    print("\nCompressing:")
    print(f"  Palette: {num_colors} colors")
    print("  Method: K-means clustering")

    indices, palette = compress_indexed_kmeans(rgb, num_colors)

    # Calculate compressed size
    index_bits = h * w * 8  # 8 bits per pixel
    palette_bits = num_colors * 3 * 8  # RGB palette
    total_bits = index_bits + palette_bits
    effective_bpp = 8.0  # per-pixel index cost; palette overhead is reported separately

    print("\nCompressed statistics:")
    print(f"  Index data:   {index_bits / 8 / 1024:.1f} KB")
    print(f"  Palette data: {palette_bits / 8 / 1024:.2f} KB")
    print(f"  Total:        {total_bits / 8 / 1024:.1f} KB")
    print(f"  Effective BPP: {effective_bpp:.1f}")
    print(f"  Compression ratio: {24 / effective_bpp:.2f}x")

    # Decompress for quality check
    rgb_reconstructed = decompress_indexed(indices, palette)

    # Calculate quality metrics
    psnr = calculate_psnr(rgb, rgb_reconstructed)
    print("\nQuality metrics:")
    print(f"  PSNR: {psnr:.2f} dB")
    if psnr > 35:
        print("  Quality: Excellent")
    elif psnr > 30:
        print("  Quality: Good")
    elif psnr > 25:
        print("  Quality: Acceptable")
    else:
        print("  Quality: Poor")

    # Save as 8bpp BMP
    output_bmp = os.path.join(input_dir, f"{base_name}_indexed_kmeans.bmp")
    save_indexed_bmp(indices, palette, output_bmp)
    print(f"\nSaved: {output_bmp}")

    print("\nDone!")


# Run the command-line workflow only when this file is executed as a script,
# so importing the module for its helper functions has no side effects.
if __name__ == "__main__":
    main()
