#!/usr/bin/env python3
"""
AI Photo Metadata Enricher

Recursively scans a photo folder, reads existing metadata with ExifTool,
creates an in-memory preview, asks a local OpenAI-compatible vision model for
structured metadata, and optionally writes the result back to the image.

Dry run is the default. Files are modified only when --write is supplied.

Requirements:
    python -m pip install pillow requests

Optional HEIC/HEIF support:
    python -m pip install pillow-heif

Example:
    python ai-photo-metadata-enricher.py "D:\\Photos" \\
        --model "your-vision-model-id" --limit 5
"""

from __future__ import annotations

import argparse
import base64
import io
import json
import math
import os
import shutil
import subprocess
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, TextIO

import requests
from PIL import Image, ImageOps, UnidentifiedImageError

try:
    from pillow_heif import register_heif_opener

    register_heif_opener()
except ImportError:
    pass


SCRIPT_VERSION = "1.0.0"
DEFAULT_ENDPOINT = "http://localhost:1234/v1/chat/completions"

# Lower-case file suffixes that are treated as images while scanning.
IMAGE_EXTENSIONS = set(".jpg .jpeg .png .tif .tiff .webp .heic .heif".split())

# Metadata keys ("<group>:<tag>") that count as an existing description or
# caption. The group prefixes are ExifTool family-0 group names.
# NOTE(review): these only match metadata read with family-0 group names
# (ExifTool `-G0`); with `-G1` the keys look like "XMP-dc:Description".
DESCRIPTION_FIELDS = (
    "XMP:Description", "IPTC:Caption-Abstract", "EXIF:ImageDescription",
)

# Fields whose presence metadata_completeness reports, in report order.
IMPORTANT_METADATA_FIELDS = (
    # Descriptive metadata.
    "XMP:Title", "XMP:Description", "XMP:Subject",
    "IPTC:ObjectName", "IPTC:Caption-Abstract", "IPTC:Keywords",
    "EXIF:ImageDescription",
    # Capture and camera details.
    "Composite:ImageSize", "EXIF:DateTimeOriginal",
    "EXIF:Make", "EXIF:Model", "EXIF:LensModel",
    "EXIF:FocalLength", "EXIF:FNumber", "EXIF:ExposureTime", "EXIF:ISO",
)


@dataclass
class AiMetadata:
    """Normalised metadata suggested by the vision model for one image.

    Built by call_local_ai from the model's JSON reply; process_image stores
    it in the audit-log record via dataclasses.asdict.
    """

    title: str  # short title; call_local_ai caps it at 80 characters
    description: str  # free-text description / caption
    alt_text: str  # accessibility (alt) text
    keywords: list[str]  # search keywords, de-duplicated case-insensitively
    category: str  # broad category; call_local_ai caps it at 64 characters
    adult_content: bool  # model-reported adult-content flag
    explicit_content: bool  # model-reported explicit-content flag
    confidence: float  # model's self-reported confidence, clamped to [0, 1]
    warnings: list[str]  # model-reported uncertainties or safety notes


def resolve_exiftool_path(configured_path: str | None) -> str:
    """Return the absolute path of the first ExifTool executable found.

    Lookup order: the explicitly configured path (if any), the ``exiftool``
    found on PATH (if any), then a few common Windows install locations.
    ``~`` is expanded in every candidate.

    Raises:
        RuntimeError: none of the candidates is an existing file.
    """
    fallback_locations = (
        r"C:\Program Files\ExifTool\exiftool.exe",
        r"C:\exiftool\exiftool.exe",
        r"C:\Windows\exiftool.exe",
    )
    search_order = [
        entry
        for entry in (configured_path, shutil.which("exiftool"), *fallback_locations)
        if entry
    ]

    for entry in search_order:
        executable = Path(entry).expanduser()
        if executable.is_file():
            return str(executable.resolve())

    raise RuntimeError(
        "ExifTool was not found. Add it to PATH or pass "
        '--exiftool "C:\\path\\to\\exiftool.exe".'
    )


