from __future__ import annotations

import os
import sqlite3
from collections.abc import Iterator
from contextlib import asynccontextmanager
from typing import Annotated, Any

from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

from app.config import (
    API_PREFIX,
    CORS_ORIGINS,
    DEFAULT_PAGE_SIZE,
    MAX_PAGE_SIZE,
)
from app.db import get_connection, init_database


def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]:
    """Convert a result row into a plain dict keyed by column name."""
    # dict() on an object exposing keys()/__getitem__ looks every key up by
    # name, which is the same mapping a manual comprehension would build.
    return dict(row)


def get_db() -> Iterator[sqlite3.Connection]:
    """Yield a database connection and close it when the generator finishes.

    This is a generator function (it is wired into ``Depends`` via the ``Db``
    alias), so its return type is an iterator of connections rather than a
    connection itself.

    Yields:
        The connection returned by ``get_connection()``.
    """
    conn = get_connection()
    try:
        yield conn
    finally:
        # Runs whether the consumer finished normally or raised, so the
        # connection is never left open.
        conn.close()


# Annotated alias used as the type of route parameters: a sqlite3.Connection
# supplied through the get_db dependency.
Db = Annotated[sqlite3.Connection, Depends(get_db)]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database, then hand control back to the application.

    The ``FORCE_RELOAD_DB`` environment variable is lower-cased and compared
    against ``1``, ``true`` and ``yes``; the outcome is forwarded to
    ``init_database`` as ``force_reload``. Nothing runs after the ``yield``.
    """
    truthy_values = ("1", "true", "yes")
    raw_flag = os.environ.get("FORCE_RELOAD_DB", "")
    init_database(force_reload=raw_flag.lower() in truthy_values)
    yield


# Application instance. The `lifespan` context manager defined above is
# registered here, so init_database() is invoked through it.
app = FastAPI(
    title="DramaBox Catalog API",
    description="REST API for drama/movie catalog and playback sources.",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS policy: when the configured origins are exactly ["*"], the wildcard is
# served with credentials disabled; any other origin list is passed through
# with credentials enabled. Methods and headers are unrestricted either way.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if CORS_ORIGINS == ["*"] else CORS_ORIGINS,
    allow_credentials=not (CORS_ORIGINS == ["*"]),
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Return a static ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload


@app.get(f"{API_PREFIX}/stats")
def stats(db: Db):
    """Return the row counts of the dramas, episodes and episode_sources tables."""
    # (response key, query) pairs, executed in this order.
    count_queries = (
        ("dramas", "SELECT COUNT(*) FROM dramas"),
        ("episodes", "SELECT COUNT(*) FROM episodes"),
        ("episode_sources", "SELECT COUNT(*) FROM episode_sources"),
    )
    return {
        key: int(db.execute(sql).fetchone()[0])
        for key, sql in count_queries
    }


@app.get(f"{API_PREFIX}/settings")
def app_settings(db: Db):
    """Return every row of the settings table as a key_name -> key_value map."""
    cursor = db.execute("SELECT key_name, key_value FROM settings")
    settings = {}
    for entry in cursor.fetchall():
        # A later row with the same key_name overwrites an earlier one.
        settings[entry["key_name"]] = entry["key_value"]
    return {"settings": settings}


@app.get(f"{API_PREFIX}/dramas")
def list_dramas(
    db: Db,
    page: int = Query(1, ge=1),
    limit: int = Query(DEFAULT_PAGE_SIZE, ge=1, le=MAX_PAGE_SIZE),
    platform: str | None = None,
    language: str | None = None,
    media_type: str | None = Query(None, alias="type"),
    q: str | None = None,
):
    """Return one page of dramas, newest first, with optional filters.

    Args:
        db: SQLite connection supplied by the ``get_db`` dependency.
        page: 1-based page number.
        limit: Page size, between 1 and ``MAX_PAGE_SIZE``.
        platform: Exact match on the ``platform`` column when non-empty.
        language: Exact match on the ``language`` column when non-empty.
        media_type: Exact match on the ``type`` column when non-empty
            (exposed as the ``type`` query parameter).
        q: Search text; after stripping whitespace it is matched with
            ``LIKE '%text%'`` against ``title`` and ``description``.

    Returns:
        A dict with ``items`` (list of row dicts), ``total`` (matching rows),
        ``page``, ``pages`` (0 when nothing matches) and ``limit``.
    """
    offset = (page - 1) * limit
    where: list[str] = ["1=1"]
    params: list[Any] = []

    if platform:
        where.append("platform = ?")
        params.append(platform)
    if language:
        where.append("language = ?")
        params.append(language)
    if media_type:
        where.append("type = ?")
        params.append(media_type)

    # Strip before testing so a whitespace-only query adds no filter at all;
    # the '%%' pattern it used to produce matched every row anyway.
    term = q.strip() if q else ""
    if term:
        pattern = f"%{term}%"
        where.append("(title LIKE ? OR IFNULL(description,'') LIKE ?)")
        params.extend([pattern, pattern])

    # Only fixed fragments are interpolated; every user value is bound via `?`.
    where_sql = " AND ".join(where)

    total = int(
        db.execute(
            f"SELECT COUNT(*) FROM dramas WHERE {where_sql}",
            params,
        ).fetchone()[0]
    )

    # `id` is a secondary sort key: without it, rows sharing the same
    # created_at (or an unparseable one, which datetime() turns into NULL)
    # have no defined relative order, so LIMIT/OFFSET pages could repeat or
    # skip rows. Assumes `id` is unique per drama, as the by-id lookup implies.
    rows = db.execute(
        f"""
        SELECT * FROM dramas
        WHERE {where_sql}
        ORDER BY datetime(created_at) DESC, id
        LIMIT ? OFFSET ?
        """,
        [*params, limit, offset],
    ).fetchall()

    # Ceiling division; an empty result set reports zero pages.
    pages = (total + limit - 1) // limit if total else 0
    return {
        "items": [_row_to_dict(r) for r in rows],
        "total": total,
        "page": page,
        "pages": pages,
        "limit": limit,
    }


@app.get(f"{API_PREFIX}/dramas/by-book/{{book_id}}")
def get_drama_by_book_id(book_id: str, db: Db):
    """Return the first drama whose ``book_id`` column equals ``book_id``.

    Raises:
        HTTPException: 404 with detail "Drama not found" when no row matches.
    """
    cursor = db.execute(
        "SELECT * FROM dramas WHERE book_id = ? LIMIT 1",
        (book_id,),
    )
    match = cursor.fetchone()
    if match:
        return _row_to_dict(match)
    raise HTTPException(status_code=404, detail="Drama not found")


@app.get(f"{API_PREFIX}/dramas/{{drama_id}}")
def get_drama(drama_id: str, db: Db):
    """Return the drama whose ``id`` column equals ``drama_id``.

    Raises:
        HTTPException: 404 with detail "Drama not found" when no row matches.
    """
    lookup_sql = "SELECT * FROM dramas WHERE id = ? LIMIT 1"
    drama = db.execute(lookup_sql, (drama_id,)).fetchone()
    if drama:
        return _row_to_dict(drama)
    raise HTTPException(status_code=404, detail="Drama not found")


@app.get(f"{API_PREFIX}/dramas/{{drama_id}}/episodes")
def list_episodes(
    drama_id: str,
    db: Db,
    page: int = Query(1, ge=1),
    limit: int = Query(50, ge=1, le=MAX_PAGE_SIZE),
):
    """Return one page of a drama's episodes.

    Episodes are ordered by ``chapter_index`` and ``episode_number`` (both
    cast to integers) and then by ``id``.

    Raises:
        HTTPException: 404 with detail "Drama not found" when the drama id
            does not exist.
    """
    drama_key = (drama_id,)

    # Distinguish "unknown drama" (404) from "known drama without episodes".
    found = db.execute(
        "SELECT 1 FROM dramas WHERE id = ? LIMIT 1",
        drama_key,
    ).fetchone()
    if not found:
        raise HTTPException(status_code=404, detail="Drama not found")

    count_row = db.execute(
        "SELECT COUNT(*) FROM episodes WHERE drama_id = ?",
        drama_key,
    ).fetchone()
    total = int(count_row[0])

    skip = (page - 1) * limit
    episodes = db.execute(
        """
        SELECT * FROM episodes
        WHERE drama_id = ?
        ORDER BY CAST(chapter_index AS INTEGER), CAST(episode_number AS INTEGER), id
        LIMIT ? OFFSET ?
        """,
        (drama_id, limit, skip),
    ).fetchall()

    # Ceiling division; zero episodes means zero pages.
    if total:
        pages = (total + limit - 1) // limit
    else:
        pages = 0

    return {
        "drama_id": drama_id,
        "items": [_row_to_dict(ep) for ep in episodes],
        "total": total,
        "page": page,
        "pages": pages,
        "limit": limit,
    }


@app.get(f"{API_PREFIX}/episodes/{{episode_id}}")
def get_episode(episode_id: str, db: Db):
    """Return the episode whose ``id`` column equals ``episode_id``.

    Raises:
        HTTPException: 404 with detail "Episode not found" when no row matches.
    """
    lookup_sql = "SELECT * FROM episodes WHERE id = ? LIMIT 1"
    episode = db.execute(lookup_sql, (episode_id,)).fetchone()
    if episode:
        return _row_to_dict(episode)
    raise HTTPException(status_code=404, detail="Episode not found")


@app.get(f"{API_PREFIX}/episodes/{{episode_id}}/sources")
def list_episode_sources(episode_id: str, db: Db):
    """Return all sources of an episode, ordered by quality, server and id.

    Raises:
        HTTPException: 404 with detail "Episode not found" when the episode id
            does not exist.
    """
    episode_key = (episode_id,)

    # Distinguish "unknown episode" (404) from "known episode without sources".
    found = db.execute(
        "SELECT 1 FROM episodes WHERE id = ? LIMIT 1",
        episode_key,
    ).fetchone()
    if not found:
        raise HTTPException(status_code=404, detail="Episode not found")

    source_rows = db.execute(
        """
        SELECT * FROM episode_sources
        WHERE episode_id = ?
        ORDER BY quality, server, id
        """,
        episode_key,
    ).fetchall()

    sources = [_row_to_dict(src) for src in source_rows]
    return {"episode_id": episode_id, "sources": sources}


@app.get(f"{API_PREFIX}/platforms")
def list_platforms(db: Db):
    """Return each non-blank platform with its drama count, most used first.

    Ties on the count are ordered by platform name.
    """
    grouped = db.execute(
        """
        SELECT platform, COUNT(*) AS cnt
        FROM dramas
        WHERE platform IS NOT NULL AND TRIM(platform) != ''
        GROUP BY platform
        ORDER BY cnt DESC, platform
        """
    ).fetchall()

    platforms = []
    for entry in grouped:
        # The SQL alias `cnt` is exposed to clients as `count`.
        platforms.append({"platform": entry["platform"], "count": entry["cnt"]})
    return {"platforms": platforms}


@app.get(f"{API_PREFIX}/languages")
def list_languages(db: Db):
    """Return each non-blank language with its drama count, most used first.

    Ties on the count are ordered by language name.
    """
    grouped = db.execute(
        """
        SELECT language, COUNT(*) AS cnt
        FROM dramas
        WHERE language IS NOT NULL AND TRIM(language) != ''
        GROUP BY language
        ORDER BY cnt DESC, language
        """
    ).fetchall()

    languages = []
    for entry in grouped:
        # The SQL alias `cnt` is exposed to clients as `count`.
        languages.append({"language": entry["language"], "count": entry["cnt"]})
    return {"languages": languages}
