import logging
import math
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

from fastapi import HTTPException

from app.v1.services.zerodha.client import ZerodhaClient
from app.v1.services.gpt_engine import (
    OPENAI_MODEL,
    REQUEST_TIMEOUT,
    get_openai_client,
    prepare_market_data_prompt,
    call_chatgpt_analysis,
)
from app.v1.services.tegpt.zerodha_services import fetch_market_data
from app.v1.utils.snapshot_sanitize import compact_analysis_for_persistence

logger = logging.getLogger(__name__)


def _looks_like_model_error(exc: Exception) -> bool:
    msg = str(exc or "").lower()
    if not msg:
        return False
    return (
        "model" in msg
        and (
            "not found" in msg
            or "does not exist" in msg
            or "no such model" in msg
            or "model_not_found" in msg
            or "unsupported model" in msg
        )
    )


# Timeframes fetched for, and summarised into, the chat context.
_CHAT_TIMEFRAMES = ("5minute", "15minute", "day")
# Number of most recent stored chat messages replayed to the model.
_CHAT_HISTORY_TURNS = 8
# Context refresh window used when the env var is missing or not an integer.
_DEFAULT_REFRESH_MINUTES = 5
_CHAT_SYSTEM_PROMPT = (
    "You are an expert intraday trader and coach. "
    "Be concise and practical. Use Markdown bullets. "
    "Do NOT invent prices/indicators; use only provided context."
)


def _as_utc_dt(v: Any) -> Optional[datetime]:
    """Coerce a stored timestamp into a naive UTC ``datetime``.

    Accepts ``datetime`` objects (naive values are treated as already UTC,
    aware values are converted to UTC) and ISO-8601 strings (a trailing ``Z``
    is accepted). Returns ``None`` for any other type or an unparsable string.

    Normalising to naive UTC matters because callers compare the result with
    ``datetime.utcnow()``: an aware value would raise ``TypeError``, and a
    non-UTC offset must be converted rather than silently dropped.
    """
    if isinstance(v, str):
        try:
            v = datetime.fromisoformat(v.replace("Z", "+00:00"))
        except ValueError:
            return None
    if not isinstance(v, datetime):
        return None
    if v.tzinfo is not None:
        v = v.astimezone(timezone.utc).replace(tzinfo=None)
    return v


def _is_fresh(ts: Optional[datetime], minutes: int) -> bool:
    """Return True when naive-UTC *ts* is at most *minutes* old."""
    if ts is None:
        return False
    return datetime.utcnow() - ts <= timedelta(minutes=minutes)


def _refresh_minutes_from_env() -> int:
    """Read the chat-context refresh window in minutes (always >= 1).

    ``CHAT_CONTEXT_REFRESH_MINUTES`` takes precedence over the legacy
    ``CHAT_CONTEXT_FRESHNESS_MINUTES``. A non-integer value logs a warning and
    falls back to the default instead of failing the chat request.
    """
    raw = os.getenv(
        "CHAT_CONTEXT_REFRESH_MINUTES",
        os.getenv("CHAT_CONTEXT_FRESHNESS_MINUTES", str(_DEFAULT_REFRESH_MINUTES)),
    )
    try:
        minutes = int(raw)
    except (TypeError, ValueError):
        logger.warning(
            "Invalid chat context refresh minutes %r; using %d",
            raw,
            _DEFAULT_REFRESH_MINUTES,
        )
        minutes = _DEFAULT_REFRESH_MINUTES
    return max(1, minutes)


def _safe_float(v: Any) -> Optional[float]:
    """Convert *v* to a finite float; return None for None, bools, NaN/inf or junk."""
    if v is None or isinstance(v, bool):
        return None
    try:
        f = float(v)
    except (TypeError, ValueError, OverflowError):
        return None
    return f if math.isfinite(f) else None


