import asyncio
import logging
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set

from app.db.database import get_mongo_db
from app.v1.background.global_intraday import (
    _get_global_user_id,
    _get_global_zerodha_client,
    _is_market_hours_ist,
)
from app.v1.services.teGPT import analyze_symbol_service
from app.v1.utils.confidence import normalize_confidence
from app.v1.utils.snapshot_sanitize import compact_analysis_for_persistence


logger = logging.getLogger(__name__)


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from the environment, falling back to *default*.

    Unset or blank variables yield *default*. A malformed value (e.g. ``"abc"``)
    is logged and ignored instead of raising ``ValueError`` at import time,
    which would otherwise stop this module from being imported at all.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw.strip())
    except ValueError:
        logger.warning(
            "[PortfolioIntraday] Ignoring invalid %s=%r; using default %d", name, raw, default
        )
        return default


# Keep the loop bounded; override via env if you truly want "all".
MAX_ANALYSES_PER_CYCLE = _env_int("PORTFOLIO_INTRADAY_MAX_ANALYSES", 200)
FRESHNESS_MINUTES = _env_int("PORTFOLIO_INTRADAY_FRESHNESS_MINUTES", 30)

IST = None  # maintained in global_intraday; here we reuse _is_market_hours_ist
# Guards against overlapping cycles if more than one loop task is started.
_PORTFOLIO_INTRADAY_LOCK = asyncio.Lock()


def _enforce_no_shorting_for_portfolio(analysis: Dict[str, Any]) -> Dict[str, Any]:
    """Portfolio constraint: never publish short/SELL signals.

    If the upstream analysis suggests SELL, we downgrade to HOLD and clear
    any execution fields so consumers (alerts, UI, paper trading) don't
    treat it as an actionable short.

    Args:
        analysis: Analysis dict. It is mutated in place when a downgrade happens.

    Returns:
        The same dict (possibly mutated), or ``{}`` when ``analysis`` is not a dict.
    """
    import math

    if not isinstance(analysis, dict):
        return {}

    # "decision" takes precedence over "action"; missing/empty counts as HOLD.
    decision = str(analysis.get("decision") or analysis.get("action") or "HOLD").strip().upper()
    if decision != "SELL":
        return analysis

    analysis["decision"] = "HOLD"
    analysis["action"] = "HOLD"
    analysis["confidence"] = "LOW"
    analysis["conviction"] = "LOW"

    # Cap the score at 39.0. Missing, boolean, unparseable or NaN scores become 0.0
    # (NaN would otherwise survive min() unchanged).
    score_val = 0.0
    raw_score = analysis.get("score")
    if raw_score is not None and not isinstance(raw_score, bool):
        try:
            score_val = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            score_val = 0.0
        if math.isnan(score_val):
            score_val = 0.0
    analysis["score"] = min(score_val, 39.0)

    # Null out execution fields that are present, so nothing downstream can act on them.
    for k in (
        "entry_price",
        "entry_zone",
        "stop_loss",
        "stop_loss_zone",
        "price_target",
        "target",
        "targets",
        "exec_entry",
        "exec_sl",
        "exec_targets",
        "trend_label",
    ):
        if k in analysis:
            analysis[k] = None

    # Tolerate non-string rationales; calling .strip() on them would raise and
    # abort the caller's whole cycle.
    raw_rationale = analysis.get("rationale")
    if isinstance(raw_rationale, str):
        rationale = raw_rationale.strip()
    elif raw_rationale:
        rationale = str(raw_rationale).strip()
    else:
        rationale = ""
    suffix = "[portfolio:no-shorting enforced: SELL downgraded to HOLD]"
    analysis["rationale"] = (rationale + " " + suffix).strip() if rationale else suffix
    return analysis


def _distinct_active_portfolio_stock_ids(db) -> List[str]:
    """Collect every distinct stock_id held in an ACTIVE portfolio item, across all users.

    Non-string and blank ids are dropped, and the survivors are whitespace-stripped.
    If the Mongo ``distinct`` call fails, the error is logged and an empty list
    is returned.
    """
    active_filter = {
        "status": "ACTIVE",
        "stock_id": {"$ne": None},
    }
    try:
        raw_ids = db["user_portfolio_items"].distinct("stock_id", active_filter)
    except Exception:
        logger.exception("[PortfolioIntraday] Failed to distinct user_portfolio_items.stock_id")
        return []

    return [
        candidate.strip()
        for candidate in (raw_ids or [])
        if isinstance(candidate, str) and candidate.strip()
    ]


