from __future__ import annotations

import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple


def _safe_float(v: Any) -> Optional[float]:
    """Coerce *v* to a finite float, or return None.

    ``None`` and booleans are rejected up front. Anything ``float()`` cannot
    convert (any exception is swallowed), NaN, and +/-infinity also map to
    ``None``.
    """
    if v is None or isinstance(v, bool):
        return None
    try:
        number = float(v)
    except Exception:
        return None
    # NaN is the only float that is not equal to itself.
    if number != number or abs(number) == float("inf"):
        return None
    return number


def _env_float(name: str, default: float, *, min_value: float = 0.0, max_value: float = 1e9) -> float:
    """Read environment variable *name* as a float clamped to [min_value, max_value].

    Returns ``float(default)`` (not clamped) when the variable is unset,
    blank, unparseable, NaN, or infinite.
    """
    text = os.getenv(name)
    if text is None:
        return float(default)
    text = str(text).strip()
    if not text:
        return float(default)
    try:
        parsed = float(text)
    except Exception:
        return float(default)
    if parsed != parsed or abs(parsed) == float("inf"):
        return float(default)
    return float(max(min_value, min(max_value, parsed)))


def _get_candles(market_data: Dict[str, Any], timeframe: str) -> List[Dict[str, Any]]:
    """Return the candle dicts stored at ``market_data["candles"][timeframe]``.

    Returns [] when *market_data* or its ``candles`` entry is not a dict, or
    when the timeframe entry is not a list. Non-dict items are dropped. The
    result is a new list, but the candle dicts themselves are shared, not copied.
    """
    by_timeframe = market_data.get("candles") if isinstance(market_data, dict) else None
    series = by_timeframe.get(timeframe) if isinstance(by_timeframe, dict) else None
    if not isinstance(series, list):
        return []
    return list(filter(lambda candle: isinstance(candle, dict), series))


def _candle_num(c: Dict[str, Any], key: str) -> Optional[float]:
    """Read numeric field *key* from candle *c*, accepting a one-letter alias.

    Candles may use long keys ("open", "high", "low", "close") or the
    single-letter form ("o", "h", "l", "c"). The long key is used whenever it
    yields a finite number, including a legitimate ``0``/``0.0``. The alias
    (``key[:1]``) is consulted only when the long key is missing or unusable.

    Returns:
        The value as a float, or None if neither key yields a finite number.
    """
    # Do not use `a or b` here: it would treat a real 0.0 price as "missing".
    value = _safe_float(c.get(key))
    if value is None:
        value = _safe_float(c.get(key[:1]))
    return value