def _compact_market_data_only(md: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the market-data keys the chat context needs (non-None values)."""
    if not isinstance(md, dict):
        return {}
    keys = ("indicators", "strategies", "pivots", "fib", "instrument_token")
    return {k: md[k] for k in keys if md.get(k) is not None}


def _indicator_line(tf: str, snap: Any) -> str:
    """Render one ``- <tf>: close=…, vwap=…, rsi=…, macd_hist=…`` line, or "" if empty."""
    if not isinstance(snap, dict):
        return ""
    parts: List[str] = []
    for key, fmt in (("close", ".2f"), ("vwap", ".2f"), ("rsi", ".1f"), ("macd_hist", ".3f")):
        val = _safe_float(snap.get(key))
        if val is not None:
            parts.append(f"{key}={val:{fmt}}")
    return (f"- {tf}: " + ", ".join(parts)) if parts else ""


def _summarize_compact_context(sym: str, analysis: Dict[str, Any], md_compact: Dict[str, Any]) -> str:
    """Build the plain-text CONTEXT block given to the chat model as a system message.

    Includes the decision/confidence headline, optional entry zone, stop loss,
    up to two targets, per-timeframe indicator lines and strategy flags, using
    only values present in *analysis* / *md_compact*.
    """
    dec = analysis.get("decision") or "HOLD"
    conf = analysis.get("confidence") or "LOW"
    prob = _safe_float(analysis.get("decision_probability"))
    trend_label = analysis.get("trend_label") or ""

    zone_txt = ""
    zone = analysis.get("entry_zone")
    if isinstance(zone, dict):
        lo = _safe_float(zone.get("low") or zone.get("lower"))
        hi = _safe_float(zone.get("high") or zone.get("upper"))
        if lo is not None and hi is not None:
            zone_txt = f"entry_zone={lo:.2f}-{hi:.2f}"

    sl = _safe_float(analysis.get("exec_sl") or analysis.get("stop_loss"))

    # Prefer execution-engine targets; fall back to the model's raw targets.
    raw_targets = analysis.get("exec_targets")
    if not isinstance(raw_targets, list):
        raw_targets = analysis.get("targets")
    targets: List[float] = []
    if isinstance(raw_targets, list):
        for x in raw_targets:
            fx = _safe_float(x)
            if fx is not None:
                targets.append(fx)
                if len(targets) >= 2:
                    break

    indicators = md_compact.get("indicators")
    if not isinstance(indicators, dict):
        indicators = {}
    tf_lines = [ln for ln in (_indicator_line(tf, indicators.get(tf)) for tf in _CHAT_TIMEFRAMES) if ln]

    header = f"- decision={dec} confidence={conf}"
    if prob is not None:
        header += f" decision_probability={prob:.0f}%"
    if trend_label:
        header += f" trend_label={trend_label}"

    lines: List[str] = [f"CONTEXT ({sym})", header]
    if zone_txt:
        lines.append(f"- {zone_txt}")
    if sl is not None:
        lines.append(f"- exec_sl={sl:.2f}")
    if targets:
        lines.append(f"- exec_targets={', '.join(f'{x:.2f}' for x in targets)}")
    if tf_lines:
        lines.append("INDICATORS:")
        lines.extend(tf_lines)

    strategies = md_compact.get("strategies")
    if isinstance(strategies, dict):
        mtf = strategies.get("multi_timeframe")
        per_tf = strategies.get("per_timeframe")
        if isinstance(mtf, dict) and mtf:
            lines.append(f"MTF: {mtf}")
        if isinstance(per_tf, dict) and per_tf:
            slim = {k: per_tf[k] for k in _CHAT_TIMEFRAMES if isinstance(per_tf.get(k), dict) and per_tf.get(k)}
            if slim:
                lines.append(f"STRATEGY_FLAGS: {slim}")

    return "\n".join(lines).strip()


def _norm_key(v: Any) -> str:
    """Upper-case *v*, turn non-alphanumerics into spaces and collapse whitespace."""
    s = str(v or "").strip().upper()
    return " ".join("".join(ch if ch.isalnum() else " " for ch in s).split())


def _best_sector_match(sector_value: str, sectors_debug: Any) -> Optional[str]:
    """Pick the key of *sectors_debug* that best matches *sector_value*.

    Score = shared normalised tokens, +1 if any target token is a substring of
    the key. Returns the highest-scoring key (first wins on ties) or None when
    nothing scores at least 1.
    """
    if not sector_value or not isinstance(sectors_debug, dict) or not sectors_debug:
        return None
    target = _norm_key(sector_value)
    if not target:
        return None

    target_tokens = set(target.split())
    best: Optional[str] = None
    best_score = 0
    for key in sectors_debug:
        norm = _norm_key(key)
        if not norm:
            continue
        score = len(target_tokens & set(norm.split()))
        if any(t in norm for t in target_tokens):
            score += 1
        if score > best_score:
            best_score, best = score, key
    return best if best_score >= 1 else None


def _get_latest_context_doc(db, user_id: str, symbol: str, conversation_id: str) -> Optional[Dict[str, Any]]:
    """Return the newest context-snapshot chat message for this conversation, or None."""
    try:
        return db["chats"].find_one(
            {
                "user_id": user_id,
                "symbol": symbol,
                "conversation_id": conversation_id,
                "is_context_snapshot": True,
            },
            sort=[("created_at", -1), ("_id", -1)],
        )
    except Exception:
        logger.exception("[Chat] Failed to load latest context snapshot")
        return None


def _get_latest_market_intelligence(db, symbol: str) -> Dict[str, Any]:
    """Summarise the latest market-intelligence document for the prompt (best-effort).

    Copies market bias / risk / volatility / sector strength from the stored
    payload and, when the ``stocks`` collection knows the symbol's sector or
    industry, attaches the best-matching sector diagnostics. Returns {} on any
    failure so the refresh can proceed without it.
    """
    try:
        doc = db["market_intelligence_summary"].find_one(
            {"type": "latest"},
            sort=[("updated_at", -1), ("captured_at", -1), ("_id", -1)],
        )
        if not doc:
            return {}
        payload = doc.get("payload") or {}
        ts = _as_utc_dt(doc.get("captured_at") or doc.get("updated_at") or doc.get("created_at"))

        out: Dict[str, Any] = {}
        if ts is not None:
            out["as_of"] = ts.isoformat() + "Z"
        if not isinstance(payload, dict):
            return out

        for src, dst in (
            ("market_bias", "market_bias"),
            ("overall_risk", "risk_regime"),
            ("volatility_state", "volatility_regime"),
            ("sector_strength", "sector_strength"),
        ):
            if payload.get(src) is not None:
                out[dst] = payload[src]

        try:
            stock = db["stocks"].find_one(
                {"symbol": symbol},
                {"_id": 0, "symbol": 1, "sector": 1, "industry": 1, "name": 1},
            )
        except Exception:
            logger.warning("[Chat] Failed to load stock metadata for %s", symbol, exc_info=True)
            stock = None

        if isinstance(stock, dict):
            sector_val = str(stock.get("sector") or "").strip()
            industry_val = str(stock.get("industry") or "").strip()
            if sector_val:
                out["symbol_sector"] = sector_val
            if industry_val:
                out["symbol_industry"] = industry_val

            # Guard each level: a malformed diagnostics entry must not discard
            # the regime fields already collected above.
            diagnostics = payload.get("diagnostics")
            sectors_debug = diagnostics.get("sectors") if isinstance(diagnostics, dict) else None
            match_key = _best_sector_match(sector_val or industry_val, sectors_debug)
            diag = sectors_debug.get(match_key) if match_key else None
            if isinstance(diag, dict):
                vol = diag.get("volatility") if isinstance(diag.get("volatility"), dict) else {}
                out["symbol_sector_trend"] = {
                    "index": match_key,
                    "class": diag.get("class"),
                    "ret": diag.get("ret"),
                    "volatility_state": vol.get("state"),
                }
        return out
    except Exception:
        logger.exception("[Chat] Failed to load market intelligence")
        return {}


def _build_refreshed_context(
    db,
    symbol: str,
    user_id: str,
    conversation_id: str,
    md_raw: Dict[str, Any],
) -> Tuple[str, Optional[str]]:
    """Run a fresh analysis on *md_raw* and persist it; return (summary, snapshot_id).

    Persists a ``stock_analysis_snapshots`` document and a context message in
    ``chats`` (each write is best-effort and only logged on failure;
    snapshot_id is None if the snapshot could not be stored). Errors from
    prompt preparation or the analysis call propagate to the caller.
    """
    md_compact = _compact_market_data_only(md_raw)
    market_intelligence = _get_latest_market_intelligence(db, symbol)

    prompt = prepare_market_data_prompt(
        symbol,
        md_compact,
        question="general intraday analysis",
        context="general",
        market_intelligence=market_intelligence,
    )
    analysis = call_chatgpt_analysis(prompt)
    if not isinstance(analysis, dict):
        raise TypeError(f"call_chatgpt_analysis returned {type(analysis).__name__}, expected dict")

    analysis["features"] = {
        "indicators": md_compact.get("indicators") or {},
        "strategies": md_compact.get("strategies") or {},
        "pivots": md_compact.get("pivots") or {},
        "fib": md_compact.get("fib") or {},
    }

    try:
        # Imported lazily as in the original code path.
        from app.v1.services.entry_engine import build_execution_plan

        plan = build_execution_plan(decision=str(analysis.get("decision") or "").upper(), market_data=md_raw)
        if plan is not None:
            # setdefault: values the model already produced take precedence.
            analysis.setdefault("entry_zone", plan.entry_zone)
            analysis.setdefault("entry_trigger_reason", plan.entry_trigger_reason)
            analysis.setdefault("exec_sl", plan.sl)
            analysis.setdefault("exec_targets", plan.targets)
            analysis.setdefault("exec_rr_ratio", plan.rr_ratio)
            analysis.setdefault("signal_state", plan.state)
    except Exception:
        logger.exception("[Chat] Failed to compute execution plan")

    snapshot_doc: Dict[str, Any] = {
        "symbol": symbol,
        "timestamp": datetime.utcnow(),
        "source": "CHAT",
        "analysis": compact_analysis_for_persistence(analysis),
        "market_data": md_compact,
        "features": analysis.get("features") or {},
    }
    snapshot_id: Optional[str] = None
    try:
        ins = db["stock_analysis_snapshots"].insert_one(snapshot_doc)
        snapshot_id = str(ins.inserted_id)
    except Exception:
        logger.exception("[Chat] Failed to persist chat snapshot")

    summary = _summarize_compact_context(symbol, analysis, md_compact)

    ctx_store = {
        "user_id": user_id,
        "symbol": symbol,
        "conversation_id": conversation_id,
        "role": "system",
        "is_context_snapshot": True,
        "context_snapshot_id": snapshot_id,
        "message": summary,
        "created_at": datetime.utcnow(),
    }
    try:
        db["chats"].insert_one(ctx_store)
    except Exception:
        logger.exception("[Chat] Failed to store context snapshot message")

    return summary, snapshot_id


def _build_chat_messages(context_summary: str, history: List[Dict[str, Any]], message: str) -> List[Dict[str, str]]:
    """Assemble system prompt, context, the last few user/assistant turns and the new message."""
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": _CHAT_SYSTEM_PROMPT},
        {"role": "system", "content": context_summary},
    ]
    for chat in history[-_CHAT_HISTORY_TURNS:]:
        role = chat.get("role")
        if role not in ("user", "assistant"):
            continue
        txt = str(chat.get("message") or "").strip()
        if txt:
            messages.append({"role": role, "content": txt})
    messages.append({"role": "user", "content": (message or "").strip()})
    return messages


def _create_chat_completion(messages: List[Dict[str, str]]) -> Tuple[str, str]:
    """Call the OpenAI chat API; return (assistant_text, model_used).

    Retries once when the error says ``max_tokens`` must be
    ``max_completion_tokens``, or once with the fallback model when the error
    looks like an unknown/inaccessible model. Any other error propagates, and
    an empty completion raises ``RuntimeError``.
    """
    client = get_openai_client()

    chat_model = os.getenv("OPENAI_CHAT_MODEL", OPENAI_MODEL)
    fallback_model = os.getenv("OPENAI_CHAT_MODEL_FALLBACK", os.getenv("OPENAI_MODEL_FALLBACK", "gpt-4o-mini"))
    max_tokens = int(os.getenv("OPENAI_CHAT_MAX_TOKENS", "800"))
    temperature = float(os.getenv("OPENAI_CHAT_TEMPERATURE", "0.2"))

    def _token_kwargs_for(model_name: str) -> Dict[str, Any]:
        # gpt-5* models reject max_tokens and require max_completion_tokens.
        if (model_name or "").strip().lower().startswith("gpt-5"):
            return {"max_completion_tokens": max_tokens}
        return {"max_tokens": max_tokens}

    def _create(model_name: str, token_kwargs: Dict[str, Any]):
        return client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            timeout=REQUEST_TIMEOUT,
            **token_kwargs,
        )

    model_used = chat_model
    try:
        response = _create(chat_model, _token_kwargs_for(chat_model))
    except Exception as e:
        msg = str(e)
        if "max_tokens" in msg and "max_completion_tokens" in msg:
            logger.warning(
                "Retrying chat completion with max_completion_tokens due to parameter mismatch (model=%s)",
                chat_model,
            )
            response = _create(chat_model, {"max_completion_tokens": max_tokens})
        elif fallback_model and fallback_model != chat_model and _looks_like_model_error(e):
            logger.warning(
                "Chat model failed (%s). Retrying with fallback model=%s",
                chat_model,
                fallback_model,
            )
            model_used = fallback_model
            response = _create(fallback_model, _token_kwargs_for(fallback_model))
        else:
            raise

    # content can be None (e.g. when the token budget is exhausted).
    text = (response.choices[0].message.content or "").strip()
    if not text:
        raise RuntimeError(f"Chat completion from {model_used} returned no text")
    return text, model_used


def chat_with_stock_service(
    db,
    zerodha_client: ZerodhaClient,
    symbol: str,
    message: str,
    user_id: str,
    conversation_id: Optional[str] = None,
    include_fresh_data: bool = True,
) -> Dict[str, Any]:
    """Interactive chat about a specific symbol.

    Loads the latest context snapshot for this user/symbol/conversation. If it
    is older than the refresh window and *include_fresh_data* is set, fetches
    market data, runs a fresh analysis and stores a new context snapshot. This
    refresh is best-effort: on failure the previous context (or a placeholder)
    is used. Then the context, the most recent chat turns and *message* are sent
    to the chat model. Both the user message and the reply are stored in
    ``chats``.

    Args:
        db: Mongo-style database handle (``db[collection]``).
        zerodha_client: Client passed through to ``fetch_market_data``.
        symbol: Trading symbol; stripped and upper-cased.
        message: The user's chat message.
        user_id: Owner of the conversation.
        conversation_id: Existing conversation id; a new UUID4 is generated if empty.
        include_fresh_data: Allow refreshing the analysis context when stale.

    Returns:
        Dict with ``conversation_id``, ``message`` (assistant reply),
        ``message_id``, ``timestamp`` (ISO string of a naive UTC datetime),
        ``model`` (model that produced the reply; the configured model when the
        call failed) and ``context`` (refresh metadata).

    Raises:
        HTTPException: status 500 for any unexpected failure (e.g. DB writes).
    """
    try:
        if not conversation_id:
            conversation_id = str(uuid.uuid4())
        symbol = (symbol or "").strip().upper()
        refresh_minutes = _refresh_minutes_from_env()

        ctx_doc = _get_latest_context_doc(db, user_id, symbol, conversation_id)
        has_ctx = isinstance(ctx_doc, dict)
        ctx_ts = _as_utc_dt(ctx_doc.get("created_at")) if has_ctx else None
        context_summary = ctx_doc.get("message") if has_ctx else None
        context_snapshot_id = ctx_doc.get("context_snapshot_id") if has_ctx else None

        used_zerodha = False
        refreshed = False
        if include_fresh_data and not _is_fresh(ctx_ts, refresh_minutes):
            # The whole refresh (including the market-data fetch) is best-effort:
            # a broker/LLM hiccup should degrade to stale context, not a 500.
            try:
                md_raw = fetch_market_data(zerodha_client, symbol, list(_CHAT_TIMEFRAMES), db=db)
                used_zerodha = True
                context_summary, context_snapshot_id = _build_refreshed_context(
                    db, symbol, user_id, conversation_id, md_raw
                )
                refreshed = True
            except Exception:
                logger.exception("[Chat] Failed to refresh analysis context")

        # Newest first + limit, then reverse: yields the most recent turns in
        # chronological order no matter how long the conversation is.
        history = list(
            db["chats"]
            .find({"user_id": user_id, "symbol": symbol, "conversation_id": conversation_id, "is_context_snapshot": {"$ne": True}})
            .sort([("created_at", -1), ("_id", -1)])
            .limit(_CHAT_HISTORY_TURNS)
        )
        history.reverse()

        if not isinstance(context_summary, str) or not context_summary.strip():
            context_summary = f"CONTEXT ({symbol})\n- No fresh computed snapshot available yet."

        messages = _build_chat_messages(context_summary, history, message)

        model_used = os.getenv("OPENAI_CHAT_MODEL", OPENAI_MODEL)
        try:
            assistant_message, model_used = _create_chat_completion(messages)
        except Exception as e:
            logger.exception("ChatGPT call failed in chat")
            # Give an actionable hint for the common misconfig (invalid model name).
            if _looks_like_model_error(e):
                assistant_message = (
                    f"Chat is unavailable because the configured model is invalid or not accessible. "
                    f"Please check OPENAI_CHAT_MODEL (current: {os.getenv('OPENAI_CHAT_MODEL', OPENAI_MODEL)})."
                )
            else:
                assistant_message = f"I'm having trouble analyzing {symbol} right now. Please try again in a moment."

        user_chat_doc = {
            "user_id": user_id,
            "symbol": symbol,
            "conversation_id": conversation_id,
            "role": "user",
            "message": message,
            "created_at": datetime.utcnow(),
        }
        db["chats"].insert_one(user_chat_doc)

        assistant_chat_doc = {
            "user_id": user_id,
            "symbol": symbol,
            "conversation_id": conversation_id,
            "role": "assistant",
            "message": assistant_message,
            "market_data_included": bool(include_fresh_data),
            "context": {
                "computed_snapshot_refreshed": refreshed,
                "context_snapshot_id": context_snapshot_id,
                "zerodha_used": used_zerodha,
                "refresh_minutes": refresh_minutes,
            },
            "created_at": datetime.utcnow(),
        }
        result = db["chats"].insert_one(assistant_chat_doc)

        return {
            "conversation_id": conversation_id,
            "message": assistant_message,
            "message_id": str(result.inserted_id),
            "timestamp": assistant_chat_doc["created_at"].isoformat(),
            "model": model_used,
            "context": assistant_chat_doc.get("context"),
        }

    except Exception as e:
        logger.exception("Chat service failed for %s", symbol)
        raise HTTPException(status_code=500, detail=str(e)) from e


def _format_signal(signal: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten one stored ``analyses`` document into the API signal shape.

    Tolerates ``analysis`` being missing/None and ``timestamp`` being missing
    or already a string, so one malformed document cannot fail the whole list.
    """
    analysis = signal.get("analysis")
    if not isinstance(analysis, dict):
        analysis = {}
    ts = signal.get("timestamp")
    return {
        "id": str(signal["_id"]),
        "symbol": signal.get("symbol"),
        "decision": analysis.get("decision", "HOLD"),
        "confidence": analysis.get("confidence", "LOW"),
        "signal_state": analysis.get("signal_state") or "WAITING_FOR_ENTRY",
        "entry_zone": analysis.get("entry_zone"),
        "entry_trigger_reason": analysis.get("entry_trigger_reason") or [],
        "sl": analysis.get("exec_sl"),
        "targets": analysis.get("exec_targets") or [],
        "rr_ratio": analysis.get("exec_rr_ratio"),
        "rationale": analysis.get("rationale", []),
        "technical_indicators": analysis.get("technical_indicators", {}),
        # datetime/date -> ISO string; strings and None pass through unchanged.
        "timestamp": ts.isoformat() if hasattr(ts, "isoformat") else ts,
        "context": signal.get("context", "general"),
    }


def get_user_signals_service(
    db,
    user_id: str,
    limit: int,
    symbol: Optional[str] = None,
    decision: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Get user's latest trading signals, newest first.

    Args:
        db: Mongo-style database handle; reads the ``analyses`` collection.
        user_id: Only signals for this user are returned.
        limit: Maximum number of documents (passed to the cursor's ``limit``).
        symbol: Optional symbol filter (stripped, upper-cased).
        decision: Optional ``analysis.decision`` filter (stripped, upper-cased).

    Returns:
        List of flattened signal dicts (see ``_format_signal``).

    Raises:
        HTTPException: status 500 if the query fails.
    """
    try:
        query: Dict[str, Any] = {"user_id": user_id}

        symbol = (symbol or "").strip().upper()
        if symbol:
            query["symbol"] = symbol

        decision = (decision or "").strip().upper()
        if decision:
            query["analysis.decision"] = decision

        # _id tiebreak keeps ordering deterministic for equal timestamps.
        cursor = db["analyses"].find(query).sort([("timestamp", -1), ("_id", -1)]).limit(limit)
        return [_format_signal(doc) for doc in cursor]

    except Exception as e:
        logger.exception("Failed to get user signals")
        raise HTTPException(status_code=500, detail=str(e)) from e