def run_exiftool(
    arguments: list[str], exiftool_path: str
) -> subprocess.CompletedProcess[str]:
    """Run ExifTool with *arguments* and capture its text output.

    No shell is involved (the command is an argument list). stdout and
    stderr are decoded as UTF-8 with undecodable bytes replaced. A non-zero
    exit status is not raised; callers inspect ``returncode`` themselves.
    """
    command = [exiftool_path]
    command.extend(arguments)
    run_options = dict(
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
        check=False,
    )
    return subprocess.run(command, **run_options)


def read_metadata(path: Path, exiftool_path: str) -> dict[str, Any]:
    """Read all metadata tags of *path* with ExifTool as a flat dict.

    Keys have the form ``"<Group>:<Tag>"`` using ExifTool's family-0 group
    names (``EXIF:``, ``XMP:``, ``IPTC:``, ``File:``, ``Composite:`` ...).
    Those are the prefixes the rest of this script looks keys up by (for
    example ``"XMP:Description"`` and ``"EXIF:Make"``). With family-1 names
    (``-G1``) the same tags come back as ``IFD0:Make``, ``XMP-dc:Title``
    and so on, so none of those lookups would match.

    Raises:
        RuntimeError: ExifTool exited with a non-zero status.
        ValueError: ExifTool's stdout was not valid JSON.
    """
    result = run_exiftool(["-json", "-G0", "-a", "-s", str(path)], exiftool_path)
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "ExifTool could not read metadata.")

    payload = json.loads(result.stdout)
    # The JSON output is a list with one object per input file; one file was
    # requested, so the first (and only) object is the result.
    return payload[0] if payload else {}


def metadata_completeness(metadata: dict[str, Any]) -> dict[str, Any]:
    """Summarise how many IMPORTANT_METADATA_FIELDS are populated.

    A field is missing when its key is absent or its value is None, an
    empty string, an empty list or an empty dict.

    Returns:
        ``filled_count``, ``total_fields_checked``, ``score`` (the filled
        fraction rounded to 3 places) and ``missing_fields`` (in the order of
        IMPORTANT_METADATA_FIELDS).
    """
    empty_values = (None, "", [], {})
    unique_fields = dict.fromkeys(IMPORTANT_METADATA_FIELDS)
    missing = [name for name in unique_fields if metadata.get(name) in empty_values]
    total = len(IMPORTANT_METADATA_FIELDS)
    filled = len(unique_fields) - len(missing)
    return {
        "filled_count": filled,
        "total_fields_checked": total,
        "score": round(filled / total, 3),
        "missing_fields": missing,
    }


def slim_metadata_for_ai(
    metadata: dict[str, Any], max_chars: int = 7000
) -> dict[str, Any]:
    """Select the metadata worth showing to the model, within a size budget.

    Keeps only keys whose group prefix is EXIF, XMP, IPTC, Composite, File
    or QuickTime. Drops keys that mention thumbnails, previews, binary or ICC
    profile data, and drops ExifTool's ``"(Binary data ...)"`` placeholder
    values, which carry no information for the model.

    If the JSON form of the result exceeds *max_chars*, entries are taken
    greedily in their original order. Any single entry that would overflow
    the budget is skipped, and ``"_truncated": True`` is added.
    """
    preferred_prefixes = (
        "EXIF:",
        "XMP:",
        "IPTC:",
        "Composite:",
        "File:",
        "QuickTime:",
    )
    excluded_fragments = ("thumbnail", "preview", "binary", "icc_profile")
    cleaned = {
        key: value
        for key, value in metadata.items()
        if key.startswith(preferred_prefixes)
        and not any(fragment in key.lower() for fragment in excluded_fragments)
        and not (isinstance(value, str) and value.startswith("(Binary data "))
    }

    if len(json.dumps(cleaned, ensure_ascii=False, default=str)) <= max_chars:
        return cleaned

    trimmed: dict[str, Any] = {}
    current_length = 0
    for key, value in cleaned.items():
        # Each piece is measured with its own braces, which slightly
        # over-estimates the merged size and so keeps the result in budget.
        piece_length = len(
            json.dumps({key: value}, ensure_ascii=False, default=str)
        )
        if current_length + piece_length > max_chars:
            # Skip only this oversized entry; later, smaller ones may fit.
            continue
        trimmed[key] = value
        current_length += piece_length

    trimmed["_truncated"] = True
    return trimmed


