"""DB-only sanity check for the intraday analysis universe.

Prints:
- ET top gainers/losers (top N per side, N = GLOBAL_INTRADAY_ET_LIMIT_PER_SIDE) resolved from `live_movers` -> `stocks.symbol`
- Early movers snapshot symbols (today IST, else latest snapshot)
- Manual watchlist symbols (today IST, ACTIVE rows from `user_portfolio_items`)
- Union (priority: early movers -> ET gainers -> ET losers -> manual) + source breakdown

No Zerodha/GPT calls.

Run:
  cd API && set -a && source ./.env && set +a && python scripts/sanity_intraday_universe.py
"""

from __future__ import annotations

import os
import sys
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Tuple
from zoneinfo import ZoneInfo

# When running as `python scripts/...py`, ensure API root is importable.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from app.db.database import get_mongo_db
from app.v1.services.intraday_watchlist import ist_date_str

# Indian Standard Time zone; main() uses it to render the printed "now_ist" timestamp.
IST = ZoneInfo("Asia/Kolkata")


def _norm_symbol(v: Any) -> str:
    """Return *v* stripped of surrounding whitespace and upper-cased.

    Any non-``str`` input (``None``, numbers, dicts, ...) normalises to ``""``
    so callers can treat a falsy result as "no usable symbol".
    """
    return v.strip().upper() if isinstance(v, str) else ""


def _top_movers(db, mover_type: str, limit: int) -> List[Dict[str, Any]]:
    """Return up to *limit* ``live_movers`` rows of *mover_type*, lowest rank first.

    Ties on ``rank`` are broken by the newest ``last_updated``. Only
    ``stock_id`` and ``rank`` are projected.
    """
    return list(
        db["live_movers"]
        .find({"mover_type": mover_type}, {"_id": 0, "stock_id": 1, "rank": 1})
        .sort([("rank", 1), ("last_updated", -1)])
        .limit(limit)
    )


def _resolve_symbols(rows: List[Dict[str, Any]], sym_by_id: Dict[str, str]) -> List[str]:
    """Map mover *rows* to symbols via *sym_by_id*, dropping unknown ids and repeats (order kept)."""
    resolved: Dict[str, None] = {}
    for r in rows:
        sid = r.get("stock_id")
        sym = sym_by_id.get(str(sid)) if sid else None
        if sym:
            # setdefault keeps the first (best-ranked) occurrence's position.
            resolved.setdefault(sym, None)
    return list(resolved)


def _load_et_symbols(db, *, limit_per_side: int) -> Tuple[List[str], List[str]]:
    """Return ``(gainer_symbols, loser_symbols)`` for the top ET movers.

    Takes the best-ranked ``limit_per_side`` ``GAINER`` and ``LOSER`` rows from
    ``live_movers`` and resolves their ``stock_id`` values to normalised symbols
    through the ``stocks`` collection. Rows whose id cannot be resolved are
    dropped. Each list is de-duplicated in rank order, so it may be shorter
    than ``limit_per_side``.

    Returns ``([], [])`` when ``limit_per_side <= 0`` (no query is issued) or
    when no mover row carries a non-empty string ``stock_id``.
    """
    if limit_per_side <= 0:
        return [], []

    limit = int(limit_per_side)
    gainers = _top_movers(db, "GAINER", limit)
    losers = _top_movers(db, "LOSER", limit)

    # dict.fromkeys dedupes in O(n) while keeping first-seen order (gainers
    # before losers); only non-empty string ids are sent to the $in lookup.
    stock_ids: List[str] = list(
        dict.fromkeys(
            sid
            for sid in (r.get("stock_id") for r in gainers + losers)
            if isinstance(sid, str) and sid
        )
    )
    if not stock_ids:
        return [], []

    sym_by_id: Dict[str, str] = {}
    stocks = db["stocks"].find({"stock_id": {"$in": stock_ids}}, {"_id": 0, "stock_id": 1, "symbol": 1})
    for s in stocks:
        sid = s.get("stock_id")
        sym = _norm_symbol(s.get("symbol"))
        if sid and sym:
            sym_by_id[str(sid)] = sym

    return _resolve_symbols(gainers, sym_by_id), _resolve_symbols(losers, sym_by_id)


def _load_early_movers_symbols(db, *, limit_total: int) -> Tuple[str, List[str]]:
    """Return ``(snapshot_date, symbols)`` from the early-movers snapshot collection.

    The collection name comes from ``EARLY_MOVERS_SNAPSHOT_COLLECTION``
    (default ``early_movers_snapshots``). The snapshot whose ``date`` equals
    ``ist_date_str()`` is preferred; otherwise the document with the highest
    ``date`` is used.

    Symbols are read from ``top.bullish`` followed by ``top.bearish``, then
    normalised and de-duplicated in order. A positive ``limit_total`` truncates
    the list; ``limit_total <= 0`` leaves it uncapped. Returns ``("", [])``
    when no snapshot document is found.
    """
    today = ist_date_str()
    coll = db[os.getenv("EARLY_MOVERS_SNAPSHOT_COLLECTION", "early_movers_snapshots")]
    # Fall back to the most recent snapshot when today's is missing.
    snap = coll.find_one({"date": today}) or coll.find_one({}, sort=[("date", -1)])

    if not isinstance(snap, dict):
        return "", []

    snap_date = str(snap.get("date") or "")

    top = snap.get("top")
    if not isinstance(top, dict):
        top = {}

    items = [
        entry
        for group in (top.get("bullish"), top.get("bearish"))
        if isinstance(group, list)
        for entry in group
    ]

    seen: Dict[str, None] = {}
    for entry in items:
        if not isinstance(entry, dict):
            continue
        sym = _norm_symbol(entry.get("symbol"))
        if sym:
            seen.setdefault(sym, None)

    symbols = list(seen)
    if limit_total > 0:
        symbols = symbols[: int(limit_total)]
    return snap_date, symbols


