"""Learning routes for platform teGPT.

NOTE: Move-only split from `teGPT.py`.
"""

from fastapi import APIRouter, Depends, HTTPException, Body, Query, File, UploadFile
from typing import Dict, Any, List, Optional
import logging
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
from bson import ObjectId
import csv
import io

from pymongo import UpdateOne

from app.db import database
from app.v1.dependencies.auth import get_current_userdetails
from app.v1.services.teGPT import get_zerodha_client_service, chat_with_stock_service

from .teGPT_helpers import (
    _compute_fifo_enrichment,
    _dt_to_iso,
    _norm_account_id,
    _norm_csv_key,
    _norm_side,
    _norm_symbol,
    _parse_ts,
    _parse_ts_flexible,
    _pick,
    _safe_float,
    _safe_int,
)

router = APIRouter()
logger = logging.getLogger(__name__)


# ============ LEARNING (Personal Zerodha trades) ============


@router.get("/learning/transactions", summary="Get user's recent Zerodha transactions (trades)")
async def get_learning_transactions(
    limit: int = Query(50, ge=1, le=200, description="Max rows to return"),
    days: int = Query(120, ge=1, le=365, description="How many days of history to return (ignored if day is set)"),
    day: Optional[str] = Query(None, description="IST date YYYY-MM-DD (filters transactions to one day)"),
    sync: bool = Query(True, description="Sync today's Zerodha trades/orders into DB before listing"),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
    """Personal endpoint: last N transactions across months.

    Zerodha's official API typically exposes only today's tradebook/orders.
    To support multi-month history, we persist daily syncs into Mongo.
    """
    try:
        user_id = str(current_user.get("_id"))
        col = db["learning_transactions"]

        ist = ZoneInfo("Asia/Kolkata")

        def _today_ist_date() -> str:
            try:
                return datetime.utcnow().replace(tzinfo=ZoneInfo("UTC")).astimezone(ist).strftime("%Y-%m-%d")
            except Exception:
                return datetime.now(ist).strftime("%Y-%m-%d")

        day_norm = (day or "").strip()
        if day_norm:
            # Basic validation: expect YYYY-MM-DD.
            try:
                datetime.strptime(day_norm, "%Y-%m-%d")
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid day; expected YYYY-MM-DD")

        cutoff = datetime.utcnow() - timedelta(days=days)

        today_ist = _today_ist_date()
        requested_day = day_norm or today_ist

        # KiteConnect tradebook/orders are typically limited to today's session.
        # If user requests a non-today day, syncing cannot help and would be misleading.
        can_sync_requested_day = requested_day == today_ist
        sync_effective = bool(sync) and can_sync_requested_day

        sync_report = {
            "attempted": bool(sync_effective),
            "trades_seen": 0,
            "orders_seen": 0,
            "upserts": 0,
            "errors": [],
        }

        if sync and not can_sync_requested_day:
            sync_report["errors"].append("historical_day_sync_not_supported")

        if sync_effective:
            zerodha_client = None
            try:
                zerodha_client = get_zerodha_client_service(db, current_user)
            except Exception:
                # Do not fail the endpoint if Zerodha auth is invalid/expired.
                sync_report["errors"].append("zerodha_client_unavailable")

            def upsert_one(kind: str, it: Dict[str, Any]) -> None:
                ts_raw = (
                    it.get("exchange_timestamp")
                    or it.get("order_timestamp")
                    or it.get("fill_timestamp")
                    or it.get("timestamp")
                )
                ts_dt = _parse_ts(ts_raw)
                symbol = (it.get("tradingsymbol") or it.get("symbol") or "").strip().upper()

                # Ensure we always have a usable timestamp for sorting/filtering.
                now = datetime.utcnow()
                effective_ts = ts_dt or now

                # Store IST date string for easy daily filtering.
                try:
                    ist_date = effective_ts.replace(tzinfo=ZoneInfo("UTC")).astimezone(ist).strftime("%Y-%m-%d")
                except Exception:
                    ist_date = _today_ist_date()

                doc = {
                    "user_id": user_id,
                    "kind": kind,
                    "symbol": symbol,
                    "exchange": it.get("exchange"),
                    "product": it.get("product"),
                    "transaction_type": (it.get("transaction_type") or "").upper() or None,
                    "quantity": _safe_int(it.get("quantity")),
                    "price": _safe_float(it.get("price")),
                    "average_price": _safe_float(it.get("average_price")),
                    "trade_id": it.get("trade_id"),
                    "order_id": it.get("order_id"),
                    "status": it.get("status"),
                    "ts": effective_ts,
                    "ist_date": ist_date,
                    "instrument_token": _safe_int(it.get("instrument_token") or it.get("token")),
                    "raw": it,
                    "updated_at": now,
                }
                # Choose a stable upsert key.
                if kind == "trade" and it.get("trade_id"):
                    key = {"user_id": user_id, "kind": "trade", "trade_id": it.get("trade_id")}
                elif it.get("order_id"):
                    key = {"user_id": user_id, "kind": kind, "order_id": it.get("order_id")}
                else:
                    key = {"user_id": user_id, "kind": kind, "symbol": symbol, "ts": effective_ts}

                col.update_one(key, {"$set": doc, "$setOnInsert": {"created_at": datetime.utcnow()}}, upsert=True)
                sync_report["upserts"] += 1

            if zerodha_client is not None:
                # Sync trades (today) and orders (today) so the DB accumulates over time.
                try:
                    trades = zerodha_client.get_trades()
                    if isinstance(trades, list):
                        sync_report["trades_seen"] = len(trades)
                        for it in trades:
                            if isinstance(it, dict):
                                upsert_one("trade", it)
                except Exception:
                    sync_report["errors"].append("trades_sync_failed")

                try:
                    orders = zerodha_client.get_orders()
                    if isinstance(orders, list):
                        sync_report["orders_seen"] = len(orders)
                        for it in orders:
                            if isinstance(it, dict):
                                upsert_one("order", it)
                except Exception:
                    sync_report["errors"].append("orders_sync_failed")

        # Read from DB (months history or one IST day)
        if day_norm:
            # Match by ist_date (new docs) OR by ts within IST day boundaries (older docs).
            # Convert IST day start/end to UTC range.
            try:
                day_start_ist = datetime.strptime(day_norm, "%Y-%m-%d").replace(
                    hour=0, minute=0, second=0, microsecond=0, tzinfo=ist
                )
                day_end_ist = day_start_ist + timedelta(days=1)
                day_start_utc = day_start_ist.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
                day_end_utc = day_end_ist.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
            except Exception:
                day_start_utc = None
                day_end_utc = None

            ors = [{"ist_date": day_norm}]
            if day_start_utc and day_end_utc:
                ors.append({"ts": {"$gte": day_start_utc, "$lt": day_end_utc}})
                # Fallback for older rows that only have updated_at.
                ors.append({"updated_at": {"$gte": day_start_utc, "$lt": day_end_utc}})

            q = {"user_id": user_id, "$or": ors}
        else:
            q = {
                "user_id": user_id,
                "$or": [
                    {"ts": {"$gte": cutoff}},
                    {"ts": None},
                    {"ts": {"$exists": False}},
                ],
            }
        # Pull more rows than we return so FIFO has context; still fast.
        compute_cap = max(limit, 2000 if day_norm else min(2000, days * 50))
        cur = col.find(q, {"_id": 0}).sort([("ts", 1), ("updated_at", 1)]).limit(compute_cap)
        docs = list(cur)

        base_items: List[Dict[str, Any]] = []
        for d in docs:
            base_items.append(
                {
                    "kind": d.get("kind") or "—",
                    "symbol": d.get("symbol") or "",
                    "exchange": d.get("exchange"),
                    "product": d.get("product"),
                    "transaction_type": d.get("transaction_type"),
                    "quantity": d.get("quantity"),
                    "price": d.get("price"),
                    "average_price": d.get("average_price"),
                    "trade_id": d.get("trade_id"),
                    "order_id": d.get("order_id"),
                    "status": d.get("status"),
                    "ts": d.get("ts") or d.get("updated_at"),
                    "exchange_timestamp": _dt_to_iso(d.get("ts")) or _dt_to_iso(d.get("updated_at")),
                    "instrument_token": d.get("instrument_token"),
                    "segment": d.get("segment"),
                    "series": d.get("series"),
                    "isin": d.get("isin"),
                    "raw": d.get("raw") if isinstance(d.get("raw"), dict) else {},
                }
            )

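        # FIFO enrichment (_compute_fifo_enrichment): pair exits with earlier entries per symbol
        # to derive the realized_pnl, matched_qty and running-position fields used below.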
        enriched_all = _compute_fifo_enrichment(base_items)

        # Summary (over the full computed set, not just the returned page)
        summary = {
            "rows": len(enriched_all),
            "realized_pnl": 0.0,
            "buy_value": 0.0,
            "sell_value": 0.0,
            "buy_qty": 0,
            "sell_qty": 0,
        }
        for it in enriched_all:
            side = _norm_side(it.get("transaction_type"))
            qty = _safe_int(it.get("quantity")) or 0
            px = _safe_float(it.get("price"))
            if px is None:
                px = _safe_float(it.get("average_price"))
            val = float(qty) * float(px or 0.0)
            if side == "BUY":
                summary["buy_value"] += val
                summary["buy_qty"] += qty
            elif side == "SELL":
                summary["sell_value"] += val
                summary["sell_qty"] += qty
            rp = _safe_float(it.get("realized_pnl"))
            if rp is not None:
                summary["realized_pnl"] += float(rp)
        for k in ("realized_pnl", "buy_value", "sell_value"):
            summary[k] = round(float(summary[k]), 2)

        # Return newest rows only, with enrichment.
        enriched_all_sorted = sorted(enriched_all, key=lambda x: x.get("ts") or datetime.min, reverse=True)
        items = []
        for it in enriched_all_sorted[:limit]:
            qty = _safe_int(it.get("quantity"))
            px = _safe_float(it.get("price"))
            if px is None:
                px = _safe_float(it.get("average_price"))
            value = round((qty or 0) * (px or 0.0), 2) if qty and px is not None else None

            items.append(
                {
                    "kind": it.get("kind") or "—",
                    "symbol": it.get("symbol") or "",
                    "exchange": it.get("exchange"),
                    "product": it.get("product"),
                    "transaction_type": _norm_side(it.get("transaction_type")),
                    "quantity": it.get("quantity"),
                    "price": it.get("price"),
                    "average_price": it.get("average_price"),
                    "trade_id": it.get("trade_id"),
                    "order_id": it.get("order_id"),
                    "status": it.get("status"),
                    "exchange_timestamp": it.get("exchange_timestamp"),
                    "instrument_token": it.get("instrument_token"),
                    "isin": it.get("isin") or (it.get("raw") or {}).get("isin"),
                    "segment": it.get("segment") or (it.get("raw") or {}).get("segment"),
                    "series": it.get("series") or (it.get("raw") or {}).get("series"),
                    "value": value,
                    "realized_pnl": it.get("realized_pnl"),
                    "matched_qty": it.get("matched_qty"),
                    "matched_avg_entry_price": it.get("matched_avg_entry_price"),
                    "position_after": it.get("position_after"),
                    "open_avg_price": it.get("open_avg_price"),
                }
            )

        total = col.count_documents(q)

        return {
            "status": "success",
            "source": "db_learning_transactions",
            "days": days,
            "day": day_norm or None,
            "summary": summary,
            "total": int(total),
            "sync": sync_report,
            "items": items,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch learning transactions")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/learning/transactions/import", summary="Import (backfill) learning transactions from CSV")
async def import_learning_transactions_csv(
    file: UploadFile = File(..., description="CSV export (tradebook/contract note)"),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
    """Backfill multi-month history.

    Zerodha APIs generally don't provide months of tradebook history.
    This endpoint lets the user import/exported transactions into Mongo.
    """
    try:
        user_id = str(current_user.get("_id"))
        col = db["learning_transactions"]

        raw_bytes = await file.read()
        if not raw_bytes:
            raise HTTPException(status_code=400, detail="Empty file")

        # Decode with BOM-safe utf-8; fall back to latin-1.
        try:
            text = raw_bytes.decode("utf-8-sig")
        except Exception:
            text = raw_bytes.decode("latin-1")

        reader = csv.DictReader(io.StringIO(text))
        if not reader.fieldnames:
            raise HTTPException(status_code=400, detail="CSV has no header")

        processed = 0
        upserts = 0
        skipped = 0
        errors: List[Dict[str, Any]] = []

        ops: List[UpdateOne] = []
        batch_size = 500

        def flush_ops() -> None:
            nonlocal upserts
            if not ops:
                return
            # ordered=False makes large imports much faster.
            col.bulk_write(ops, ordered=False)
            upserts += len(ops)
            ops.clear()

        for row in reader:
            processed += 1
            # Normalize keys once per row.
            nrow: Dict[str, Any] = {}
            for k, v in (row or {}).items():
                nk = _norm_csv_key(k)
                if not nk:
                    continue
                nrow[nk] = v.strip() if isinstance(v, str) else v

            symbol = _norm_symbol(
                str(
                    _pick(
                        nrow,
                        "tradingsymbol",
                        "trading_symbol",
                        "symbol",
                        "instrument",
                    )
                    or ""
                )
            )
            if not symbol:
                skipped += 1
                continue

            side = _pick(
                nrow,
                "transaction_type",
                "transactiontype",
                "trade_type",
                "tradetype",
                "buy_sell",
                "side",
            )
            side_norm = (str(side).strip().upper() if side is not None else "") or None

            qty = _safe_int(_pick(nrow, "quantity", "qty"))
            price = _safe_float(_pick(nrow, "price", "tradeprice", "rate"))
            avg = _safe_float(_pick(nrow, "average_price", "avgprice", "averageprice"))

            exchange = _pick(nrow, "exchange")
            product = _pick(nrow, "product")
            segment = _pick(nrow, "segment")
            series = _pick(nrow, "series")
            isin = _pick(nrow, "isin")
            auction = _pick(nrow, "auction")

            trade_id = _pick(nrow, "trade_id", "tradeid")
            order_id = _pick(nrow, "order_id", "orderid")

            trade_date = _pick(nrow, "trade_date", "tradedate", "date")
            exec_time = _pick(
                nrow,
                "order_execution_time",
                "orderexecutiontime",
                "executiontime",
                "timestamp",
                "time",
            )

            ts_dt = _parse_ts_flexible(exec_time) or _parse_ts_flexible(trade_date)
            if ts_dt is None:
                # No parseable timestamp; fall back to the import time so the row still sorts.
                ts_dt = datetime.utcnow()

            now = datetime.utcnow()
            doc = {
                "user_id": user_id,
                "kind": "import",
                "symbol": symbol,
                "exchange": exchange,
                "product": product,
                "segment": segment,
                "series": series,
                "isin": isin,
                "auction": auction,
                "transaction_type": side_norm,
                "quantity": qty,
                "price": price,
                "average_price": avg,
                "trade_id": trade_id,
                "order_id": order_id,
                "ts": ts_dt,
                "raw": nrow,
                "updated_at": now,
            }

            if trade_id:
                key = {"user_id": user_id, "kind": "import", "trade_id": trade_id}
            elif order_id:
                key = {"user_id": user_id, "kind": "import", "order_id": order_id}
            else:
                key = {
                    "user_id": user_id,
                    "kind": "import",
                    "symbol": symbol,
                    "ts": ts_dt,
                    "transaction_type": side_norm,
                    "quantity": qty,
                    "price": price,
                }

            try:
                ops.append(UpdateOne(key, {"$set": doc, "$setOnInsert": {"created_at": now}}, upsert=True))
                if len(ops) >= batch_size:
                    flush_ops()
            except Exception as e:
                if len(errors) < 20:
                    errors.append({"row": processed, "error": str(e)})

        # flush remaining
        try:
            flush_ops()
        except Exception as e:
            if len(errors) < 20:
                errors.append({"row": processed, "error": f"bulk_write_failed: {e}"})

        return {
            "status": "success",
            "processed": processed,
            "upserts": upserts,
            "skipped": skipped,
            "errors": errors,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to import learning transactions CSV")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/learning/paper-trades", summary="Get user's paper trades (simulation only)")
async def get_learning_paper_trades(
    page: int = Query(1, ge=1, le=100000, description="1-indexed page"),
    page_size: int = Query(50, ge=1, le=200, description="Rows per page"),
    limit: Optional[int] = Query(None, ge=1, le=200, description="(legacy) Max rows to return"),
    days: int = Query(120, ge=1, le=365, description="How many days of history to return"),
    day: Optional[str] = Query(None, description="Filter to a single day (YYYY-MM-DD) in Asia/Kolkata"),
    status: Optional[str] = Query(None, description="Filter by status: OPEN or CLOSED"),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
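    """List the user's paper trades from Mongo, filtered by day (IST) and/or status, with page/page_size pagination."""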
    try:
        user_id = str(current_user.get("_id"))
        account_id = _norm_account_id(current_user.get("account_id"))
        tz = ZoneInfo("Asia/Kolkata")
        if day:
            try:
                base = datetime.fromisoformat(str(day).strip())
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid day; expected YYYY-MM-DD")
            start_local = base.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=tz)
            end_local = start_local + timedelta(days=1)
            start_utc = start_local.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
            end_utc = end_local.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
            q: Dict[str, Any] = {"user_id": user_id, "created_at": {"$gte": start_utc, "$lt": end_utc}}
        else:
            cutoff = datetime.utcnow() - timedelta(days=days)
            q = {"user_id": user_id, "created_at": {"$gte": cutoff}}

        if account_id:
            q["account_id"] = account_id
        if status:
            q["status"] = str(status).strip().upper()

        if isinstance(limit, int) and limit > 0:
            # Back-compat: treat `limit` as page_size for page 1
            page = 1
            page_size = int(limit)

        skip = max(0, (int(page) - 1) * int(page_size))
        cur = db["paper_trades"].find(q).sort([("created_at", -1)]).skip(skip).limit(int(page_size))
        docs = list(cur)

        items: List[Dict[str, Any]] = []
        for d in docs:
            items.append(
                {
                    "paper_trade_id": d.get("paper_trade_id") or str(d.get("_id")),
                    "symbol": d.get("symbol"),
                    "stock_id": d.get("stock_id"),
                    "direction": d.get("direction"),
                    "status": d.get("status"),
                    "source": d.get("source"),
                    "signal_id": d.get("signal_id"),
                    "snapshot_id": d.get("snapshot_id"),
                    "signal_strength": d.get("signal_strength"),
                    "decision_probability": d.get("decision_probability"),
                    "entry_price": d.get("entry_price"),
                    "stop_loss": d.get("stop_loss"),
                    "target": d.get("target"),
                    "quantity": d.get("quantity"),
                    "trade_value": d.get("trade_value"),
                    "reserved_amount": d.get("reserved_amount"),
                    "created_at": _dt_to_iso(d.get("created_at")),
                    "opened_at": _dt_to_iso(d.get("opened_at")),
                    "closed_at": _dt_to_iso(d.get("closed_at")),
                    "exit_price": d.get("exit_price"),
                    "exit_reason": d.get("exit_reason"),
                    "realized_pnl": d.get("realized_pnl"),
                    "realized_pnl_per_unit": d.get("realized_pnl_per_unit"),
                    "current_unrealized_pnl": d.get("current_unrealized_pnl"),
                    "current_unrealized_pnl_per_unit": d.get("current_unrealized_pnl_per_unit"),
                    "max_favorable_move": d.get("max_favorable_move"),
                    "max_adverse_move": d.get("max_adverse_move"),
                    "last_price_close": d.get("last_price_close"),
                    "last_candle_ts": _dt_to_iso(d.get("last_candle_ts")),
                    "last_candle_timeframe": d.get("last_candle_timeframe"),
                }
            )

        total = db["paper_trades"].count_documents(q)

        total_pages = int((int(total) + int(page_size) - 1) // int(page_size)) if int(page_size) > 0 else 1
        return {
            "status": "success",
            "source": "db_paper_trades",
            "days": days,
            "day": str(day).strip() if day else None,
            "total": int(total),
            "page": int(page),
            "page_size": int(page_size),
            "total_pages": int(total_pages),
            "items": items,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch learning paper trades")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/learning/paper-trades/manual", summary="Create a manual paper trade")
async def create_manual_paper_trade_endpoint(
    payload: Dict[str, Any] = Body(...),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
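    """Create a manual (simulated) paper trade from user-supplied entry, stop-loss, target and quantity."""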
    try:
        user_id = str(current_user.get("_id"))
        account_id = _norm_account_id(current_user.get("account_id"))

        symbol = str(payload.get("symbol") or "").strip().upper()
        direction = str(payload.get("direction") or payload.get("side") or "").strip().upper()
        stock_id = payload.get("stock_id")
        snapshot_id = payload.get("snapshot_id")

        entry_price = payload.get("entry_price")
        stop_loss = payload.get("stop_loss")
        target = payload.get("target")
        quantity = payload.get("quantity")

        if not symbol or not direction:
            raise HTTPException(status_code=400, detail="symbol and direction are required")
        if entry_price is None or stop_loss is None or target is None or quantity is None:
            raise HTTPException(
                status_code=400, detail="entry_price, stop_loss, target and quantity are required"
            )

        from app.v1.services.paper_trading import create_manual_paper_trade

        res = create_manual_paper_trade(
            db,
            user_id=user_id,
            account_id=account_id,
            symbol=symbol,
            direction=direction,
            entry_price=float(entry_price),
            stop_loss=float(stop_loss),
            target=float(target),
            quantity=int(quantity),
            stock_id=str(stock_id) if stock_id is not None else None,
            snapshot_id=str(snapshot_id) if snapshot_id is not None else None,
        )

        if not isinstance(res, dict) or not res.get("ok"):
            return {"status": "error", "error": res.get("reason") if isinstance(res, dict) else "FAILED"}

        return {"status": "success", "paper_trade_id": res.get("paper_trade_id")}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to create manual paper trade")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/learning/paper-day", summary="Paper trades + balance statement for a day")
async def get_paper_day_statement(
    day: Optional[str] = Query(None, description="YYYY-MM-DD in Asia/Kolkata (default today)"),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
    """Day-based view for intraday-only paper trading.

    Returns a statement:
    - start balance (previous day close if available)
    - realized pnl for the day (closed trades)
    - end balance (= start + realized)
    - trades created that day
    """

    user_id = str(current_user.get("_id"))
    account_id = _norm_account_id(current_user.get("account_id"))
    if not account_id:
        raise HTTPException(status_code=400, detail="account_id missing")

    tz = ZoneInfo("Asia/Kolkata")
    now_local = datetime.utcnow().replace(tzinfo=ZoneInfo("UTC")).astimezone(tz)
    if day:
        try:
            base = datetime.fromisoformat(str(day).strip())
        except Exception:
            raise HTTPException(status_code=400, detail="Invalid day; expected YYYY-MM-DD")
        day_local = base.replace(tzinfo=tz)
    else:
        day_local = now_local

    start_local = day_local.replace(hour=0, minute=0, second=0, microsecond=0)
    end_local = start_local + timedelta(days=1)
    start_utc = start_local.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
    end_utc = end_local.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)

    day_key = start_local.strftime("%Y-%m-%d")
    prev_key = (start_local - timedelta(days=1)).strftime("%Y-%m-%d")

    from app.v1.services.paper_trading import get_or_create_paper_account, PAPER_BALANCE_DAILY_COLLECTION

    acc = get_or_create_paper_account(db, user_id=user_id, account_id=account_id)

    daily = db[PAPER_BALANCE_DAILY_COLLECTION].find_one({"user_id": user_id, "account_id": account_id, "day": day_key}) or {}
    prev = db[PAPER_BALANCE_DAILY_COLLECTION].find_one({"user_id": user_id, "account_id": account_id, "day": prev_key}) or {}

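    # Opening balance: today's recorded open, else yesterday's close, else the account's starting balance.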
    start_balance = daily.get("balance_open")
    if start_balance is None:
        start_balance = prev.get("balance_close")
    if start_balance is None:
        start_balance = acc.get("starting_balance")

    q = {"user_id": user_id, "account_id": account_id, "created_at": {"$gte": start_utc, "$lt": end_utc}}
    trades = list(db["paper_trades"].find(q, {"realized_pnl": 1, "status": 1}))

    realized = 0.0
    open_count = 0
    for t in trades:
        if str(t.get("status") or "").upper() == "OPEN":
            open_count += 1
        p = t.get("realized_pnl")
        try:
            realized += float(p) if p is not None else 0.0
        except Exception:
            pass

    end_balance = None
    try:
        end_balance = float(start_balance) + float(realized) if start_balance is not None else None
    except Exception:
        end_balance = None

    return {
        "status": "success",
        "day": day_key,
        "statement": {
            "balance_open": start_balance,
            "pnl_day_realized": float(realized),
            "balance_close": end_balance,
            "trades": len(trades),
            "open_trades": int(open_count),
        },
        "account": {
            "balance": acc.get("balance"),
            "reserved_balance": acc.get("reserved_balance"),
            "available_balance": acc.get("available_balance"),
            "starting_balance": acc.get("starting_balance"),
        },
    }


@router.get("/learning/paper-trades/snapshot/{snapshot_id}", summary="Get stored analysis snapshot for a paper trade")
async def get_paper_trade_snapshot(
    snapshot_id: str,
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
    """DB-only: returns the exact `stock_analysis_snapshots` doc referenced by the user's paper trade."""

    user_id = str(current_user.get("_id"))
    account_id = _norm_account_id(current_user.get("account_id"))
    snap_id = (snapshot_id or "").strip()
    if not snap_id:
        raise HTTPException(status_code=400, detail="snapshot_id is required")

    q: Dict[str, Any] = {"user_id": user_id, "snapshot_id": snap_id}
    if account_id:
        q["account_id"] = account_id

    pt = db["paper_trades"].find_one(q, {"_id": 1})
    if not pt:
        raise HTTPException(status_code=404, detail="Paper trade snapshot not found")

    try:
        oid = ObjectId(snap_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid snapshot_id")

    snap = db["stock_analysis_snapshots"].find_one({"_id": oid})
    if not snap:
        raise HTTPException(status_code=404, detail="Snapshot not found")

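    # Some snapshots keep market_data/features at the top level; fold them into the analysis payload when missing.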
    analysis = snap.get("analysis") if isinstance(snap.get("analysis"), dict) else {}
    if isinstance(snap.get("market_data"), dict) and not isinstance(analysis.get("market_data"), dict):
        analysis["market_data"] = snap.get("market_data")
    if isinstance(snap.get("features"), dict) and not isinstance(analysis.get("features"), dict):
        analysis["features"] = snap.get("features")

    symbol = (analysis.get("symbol") or "").strip().upper()
    if not symbol:
        # best-effort: resolve from stocks
        try:
            stock = db["stocks"].find_one({"stock_id": snap.get("stock_id")}, {"symbol": 1})
            if stock and stock.get("symbol"):
                analysis["symbol"] = str(stock.get("symbol")).strip().upper()
        except Exception:
            pass

    return {
        "status": "success",
        "snapshot": {
            "snapshot_id": str(snap.get("_id")),
            "stock_id": snap.get("stock_id"),
            "timestamp": _dt_to_iso(snap.get("timestamp")),
            "source": snap.get("source"),
            "analysis": analysis,
        },
    }


@router.get("/learning/paper-account", summary="Get user's paper trading account (game)")
async def get_paper_account(
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
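    """Return the user's paper trading account, creating it if it does not exist yet."""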
    try:
        user_id = str(current_user.get("_id"))
        account_id = _norm_account_id(current_user.get("account_id"))
        if not account_id:
            raise HTTPException(status_code=400, detail="account_id missing")

        from app.v1.services.paper_trading import get_or_create_paper_account

        acc = get_or_create_paper_account(db, user_id=user_id, account_id=account_id)
        return {
            "status": "success",
            "account": {
                "paper_account_id": acc.get("paper_account_id") or f"{user_id}:{account_id}",
                "starting_balance": acc.get("starting_balance"),
                "balance": acc.get("balance"),
                "reserved_balance": acc.get("reserved_balance"),
                "available_balance": acc.get("available_balance"),
                "settings": acc.get("settings") if isinstance(acc.get("settings"), dict) else {},
                "updated_at": _dt_to_iso(acc.get("updated_at")),
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to fetch paper account")
        raise HTTPException(status_code=500, detail=str(e))


@router.put("/learning/paper-account", summary="Update paper trading account (balance + rules)")
async def update_paper_account(
    payload: Dict[str, Any] = Body(...),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
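    """Update the paper account: balance, starting_balance and/or the allowed `settings` rules."""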
    try:
        user_id = str(current_user.get("_id"))
        account_id = _norm_account_id(current_user.get("account_id"))
        if not account_id:
            raise HTTPException(status_code=400, detail="account_id missing")

        from app.v1.services.paper_trading import (
            PAPER_ACCOUNTS_COLLECTION,
            get_or_create_paper_account,
            _upsert_daily_balance,
            _now_utc,
        )

        acc = get_or_create_paper_account(db, user_id=user_id, account_id=account_id)
        col = db[PAPER_ACCOUNTS_COLLECTION]

        # Editable fields
        next_balance = payload.get("balance")
        next_starting = payload.get("starting_balance")
        next_settings = payload.get("settings") if isinstance(payload.get("settings"), dict) else None

        updates: Dict[str, Any] = {"updated_at": _now_utc()}
        pushes: Dict[str, Any] = {}

        if next_settings is not None:
            # sanitize only allowed keys
            allowed = {
                "min_trade_value",
                "max_trade_value",
                "max_quantity",
                "max_loss_pct",
                "max_profit_pct",
                "eod_exit",
                "product",
                # New: paper-trade filters
                "min_score",
                "sources",
                "decisions",
            }

            sanitized: Dict[str, Any] = {}
            for k, v in next_settings.items():
                if k not in allowed:
                    continue

                if k == "min_score":
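                    # Reject NaN (NaN != NaN) and +/-inf, then clamp the score to the 0-100 range.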
                    try:
                        vv = float(v)
                        if vv == vv and vv not in (float("inf"), float("-inf")):
                            sanitized[k] = max(0.0, min(100.0, float(vv)))
                    except Exception:
                        continue
                    continue

                if k == "sources":
                    if isinstance(v, list):
                        vals = [str(x or "").strip().upper() for x in v]
                        vals = [x for x in vals if x in {"MANUAL", "GAINER", "LOSER", "EARLY_MOVERS"}]
                        if vals:
                            sanitized[k] = vals
                    continue

                if k == "decisions":
                    if isinstance(v, list):
                        vals = [str(x or "").strip().upper() for x in v]
                        # Normalize LONG/SHORT aliases
                        vals = [("BUY" if x == "LONG" else "SELL" if x == "SHORT" else x) for x in vals]
                        vals = [x for x in vals if x in {"BUY", "SELL"}]
                        if vals:
                            sanitized[k] = vals
                    continue

                sanitized[k] = v

            # merge with existing
            merged = dict(acc.get("settings") or {})
            merged.update(sanitized)
            updates["settings"] = merged

        def _to_float(x):
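            # Coerce to float; return None for NaN, +/-inf or otherwise unparseable values.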
            try:
                f = float(x)
                if f != f or f in (float("inf"), float("-inf")):
                    return None
                return f
            except Exception:
                return None

        bal_before = _to_float(acc.get("balance")) or 0.0
        reserved = _to_float(acc.get("reserved_balance")) or 0.0

        if next_starting is not None:
            sb = _to_float(next_starting)
            if sb is None or sb < 0:
                raise HTTPException(status_code=400, detail="starting_balance must be >= 0")
            updates["starting_balance"] = float(sb)

        if next_balance is not None:
            nb = _to_float(next_balance)
            if nb is None or nb < 0:
                raise HTTPException(status_code=400, detail="balance must be >= 0")

            # balance can be edited, but must not be below reserved
            nb = max(float(nb), float(reserved))
            updates["balance"] = float(nb)
            updates["available_balance"] = float(max(0.0, nb - reserved))
            pushes = {
                "ts": _now_utc(),
                "type": "MANUAL_EDIT",
                "balance_before": float(bal_before),
                "balance_after": float(nb),
                "reserved": float(reserved),
            }

        if len(updates) == 1 and not pushes:
            return {"status": "success", "account": {"paper_account_id": acc.get("paper_account_id")}}

        upd_doc: Dict[str, Any] = {"$set": updates}
        if pushes:
            upd_doc["$push"] = {"events": pushes}

        col.update_one({"user_id": user_id, "account_id": account_id}, upd_doc)

        # daily history for manual edit
        if next_balance is not None:
            nb2 = float(updates.get("balance") or bal_before)
            _upsert_daily_balance(
                db,
                user_id=user_id,
                account_id=account_id,
                dt_utc=_now_utc(),
                balance=float(nb2),
                delta=float(nb2 - bal_before),
                reason="MANUAL_BALANCE_EDIT",
            )

        acc2 = col.find_one({"user_id": user_id, "account_id": account_id}) or {}
        return {
            "status": "success",
            "account": {
                "paper_account_id": acc2.get("paper_account_id") or f"{user_id}:{account_id}",
                "starting_balance": acc2.get("starting_balance"),
                "balance": acc2.get("balance"),
                "reserved_balance": acc2.get("reserved_balance"),
                "available_balance": acc2.get("available_balance"),
                "settings": acc2.get("settings") if isinstance(acc2.get("settings"), dict) else {},
                "updated_at": _dt_to_iso(acc2.get("updated_at")),
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to update paper account")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/learning/paper-game/summary", summary="Paper trading gamification summary")
async def get_paper_game_summary(
    days: int = Query(120, ge=1, le=365, description="History window for totals"),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
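    """Summarize paper trading over the last `days` days: trade counts, realized PnL, win rate and today's PnL."""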
    try:
        user_id = str(current_user.get("_id"))
        account_id = _norm_account_id(current_user.get("account_id"))
        if not account_id:
            raise HTTPException(status_code=400, detail="account_id missing")

        from app.v1.services.paper_trading import get_or_create_paper_account, _paper_day_key, _now_utc

        acc = get_or_create_paper_account(db, user_id=user_id, account_id=account_id)
        cutoff = datetime.utcnow() - timedelta(days=days)
        qbase: Dict[str, Any] = {"user_id": user_id, "account_id": account_id, "created_at": {"$gte": cutoff}}

        total_trades = int(db["paper_trades"].count_documents(qbase))
        open_trades = int(db["paper_trades"].count_documents({**qbase, "status": "OPEN"}))

        closed_q = {**qbase, "status": "CLOSED"}
        closed_docs = list(db["paper_trades"].find(closed_q, {"realized_pnl": 1, "closed_at": 1}))
        closed_count = len(closed_docs)

        pnl_sum = 0.0
        wins = 0
        today_key = _paper_day_key(_now_utc())
        today_pnl = 0.0
        for d in closed_docs:
            p = d.get("realized_pnl")
            try:
                pf = float(p) if p is not None else 0.0
            except Exception:
                pf = 0.0
            pnl_sum += pf
            if pf > 0:
                wins += 1
            ca = d.get("closed_at")
            if isinstance(ca, datetime) and _paper_day_key(ca) == today_key:
                today_pnl += pf

        win_pct = (float(wins) / float(closed_count) * 100.0) if closed_count > 0 else 0.0

        return {
            "status": "success",
            "account": {
                "balance": acc.get("balance"),
                "reserved_balance": acc.get("reserved_balance"),
                "available_balance": acc.get("available_balance"),
                "starting_balance": acc.get("starting_balance"),
            },
            "metrics": {
                "total_trades": total_trades,
                "open_trades": open_trades,
                "closed_trades": closed_count,
                "pnl_sum": float(pnl_sum),
                "wins": int(wins),
                "win_pct": float(win_pct),
                "today_pnl": float(today_pnl),
                "days": int(days),
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to compute paper game summary")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/learning/chat/{symbol}", summary="Learning chat about a trade/symbol (no fresh Zerodha data)")
async def learning_chat_with_symbol(
    symbol: str,
    payload: Dict[str, Any] = Body(...),
    db=Depends(database.get_mongo_db),
    current_user=Depends(get_current_userdetails),
):
    """Learning chat that stays within /learning.

    This endpoint does NOT require a Zerodha access_token because it forces
    include_fresh_data=False and relies on DB snapshots plus user-provided context.
    """
    try:
        result = chat_with_stock_service(
            db=db,
            zerodha_client=None,  # safe when include_fresh_data=False
            symbol=symbol,
            message=payload.get("message", ""),
            user_id=str(current_user.get("_id")),
            conversation_id=payload.get("conversation_id"),
            include_fresh_data=False,
        )
        return {"status": "success", "response": result}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Learning chat failed for symbol %s", symbol)
        raise HTTPException(status_code=500, detail=str(e))