def image_preview_as_data_url(
    path: Path, long_side: int, jpeg_quality: int
) -> str:
    """Render *path* as a downscaled JPEG and return it as a ``data:`` URL.

    Processing steps:
    - rotate according to the EXIF orientation (``ImageOps.exif_transpose``);
    - flatten transparency onto a white background;
    - convert to RGB;
    - shrink (never enlarge) so the longer edge is at most *long_side* pixels;
    - encode as JPEG at *jpeg_quality*.

    Raises:
        RuntimeError: Pillow could not identify the file format. HEIC/HEIF
            files get a hint to install pillow-heif.
    """
    try:
        with Image.open(path) as source:
            image = ImageOps.exif_transpose(source)

            # Palette, greyscale and RGB images may declare transparency in
            # ``info["transparency"]`` instead of carrying an alpha band.
            # Converting them straight to RGB would drop it and expose
            # whatever colour the transparent pixels store. Give them a real
            # alpha channel so the branch below flattens them onto white.
            if image.mode in ("P", "L", "RGB") and "transparency" in image.info:
                image = image.convert("RGBA")

            if image.mode in ("RGBA", "LA"):
                background = Image.new("RGB", image.size, (255, 255, 255))
                background.paste(image, mask=image.getchannel("A"))
                image = background
            elif image.mode != "RGB":
                image = image.convert("RGB")

            width, height = image.size
            current_long_side = max(width, height)
            if current_long_side > long_side:
                scale = long_side / current_long_side
                # Clamp to 1 px: a very thin strip would otherwise round to 0
                # and make resize() fail.
                image = image.resize(
                    (max(1, round(width * scale)), max(1, round(height * scale))),
                    Image.Resampling.LANCZOS,
                )

            buffer = io.BytesIO()
            image.save(buffer, format="JPEG", quality=jpeg_quality, optimize=True)
    except UnidentifiedImageError as exc:
        if path.suffix.lower() in {".heic", ".heif"}:
            raise RuntimeError(
                "Pillow could not open this HEIC/HEIF file. Install pillow-heif."
            ) from exc
        raise RuntimeError("Pillow could not identify this image format.") from exc

    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


def extract_json(text: str) -> dict[str, Any]:
    """Return the JSON object embedded in a model reply.

    Handles several kinds of wrapping around the object:
    - surrounding prose and Markdown code fences;
    - trailing text that itself contains braces;
    - a leading ``<think>...</think>`` section, where everything up to the
      last ``</think>`` is discarded.

    The whole reply is tried first. Otherwise the first ``{`` position from
    which a complete JSON object decodes wins.

    Raises:
        ValueError: no JSON object could be decoded from *text*.
    """
    text = text.strip()
    # A reasoning section, if present, may contain braces of its own; only
    # the text after it is considered.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1].strip()

    try:
        whole = json.loads(text)
    except json.JSONDecodeError:
        whole = None
    if isinstance(whole, dict):
        return whole

    # raw_decode stops at the end of the first complete value, so text after
    # the object (even with braces) no longer breaks parsing.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start >= 0:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)

    raise ValueError(f"No JSON object found in model response: {text[:300]}")