def _load_manual_symbols(db, *, limit_total: int) -> List[str]:
    """Return today's ACTIVE manual-watchlist symbols, most recently touched first.

    Reads ``user_portfolio_items`` rows whose ``ist_date`` equals
    ``ist_date_str()`` and whose ``status`` is ``"ACTIVE"``. Rows are sorted by
    ``updated_at`` then ``created_at``, both descending. Symbols are normalised
    and de-duplicated in order, and at most ``limit_total`` are returned.
    ``limit_total <= 0`` short-circuits to ``[]`` without querying.
    """
    if limit_total <= 0:
        return []

    cap = int(limit_total)
    query = {"ist_date": ist_date_str(), "status": "ACTIVE"}
    cursor = db["user_portfolio_items"].find(query, {"_id": 0, "symbol": 1})
    cursor = cursor.sort([("updated_at", -1), ("created_at", -1)])

    picked: Dict[str, None] = {}
    for row in cursor:
        if len(picked) >= cap:
            break
        sym = _norm_symbol(row.get("symbol"))
        if sym:
            picked.setdefault(sym, None)
    return list(picked)


def _build_union(groups: List[Tuple[str, List[str]]]) -> Tuple[List[str], Dict[str, str]]:
    """Merge symbol lists in priority order, remembering where each symbol came from.

    Args:
        groups: ``(source_label, symbols)`` pairs, highest priority first.

    Returns:
        ``(union, sources)``. ``union`` lists every distinct symbol in
        first-seen order. ``sources`` maps each symbol to the label of the
        first group that contained it.
    """
    sources: Dict[str, str] = {}
    for label, syms in groups:
        for sym in syms:
            # setdefault keeps the earliest (highest-priority) label on repeats.
            sources.setdefault(sym, label)
    # dicts preserve insertion order, so the keys double as the ordered union.
    return list(sources), sources


def main() -> None:
    """Load every intraday-universe source from Mongo and print a summary.

    Per-source limits come from the environment, with defaults in parentheses:
    ``GLOBAL_INTRADAY_ET_LIMIT_PER_SIDE`` (10),
    ``GLOBAL_INTRADAY_EARLY_MOVERS_LIMIT`` (40) and
    ``GLOBAL_INTRADAY_MANUAL_LIMIT`` (50). Empty values fall back to the
    default; non-integer values raise ``ValueError``.
    """
    et_limit = int(os.getenv("GLOBAL_INTRADAY_ET_LIMIT_PER_SIDE", "10") or "10")
    early_limit = int(os.getenv("GLOBAL_INTRADAY_EARLY_MOVERS_LIMIT", "40") or "40")
    manual_limit = int(os.getenv("GLOBAL_INTRADAY_MANUAL_LIMIT", "50") or "50")

    # get_mongo_db() is consumed as a generator: next() yields the db handle,
    # and close() in the finally block lets the generator finish.
    gen = get_mongo_db()
    db = next(gen)
    try:
        # datetime.utcnow() is deprecated since Python 3.12. Build an aware UTC
        # timestamp and derive the naive-UTC form that the printout and
        # ist_date_str() received previously.
        now_utc = datetime.now(ZoneInfo("UTC"))
        now_utc_naive = now_utc.replace(tzinfo=None)
        now_ist = now_utc.astimezone(IST)
        # NOTE(review): ist_date_str has always been given naive UTC here;
        # keep that until its timezone handling is confirmed.
        today_ist = ist_date_str(now_utc_naive)

        g_syms, l_syms = _load_et_symbols(db, limit_per_side=et_limit)
        snap_date, early_syms = _load_early_movers_symbols(db, limit_total=early_limit)
        manual_syms = _load_manual_symbols(db, limit_total=manual_limit)

        # Union preserving priority: early -> ET gainers -> ET losers -> manual
        union, sources = _build_union(
            [
                ("EARLY_MOVERS", early_syms),
                ("ET_GAINER", g_syms),
                ("ET_LOSER", l_syms),
                ("MANUAL", manual_syms),
            ]
        )

        print("now_utc:", now_utc_naive.isoformat() + "Z")
        print("now_ist:", now_ist.isoformat())
        print("ist_date:", today_ist)
        print()

        print(f"ET top gainers (limit={et_limit}):", g_syms)
        print(f"ET top losers  (limit={et_limit}):", l_syms)
        print()

        print(f"Early movers snapshot date={snap_date!r} (limit={early_limit}):", early_syms)
        print(f"Manual watchlist (limit={manual_limit}):", manual_syms)
        print()

        print("Union count:", len(union))
        print("Union symbols:", union)
        print("Source counts:", dict(Counter(sources.values())))

    finally:
        try:
            gen.close()
        except Exception as exc:
            # Cleanup stays best-effort (never masks the real outcome), but
            # report the failure instead of silently discarding it.
            print(f"warning: failed to close get_mongo_db() generator: {exc!r}", file=sys.stderr)


# Script entry point: run the sanity check only when executed directly, not on import.
if __name__ == "__main__":
    main()