def _filter_to_active_stocks(db, stock_ids: List[str]) -> Dict[str, Dict[str, Any]]:
    """Load stocks map for given stock_ids and filter to active+symbol.

    Args:
        db: Mongo database handle exposing a ``stocks`` collection.
        stock_ids: Candidate stock ids to look up.

    Returns:
        Stock documents keyed by ``stock_id``. A stock is included only when it
        has a truthy stock_id and a non-blank string symbol. The symbol is
        stripped and upper-cased in place. A query or cursor failure is logged,
        and whatever was loaded before the failure is returned, matching the
        best-effort style of the other loaders in this module.
    """

    if not stock_ids:
        return {}

    by_id: Dict[str, Dict[str, Any]] = {}
    try:
        stocks = db["stocks"].find(
            {
                "stock_id": {"$in": stock_ids},
                # If field missing, treat as active.
                "$or": [{"is_active": True}, {"is_active": {"$exists": False}}],
            },
            {"_id": 0},
        )
        for s in stocks:
            sid = s.get("stock_id")
            raw_symbol = s.get("symbol")
            # Non-string symbols cannot be normalised; skip them instead of
            # raising on .strip() and aborting the whole cycle.
            if not isinstance(raw_symbol, str):
                continue
            sym = raw_symbol.strip().upper()
            if sid and sym:
                s["symbol"] = sym
                by_id[sid] = s
    except Exception:
        logger.exception("[PortfolioIntraday] Failed to load stocks for portfolio universe")
    return by_id


def _get_fresh_stock_ids(db, stock_ids: List[str], cutoff: datetime, batch_size: int = 500) -> Set[str]:
    """Return stock_ids that already have a snapshot at or after ``cutoff``.

    Queries ``stock_analysis_snapshots`` in chunks of ``batch_size`` ids so each
    ``$in`` list stays bounded. A failed chunk is logged and skipped. Its stocks
    are then treated as stale and may be re-analysed, which is the safe direction.

    Args:
        db: Mongo database handle.
        stock_ids: Ids to check.
        cutoff: Lower bound (inclusive) on the snapshot ``timestamp``.
        batch_size: Ids per query. Values below 1 are treated as 1.
    """

    fresh: Set[str] = set()
    if not stock_ids:
        return fresh

    # range() rejects a zero step, and a negative step would silently skip every id.
    step = max(1, int(batch_size))
    for start in range(0, len(stock_ids), step):
        batch = stock_ids[start : start + step]
        try:
            cursor = db["stock_analysis_snapshots"].find(
                {"stock_id": {"$in": batch}, "timestamp": {"$gte": cutoff}},
                {"_id": 0, "stock_id": 1},
            )
            for doc in cursor:
                sid = doc.get("stock_id")
                if sid:
                    fresh.add(sid)
        except Exception:
            # Worst case we re-analyse some stocks in this batch.
            logger.exception("[PortfolioIntraday] Failed to query recent snapshots")

    return fresh


