from __future__ import annotations

import logging
import math
import os
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from statistics import median
from typing import Any, Dict, Iterable, List, Optional, Tuple
from zoneinfo import ZoneInfo

from app.v1.services.zerodha.client import ZerodhaClient

logger = logging.getLogger(__name__)

IST = ZoneInfo("Asia/Kolkata")

# Core indices requested (optional ones included if available in instrument master).
DEFAULT_INDEX_QUERIES: List[str] = [
    "NIFTY 50",
    "NIFTY BANK",
    "NIFTY PSU BANK",
    "NIFTY FIN SERVICE",
    "NIFTY IT",
    "NIFTY AUTO",
    "NIFTY PHARMA",
    "NIFTY HEALTHCARE INDEX",
    "NIFTY METAL",
    "NIFTY ENERGY",
    "NIFTY OIL & GAS",
    "NIFTY REALTY",
    "NIFTY INFRA",
    "NIFTY FMCG",
    "NIFTY CONSUMER DURABLES",
    "NIFTY CAPITAL GOODS",
]

PRIMARY_MARKET_INDEX = os.getenv("MI_PRIMARY_INDEX", "NIFTY 50").strip() or "NIFTY 50"

# 15m cadence.
MI_INTERVAL_SECONDS = int(os.getenv("MARKET_INTELLIGENCE_INTERVAL_SECONDS", "900"))

# Heuristics / thresholds
EMA_PERIOD = int(os.getenv("MI_EMA_PERIOD", "20"))
ATR_PERIOD = int(os.getenv("MI_ATR_PERIOD", "14"))
STRUCTURE_LOOKBACK = int(os.getenv("MI_STRUCTURE_LOOKBACK", "6"))
TREND_LOOKBACK = int(os.getenv("MI_TREND_LOOKBACK", "12"))

# Market bias thresholds
MIN_TREND_PCT = float(os.getenv("MI_MIN_TREND_PCT", "0.35")) / 100.0  # 0.35%
SLOPE_ATR_FRACTION = float(os.getenv("MI_SLOPE_ATR_FRACTION", "0.08"))  # EMA slope threshold vs ATR

# Volatility thresholds
VOL_LOW_RATIO = float(os.getenv("MI_VOL_LOW_RATIO", "0.85"))
VOL_HIGH_RATIO = float(os.getenv("MI_VOL_HIGH_RATIO", "1.25"))

# Sector thresholds
SECTOR_MIN_MOVE_PCT = float(os.getenv("MI_SECTOR_MIN_MOVE_PCT", "0.30")) / 100.0


@dataclass(frozen=True)
class Candle:
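    """Single OHLCV bar; `ts` is the bar timestamp as returned by Kite (typically timezone-aware IST)."""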
    ts: datetime
    o: float
    h: float
    l: float
    c: float
    v: float


def _safe_float(v: Any) -> Optional[float]:
    try:
        if v is None or isinstance(v, bool):
            return None
        f = float(v)
        if math.isnan(f) or math.isinf(f):
            return None
        return f
    except Exception:
        return None


def _to_datetime(v: Any) -> Optional[datetime]:
    if isinstance(v, datetime):
        return v
    if isinstance(v, str) and v.strip():
        try:
            # Zerodha returns ISO-like strings in some places.
            return datetime.fromisoformat(v.replace("Z", "+00:00"))
        except Exception:
            return None
    return None


def _normalize_candles(raw: Iterable[Dict[str, Any]]) -> List[Candle]:
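    """Convert raw Kite historical rows into a time-sorted list of validated Candles, skipping malformed rows."""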
    out: List[Candle] = []
    for row in raw or []:
        if not isinstance(row, dict):
            continue
        ts = _to_datetime(row.get("date"))
        if not ts:
            continue
        o = _safe_float(row.get("open"))
        h = _safe_float(row.get("high"))
        l = _safe_float(row.get("low"))
        c = _safe_float(row.get("close"))
        v = _safe_float(row.get("volume"))
        if None in (o, h, l, c):
            continue
        out.append(Candle(ts=ts, o=float(o), h=float(h), l=float(l), c=float(c), v=float(v or 0.0)))
    out.sort(key=lambda x: x.ts)
    return out


def _ema(values: List[float], period: int) -> List[float]:
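    """Exponential moving average seeded with the first value; smoothing factor k = 2 / (period + 1)."""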
    if not values:
        return []
    period = max(1, int(period))
    k = 2.0 / (period + 1.0)
    ema: List[float] = []
    prev = values[0]
    ema.append(prev)
    for x in values[1:]:
        prev = (x * k) + (prev * (1.0 - k))
        ema.append(prev)
    return ema


def _true_ranges(candles: List[Candle]) -> List[float]:
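    """Per-candle true range: max(high - low, |high - prev_close|, |low - prev_close|)."""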
    if len(candles) < 2:
        return []
    out: List[float] = []
    prev_close = candles[0].c
    for c in candles[1:]:
        tr = max(c.h - c.l, abs(c.h - prev_close), abs(c.l - prev_close))
        out.append(float(tr))
        prev_close = c.c
    return out


def _atr(candles: List[Candle], period: int) -> Optional[float]:
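    """Simple (non-Wilder) ATR: mean of the last `period` true ranges, or of whatever history exists."""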
    trs = _true_ranges(candles)
    if not trs:
        return None
    period = max(1, int(period))
    if len(trs) < period:
        return float(sum(trs) / len(trs))
    window = trs[-period:]
    return float(sum(window) / period)


def _vwap(candles: List[Candle]) -> Optional[float]:
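    """Volume-weighted average price; returns None when total volume is zero (common for index candles, whose volume is often reported as 0)."""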
    pv = 0.0
    vv = 0.0
    for c in candles:
        tp = (c.h + c.l + c.c) / 3.0
        pv += tp * c.v
        vv += c.v
    if vv <= 0:
        return None
    return float(pv / vv)


def _structure_score(candles: List[Candle], lookback: int) -> int:
    """Higher-high + higher-low score in the last `lookback` candles."""
    lookback = max(3, int(lookback))
    window = candles[-lookback:] if len(candles) >= lookback else candles
    if len(window) < 3:
        return 0

    score = 0
    for i in range(1, len(window)):
        if window[i].h > window[i - 1].h:
            score += 1
        if window[i].l > window[i - 1].l:
            score += 1
    return score


def _pct_change(a: Optional[float], b: Optional[float]) -> Optional[float]:
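    """Fractional change (a - b) / b; None if either input is missing or the base is zero."""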
    if a is None or b is None:
        return None
    if b == 0:
        return None
    return float((a - b) / b)


def _volatility_state(atr_pct: Optional[float], atr_pct_history: List[float]) -> Tuple[str, Dict[str, Any]]:
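    """Classify volatility as "Low" / "Moderate" / "High" from the ratio of current ATR%
    to the median of its recent history (VOL_LOW_RATIO / VOL_HIGH_RATIO thresholds)."""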
    if atr_pct is None:
        return "Moderate", {"reason": "missing_atr"}

    hist = [x for x in atr_pct_history if x is not None and x > 0]
    base = median(hist) if hist else atr_pct
    ratio = float(atr_pct / base) if base else 1.0

    if ratio >= VOL_HIGH_RATIO:
        return "High", {"atr_pct": atr_pct, "baseline": base, "ratio": ratio}
    if ratio <= VOL_LOW_RATIO:
        return "Low", {"atr_pct": atr_pct, "baseline": base, "ratio": ratio}
    return "Moderate", {"atr_pct": atr_pct, "baseline": base, "ratio": ratio}


def _market_bias(
    *,
    close: float,
    ema: float,
    ema_prev: float,
    vwap: Optional[float],
    atr: Optional[float],
    structure_score: int,
    trend_ret: Optional[float],
) -> Tuple[str, Dict[str, Any]]:
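    """Rule-based bias for the primary index: returns ("Trending" | "Bearish" | "Sideways", diagnostics).

    The bullish ("Trending") case requires price above the EMA and session VWAP, a rising
    EMA whose slope clears the ATR-scaled gate, a strong structure score, and a trend
    return of at least MIN_TREND_PCT; the bearish case mirrors these conditions.
    """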
    atr = atr or 0.0
    slope = float(ema - ema_prev)
    slope_ok = abs(slope) >= (SLOPE_ATR_FRACTION * atr) if atr > 0 else True  # no slope gate without ATR

    above = (vwap is None) or (close >= vwap)
    below = (vwap is None) or (close <= vwap)

    trend_ok_up = (trend_ret is not None) and (trend_ret >= MIN_TREND_PCT)
    trend_ok_dn = (trend_ret is not None) and (trend_ret <= -MIN_TREND_PCT)

    bull = close > ema and above and slope > 0 and slope_ok and structure_score >= 6 and trend_ok_up
    bear = close < ema and below and slope < 0 and slope_ok and structure_score <= 2 and trend_ok_dn
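    # With the default STRUCTURE_LOOKBACK of 6 the structure score spans 0-10, so the
    # >= 6 / <= 2 gates demand a clear majority of higher (or lower) highs and lows.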

    if bull:
        return "Trending", {
            "direction": "up",
            "close": close,
            "ema": ema,
            "ema_slope": slope,
            "vwap": vwap,
            "structure_score": structure_score,
            "trend_ret": trend_ret,
        }
    if bear:
        return "Bearish", {
            "direction": "down",
            "close": close,
            "ema": ema,
            "ema_slope": slope,
            "vwap": vwap,
            "structure_score": structure_score,
            "trend_ret": trend_ret,
        }

    return "Sideways", {
        "direction": "flat",
        "close": close,
        "ema": ema,
        "ema_slope": slope,
        "vwap": vwap,
        "structure_score": structure_score,
        "trend_ret": trend_ret,
    }


def _overall_risk(market_bias: str, volatility_state: str) -> str:
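    """Coarse risk label: high volatility always maps to High; a calm (low-vol) trending
    or bearish market maps to Low; everything else is Medium."""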
    vb = (volatility_state or "").strip().lower()
    mb = (market_bias or "").strip().lower()

    if vb == "high":
        return "High"
    if vb == "low" and mb in {"trending", "bearish"}:
        return "Low"
    return "Medium"


def _resolve_instrument_token(db, query: str) -> Optional[int]:
    """Resolve an instrument token from cached Zerodha instruments.

    We attempt exact match on tradingsymbol (case-insensitive), and fall back to name.
    """

    q = (query or "").strip()
    if not q:
        return None

    # Primary: stocks master list (single source of truth).
    try:
        doc = db["stocks"].find_one(
            {"symbol": {"$regex": f"^{q}$", "$options": "i"}, "exchange": "NSE"},
            {"instrument_token": 1},
        )
        if doc and doc.get("instrument_token") is not None:
            return int(doc.get("instrument_token"))
    except Exception:
        pass

    try:
        doc = db["stocks"].find_one(
            {"name": {"$regex": f"^{q}$", "$options": "i"}, "exchange": "NSE"},
            {"instrument_token": 1},
        )
        if doc and doc.get("instrument_token") is not None:
            return int(doc.get("instrument_token"))
    except Exception:
        pass

    # Legacy fallback: zerodha_instruments cache (to be phased out).
    doc = db["zerodha_instruments"].find_one(
        {"tradingsymbol": {"$regex": f"^{q}$", "$options": "i"}, "exchange": "NSE"},
        {"instrument_token": 1},
    )
    if doc and doc.get("instrument_token"):
        try:
            return int(doc.get("instrument_token"))
        except Exception:
            return None

    # Fallback: name match (indices often store human names)
    doc = db["zerodha_instruments"].find_one(
        {"name": {"$regex": f"^{q}$", "$options": "i"}, "exchange": "NSE"},
        {"instrument_token": 1},
    )
    if doc and doc.get("instrument_token"):
        try:
            return int(doc.get("instrument_token"))
        except Exception:
            return None

    # Fallback: some parts of the codebase store instruments as a single cached document
    # with `type: nse_equity` and an `instruments` array.
    cached = db["zerodha_instruments"].find_one({"type": "nse_equity"}, {"instruments": 1})
    arr = cached.get("instruments") if isinstance(cached, dict) else None
    if isinstance(arr, list) and arr:
        q_upper = q.upper()
        for inst in arr:
            if not isinstance(inst, dict):
                continue
            ts = str(inst.get("tradingsymbol") or "").strip().upper()
            nm = str(inst.get("name") or "").strip().upper()
            ex = str(inst.get("exchange") or "").strip().upper()
            if ex and ex != "NSE":
                continue
            if ts == q_upper or nm == q_upper:
                tok = inst.get("instrument_token")
                try:
                    return int(tok)
                except Exception:
                    continue  # malformed token on this entry; keep scanning

    return None


def _fetch_candles(
    *,
    zerodha_client: ZerodhaClient,
    instrument_token: int,
    interval: str,
    days_back: int,
) -> List[Candle]:
    """Fetch candles from Zerodha (raw list), normalize to Candle list."""

    # Kite interprets from/to timestamps in exchange time (IST), so build the window in IST.
    end = datetime.now(IST)
    start = end - timedelta(days=int(max(1, days_back)))

    # Kite accepts datetime objects. The ZerodhaClient wrapper returns a DataFrame,
    # so we call the underlying kite client directly to avoid a pandas dependency here.
    raw = zerodha_client.kite.historical_data(
        instrument_token=instrument_token,
        from_date=start,
        to_date=end,
        interval=interval,
        continuous=False,
        oi=False,
    )
    return _normalize_candles(raw)


def compute_market_intelligence_summary(
    *,
    db,
    zerodha_client: ZerodhaClient,
    captured_at: Optional[datetime] = None,
    interval: str = "15minute",
    days_back: int = 7,
    index_queries: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Compute deterministic Market Intelligence Summary.

    No LLMs, no probabilistic logic. Purely rule-based.
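
    Illustrative usage (assumes a pymongo database handle `db` and an authenticated
    ZerodhaClient `client`; the variable names are the caller's, not prescribed here):

        summary = compute_market_intelligence_summary(db=db, zerodha_client=client)
        print(summary["market_bias"], summary["volatility_state"], summary["overall_risk"])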
    """

    captured_at = captured_at or datetime.now(timezone.utc)
    index_queries = index_queries or DEFAULT_INDEX_QUERIES

    token_by_name: Dict[str, int] = {}
    missing: List[str] = []
    for name in index_queries:
        try:
            tok = _resolve_instrument_token(db, name)
        except Exception:
            tok = None
        if tok:
            token_by_name[name] = tok
        else:
            missing.append(name)

    if PRIMARY_MARKET_INDEX not in token_by_name:
        raise RuntimeError(f"Primary index '{PRIMARY_MARKET_INDEX}' not found in zerodha_instruments")

    # --- Primary market index analytics (NIFTY 50) ---
    nifty_tok = token_by_name[PRIMARY_MARKET_INDEX]
    nifty_candles = _fetch_candles(
        zerodha_client=zerodha_client,
        instrument_token=nifty_tok,
        interval=interval,
        days_back=days_back,
    )
    if len(nifty_candles) < max(EMA_PERIOD, ATR_PERIOD, TREND_LOOKBACK) + 5:
        raise RuntimeError(f"Not enough candles for {PRIMARY_MARKET_INDEX}: {len(nifty_candles)}")

    closes = [c.c for c in nifty_candles]
    ema_series = _ema(closes, EMA_PERIOD)
    ema_last = ema_series[-1]
    ema_prev = ema_series[-2] if len(ema_series) >= 2 else ema_last

    close_last = closes[-1]
    atr = _atr(nifty_candles, ATR_PERIOD)
    atr_pct = (float(atr) / close_last * 100.0) if atr and close_last else None

    # Volatility baseline: ATR% history over the most recent candles (capped at ~150 for cost).
    atr_pct_hist: List[float] = []
    for i in range(max(ATR_PERIOD + 2, len(nifty_candles) - 150), len(nifty_candles)):
        sub = nifty_candles[: i + 1]
        a = _atr(sub, ATR_PERIOD)
        if a and sub[-1].c:
            atr_pct_hist.append(float(a) / sub[-1].c * 100.0)

    vol_state, vol_diag = _volatility_state(atr_pct, atr_pct_hist[-64:])

    # Session VWAP: candles from today (IST)
    today_ist = captured_at.astimezone(IST).date()
    session = [c for c in nifty_candles if c.ts.astimezone(IST).date() == today_ist]
    vwap = _vwap(session or nifty_candles[-40:])

    struct_score = _structure_score(nifty_candles, STRUCTURE_LOOKBACK)
    ref = nifty_candles[-TREND_LOOKBACK - 1].c if len(nifty_candles) > TREND_LOOKBACK else nifty_candles[0].c
    trend_ret = _pct_change(close_last, ref)

    market_bias, bias_diag = _market_bias(
        close=close_last,
        ema=float(ema_last),
        ema_prev=float(ema_prev),
        vwap=vwap,
        atr=atr,
        structure_score=struct_score,
        trend_ret=trend_ret,
    )

    overall_risk = _overall_risk(market_bias, vol_state)

    # --- Sector analytics (exclude primary) ---
    bullish: List[str] = []
    bearish: List[str] = []
    volatile: List[str] = []

    sectors_debug: Dict[str, Any] = {}

    for name, tok in token_by_name.items():
        if name == PRIMARY_MARKET_INDEX:
            continue

        try:
            candles = _fetch_candles(
                zerodha_client=zerodha_client,
                instrument_token=tok,
                interval=interval,
                days_back=days_back,
            )
            if len(candles) < max(EMA_PERIOD, ATR_PERIOD, TREND_LOOKBACK) + 5:
                continue

            closes_s = [c.c for c in candles]
            ema_s = _ema(closes_s, EMA_PERIOD)
            close_s = closes_s[-1]
            ema_last_s = ema_s[-1]
            ema_prev_s = ema_s[-2] if len(ema_s) >= 2 else ema_last_s
            atr_s = _atr(candles, ATR_PERIOD)
            atr_pct_s = (float(atr_s) / close_s * 100.0) if atr_s and close_s else None

            # Volatility baseline over the most recent candles (capped at ~150 for cost).
            atr_pct_hist_s: List[float] = []
            for i in range(max(ATR_PERIOD + 2, len(candles) - 150), len(candles)):
                sub = candles[: i + 1]
                a = _atr(sub, ATR_PERIOD)
                if a and sub[-1].c:
                    atr_pct_hist_s.append(float(a) / sub[-1].c * 100.0)
            sector_vol_state, sector_vol_diag = _volatility_state(atr_pct_s, atr_pct_hist_s[-64:])

            ref_s = candles[-TREND_LOOKBACK - 1].c if len(candles) > TREND_LOOKBACK else candles[0].c
            ret_s = _pct_change(close_s, ref_s)
            slope_s = float(ema_last_s - ema_prev_s)

            cls = "NEUTRAL"
            if sector_vol_state == "High":
                cls = "VOLATILE"
                volatile.append(name)
            else:
                up = (close_s > ema_last_s) and (ret_s is not None and ret_s >= SECTOR_MIN_MOVE_PCT) and (slope_s > 0)
                dn = (close_s < ema_last_s) and (ret_s is not None and ret_s <= -SECTOR_MIN_MOVE_PCT) and (slope_s < 0)
                if up:
                    cls = "BULLISH"
                    bullish.append(name)
                elif dn:
                    cls = "BEARISH"
                    bearish.append(name)

            sectors_debug[name] = {
                "class": cls,
                "ret": ret_s,
                "close": close_s,
                "ema": float(ema_last_s),
                "ema_slope": slope_s,
                "volatility": {"state": sector_vol_state, **sector_vol_diag},
            }
        except Exception as e:
            raw = (os.getenv("QUIET_LOGS", "true") or "true").strip().lower()
            quiet = raw not in ("0", "false", "no", "off")
            if quiet:
                logger.warning("[MI] Failed computing sector %s: %s", name, e)
            else:
                logger.exception("[MI] Failed computing sector %s", name)

    summary: Dict[str, Any] = {
        "captured_at": captured_at,
        "primary_index": PRIMARY_MARKET_INDEX,
        "interval": interval,
        "market_bias": market_bias,
        "volatility_state": vol_state,
        "overall_risk": overall_risk,
        "sector_strength": {
            "bullish": sorted(set(bullish)),
            "bearish": sorted(set(bearish)),
            "volatile": sorted(set(volatile)),
        },
        "diagnostics": {
            "market": {
                "bias": bias_diag,
                "volatility": vol_diag,
            },
            "sectors": sectors_debug,
            "missing_indices": missing,
        },
    }

    return summary