def parse_bool(value: Any) -> bool:
    """Interpret a loosely typed model value as a boolean.

    - Booleans pass through unchanged.
    - Numbers are true when non-zero.
    - Strings are true only for "1", "true", "yes" or "y" (case-insensitive,
      surrounding whitespace ignored).
    - Anything else, including None, is False.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes", "y")
    # bool is a subclass of int, and True != 0 / False != 0 give the same bool.
    if isinstance(value, (int, float)):
        return value != 0
    return False


def normalize_string_list(value: Any, limit: int) -> list[str]:
    """Turn a model-supplied list (or comma-separated string) into clean strings.

    Items are stringified and stripped, and blanks are dropped. Duplicates
    are removed case-insensitively (``str.casefold``), keeping the first
    spelling seen. At most *limit* items are returned. Any input that is
    neither a str nor a list yields ``[]``.
    """
    items = value.split(",") if isinstance(value, str) else value
    if not isinstance(items, list):
        return []

    first_spelling: dict[str, str] = {}
    for raw in items:
        cleaned = str(raw).strip()
        if cleaned:
            first_spelling.setdefault(cleaned.casefold(), cleaned)
    return list(first_spelling.values())[:limit]


def response_content(payload: dict[str, Any]) -> str:
    """Extract the assistant's text from an OpenAI-style chat completion.

    ``choices[0].message.content`` may be a plain string or a list of content
    parts. For a list, the ``text`` of every ``"type": "text"`` part is
    concatenated and other parts are ignored.

    Raises:
        ValueError: the content is missing, or is neither a string nor a list.
    """
    try:
        message = payload["choices"][0]["message"]
        content = message["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise ValueError("The model response did not contain message content.") from exc

    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)
    if not isinstance(content, str):
        raise ValueError("The model returned an unsupported message content format.")
    return content


def _clean_text(value: Any, limit: int | None = None) -> str:
    """Return *value* as stripped text; a missing/null value becomes ``""``.

    When *limit* is given, the text is cut to at most that many characters
    and any trailing whitespace left by the cut is removed.
    """
    if value is None:
        return ""
    text = str(value).strip()
    if limit is not None:
        text = text[:limit].rstrip()
    return text


def call_local_ai(
    endpoint: str,
    model: str,
    api_key: str | None,
    image_data_url: str,
    metadata: dict[str, Any],
    completeness: dict[str, Any],
    timeout_seconds: int,
) -> AiMetadata:
    """Ask an OpenAI-compatible chat endpoint to describe one image.

    Sends two messages:
    - a system prompt that includes the expected JSON schema;
    - a user message with the existing metadata, its completeness summary
      and the image as a data URL.

    The model's JSON reply is then normalised into AiMetadata:
    - null or missing text fields become "" (never the string "None");
    - title and category are capped at 80 and 64 characters;
    - keywords and warnings are de-duplicated and capped at 50 and 20 items;
    - a confidence that is missing, non-numeric or not finite becomes 0.0,
      and any other value is clamped to [0, 1].

    Raises:
        RuntimeError: the HTTP request failed, returned an error status, or
            returned a body that is not JSON.
        ValueError: the reply had no message content or no JSON object.
    """
    system_prompt = """
You are an image metadata assistant for a private local photo archive.
Return only valid JSON, without markdown or commentary.

Create accurate, useful metadata from the visible image and existing metadata.
Do not invent identity, age, location, relationships, intent, or private traits.
Preserve useful existing facts when they are better supported than an inference.
Use direct, factual cataloging language for boudoir, nude, fetish, or explicit
adult content when visible. Do not eroticize a person whose age is uncertain.
If a person may be under 18, or content appears exploitative or non-consensual,
avoid sexualized metadata and add a warning.