def _swing_high_low(candles: List[Dict[str, Any]], *, lookback: int) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(highest high, lowest low)`` over the last *lookback* candles.

    *lookback* is coerced with ``int()`` and floored at 1. Candles whose
    high/low cannot be read are skipped. Either side is None when no usable
    value exists, and both are None for empty *candles*.
    """
    if not candles:
        return None, None
    recent = candles[-max(1, int(lookback)):]
    highs = [h for h in (_candle_num(c, "high") for c in recent) if h is not None]
    lows = [lo for lo in (_candle_num(c, "low") for c in recent) if lo is not None]
    return max(highs, default=None), min(lows, default=None)


def _latest_indicator_summary(market_data: Dict[str, Any], preferred: str, fallback: str) -> Dict[str, Any]:
    """Return ``market_data["indicators"][tf]`` for the first usable timeframe.

    *preferred* is tried before *fallback*. An entry is usable only if it is a
    non-empty dict. Returns a fresh empty dict when nothing qualifies.
    """
    indicators = market_data.get("indicators") if isinstance(market_data, dict) else None
    if not isinstance(indicators, dict):
        return {}
    candidates = map(indicators.get, (preferred, fallback))
    return next((summary for summary in candidates if isinstance(summary, dict) and summary), {})


def _get_pivots_day(market_data: Dict[str, Any]) -> Dict[str, float]:
    """Extract finite daily pivot levels from ``market_data["pivots"]["day"]``.

    Only the keys R1, R2, R3, P, S1, S2, S3 are considered. Keys whose value
    is missing or not a finite number are omitted. Returns {} when either
    nesting level is not a dict.
    """
    pivots = market_data.get("pivots") if isinstance(market_data, dict) else None
    day = pivots.get("day") if isinstance(pivots, dict) else None
    if not isinstance(day, dict):
        return {}

    levels: Dict[str, float] = {}
    for name in ("R1", "R2", "R3", "P", "S1", "S2", "S3"):
        level = _safe_float(day.get(name))
        if level is None:
            continue
        levels[name] = float(level)
    return levels


@dataclass(frozen=True)
class ExecutionPlan:
    """Immutable rule-based execution plan (as produced by build_execution_plan)."""

    # {"low": ..., "high": ...}; build_execution_plan emits 0.0/0.0 when no zone could be derived.
    entry_zone: Dict[str, float]
    # Trigger reason codes that fired on the latest candle; empty while waiting for entry.
    entry_trigger_reason: List[str]
    # Stop-loss level, or None when no structural level was available.
    sl: Optional[float]
    # Target levels; build_execution_plan keeps at most two.
    targets: List[float]
    # Reward/risk to targets[0] measured from the entry-zone midpoint (rounded to 3 dp), or None.
    rr_ratio: Optional[float]
    state: str  # WAITING_FOR_ENTRY | ENTRY_ACTIVATED
    # Intermediate inputs and levels used to build the plan (for inspection/debugging).
    diagnostics: Dict[str, Any]


def build_execution_plan(
    *,
    decision: str,
    market_data: Dict[str, Any],
    now_utc: Optional[datetime] = None,
) -> Optional[ExecutionPlan]:
    """Rule-based execution plan.

    - Does NOT use LTP/current price as entry.
    - Produces an entry_zone + trigger-based activation state.
    - SL/Targets are derived from pivots + recent swings (structure).

    This function is intentionally deterministic and read-only.

    Args:
        decision: "BUY" or "SELL" (surrounding whitespace and case are
            ignored). Any other value makes the function return None.
        market_data: Mapping read through this module's helpers:
            ``candles[<tf>]`` lists for "5minute"/"15minute"/"day",
            ``indicators[<tf>]`` summaries (``ema["9"]``, ``ema["21"]``,
            ``vwap``, ``close``) and ``pivots["day"]`` levels. Missing or
            malformed pieces are tolerated and simply narrow the plan.
        now_utc: Timestamp stored in ``diagnostics["computed_at"]``. Naive
            values are taken to be UTC; aware values are converted to UTC.
            Defaults to the current UTC time.

    Tunables are read from the environment and clamped to a range:
    ENTRY_ZONE_BUFFER_PCT, ENTRY_SWING_LOOKBACK, ENTRY_MISSED_PCT and
    ENTRY_SL_BUFFER_PCT.

    Returns:
        An ExecutionPlan, or None for a non BUY/SELL decision. ``state`` is
        "ENTRY_ACTIVATED" only when a trigger fires on the latest 5m candle
        (15m when there is no 5m data) against a valid, positive entry zone.
    """
    d = (decision or "").strip().upper()
    if d not in {"BUY", "SELL"}:
        return None
    is_buy = d == "BUY"

    if now_utc is None:
        # Naive UTC; equivalent to the deprecated datetime.utcnow().
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    elif now_utc.utcoffset() is not None:
        # Aware input: normalise to naive UTC so the "Z" suffix appended to
        # computed_at is valid (isoformat() would otherwise emit "+00:00Z").
        now_utc = now_utc.astimezone(timezone.utc).replace(tzinfo=None)

    zone_buffer_pct = _env_float("ENTRY_ZONE_BUFFER_PCT", 0.001, min_value=0.0, max_value=0.02)
    swing_lookback = int(_env_float("ENTRY_SWING_LOOKBACK", 12, min_value=3.0, max_value=200.0))
    missed_pct = _env_float("ENTRY_MISSED_PCT", 0.01, min_value=0.0, max_value=0.25)
    sl_buffer_pct = _env_float("ENTRY_SL_BUFFER_PCT", 0.001, min_value=0.0, max_value=0.02)

    candles_5m = _get_candles(market_data, "5minute")
    candles_15m = _get_candles(market_data, "15minute")
    candles_day = _get_candles(market_data, "day")

    ind = _latest_indicator_summary(market_data, "5minute", "15minute")
    ema = ind.get("ema") if isinstance(ind.get("ema"), dict) else {}
    ema9 = _safe_float(ema.get("9"))
    ema21 = _safe_float(ema.get("21"))
    vwap = _safe_float(ind.get("vwap"))
    close = _safe_float(ind.get("close"))

    piv = _get_pivots_day(market_data)

    # --- Entry zone: EMA9/EMA21 band, else a buffered structural level ---
    entry_zone: Dict[str, float] = {"low": 0.0, "high": 0.0}
    ref_level: Optional[float] = None
    ref_reason = ""

    if ema9 is not None and ema21 is not None:
        # Same band for both sides: support for BUY, resistance for SELL.
        band_low, band_high = min(ema9, ema21), max(ema9, ema21)
        if band_low > 0 and band_high > band_low:
            entry_zone = {"low": band_low, "high": band_high}
            ref_reason = "EMA_PULLBACK_ZONE"

    if entry_zone["low"] <= 0 or entry_zone["high"] <= 0:
        # 1) Daily pivots: the breakout/breakdown level first, then P.
        for k in (("R1", "P") if is_buy else ("S1", "P")):
            if k in piv:
                ref_level = piv[k]
                ref_reason = f"PIVOT_{k}"
                break

        # 2) Second-to-last 15m candle's extreme (the last one if only one exists).
        if ref_level is None and candles_15m:
            prev = candles_15m[-2] if len(candles_15m) >= 2 else candles_15m[-1]
            prev_high = _candle_num(prev, "high")
            prev_low = _candle_num(prev, "low")
            if is_buy and prev_high is not None:
                ref_level = prev_high
                ref_reason = "PREV_15M_HIGH_BREAKOUT"
            elif not is_buy and prev_low is not None:
                ref_level = prev_low
                ref_reason = "PREV_15M_LOW_BREAKDOWN"

        # 3) Last resort: the indicator summary's close.
        if ref_level is None and close is not None:
            ref_level = close
            ref_reason = "FALLBACK_CLOSE_LEVEL"

        if ref_level is not None and ref_level > 0:
            entry_zone = {
                "low": float(ref_level) * (1.0 - zone_buffer_pct),
                "high": float(ref_level) * (1.0 + zone_buffer_pct),
            }

    zone_low = float(entry_zone["low"])
    zone_high = float(entry_zone["high"])
    # (0, 0) means "no zone could be derived". Such a zone must never
    # activate a trigger or anchor the SL clamp.
    zone_valid = zone_low > 0 and zone_high > 0

    # --- Structural SL + targets ---
    swing_high, swing_low = _swing_high_low(candles_5m or candles_15m, lookback=swing_lookback)

    day_high: Optional[float] = None
    day_low: Optional[float] = None
    if candles_day:
        day_high = _candle_num(candles_day[-1], "high")
        day_low = _candle_num(candles_day[-1], "low")

    sl: Optional[float] = None
    if is_buy:
        candidates = [lvl for lvl in (swing_low, piv.get("S1"), vwap) if lvl is not None]
        if zone_valid:
            # Only levels below the zone are support for a long. A level above
            # it (e.g. VWAP) would otherwise win max() and then be clamped to a
            # non-structural stop just under the zone, discarding real support.
            candidates = [lvl for lvl in candidates if lvl < zone_low] or candidates
        if candidates:
            # Highest support = tightest structural invalidation.
            sl = max(candidates)

        # Next resistances above the zone; fall back to swing/day highs.
        targets = [piv[k] for k in ("R1", "R2", "R3") if k in piv and piv[k] > zone_high]
        if not targets:
            targets = [lvl for lvl in (swing_high, day_high) if lvl is not None and lvl > zone_high]
        # Nearest target first (targets[0] drives the RR ratio); drop duplicates.
        targets = sorted(set(targets))
    else:  # SELL
        candidates = [lvl for lvl in (swing_high, piv.get("R1"), vwap) if lvl is not None]
        if zone_valid:
            # Mirror of BUY: only levels above the zone are resistance for a short.
            candidates = [lvl for lvl in candidates if lvl > zone_high] or candidates
        if candidates:
            # Lowest resistance = tightest structural invalidation.
            sl = min(candidates)

        # Next supports below the zone; fall back to swing/day lows.
        targets = [piv[k] for k in ("S1", "S2", "S3") if k in piv and piv[k] < zone_low]
        if not targets:
            targets = [lvl for lvl in (swing_low, day_low) if lvl is not None and lvl < zone_low]
        # Nearest target first (highest price for a short); drop duplicates.
        targets = sorted(set(targets), reverse=True)

    # --- Sanity: keep the SL at least sl_buffer_pct beyond the zone ---
    # BUY: SL <= zone.low * (1 - buffer); SELL: SL >= zone.high * (1 + buffer).
    if sl is not None and zone_valid:
        if is_buy:
            sl = min(sl, zone_low * (1.0 - sl_buffer_pct))
        else:
            sl = max(sl, zone_high * (1.0 + sl_buffer_pct))

    # Keep 1-2 targets, nearest first.
    targets = targets[:2]

    # --- Trigger evaluation (prefer 5m; fall back to 15m) ---
    reasons: List[str] = []
    trigger_tf = "5minute" if candles_5m else "15minute" if candles_15m else ""
    trigger_candles = candles_5m if candles_5m else candles_15m
    is_5m = trigger_tf == "5minute"

    if trigger_candles and zone_valid:
        last = trigger_candles[-1]
        prev_candle = trigger_candles[-2] if len(trigger_candles) >= 2 else None
        prev_close = _candle_num(prev_candle, "close") if prev_candle is not None else None
        c_open = _candle_num(last, "open")
        c_close = _candle_num(last, "close")
        c_high = _candle_num(last, "high")
        c_low = _candle_num(last, "low")

        if is_buy:
            if c_close is not None and c_close >= zone_high:
                reasons.append("5M_CLOSE_ABOVE_ZONE" if is_5m else "15M_CLOSE_ABOVE_ZONE")
            elif c_open is not None and c_open > zone_high:
                reasons.append("5M_GAP_OPEN_ABOVE_ZONE" if is_5m else "15M_GAP_OPEN_ABOVE_ZONE")
            # Retest: prior close above the zone, this candle touched the zone
            # top and still closed above it.
            if (
                prev_close is not None
                and prev_close >= zone_high
                and c_low is not None
                and c_low <= zone_high
                and c_close is not None
                and c_close > zone_high
            ):
                reasons.append("BREAKOUT_RETEST_HOLD" if is_5m else "BREAKOUT_RETEST_HOLD_15M")
        else:  # SELL
            if c_close is not None and c_close <= zone_low:
                reasons.append("5M_CLOSE_BELOW_ZONE" if is_5m else "15M_CLOSE_BELOW_ZONE")
            elif c_open is not None and c_open < zone_low:
                reasons.append("5M_GAP_OPEN_BELOW_ZONE" if is_5m else "15M_GAP_OPEN_BELOW_ZONE")
            # Retest: prior close below the zone, this candle touched the zone
            # bottom and still closed below it.
            if (
                prev_close is not None
                and prev_close <= zone_low
                and c_high is not None
                and c_high >= zone_low
                and c_close is not None
                and c_close < zone_low
            ):
                reasons.append("BREAKDOWN_RETEST_HOLD" if is_5m else "BREAKDOWN_RETEST_HOLD_15M")

    state = "ENTRY_ACTIVATED" if reasons else "WAITING_FOR_ENTRY"

    # --- RR ratio (from the zone midpoint to the nearest target) ---
    rr: Optional[float] = None
    if sl is not None and targets:
        zone_mid = (zone_low + zone_high) / 2.0
        t1 = float(targets[0])
        if is_buy:
            risk, reward = zone_mid - float(sl), t1 - zone_mid
        else:
            risk, reward = float(sl) - zone_mid, zone_mid - t1
        if risk > 0 and reward > 0:
            rr = round(reward / risk, 3)

    diagnostics: Dict[str, Any] = {
        "ref_reason": ref_reason,
        "ema9": ema9,
        "ema21": ema21,
        "vwap": vwap,
        "close": close,
        "swing_high": swing_high,
        "swing_low": swing_low,
        "pivots": piv,
        "computed_at": now_utc.isoformat() + "Z",
    }

    # Missed-trade analysis (diagnostic-only): price moved missed_pct beyond
    # the zone on the latest candle without any trigger firing.
    missed: List[str] = []
    if state == "WAITING_FOR_ENTRY" and trigger_candles:
        last = trigger_candles[-1]
        c_high = _candle_num(last, "high")
        c_low = _candle_num(last, "low")
        if is_buy and c_high is not None and zone_high > 0 and c_high >= zone_high * (1.0 + missed_pct):
            missed.append("PRICE_RAN_UP_WITHOUT_TRIGGER" if is_5m else "PRICE_RAN_UP_WITHOUT_TRIGGER_15M")
        if not is_buy and c_low is not None and zone_low > 0 and c_low <= zone_low * (1.0 - missed_pct):
            missed.append("PRICE_DROPPED_WITHOUT_TRIGGER" if is_5m else "PRICE_DROPPED_WITHOUT_TRIGGER_15M")
    if missed:
        diagnostics["missed_trade_reasons"] = missed

    if trigger_tf:
        diagnostics["trigger_timeframe"] = trigger_tf

    return ExecutionPlan(
        entry_zone={"low": zone_low, "high": zone_high},
        entry_trigger_reason=reasons,
        sl=float(sl) if sl is not None else None,
        targets=[float(t) for t in targets if t is not None],
        rr_ratio=float(rr) if rr is not None else None,
        state=state,
        diagnostics=diagnostics,
    )
