"""Helpers for platform teGPT routers.

NOTE: This file is a move-only extraction from `teGPT.py` to keep router modules small.
"""

from fastapi import HTTPException
from typing import Dict, Any, List, Optional
from datetime import datetime
from bson import ObjectId
import re
from collections import deque
import os

from app.v1.services.zerodha.client import ZerodhaClient
from app.v1.utils.confidence import confidence_rank, normalize_confidence


def _env_int(name: str, default: int, *, min_value: int, max_value: int) -> int:
    """Read an integer setting from the environment and clamp it.

    Returns ``default`` (as-is, not clamped) when the variable is unset, blank,
    or not a valid integer literal. Otherwise returns the parsed value limited
    to ``[min_value, max_value]``.
    """
    text = os.getenv(name)
    if text is None:
        return default
    text = str(text).strip()
    if not text:
        return default
    try:
        parsed = int(text)
    except Exception:
        return default
    # Upper bound first, then lower bound: min_value wins if the range is inverted.
    return max(min_value, min(max_value, parsed))


# Freshness window, in minutes, for live portfolio data (presumably the max age
# of live data before it is treated as stale; confirm against callers).
# Read once at import time from PORTFOLIO_LIVE_FRESHNESS_MINUTES. Blank, unset or
# non-integer values fall back to 10; a parsed value is clamped to [1, 60].
DEFAULT_PORTFOLIO_LIVE_FRESHNESS_MINUTES = _env_int(
    "PORTFOLIO_LIVE_FRESHNESS_MINUTES",
    10,
    min_value=1,
    max_value=60,
)


def _dt_to_iso(v: Any) -> Optional[str]:
    """Render a timestamp-like value as a string.

    datetimes are formatted with ``isoformat()``. Strings pass through
    unchanged, since KiteConnect may already hand back a string. Anything
    else, including None, yields None.
    """
    if isinstance(v, datetime):
        return v.isoformat()
    return v if isinstance(v, str) else None


def _parse_ts(v: Any) -> Optional[datetime]:
    """Best-effort conversion of a value to a datetime.

    datetimes are returned unchanged. Non-blank strings are parsed with
    ``datetime.fromisoformat`` after replacing ``Z`` with ``+00:00``, so a
    UTC-suffixed string yields an aware datetime. Unparseable strings, blank
    strings, None and other types return None.
    """
    if isinstance(v, datetime):
        return v
    if not isinstance(v, str) or not v.strip():
        return None
    try:
        return datetime.fromisoformat(v.replace("Z", "+00:00"))
    except Exception:
        return None


def _parse_ts_flexible(v: Any) -> Optional[datetime]:
    """Parse common CSV/export timestamp formats.

    First tries ``_parse_ts`` (datetime passthrough / ISO strings). Then, for
    non-blank strings, tries a fixed list of day-first and year-first
    ``strptime`` formats in order and returns the first match as a naive
    datetime. Keeps behavior conservative: returns None if unparseable.
    """
    parsed = _parse_ts(v)
    if parsed is not None or not isinstance(v, str):
        return parsed

    text = v.strip()
    if not text:
        return None

    # Order matters: the first format that parses wins.
    formats = (
        "%d-%m-%Y %H:%M:%S",
        "%d/%m/%Y %H:%M:%S",
        "%Y-%m-%d %H:%M:%S",
        "%Y/%m/%d %H:%M:%S",
        "%d-%m-%Y %H:%M",
        "%d/%m/%Y %H:%M",
        "%Y-%m-%d",
        "%d-%m-%Y",
        "%d/%m/%Y",
    )
    for fmt in formats:
        try:
            return datetime.strptime(text, fmt)
        except Exception:
            pass
    return None


def _norm_csv_key(k: Any) -> str:
    """Normalise a CSV header to lowercase ASCII alphanumerics.

    Example: ``"Trade Date"`` -> ``"tradedate"``. Non-string keys map to "".
    """
    if isinstance(k, str):
        return re.sub(r"[^a-z0-9]+", "", k.strip().lower())
    return ""


def _pick(row: Dict[str, Any], *keys: str) -> Any:
    """Return the first non-empty value in ``row`` among candidate column names.

    Each candidate in ``keys`` is normalised with ``_norm_csv_key`` before
    lookup, so ``row`` must be keyed by normalised header names. Candidates
    that normalise to "" are ignored, and values of None or "" are skipped.
    Returns None when no candidate yields a value.
    """
    for candidate in keys:
        norm = _norm_csv_key(candidate)
        if norm and norm in row:
            value = row.get(norm)
            if value not in (None, ""):
                return value
    return None