Use this schema:
{
  "title": "descriptive title, 80 characters maximum",
  "description": "detailed factual description, normally 2-4 sentences",
  "alt_text": "concise accessibility description",
  "keywords": ["15-40 useful search terms"],
  "category": "short broad category",
  "adult_content": false,
  "explicit_content": false,
  "confidence": 0.0,
  "warnings": ["uncertainties or safety notes"]
}
""".strip()

    payload = {
        "model": model,
        "temperature": 0.2,
        "max_tokens": 1200,
        "messages": [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": json.dumps(
                            {
                                "existing_metadata": metadata,
                                "metadata_completeness": completeness,
                                "task": "Return improved metadata using the JSON schema.",
                            },
                            ensure_ascii=False,
                            default=str,
                        ),
                    },
                    {"type": "image_url", "image_url": {"url": image_data_url}},
                ],
            },
        ],
    }
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    try:
        response = requests.post(
            endpoint,
            json=payload,
            headers=headers,
            timeout=timeout_seconds,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        raise RuntimeError(f"Local AI request failed: {exc}") from exc

    # A proxy error page or a misconfigured endpoint can return a 2xx
    # non-JSON body; requests reports that as a ValueError subclass.
    try:
        body = response.json()
    except ValueError as exc:
        raise RuntimeError("Local AI returned a response that is not JSON.") from exc

    parsed = extract_json(response_content(body))
    try:
        confidence = float(parsed.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    # NaN would slip through min/max clamping (min(1.0, nan) is 1.0), so any
    # non-finite value is treated as "no confidence".
    if not math.isfinite(confidence):
        confidence = 0.0

    return AiMetadata(
        title=_clean_text(parsed.get("title"), 80),
        description=_clean_text(parsed.get("description")),
        alt_text=_clean_text(parsed.get("alt_text")),
        keywords=normalize_string_list(parsed.get("keywords"), 50),
        category=_clean_text(parsed.get("category"), 64),
        adult_content=parse_bool(parsed.get("adult_content")),
        explicit_content=parse_bool(parsed.get("explicit_content")),
        confidence=max(0.0, min(1.0, confidence)),
        warnings=normalize_string_list(parsed.get("warnings"), 20),
    )


def build_exiftool_write_args(
    path: Path, ai_metadata: AiMetadata, keep_backup: bool
) -> list[str]:
    """Build the ExifTool arguments that write *ai_metadata* into *path*.

    Only non-empty AI fields are written, so an empty title, description,
    category or alt text leaves the file's existing value untouched.

    The keyword lists (XMP-dc:Subject and IPTC:Keywords) are replaced only
    when the model returned at least one keyword. An empty keyword list no
    longer wipes the keywords already in the file.

    The XMP MetadataDate, ModifyDate and CreatorTool stamps are always
    written, and the target file path is the last argument. When
    *keep_backup* is false, ``-overwrite_original`` is passed so no
    ``_original`` backup file is kept.

    NOTE(review): IPTC text is encoded per ExifTool's IPTC character-set
    handling; confirm non-Latin titles/keywords round-trip as expected.
    """
    arguments: list[str] = [] if keep_backup else ["-overwrite_original"]

    if ai_metadata.title:
        arguments.extend(
            [
                f"-XMP-dc:Title={ai_metadata.title}",
                f"-IPTC:ObjectName={ai_metadata.title}",
            ]
        )
    if ai_metadata.description:
        arguments.extend(
            [
                f"-XMP-dc:Description={ai_metadata.description}",
                f"-IPTC:Caption-Abstract={ai_metadata.description}",
                f"-EXIF:ImageDescription={ai_metadata.description}",
            ]
        )
    if ai_metadata.category:
        arguments.append(f"-XMP-photoshop:Category={ai_metadata.category}")
    if ai_metadata.alt_text:
        arguments.append(
            f"-XMP-iptcExt:AltTextAccessibility={ai_metadata.alt_text}"
        )
    if ai_metadata.keywords:
        # Replace rather than merge: the empty assignment clears each list
        # tag, then every "+=" appends one keyword.
        arguments.extend(["-XMP-dc:Subject=", "-IPTC:Keywords="])
        for keyword in ai_metadata.keywords:
            arguments.extend(
                [
                    f"-XMP-dc:Subject+={keyword}",
                    f"-IPTC:Keywords+={keyword}",
                ]
            )

    arguments.extend(
        [
            "-XMP-xmp:MetadataDate=now",
            "-XMP-xmp:ModifyDate=now",
            f"-XMP-xmp:CreatorTool=AI Photo Metadata Enricher {SCRIPT_VERSION}",
            str(path),
        ]
    )
    return arguments


def write_metadata(
    path: Path, ai_metadata: AiMetadata, keep_backup: bool, exiftool_path: str
) -> str:
    """Write *ai_metadata* into *path* via ExifTool and return its stdout.

    Raises:
        RuntimeError: ExifTool exited non-zero. The message is ExifTool's
            stderr, or a generic message when stderr was empty.
    """
    command_args = build_exiftool_write_args(path, ai_metadata, keep_backup)
    completed = run_exiftool(command_args, exiftool_path)
    if completed.returncode == 0:
        return completed.stdout.strip()
    failure_text = completed.stderr.strip()
    raise RuntimeError(failure_text or "ExifTool could not write metadata.")


def append_log(log_file: TextIO, record: dict[str, Any]) -> None:
    """Append *record* to *log_file* as one JSON Lines entry, then flush.

    Non-ASCII text is written as-is. Values JSON cannot encode natively (for
    example ``Path``) are stringified. Flushing after every line keeps the
    file up to date while a long run is in progress.
    """
    line = json.dumps(record, ensure_ascii=False, default=str)
    log_file.write(f"{line}\n")
    log_file.flush()


def process_image(
    path: Path,
    arguments: argparse.Namespace,
    exiftool_path: str,
    log_file: TextIO,
) -> dict[str, Any]:
    """Run the read -> AI -> (optional) write pipeline for one image.

    Any ``Exception`` raised while processing is caught and recorded as
    ``status="error"`` with its message instead of being propagated. Every
    outcome is appended to *log_file* with the elapsed seconds, and the same
    record is returned.

    Possible ``status`` values:
    - ``"skipped_existing_description"``
    - ``"skipped_low_confidence"``
    - ``"written"``
    - ``"dry_run"``
    - ``"error"``
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long run (time.time() can).
    started = time.perf_counter()
    record: dict[str, Any] = {
        "file": str(path),
        "status": "started",
        "write_enabled": arguments.write,
    }

    try:
        existing_metadata = read_metadata(path, exiftool_path)
        if arguments.skip_existing_description and any(
            existing_metadata.get(field) for field in DESCRIPTION_FIELDS
        ):
            record.update(
                status="skipped_existing_description",
                reason="A description or caption already exists.",
            )
            return record

        completeness = metadata_completeness(existing_metadata)
        preview = image_preview_as_data_url(
            path, arguments.long_side, arguments.jpeg_quality
        )
        ai_metadata = call_local_ai(
            endpoint=arguments.endpoint,
            model=arguments.model,
            api_key=arguments.api_key,
            image_data_url=preview,
            metadata=slim_metadata_for_ai(existing_metadata),
            completeness=completeness,
            timeout_seconds=arguments.timeout,
        )
        record.update(
            metadata_completeness_before=completeness,
            ai=asdict(ai_metadata),
        )

        if ai_metadata.confidence < arguments.min_confidence:
            record.update(
                status="skipped_low_confidence",
                reason=(
                    f"Confidence {ai_metadata.confidence:.2f} is below "
                    f"{arguments.min_confidence:.2f}."
                ),
            )
        elif arguments.write:
            record.update(
                status="written",
                exiftool_output=write_metadata(
                    path,
                    ai_metadata,
                    keep_backup=arguments.backup,
                    exiftool_path=exiftool_path,
                ),
            )
        else:
            record["status"] = "dry_run"
    except Exception as exc:  # per-file boundary: record the failure, move on
        record.update(status="error", error=str(exc))
    finally:
        # Runs for every outcome, including the early "skipped" return.
        record["seconds"] = round(time.perf_counter() - started, 3)
        append_log(log_file, record)

    return record