def _run_portfolio_cycle(db) -> None:
    """Analyse stale ACTIVE portfolio stocks once and persist a snapshot for each.

    Steps:
      1. Resolve the global Zerodha client and global user_id (abort if missing).
      2. Build the universe from the distinct ACTIVE portfolio stock_ids, then keep
         only active stocks that have a symbol.
      3. Skip stocks that already have a snapshot within FRESHNESS_MINUTES.
      4. Analyse the rest, sorted by stock_id, with at most one call per symbol
         and at most MAX_ANALYSES_PER_CYCLE successful analyses. Each result is
         inserted into ``stock_analysis_snapshots`` with ``source="PORTFOLIO"``.

    Per-symbol failures (analysis errors, unexpected result types, persistence
    errors) are logged and skipped, so one bad symbol does not abort the cycle.
    """
    logger.info("[PortfolioIntraday] Cycle start")

    zerodha_client = _get_global_zerodha_client(db)
    if not zerodha_client:
        logger.error("[PortfolioIntraday] Skipping cycle: global Zerodha client unavailable")
        return

    user_id = _get_global_user_id(db)
    if not user_id:
        logger.error("[PortfolioIntraday] Global user_id missing, cannot tag analyses")
        return

    stock_ids = _distinct_active_portfolio_stock_ids(db)
    if not stock_ids:
        logger.info("[PortfolioIntraday] No ACTIVE portfolio stocks found")
        return

    stocks_by_id = _filter_to_active_stocks(db, stock_ids)
    if not stocks_by_id:
        logger.info("[PortfolioIntraday] No eligible stocks (active+symbol) after filter")
        return

    cutoff = datetime.utcnow() - timedelta(minutes=max(1, int(FRESHNESS_MINUTES)))
    fresh_ids = _get_fresh_stock_ids(db, list(stocks_by_id.keys()), cutoff=cutoff)

    eligible_ids = [sid for sid in stocks_by_id if sid not in fresh_ids]

    logger.info(
        "[PortfolioIntraday] Universe=%d eligible=%d fresh=%d freshness_minutes=%d cap=%d",
        len(stocks_by_id),
        len(eligible_ids),
        len(fresh_ids),
        int(FRESHNESS_MINUTES),
        int(MAX_ANALYSES_PER_CYCLE),
    )

    analyzed = 0
    # Several stock_ids may share a symbol; analyse each symbol only once per cycle.
    seen_symbols: Set[str] = set()
    for sid in sorted(eligible_ids):
        if analyzed >= MAX_ANALYSES_PER_CYCLE:
            logger.info(
                "[PortfolioIntraday] Reached per-cycle cap (max=%d); stopping early",
                MAX_ANALYSES_PER_CYCLE,
            )
            break

        symbol = (stocks_by_id.get(sid) or {}).get("symbol")
        if not symbol:
            continue
        sym_key = str(symbol).strip().upper()
        if not sym_key or sym_key in seen_symbols:
            continue
        seen_symbols.add(sym_key)

        try:
            analysis = analyze_symbol_service(
                db=db,
                zerodha_client=zerodha_client,
                symbol=sym_key,
                timeframes=["5minute", "15minute", "30minute", "day", "week", "month"],
                question="general portfolio analysis",
                context="portfolio_intraday",
                user_id=user_id,
                include_market_data=True,
            )
        except Exception as e:  # pragma: no cover
            logger.exception("[PortfolioIntraday] Analysis failed for %s: %s", symbol, e)
            continue

        # A non-dict result cannot be normalised or persisted. Skip it here rather
        # than fail on .get() below and abort every remaining symbol.
        if not isinstance(analysis, dict):
            logger.warning(
                "[PortfolioIntraday] Unexpected analysis result for %s (type=%s); skipping",
                symbol,
                type(analysis).__name__,
            )
            continue

        analysis["confidence"] = normalize_confidence(
            analysis.get("confidence"),
            decision_probability=analysis.get("decision_probability"),
            score=analysis.get("score"),
        )
        analysis = _enforce_no_shorting_for_portfolio(analysis)

        try:
            db["stock_analysis_snapshots"].insert_one(_build_portfolio_snapshot(sid, analysis))
        except Exception:  # pragma: no cover
            logger.exception("[PortfolioIntraday] Failed to persist analysis snapshot for %s", symbol)

        analyzed += 1

    logger.info("[PortfolioIntraday] Cycle finished | analyzed=%d", analyzed)


def _build_portfolio_snapshot(sid: str, analysis: Dict[str, Any]) -> Dict[str, Any]:
    """Build the ``stock_analysis_snapshots`` document for one analysed stock.

    Expects ``analysis`` to already have a normalised confidence and the
    no-shorting rule applied. Top-level fields are read from that final dict,
    so a SELL->HOLD downgrade (which forces confidence to "LOW" and clears
    execution fields) is reflected both at the top level and in the embedded
    analysis.
    """
    # Read these before compact_analysis_for_persistence() runs. Whether that
    # helper trims the dict in place is not visible here, so don't rely on it.
    market_data = analysis.get("market_data")
    features = analysis.get("features")
    targets = analysis.get("targets")
    primary_target = targets[0] if isinstance(targets, list) and targets else None

    return {
        "stock_id": sid,
        "timestamp": datetime.utcnow(),
        "decision": analysis.get("decision"),
        "confidence": analysis.get("confidence"),
        "entry": analysis.get("entry_price"),
        "stop_loss": analysis.get("stop_loss"),
        "target": analysis.get("price_target") or analysis.get("target") or primary_target,
        "reason": analysis.get("rationale"),
        "source": "PORTFOLIO",
        "analysis": compact_analysis_for_persistence(analysis),
        "market_data": _compact_market_data_no_candles(market_data),
        "features": features,
    }