def _safe_float(v: Any) -> Optional[float]:
    """Convert ``v`` to a finite float, or return None.

    Returns None when ``float(v)`` raises, or when the result is NaN or +/-inf.
    """
    try:
        n = float(v)
    except Exception:
        return None
    # The chained comparison is False for NaN and for both infinities.
    return n if float("-inf") < n < float("inf") else None


def _safe_int(v: Any) -> Optional[int]:
    """Return ``int(v)``, or None if the conversion raises.

    Note: floats are truncated toward zero (``2.9`` -> ``2``). Strings with a
    decimal point such as ``"10.0"`` are rejected (None), as are NaN and inf.
    """
    result: Optional[int]
    try:
        result = int(v)
    except Exception:
        result = None
    return result


def _norm_symbol(symbol: str) -> str:
    """Canonicalise a ticker: strip whitespace and upper-case; falsy input -> ""."""
    if not symbol:
        return ""
    return symbol.strip().upper()


def _norm_tags(tags: Any) -> List[str]:
    """Normalise user-supplied tags into a de-duplicated list of lowercase strings.

    Accepts either a comma-separated string (``"Swing, momentum"``) or a
    list/tuple of strings. Each entry is stripped and lower-cased. Blank
    entries and non-string list items are dropped. Duplicates are removed,
    keeping the first occurrence, and this applies to both input forms.
    Any other input (including None) yields an empty list.
    """
    if isinstance(tags, str):
        candidates = tags.split(",")
    elif isinstance(tags, (list, tuple)):
        candidates = [t for t in tags if isinstance(t, str)]
    else:
        return []

    cleaned = (c.strip().lower() for c in candidates)
    # dict keeps insertion order, giving an order-stable de-dupe.
    return list(dict.fromkeys(c for c in cleaned if c))


def _norm_account_id(v: Any) -> str:
    """Normalise an account identifier to a stripped string.

    Falsy values (None, "", 0) map to "". Non-string ids, such as numeric
    account ids from CSV/JSON, are converted with ``str()``. Previously they
    raised AttributeError on ``.strip()``.
    """
    if not v:
        return ""
    return str(v).strip()


def _norm_side(v: Any) -> Optional[str]:
    """Normalise a transaction side to ``"BUY"`` or ``"SELL"``.

    Matching is case-insensitive and ignores surrounding whitespace.
    Recognised aliases (including common CSV values):

    - ``BUY``, ``B``, ``PURCHASE`` -> ``"BUY"``
    - ``SELL``, ``S``, ``SALE`` -> ``"SELL"``

    Anything else, including None, returns None.
    """
    if v is None:
        return None
    aliases = {
        "BUY": "BUY",
        "B": "BUY",
        "PURCHASE": "BUY",
        "SELL": "SELL",
        "S": "SELL",
        "SALE": "SELL",
    }
    return aliases.get(str(v).strip().upper())