def validate_arguments(arguments: argparse.Namespace) -> None:
    """Reject out-of-range numeric command-line options.

    The checks run in a fixed order and the first failing one wins.

    Raises:
        ValueError: with a message naming the offending option.
    """
    rules = (
        (lambda: arguments.long_side < 256, "--long-side must be at least 256."),
        (
            lambda: not 1 <= arguments.jpeg_quality <= 100,
            "--jpeg-quality must be between 1 and 100.",
        ),
        (lambda: arguments.timeout < 1, "--timeout must be at least 1 second."),
        (
            lambda: not 0 <= arguments.min_confidence <= 1,
            "--min-confidence must be between 0 and 1.",
        ),
        (lambda: arguments.limit < 0, "--limit cannot be negative."),
    )
    for is_invalid, message in rules:
        if is_invalid():
            raise ValueError(message)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) means ``sys.argv[1:]``, so existing ``parse_args()``
            calls behave exactly as before.

    Returns:
        The parsed options. Numeric ranges are not checked here; main()
        passes the result to validate_arguments for that.
    """
    parser = argparse.ArgumentParser(
        description="Enrich photo metadata with a local vision model."
    )
    parser.add_argument("folder", help="Root photo folder to scan recursively.")
    parser.add_argument(
        "--model",
        required=True,
        help="Vision model identifier shown by LM Studio.",
    )
    parser.add_argument(
        "--endpoint",
        default=DEFAULT_ENDPOINT,
        help=f"OpenAI-compatible chat endpoint. Default: {DEFAULT_ENDPOINT}",
    )
    parser.add_argument(
        "--api-key",
        default=os.getenv("LM_STUDIO_API_KEY"),
        help="Optional bearer token. Defaults to LM_STUDIO_API_KEY.",
    )
    parser.add_argument(
        "--exiftool",
        default=None,
        help="Optional full path to the ExifTool executable.",
    )
    parser.add_argument(
        "--write",
        action="store_true",
        help="Write metadata. Without this flag, no image is modified.",
    )
    parser.add_argument(
        "--backup",
        action="store_true",
        help="Keep ExifTool _original backup files when writing.",
    )
    parser.add_argument(
        "--skip-existing-description",
        action="store_true",
        help="Skip files that already have a description or caption.",
    )
    parser.add_argument(
        "--long-side",
        type=int,
        default=1024,
        help="Longest edge, in pixels, of the preview image sent to the model "
        "(minimum 256). Default: %(default)s",
    )
    parser.add_argument(
        "--jpeg-quality",
        type=int,
        default=88,
        help="JPEG quality (1-100) of the preview image sent to the model. "
        "Default: %(default)s",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=120,
        help="Seconds to wait for each model request. Default: %(default)s",
    )
    parser.add_argument(
        "--min-confidence",
        type=float,
        default=0.35,
        help="Minimum model confidence (0-1) needed to accept a result; "
        "lower-confidence results are skipped. Default: %(default)s",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Maximum images to process. Zero means no limit.",
    )
    parser.add_argument(
        "--log",
        default="ai_photo_metadata_log.jsonl",
        help="JSON Lines audit log path.",
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {SCRIPT_VERSION}",
    )
    return parser.parse_args(argv)


def main() -> int:
    """Command-line entry point; returns the process exit status.

    Exit status:
    - 2 for configuration problems (invalid options, missing photo folder,
      ExifTool not found);
    - 1 if any image ended with an error;
    - 0 otherwise.
    """
    arguments = parse_args()
    try:
        validate_arguments(arguments)
    except ValueError as exc:
        print(f"Configuration error: {exc}", file=sys.stderr)
        return 2

    root = Path(arguments.folder).expanduser().resolve()
    if not root.is_dir():
        print(f"Photo folder does not exist: {root}", file=sys.stderr)
        return 2

    try:
        exiftool_path = resolve_exiftool_path(arguments.exiftool)
    except RuntimeError as exc:
        print(f"ExifTool error: {exc}", file=sys.stderr)
        return 2

    candidates = [
        entry
        for entry in root.rglob("*")
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTENSIONS
    ]
    candidates.sort()
    if arguments.limit:
        del candidates[arguments.limit:]
    total = len(candidates)

    log_path = Path(arguments.log).expanduser().resolve()
    log_path.parent.mkdir(parents=True, exist_ok=True)

    mode_label = "WRITE" if arguments.write else "DRY RUN"
    for banner_line in (
        f"Found:    {total} image(s)",
        f"Mode:     {mode_label}",
        f"Endpoint: {arguments.endpoint}",
        f"Model:    {arguments.model}",
        f"ExifTool: {exiftool_path}",
        f"Log:      {log_path}",
        "",
    ):
        print(banner_line)

    totals = {"written": 0, "dry_run": 0, "skipped": 0, "error": 0}
    with log_path.open("a", encoding="utf-8") as log_file:
        for position, image_path in enumerate(candidates, start=1):
            print(f"[{position}/{total}] {image_path}")
            record = process_image(image_path, arguments, exiftool_path, log_file)
            status = record["status"]

            if status == "written":
                totals["written"] += 1
                print("  metadata written")
                continue
            if status == "dry_run":
                totals["dry_run"] += 1
                suggestion = record["ai"]
                print(f"  title: {suggestion['title']}")
                print(f"  category: {suggestion['category']}")
                print(f"  confidence: {suggestion['confidence']:.2f}")
                print(f"  keywords: {', '.join(suggestion['keywords'])}")
                continue
            if status.startswith("skipped"):
                totals["skipped"] += 1
                print(f"  skipped: {record.get('reason', status)}")
                continue
            totals["error"] += 1
            print(f"  error: {record.get('error', 'Unknown error')}")

    summary_lines = (
        "",
        "Done",
        f"Dry runs: {totals['dry_run']}",
        f"Written:  {totals['written']}",
        f"Skipped:  {totals['skipped']}",
        f"Errors:   {totals['error']}",
    )
    print("\n".join(summary_lines))
    if not arguments.write:
        print("No image files were modified.")

    return 1 if totals["error"] else 0


if __name__ == "__main__":
    sys.exit(main())