def _run_portfolio_cycle_sync() -> None:
    """Blocking wrapper: obtain a DB handle from ``get_mongo_db()`` and run one cycle.

    The DB generator is always closed, even when the cycle raises. Errors while
    closing are logged at debug level and swallowed.

    Raises:
        RuntimeError: if ``get_mongo_db()`` finishes without yielding a database.
    """
    db_gen = get_mongo_db()
    try:
        db = next(db_gen)
    except StopIteration as exc:
        # Don't let a bare StopIteration escape: this function runs via
        # asyncio.to_thread, and asyncio Futures refuse StopIteration as-is.
        raise RuntimeError("[PortfolioIntraday] get_mongo_db() yielded no database") from exc

    try:
        _run_portfolio_cycle(db)
    finally:
        try:
            db_gen.close()
        except Exception:
            logger.debug("[PortfolioIntraday] Error closing DB generator", exc_info=True)


async def portfolio_intraday_loop(interval_seconds: int = 1800) -> None:
    """Background loop for global portfolio-driven analysis.

    - Universe: distinct ACTIVE portfolio stocks across all users (deduped by stock_id)
    - Skips stocks with a snapshot in the last PORTFOLIO_INTRADAY_FRESHNESS_MINUTES
    - Runs only during IST market hours (same gate as ET/global loops)
    - Any exception in a tick, including the market-hours gate itself, is
      logged and the loop keeps running. Only cancellation stops it.

    Start from FastAPI startup via `asyncio.create_task(portfolio_intraday_loop(...))`.
    """
    # A non-positive interval would turn the idle path into a busy loop.
    sleep_seconds = interval_seconds if interval_seconds > 0 else 1

    await asyncio.sleep(5)
    logger.info("[PortfolioIntraday] Background loop started (interval=%ss)", sleep_seconds)

    while True:
        try:
            await _portfolio_intraday_tick()
        except Exception:  # pragma: no cover
            # Without this, one failing gate check would kill the task for good.
            logger.exception("[PortfolioIntraday] Unexpected error in tick; loop continues")
        await asyncio.sleep(sleep_seconds)


async def _portfolio_intraday_tick() -> None:
    """Run at most one portfolio cycle, honouring the market-hours gate and the lock."""
    if not _is_market_hours_ist():
        from app.v1.utils.market_time import backend_market_window, format_window, local_time_str

        logger.info(
            "[PortfolioIntraday] Backend skipped due to time | now=%s | window=%s",
            local_time_str(),
            format_window(backend_market_window()),
        )
        return

    if _PORTFOLIO_INTRADAY_LOCK.locked():
        logger.info("[PortfolioIntraday] Previous cycle still running; skipping this tick")
        return

    async with _PORTFOLIO_INTRADAY_LOCK:
        try:
            # The cycle does blocking DB/broker I/O, so run it off the event loop.
            await asyncio.to_thread(_run_portfolio_cycle_sync)
        except Exception:  # pragma: no cover
            logger.exception("[PortfolioIntraday] Unexpected error in cycle")


def _compact_market_data_no_candles(md: Any) -> Dict[str, Any]:
    """Reduce ``market_data`` to a persist-safe whitelist of keys.

    Strict rule: do NOT persist candles/quote snapshots (especially intraday).
    Only the whitelisted keys below are copied, in whitelist order, and only
    when their value is not None. Any non-dict input yields an empty dict.
    """
    if isinstance(md, dict):
        persistable = ("instrument_token", "stock_id", "indicators", "strategies", "pivots", "fib", "error")
        return {key: value for key in persistable if (value := md.get(key)) is not None}
    return {}