def _compute_fifo_enrichment(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Compute gross realized P&L using FIFO matching.

    - Works on execution-level rows (BUY/SELL, qty, price).
    - Rows are bucketed into independent lot books keyed by
      (symbol, product, exchange, segment, series).
    - Rows are replayed oldest -> newest by ``ts``. Datetimes are used
      directly; strings are parsed with ``_parse_ts_flexible``. Missing or
      unparseable timestamps sort first. Aware timestamps are compared in UTC,
      so mixing aware and naive values no longer raises TypeError.
    - Returns shallow copies of the input rows (inputs are not mutated). Each
      copy is extended with realized_pnl, matched_qty, matched_avg_entry_price,
      position_after and open_avg_price.
    - A row is not matched when it lacks a recognisable side, a positive
      integer quantity, or a price (``price``, else ``average_price``). Such
      rows get realized_pnl=None, matched_qty=0 and a snapshot of the book's
      current position.
    - This is gross P&L (fees/taxes not included).
    """

    def raw_of(it: Dict[str, Any]) -> Dict[str, Any]:
        # ``raw`` may be absent, None, or a non-dict; only a dict is usable.
        raw = it.get("raw")
        return raw if isinstance(raw, dict) else {}

    def key_of(it: Dict[str, Any]) -> tuple:
        raw = raw_of(it)
        return (
            (it.get("symbol") or "").strip().upper(),
            (it.get("product") or ""),
            (it.get("exchange") or ""),
            (it.get("segment") or raw.get("segment") or ""),
            (it.get("series") or raw.get("series") or ""),
        )

    def sort_key(it: Dict[str, Any]) -> datetime:
        ts = it.get("ts")
        if not isinstance(ts, datetime):
            ts = _parse_ts_flexible(ts)
        if ts is None:
            return datetime.min
        offset = ts.utcoffset()
        if offset is not None:
            # Shift aware values to naive UTC so they compare with naive ones.
            ts = ts.replace(tzinfo=None) - offset
        return ts

    def snapshot(lots: Any) -> tuple:
        """Return (net position as int, qty-weighted open avg price rounded to 4dp or None)."""
        pos = sum(lot["qty"] for lot in lots)
        open_avg = None
        if pos:
            # The matching loop keeps every lot in a book on the same side;
            # filter by sign anyway so the average only covers the open side.
            same = [lot for lot in lots if (lot["qty"] > 0) == (pos > 0)]
            denom = sum(abs(lot["qty"]) for lot in same)
            if denom:
                open_avg = sum(abs(lot["qty"]) * lot["price"] for lot in same) / denom
        return int(pos), (round(open_avg, 4) if open_avg is not None else None)

    enriched: List[Dict[str, Any]] = []
    lots_by_key: Dict[tuple, deque] = {}

    # process oldest -> newest
    for it in sorted(items, key=sort_key):
        side = _norm_side(it.get("transaction_type"))
        qty_i = _safe_int(it.get("quantity"))
        px = _safe_float(it.get("price"))
        if px is None:
            px = _safe_float(it.get("average_price"))
        k = key_of(it)

        if not side or not qty_i or qty_i <= 0 or px is None:
            # Not matchable: report the book's current state without creating a book.
            pos, open_avg = snapshot(lots_by_key.get(k) or ())
            out = dict(it)
            out.update(
                {
                    "realized_pnl": None,
                    "matched_qty": 0,
                    "matched_avg_entry_price": None,
                    "position_after": pos,
                    "open_avg_price": open_avg,
                }
            )
            enriched.append(out)
            continue

        dq = lots_by_key.setdefault(k, deque())
        realized = 0.0
        matched_qty = 0
        matched_entry_val = 0.0
        remaining = qty_i

        if side == "BUY":
            # BUY: first cover open shorts (negative lots), oldest first.
            while remaining > 0 and dq and dq[0]["qty"] < 0:
                lot = dq[0]
                cover = min(remaining, -lot["qty"])
                # Short P&L: sold at the lot price, bought back at px.
                realized += (lot["price"] - px) * cover
                matched_qty += cover
                matched_entry_val += cover * lot["price"]
                lot["qty"] += cover  # lot qty is negative
                remaining -= cover
                if lot["qty"] == 0:
                    dq.popleft()
            if remaining > 0:
                dq.append({"qty": remaining, "price": px})
        else:
            # SELL: first close open longs (positive lots), oldest first.
            while remaining > 0 and dq and dq[0]["qty"] > 0:
                lot = dq[0]
                close = min(remaining, lot["qty"])
                realized += (px - lot["price"]) * close
                matched_qty += close
                matched_entry_val += close * lot["price"]
                lot["qty"] -= close
                remaining -= close
                if lot["qty"] == 0:
                    dq.popleft()
            if remaining > 0:
                # Excess quantity opens a short position.
                dq.append({"qty": -remaining, "price": px})

        pos, open_avg = snapshot(dq)
        out = dict(it)
        out.update(
            {
                "realized_pnl": round(realized, 2),
                "matched_qty": int(matched_qty),
                "matched_avg_entry_price": round(matched_entry_val / matched_qty, 4) if matched_qty else None,
                "position_after": pos,
                "open_avg_price": open_avg,
            }
        )
        enriched.append(out)

    return enriched


def _portfolio_account_match(account_id: str) -> Dict[str, Any]:
    """Build a Mongo ``$or`` filter matching portfolio docs for an account.

    Backward compatible: docs whose ``account_id`` is missing, "" or null
    always match. If ``account_id`` normalises to a non-empty value, an exact
    match on it is added as the first clause. Otherwise only those unscoped
    docs match.
    """
    clauses: List[Dict[str, Any]] = [
        {"account_id": {"$exists": False}},
        {"account_id": ""},
        {"account_id": None},
    ]
    aid = _norm_account_id(account_id)
    if aid:
        clauses.insert(0, {"account_id": aid})
    return {"$or": clauses}


def _resolve_instrument_token_via_user_zerodha(
    zerodha_client: ZerodhaClient,
    symbol: str,
    exchange: str,
) -> Optional[Dict[str, Any]]:
    """Look up a symbol's instrument_token through the user's Zerodha client.

    The requested exchange (default NSE) is tried first, then NSE and BSE as
    fallbacks. For each exchange it calls ``zerodha_client.get_quote`` and then,
    if the client exposes a ``kite`` attribute, ``kite.ltp``.

    Returns ``{"instrument_token", "exchange", "tradingsymbol"}`` for the first
    response row carrying a truthy instrument_token. Returns None when nothing
    resolves, including for a blank symbol.

    Errors from individual lookups are deliberately swallowed: resolution is
    best-effort, and a failure on one exchange/endpoint must not stop the next
    attempt.
    """
    sym = _norm_symbol(symbol)
    if not sym:
        # Keys like "NSE:" can never resolve; skip the broker round-trips.
        return None
    ex = (exchange or "NSE").strip().upper()

    attempts: List[str] = []
    for candidate in (ex, "NSE", "BSE"):
        if candidate and candidate not in attempts:
            attempts.append(candidate)

    def token_hit(resp: Any, key: str, exch: str) -> Optional[Dict[str, Any]]:
        # Extract the result for ``key`` from a quote/ltp response dict, if usable.
        row = resp.get(key) if isinstance(resp, dict) else None
        if isinstance(row, dict) and row.get("instrument_token"):
            return {
                "instrument_token": row.get("instrument_token"),
                "exchange": exch,
                "tradingsymbol": sym,
            }
        return None

    for exch in attempts:
        key = f"{exch}:{sym}"

        # Prefer quote (richer) then ltp (lighter).
        try:
            hit = token_hit(zerodha_client.get_quote([key]), key, exch)
            if hit:
                return hit
        except Exception:
            pass

        try:
            kite = getattr(zerodha_client, "kite", None)
            if kite is None:
                continue
            hit = token_hit(kite.ltp([key]), key, exch)
            if hit:
                return hit
        except Exception:
            pass

    return None


def _ensure_stock_identity(db, symbol: str, exchange: str = "NSE") -> Dict[str, Any]:
    """Read-only stock lookup.

    Non-negotiable: no manual stock creation.
    Stocks are created ONLY from Zerodha instruments via the stocks master refresh.

    Looks up ``symbol`` on ``exchange`` (default NSE) in ``db["stocks"]``. If
    that misses, any doc with the symbol on any exchange is used. When the
    found doc has no ``stock_id`` but has an ``_id``, ``stock_id`` is
    back-filled as ``str(_id)`` in both the DB and the returned dict. The
    back-fill is best-effort; write errors are ignored.

    Raises:
        HTTPException: 400 if the symbol is blank; 404 if no stock doc matches.
    """
    sym = _norm_symbol(symbol)
    if not sym:
        raise HTTPException(status_code=400, detail="symbol is required")

    exchange = (exchange or "NSE").strip().upper()
    stocks = db["stocks"]
    stock = stocks.find_one({"symbol": sym, "exchange": exchange})
    if not stock:
        stock = stocks.find_one({"symbol": sym})
    if not isinstance(stock, dict):
        raise HTTPException(status_code=404, detail=f"Stock not found in stocks master list: {exchange}:{sym}")

    needs_backfill = not stock.get("stock_id") and stock.get("_id") is not None
    if needs_backfill:
        try:
            backfilled_id = str(stock.get("_id"))
            stocks.update_one({"_id": stock["_id"]}, {"$set": {"stock_id": backfilled_id}})
            stock["stock_id"] = backfilled_id
        except Exception:
            # Best-effort: the looked-up stock is still returned.
            pass

    return stock


def _normalize_mover_param(mover: str) -> str:
    """Map a user-facing mover filter to ``GAINER``, ``LOSER`` or ``BOTH``.

    Case-insensitive and whitespace-tolerant; unknown or empty values default
    to ``GAINER``.
    """
    aliases = {
        "gainers": "GAINER",
        "gainer": "GAINER",
        "top-gainers": "GAINER",
        "losers": "LOSER",
        "loser": "LOSER",
        "top-losers": "LOSER",
        "both": "BOTH",
        "all": "BOTH",
    }
    return aliases.get((mover or "").strip().lower(), "GAINER")


def _get_stock_by_symbol(db, symbol: str) -> Optional[Dict[str, Any]]:
    """Find a stock doc by symbol, preferring the NSE listing.

    A blank symbol returns None. Otherwise the NSE doc is returned if present,
    else the result of a symbol-only lookup (which may be None).
    """
    sym = (symbol or "").strip().upper()
    if sym:
        collection = db["stocks"]
        return collection.find_one({"symbol": sym, "exchange": "NSE"}) or collection.find_one({"symbol": sym})
    return None


def _get_symbols_from_live_movers(db, mover_type: str, limit: int) -> List[str]:
    """Return upper-cased symbols of the top live movers, in mover order.

    A ``mover_type`` of GAINER or LOSER filters ``live_movers`` on that type;
    any other value (e.g. BOTH) applies no filter. Movers are ordered by rank
    ascending, then last_updated descending, and capped at ``limit``. They are
    then joined to ``stocks`` on ``stock_id``. Movers with no stock_id, or
    whose stock is missing or has an empty symbol, are skipped.

    NOTE(review): with PyMongo, ``limit(0)`` means "no limit", not "none".
    """
    query: Dict[str, Any] = {"mover_type": mover_type} if mover_type in ("GAINER", "LOSER") else {}
    cursor = db["live_movers"].find(query).sort([("rank", 1), ("last_updated", -1)]).limit(limit)
    ordered_ids = [doc.get("stock_id") for doc in cursor if doc.get("stock_id")]
    if not ordered_ids:
        return []

    symbol_by_id = {
        s.get("stock_id"): (s.get("symbol") or "").strip().upper()
        for s in db["stocks"].find({"stock_id": {"$in": ordered_ids}}, {"stock_id": 1, "symbol": 1})
    }
    return [symbol_by_id[sid] for sid in ordered_ids if symbol_by_id.get(sid)]


def _get_latest_snapshot_by_stock_id(db, stock_id: str) -> Optional[Dict[str, Any]]:
    """Return the newest ``stock_analysis_snapshots`` doc (by timestamp) for a stock.

    Returns None for a falsy ``stock_id`` (without querying) or when no
    snapshot exists.
    """
    if stock_id:
        snapshots = db["stock_analysis_snapshots"]
        return snapshots.find_one({"stock_id": stock_id}, sort=[("timestamp", -1)])
    return None


def _get_latest_snapshot_by_symbol(db, symbol: str) -> Optional[Dict[str, Any]]:
    """Resolve ``symbol`` to a stock (NSE preferred) and return its newest snapshot.

    Returns None when the stock is unknown or has no ``stock_id``.
    """
    stock = _get_stock_by_symbol(db, symbol)
    stock_id = stock.get("stock_id") if stock else None
    if not stock_id:
        return None
    return _get_latest_snapshot_by_stock_id(db, stock_id)


def _extract_analysis_fields(symbol: str, stock: Optional[Dict[str, Any]], snap: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a snapshot document into a compact analysis summary.

    ``snap["analysis"]`` is treated as {} when missing or not a dict.

    Args:
        symbol: Symbol to report (stripped and upper-cased).
        stock: Optional stock doc; if given, its instrument_token and stock_id
            are attached.
        snap: Snapshot doc carrying ``analysis``, ``_id`` and ``timestamp``.

    Returns:
        Dict with the following keys:

        - analysis_id: ``analysis.analysis_id``, else ``str(snap["_id"])``, else
          None. Previously a missing ``_id`` produced the string "None".
        - symbol
        - decision: ``decision``, else ``action``, else "HOLD"; upper-cased.
        - confidence: normalised via ``normalize_confidence``.
        - entry_price: ``analysis.entry_price``; if falsy and ``targets`` is a
          dict, ``targets["entry"]``.
        - stop_loss and targets: passed through.
        - timestamp: ISO string for datetimes, passed through for strings,
          else None.
    """
    analysis = snap.get("analysis")
    if not isinstance(analysis, dict):
        analysis = {}

    decision = analysis.get("decision") or analysis.get("action") or "HOLD"
    confidence = normalize_confidence(
        (analysis.get("confidence") or analysis.get("confidence_level")),
        decision_probability=analysis.get("decision_probability"),
        score=analysis.get("score"),
    )

    targets = analysis.get("targets")
    entry_price = analysis.get("entry_price")
    if not entry_price and isinstance(targets, dict):
        entry_price = targets.get("entry")

    snap_id = snap.get("_id")
    analysis_id = analysis.get("analysis_id") or (str(snap_id) if snap_id is not None else None)

    out = {
        "analysis_id": analysis_id,
        "symbol": (symbol or "").strip().upper(),
        "decision": str(decision).upper(),
        "confidence": confidence,
        "entry_price": entry_price,
        "stop_loss": analysis.get("stop_loss"),
        "targets": targets,
        "timestamp": _dt_to_iso(snap.get("timestamp")),
    }

    if stock:
        out["instrument_token"] = stock.get("instrument_token")
        out["stock_id"] = stock.get("stock_id")

    return out


def _what_changed(prev: Optional[Dict[str, Any]], curr: Dict[str, Any]) -> str:
    """Summarise what differs between two analysis summaries, for display.

    Returns "New" when ``prev`` is falsy and "-" when nothing tracked changed.
    Otherwise returns a "; "-joined list containing:

    - "decision A→B" / "confidence A→B" for changed values;
    - "<field> updated" for entry_price / stop_loss when the new value is
      non-None and differs from the old one.
    """
    if not prev:
        return "New"

    changes = [
        f"{field} {prev.get(field)}→{curr.get(field)}"
        for field in ("decision", "confidence")
        if prev.get(field) != curr.get(field)
    ]
    changes.extend(
        f"{field} updated"
        for field in ("entry_price", "stop_loss")
        if curr.get(field) is not None and prev.get(field) != curr.get(field)
    )
    return "; ".join(changes) or "-"


def _confidence_score(v: str) -> int:
    """Return the sortable rank for a confidence label (delegates to ``confidence_rank``)."""
    rank = confidence_rank(v)
    return rank


def _format_analysis_for_stream_row(analysis: Dict[str, Any], rank: int, mover: str) -> Dict[str, Any]:
    """Shape one analysis dict into a flat row for the movers Stream table.

    Args:
        analysis: Analysis payload. Keys read include symbol, decision,
            confidence, decision_probability, score, entry_price, stop_loss,
            exec_sl, sl, exec_targets, targets, price_target, current_price,
            entry_zone, market_data / quote, instrument_token and several
            display fields passed through unchanged.
        rank: Row rank, passed through unchanged.
        mover: Mover label, passed through unchanged.

    Returns:
        A dict for the UI. Derived fields:

        - change_pct: % move. Computed vs today's open when the quote has
          last_price and ohlc.open. Otherwise vs previous close from the
          quote. Otherwise from the first/last intraday candle. None if none
          of these is available.
        - current_price: ``current_price``, else quote ``last_price``, else
          the last close of the 5minute/15minute/day candles.
        - exec_sl: ``exec_sl`` (else ``sl``). When a valid entry zone exists,
          it is pushed at least ENTRY_SL_BUFFER_BP basis points beyond the
          zone: below ``low`` for BUY, above ``high`` for SELL.
        - exec_targets: at most two targets. For BUY/SELL, missing targets
          are synthesised at 1R/2R from the entry reference and SL risk.
    """
    symbol = (analysis.get("symbol") or "").strip().upper()

    def _extract_quote(a: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        # Prefer market_data.quote, then a top-level quote; only non-empty dicts count.
        md = a.get("market_data")
        if isinstance(md, dict):
            q = md.get("quote")
            if isinstance(q, dict) and q:
                return q
        q2 = a.get("quote")
        if isinstance(q2, dict) and q2:
            return q2
        return None

    def _pick_candles_dict(md: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        # Candles dict from md["candles"], else md["market_data"]["candles"];
        # None if neither is a non-empty dict.
        if not isinstance(md, dict) or not md:
            return None
        c = md.get("candles")
        if isinstance(c, dict) and c:
            return c
        nested = md.get("market_data")
        if isinstance(nested, dict):
            c2 = nested.get("candles")
            if isinstance(c2, dict) and c2:
                return c2
        return None

    def _compute_change_pct(quote: Optional[Dict[str, Any]], market_data: Optional[Dict[str, Any]]) -> Optional[float]:
        # % Change definition for Stream:
        #   intraday % = (current - today's open) / open * 100
        # This matches what traders expect in an intraday movers table.
        #
        # 1) Preferred: quote.ohlc.open -> quote.last_price
        if isinstance(quote, dict) and quote:
            last = _safe_float(quote.get("last_price"))
            ohlc = quote.get("ohlc") if isinstance(quote.get("ohlc"), dict) else {}
            open_px = _safe_float(ohlc.get("open"))
            if last is not None and open_px is not None and open_px != 0:
                return ((last - open_px) / open_px) * 100.0

            # Fallback: some feeds provide net_change (difference) but not net_change_percentage.
            # NOTE: both fallbacks below measure change vs the previous close
            # (ohlc.close), not vs today's open.
            net = _safe_float(quote.get("net_change"))
            prev_close = _safe_float(ohlc.get("close"))
            if net is not None and prev_close is not None and prev_close != 0:
                return (net / prev_close) * 100.0
            if last is not None and prev_close is not None and prev_close != 0:
                return ((last - prev_close) / prev_close) * 100.0

        # 2) Fallback: compute from stored INTRADAY candles (first open -> last close).
        # Uses the first timeframe (in this order) with >= 2 candles and valid prices.
        # NOTE(review): assumes each candle is a dict with "open"/"close" keys.
        candles = _pick_candles_dict(market_data)
        if not isinstance(candles, dict) or not candles:
            return None

        for tf in ("5minute", "5min", "15minute", "15min", "30minute", "30min"):
            series = candles.get(tf)
            if not isinstance(series, list) or len(series) < 2:
                continue
            open_px = _safe_float((series[0] or {}).get("open"))
            last_close = _safe_float((series[-1] or {}).get("close"))
            if last_close is None or open_px is None or open_px == 0:
                continue
            return ((last_close - open_px) / open_px) * 100.0

        return None

    def _ensure_two_targets(
        *,
        decision: str,
        entry_zone: Optional[Dict[str, Any]],
        entry_price: Optional[float],
        sl: Optional[float],
        targets: List[float],
    ) -> List[float]:
        # Return at most two numeric targets. For BUY/SELL with a usable entry
        # reference (entry-zone midpoint, else entry_price) and a stop on the
        # correct side (positive risk), missing targets are filled from the
        # risk R = |entry - sl|. With none supplied: entry +/- 1R and 2R. With
        # one supplied: the second is at least 1R past it and at least 2R from
        # entry. Otherwise the supplied targets are truncated to two.
        out = [float(t) for t in targets if _safe_float(t) is not None]

        zone_mid = None
        if isinstance(entry_zone, dict):
            lo = _safe_float(entry_zone.get("low") or entry_zone.get("lower"))
            hi = _safe_float(entry_zone.get("high") or entry_zone.get("upper"))
            if lo is not None and hi is not None and lo > 0 and hi > 0:
                zone_mid = (lo + hi) / 2.0
        entry_ref = zone_mid if zone_mid is not None else entry_price
        entry_ref = _safe_float(entry_ref)
        sl = _safe_float(sl)

        d = (decision or "").strip().upper()
        if entry_ref is None or sl is None or d not in ("BUY", "SELL"):
            return out[:2]

        # Positive only when the stop sits on the losing side of entry.
        risk = (entry_ref - sl) if d == "BUY" else (sl - entry_ref)
        if risk is None or risk <= 0:
            return out[:2]

        if len(out) == 0:
            t1 = entry_ref + risk if d == "BUY" else entry_ref - risk
            t2 = entry_ref + 2 * risk if d == "BUY" else entry_ref - 2 * risk
            return [round(t1, 2), round(t2, 2)]

        if len(out) == 1:
            t1 = float(out[0])
            if d == "BUY":
                t2 = max(t1 + risk, entry_ref + 2 * risk)
            else:
                t2 = min(t1 - risk, entry_ref - 2 * risk)
            out.append(round(float(t2), 2))

        return out[:2]

    def _ensure_sl_side(
        *,
        decision: str,
        entry_zone: Optional[Dict[str, Any]],
        sl: Optional[float],
    ) -> Optional[float]:
        # Keep the stop beyond the entry zone by ENTRY_SL_BUFFER_BP basis points
        # (env, default 10, clamped to 0..200). BUY: SL <= low * (1 - buf).
        # SELL: SL >= high * (1 + buf). A stop inside the zone is moved to the
        # boundary. Returned unchanged when SL or the zone is missing/invalid,
        # or the decision is neither BUY nor SELL.
        sl = _safe_float(sl)
        if sl is None or not isinstance(entry_zone, dict):
            return sl
        lo = _safe_float(entry_zone.get("low") or entry_zone.get("lower"))
        hi = _safe_float(entry_zone.get("high") or entry_zone.get("upper"))
        if lo is None or hi is None or lo <= 0 or hi <= 0:
            return sl

        d = (decision or "").strip().upper()
        buf = _env_int("ENTRY_SL_BUFFER_BP", 10, min_value=0, max_value=200)  # basis points
        buf_pct = float(buf) / 10000.0

        if d == "BUY":
            max_ok = float(lo) * (1.0 - buf_pct)
            if sl >= float(lo):
                return max_ok
            return min(float(sl), max_ok)
        if d == "SELL":
            min_ok = float(hi) * (1.0 + buf_pct)
            if sl <= float(hi):
                return min_ok
            return max(float(sl), min_ok)
        return sl

    # Legacy "exit" fallback: the first element when targets is a non-empty list/tuple.
    raw_targets = analysis.get("targets")
    exit_target = None
    if isinstance(raw_targets, (list, tuple)) and raw_targets:
        exit_target = raw_targets[0]

    quote = _extract_quote(analysis)
    md_for_change = analysis.get("market_data") if isinstance(analysis.get("market_data"), dict) else None
    change_pct = _compute_change_pct(quote, md_for_change)

    # Derive a stable current_price for UI (prefer explicit field; else quote last_price).
    current_price = _safe_float(analysis.get("current_price"))
    if current_price is None and isinstance(quote, dict):
        current_price = _safe_float(quote.get("last_price"))
    if current_price is None and isinstance(md_for_change, dict):
        # last resort: use last candle close if available
        # NOTE(review): timeframe list differs from _compute_change_pct (no "5min", adds "day").
        candles = _pick_candles_dict(md_for_change)
        if isinstance(candles, dict):
            for tf in ("5minute", "15minute", "day"):
                series = candles.get(tf)
                if isinstance(series, list) and series:
                    current_price = _safe_float((series[-1] or {}).get("close"))
                    if current_price is not None:
                        break

    # Execution stop: exec_sl preferred, then sl; then forced outside the entry zone.
    entry_zone = analysis.get("entry_zone") if isinstance(analysis.get("entry_zone"), dict) else None
    sl_exec = _safe_float(analysis.get("exec_sl"))
    sl_fallback = _safe_float(analysis.get("sl"))
    sl = sl_exec if sl_exec is not None else sl_fallback

    sl = _ensure_sl_side(decision=str(analysis.get("decision", "HOLD")), entry_zone=entry_zone, sl=sl)

    raw_exec_targets = analysis.get("exec_targets")
    exec_targets: List[float] = []
    if isinstance(raw_exec_targets, list):
        exec_targets = [float(x) for x in raw_exec_targets if _safe_float(x) is not None]
    # NOTE(review): the `or` fallbacks below also skip legitimate 0.0 values.
    # stop_loss is used here only for target synthesis; exec_sl in the output
    # stays as computed above.
    exec_targets = _ensure_two_targets(
        decision=str(analysis.get("decision", "HOLD")),
        entry_zone=entry_zone,
        entry_price=_safe_float(analysis.get("entry_price")) or _safe_float(analysis.get("current_price")),
        sl=sl or _safe_float(analysis.get("stop_loss")),
        targets=exec_targets,
    )

    # NOTE(review): top-level "confidence" is normalised, but metrics.confidence
    # is the raw value (default "LOW"); confirm the frontend expects this.
    return {
        "symbol": symbol,
        "rank": rank,
        "trend_label": analysis.get("trend_label") or "",
        "decision": analysis.get("decision", "HOLD"),
        "confidence": normalize_confidence(
            analysis.get("confidence"),
            decision_probability=analysis.get("decision_probability"),
            score=analysis.get("score"),
        ),
        "instrument_token": analysis.get("instrument_token"),
        "change_pct": change_pct,

        # New execution-plan fields (preferred by newer frontend)
        "entry_zone": entry_zone,
        "entry_trigger_reason": analysis.get("entry_trigger_reason"),
        "signal_state": analysis.get("signal_state"),
        "exec_sl": sl,
        "exec_targets": exec_targets,
        "exec_rr_ratio": analysis.get("exec_rr_ratio"),

        "entry_price": analysis.get("entry_price"),
        "stop_loss": analysis.get("stop_loss"),
        "price_target": analysis.get("price_target"),
        "risk_reward_ratio": analysis.get("risk_reward_ratio"),
        "current_price": current_price,
        "metrics": {
            "decision": analysis.get("decision", "HOLD"),
            "confidence": analysis.get("confidence", "LOW"),
        },
        "rationale": analysis.get("rationale", []),
        "features": analysis.get("features", {}),
        "technical_indicators": analysis.get("technical_indicators", {}),
        "targets": {
            "entry": analysis.get("entry_price"),
            "exit": analysis.get("price_target") or exit_target,
        },
        "timestamp": analysis.get("timestamp"),
        "mover": mover,
    }
