from __future__ import annotations

from typing import Dict, Any, Optional

try:  # pragma: no cover - runtime dependency may be missing/broken
    import pandas as pd  # type: ignore
    _PANDAS_OK = True
except Exception:  # pragma: no cover
    pd = None  # type: ignore
    _PANDAS_OK = False

from .indicators import IndicatorCalculator


class StrategyFeatureCalculator:
    """Numeric/enum strategy features for GPT consumption.

    This is intentionally read-only and does not emit trading signals.
    It summarizes trend, volatility, range/ORB breakouts, HOD/LOD
    proximity, momentum bursts, volume climaxes, Fibonacci zones and
    simple candlestick patterns per timeframe.
    """

    @staticmethod
    def _safe_df(candles: Any) -> Optional[pd.DataFrame]:
        if not _PANDAS_OK:
            return None
        if not candles:
            return None
        try:
            df = pd.DataFrame(candles)
            required_cols = {"open", "high", "low", "close"}
            if not required_cols.issubset(df.columns):
                return None
            return df
        except Exception:
            return None

    @staticmethod
    def _trend_structure(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty:
            return {"direction": "unknown", "strength": "unknown"}

        closes = df["close"]
        if len(closes) < 5:
            return {"direction": "unknown", "strength": "unknown"}

        ema_fast = closes.ewm(span=9, adjust=False).mean()
        ema_slow = closes.ewm(span=21, adjust=False).mean()

        last_close = float(closes.iloc[-1])
        slope = float(ema_fast.iloc[-1] - ema_fast.iloc[0])

        if last_close > float(ema_fast.iloc[-1]) > float(ema_slow.iloc[-1]) and slope > 0:
            direction = "uptrend"
        elif last_close < float(ema_fast.iloc[-1]) < float(ema_slow.iloc[-1]) and slope < 0:
            direction = "downtrend"
        else:
            direction = "sideways"

        strength = "strong" if abs(slope) > 0.002 * last_close else "normal"

        return {"direction": direction, "strength": strength}

    @staticmethod
    def _volatility_regime(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty:
            return {"regime": "unknown"}

        closes = df["close"].astype(float)
        if len(closes) < 20:
            return {"regime": "unknown"}

        returns = closes.pct_change().dropna()
        if returns.empty:
            return {"regime": "unknown"}

        rolling = returns.rolling(20).std().dropna()
        if rolling.empty:
            return {"regime": "unknown"}

        current = float(rolling.iloc[-1])
        median = float(rolling.median()) if len(rolling) > 0 else current

        if current > 1.5 * median:
            regime = "high"
        elif current < 0.7 * median:
            regime = "low"
        else:
            regime = "normal"

        return {"regime": regime}

    @staticmethod
    def _range_breakout(df: pd.DataFrame, lookback: int = 20) -> Dict[str, Any]:
        if df is None or df.empty:
            return {"status": "unknown"}

        if len(df) < lookback + 1:
            return {"status": "unknown"}

        window = df.iloc[-(lookback + 1):-1]
        last = df.iloc[-1]

        high_range = float(window["high"].max())
        low_range = float(window["low"].min())
        last_close = float(last["close"])

        tol = 0.002 * last_close

        if last_close > high_range + tol:
            status = "breakout_up"
        elif last_close < low_range - tol:
            status = "breakout_down"
        elif abs(last_close - high_range) <= tol:
            status = "near_range_high"
        elif abs(last_close - low_range) <= tol:
            status = "near_range_low"
        else:
            status = "inside_range"

        return {
            "status": status,
            "range_high": high_range,
            "range_low": low_range,
        }

    @staticmethod
    def _opening_range_breakout(df: pd.DataFrame, candles_in_open: int = 3) -> Dict[str, Any]:
        if df is None or df.empty or len(df) <= candles_in_open:
            return {"status": "unknown"}

        first = df.iloc[:candles_in_open]
        last = df.iloc[-1]

        or_high = float(first["high"].max())
        or_low = float(first["low"].min())
        last_close = float(last["close"])
        tol = 0.002 * last_close

        if last_close > or_high + tol:
            status = "orb_breakout_up"
        elif last_close < or_low - tol:
            status = "orb_breakout_down"
        elif abs(last_close - or_high) <= tol:
            status = "near_or_high"
        elif abs(last_close - or_low) <= tol:
            status = "near_or_low"
        else:
            status = "inside_open_range"

        return {"status": status, "or_high": or_high, "or_low": or_low}

    @staticmethod
    def _hod_lod_proximity(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty:
            return {"near_hod": False, "near_lod": False}

        high_day = float(df["high"].max())
        low_day = float(df["low"].min())
        last_close = float(df["close"].iloc[-1])

        dist_h = (high_day - last_close) / last_close
        dist_l = (last_close - low_day) / last_close

        near_hod = dist_h >= 0 and dist_h < 0.003
        near_lod = dist_l >= 0 and dist_l < 0.003

        return {
            "near_hod": bool(near_hod),
            "near_lod": bool(near_lod),
            "distance_to_hod": float(dist_h),
            "distance_to_lod": float(dist_l),
        }

    @staticmethod
    def _momentum_burst(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty or len(df) < 10:
            return {"status": "unknown"}

        closes = df["close"].astype(float)
        highs = df["high"].astype(float)
        lows = df["low"].astype(float)
        volumes = df["volume"].astype(float) if "volume" in df.columns else None

        body = (closes - df["open"].astype(float)).abs()
        range_ = (highs - lows).abs()
        avg_range = range_.rolling(20).mean().iloc[-1] if len(range_) >= 20 else range_.mean()

        last_body = float(body.iloc[-1])
        last_range = float(range_.iloc[-1])

        vol_spike = False
        volume_ratio = None
        if volumes is not None and len(volumes) > 5:
            avg_vol = float(volumes.rolling(20).mean().iloc[-1]) if len(volumes) >= 20 else float(volumes.mean())
            if avg_vol > 0:
                volume_ratio = float(volumes.iloc[-1]) / avg_vol
                vol_spike = volume_ratio > 1.5

        strong_body = avg_range > 0 and last_body > 0.7 * avg_range
        expanded_range = avg_range > 0 and last_range > 1.5 * avg_range

        if strong_body and expanded_range and vol_spike:
            status = "momentum_burst"
        else:
            status = "normal"

        return {
            "status": status,
            "volume_ratio": volume_ratio,
        }

    @staticmethod
    def _volume_climax(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty or "volume" not in df.columns or len(df) < 30:
            return {"climax": False}

        volumes = df["volume"].astype(float)
        avg_vol = float(volumes.rolling(30).mean().iloc[-1])
        if avg_vol <= 0:
            return {"climax": False}

        last_vol = float(volumes.iloc[-1])
        vol_ratio = last_vol / avg_vol

        is_climax = vol_ratio > 2.0 and last_vol == float(volumes.max())

        return {"climax": bool(is_climax), "volume_ratio": vol_ratio}

    @staticmethod
    def _mean_reversion_stretch(df: pd.DataFrame) -> Dict[str, Any]:
        """How far price is stretched from a recent mean (non-VWAP).

        Uses EMA(20) and ATR-like true range to express stretch in ATR multiples.
        """

        if df is None or df.empty or len(df) < 20:
            return {"stretch": "unknown"}

        closes = df["close"].astype(float)
        highs = df["high"].astype(float)
        lows = df["low"].astype(float)

        ema20 = closes.ewm(span=20, adjust=False).mean()

        prev_close = closes.shift(1)
        tr = pd.concat(
            [
                (highs - lows).abs(),
                (highs - prev_close).abs(),
                (lows - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)
        atr14 = tr.rolling(14).mean()

        last_close = float(closes.iloc[-1])
        last_ema = float(ema20.iloc[-1])
        last_atr = float(atr14.iloc[-1]) if not pd.isna(atr14.iloc[-1]) else 0.0

        distance = last_close - last_ema
        atr_mult = distance / last_atr if last_atr > 0 else 0.0

        if abs(atr_mult) >= 2.0:
            stretch = "extreme"
        elif abs(atr_mult) >= 1.0:
            stretch = "mild"
        else:
            stretch = "none"

        side = "above" if distance > 0 else "below" if distance < 0 else "at_mean"

        return {
            "stretch": stretch,
            "side": side,
            "distance_from_ema": float(distance),
            "atr_multiples": float(atr_mult),
        }

    @staticmethod
    def _volatility_shift(df: pd.DataFrame) -> Dict[str, Any]:
        """Detect compression vs expansion using recent true range."""

        if df is None or df.empty or len(df) < 40:
            return {"phase": "unknown"}

        highs = df["high"].astype(float)
        lows = df["low"].astype(float)
        closes = df["close"].astype(float)
        prev_close = closes.shift(1)

        tr = pd.concat(
            [
                (highs - lows).abs(),
                (highs - prev_close).abs(),
                (lows - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)

        recent = tr.iloc[-20:]
        prior = tr.iloc[-40:-20]
        if recent.empty or prior.empty:
            return {"phase": "unknown"}

        recent_avg = float(recent.mean())
        prior_avg = float(prior.mean())

        if prior_avg <= 0:
            return {"phase": "unknown"}

        ratio = recent_avg / prior_avg
        if ratio < 0.7:
            phase = "compression"
        elif ratio > 1.5:
            phase = "expansion"
        else:
            phase = "stable"

        return {"phase": phase, "recent_to_prior_ratio": ratio}

    @staticmethod
    def _basic_patterns(df: pd.DataFrame) -> Dict[str, Any]:
        if df is None or df.empty or len(df) < 2:
            return {"pattern": "none"}

        o = df["open"].astype(float)
        h = df["high"].astype(float)
        l = df["low"].astype(float)
        c = df["close"].astype(float)

        o1, c1, h1, l1 = o.iloc[-1], c.iloc[-1], h.iloc[-1], l.iloc[-1]
        o2, c2, h2, l2 = o.iloc[-2], c.iloc[-2], h.iloc[-2], l.iloc[-2]

        body = abs(c1 - o1)
        range_ = h1 - l1
        small_body = range_ > 0 and body / range_ < 0.2

        pattern = "none"

        if small_body and abs(c1 - o1) <= 0.1 * range_:
            pattern = "doji"
        elif c1 > o1 and c2 < o2 and c1 >= h2 and o1 <= l2:
            pattern = "bullish_engulfing"
        elif c1 < o1 and c2 > o2 and c1 <= l2 and o1 >= h2:
            pattern = "bearish_engulfing"
        else:
            upper_wick = h1 - max(o1, c1)
            lower_wick = min(o1, c1) - l1
            if lower_wick > 2 * body and upper_wick < body:
                pattern = "hammer"
            elif upper_wick > 2 * body and lower_wick < body:
                pattern = "shooting_star"

        return {"pattern": pattern}

    @staticmethod
    def _rsi_macd_divergence(df: pd.DataFrame) -> Dict[str, Any]:
        """Very simple regular divergence detection on last two swings.

        Looks at price vs RSI and MACD highs in two recent windows.
        """

        if df is None or df.empty or len(df) < 30:
            return {"rsi": "none", "macd": "none"}

        if "rsi" not in df.columns or "macd" not in df.columns:
            return {"rsi": "none", "macd": "none"}

        closes = df["close"].astype(float)
        rsi = df["rsi"].astype(float)
        macd = df["macd"].astype(float)

        # Split last 30 bars into older and recent window
        window = closes.iloc[-30:]
        if len(window) < 30:
            return {"rsi": "none", "macd": "none"}

        older_idx = window.index[:15]
        recent_idx = window.index[15:]

        price_old_high = float(closes.loc[older_idx].max())
        price_recent_high = float(closes.loc[recent_idx].max())
        rsi_old_at_high = float(rsi.loc[closes.loc[older_idx].idxmax()])
        rsi_recent_at_high = float(rsi.loc[closes.loc[recent_idx].idxmax()])
        macd_old_at_high = float(macd.loc[closes.loc[older_idx].idxmax()])
        macd_recent_at_high = float(macd.loc[closes.loc[recent_idx].idxmax()])

        rsi_div = "none"
        macd_div = "none"

        # Bearish: price higher high, indicator lower high
        if price_recent_high > price_old_high and rsi_recent_at_high < rsi_old_at_high:
            rsi_div = "bearish"
        # Bullish: price lower low, indicator higher low (approx via inverted logic)
        price_old_low = float(closes.loc[older_idx].min())
        price_recent_low = float(closes.loc[recent_idx].min())
        rsi_old_at_low = float(rsi.loc[closes.loc[older_idx].idxmin()])
        rsi_recent_at_low = float(rsi.loc[closes.loc[recent_idx].idxmin()])
        if price_recent_low < price_old_low and rsi_recent_at_low > rsi_old_at_low:
            rsi_div = "bullish"

        if price_recent_high > price_old_high and macd_recent_at_high < macd_old_at_high:
            macd_div = "bearish"
        macd_old_at_low = float(macd.loc[closes.loc[older_idx].idxmin()])
        macd_recent_at_low = float(macd.loc[closes.loc[recent_idx].idxmin()])
        if price_recent_low < price_old_low and macd_recent_at_low > macd_old_at_low:
            macd_div = "bullish"

        return {"rsi": rsi_div, "macd": macd_div}

    @staticmethod
    def _volume_divergence(df: pd.DataFrame) -> Dict[str, Any]:
        """Price makes a new extreme but volume fades relative to prior extreme."""

        if df is None or df.empty or "volume" not in df.columns or len(df) < 30:
            return {"divergence": "none"}

        closes = df["close"].astype(float)
        volumes = df["volume"].astype(float)

        window = closes.iloc[-30:]
        if len(window) < 30:
            return {"divergence": "none"}

        older_idx = window.index[:15]
        recent_idx = window.index[15:]

        price_old_high_idx = closes.loc[older_idx].idxmax()
        price_recent_high_idx = closes.loc[recent_idx].idxmax()

        price_old_high = float(closes.loc[price_old_high_idx])
        price_recent_high = float(closes.loc[price_recent_high_idx])
        vol_old_high = float(volumes.loc[price_old_high_idx])
        vol_recent_high = float(volumes.loc[price_recent_high_idx])

        div = "none"
        vol_ratio = None
        if price_recent_high > price_old_high and vol_recent_high < vol_old_high * 0.8:
            div = "bearish_price_up_vol_down"
            vol_ratio = vol_recent_high / vol_old_high if vol_old_high > 0 else None

        return {"divergence": div, "volume_ratio_vs_prev_high": vol_ratio}

    @staticmethod
    def _support_resistance_zones(df: pd.DataFrame) -> Dict[str, Any]:
        """Approximate horizontal support/resistance from recent swing highs/lows."""

        if df is None or df.empty or len(df) < 20:
            return {}

        closes = df["close"].astype(float)
        highs = df["high"].astype(float)
        lows = df["low"].astype(float)

        last_close = float(closes.iloc[-1])

        window_highs = highs.iloc[-50:] if len(highs) > 50 else highs
        window_lows = lows.iloc[-50:] if len(lows) > 50 else lows

        # Simple fractal pivots
        pivot_highs = []
        pivot_lows = []
        for i in range(2, len(window_highs) - 2):
            if window_highs.iloc[i] >= window_highs.iloc[i - 1 : i + 2].max():
                pivot_highs.append(float(window_highs.iloc[i]))
            if window_lows.iloc[i] <= window_lows.iloc[i - 1 : i + 2].min():
                pivot_lows.append(float(window_lows.iloc[i]))

        if not pivot_highs and not pivot_lows:
            return {}

        sup = None
        res = None
        for ph in pivot_highs:
            if ph >= last_close and (res is None or ph < res):
                res = ph
        for pl in pivot_lows:
            if pl <= last_close and (sup is None or pl > sup):
                sup = pl

        info: Dict[str, Any] = {}
        if sup is not None:
            info["support"] = sup
            info["distance_to_support"] = (last_close - sup) / last_close
        if res is not None:
            info["resistance"] = res
            info["distance_to_resistance"] = (res - last_close) / last_close

        return info

    @staticmethod
    def _liquidity_sweep(df: pd.DataFrame, lookback: int = 20) -> Dict[str, Any]:
        """Detect simple buy/sell-side liquidity sweeps around obvious highs/lows."""

        if df is None or df.empty or len(df) < lookback + 2:
            return {"sweep": "none"}

        window = df.iloc[-(lookback + 1) : -1]
        last = df.iloc[-1]

        prev_high = float(window["high"].max())
        prev_low = float(window["low"].min())
        last_high = float(last["high"])
        last_low = float(last["low"])
        last_close = float(last["close"])

        sweep = "none"
        if last_high > prev_high and last_close < prev_high:
            sweep = "buy_side"
        elif last_low < prev_low and last_close > prev_low:
            sweep = "sell_side"

        return {"sweep": sweep, "ref_high": prev_high, "ref_low": prev_low}

    @staticmethod
    def _fib_zone_for_price(price: float, fib_levels: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        if not fib_levels or price is None:
            return {"zone": "unknown"}

        retr = fib_levels.get("retracements", {}) or {}
        try:
            r0 = float(retr.get("0"))
            r382 = float(retr.get("38.2"))
            r618 = float(retr.get("61.8"))
            r100 = float(retr.get("100"))
        except Exception:
            return {"zone": "unknown"}

        zone = "between_0_38"
        if price >= r0 and price < r382:
            zone = "between_0_38"
        elif price >= r382 and price < r618:
            zone = "between_38_62"
        elif price >= r618 and price <= r100:
            zone = "between_62_100"
        elif price > r100:
            zone = "above_100"
        elif price < r0:
            zone = "below_0"

        return {"zone": zone, "ref_levels": {"0": r0, "38.2": r382, "61.8": r618, "100": r100}}

    @staticmethod
    def summarize_timeframe(
        df: Optional[pd.DataFrame],
        timeframe: str,
        fib_day: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        if df is None or df.empty:
            return {}

        try:
            df_ind = IndicatorCalculator.add_all_indicators(df.copy())
        except Exception:
            df_ind = df

        trend = StrategyFeatureCalculator._trend_structure(df_ind)
        vol_regime = StrategyFeatureCalculator._volatility_regime(df_ind)
        range_brk = StrategyFeatureCalculator._range_breakout(df_ind)
        hod_lod = StrategyFeatureCalculator._hod_lod_proximity(df_ind)
        momentum = StrategyFeatureCalculator._momentum_burst(df_ind)
        vol_climax = StrategyFeatureCalculator._volume_climax(df_ind)
        patterns = StrategyFeatureCalculator._basic_patterns(df_ind)
        mean_rev = StrategyFeatureCalculator._mean_reversion_stretch(df_ind)
        vol_shift = StrategyFeatureCalculator._volatility_shift(df_ind)
        rsi_macd_div = StrategyFeatureCalculator._rsi_macd_divergence(df_ind)
        vol_div = StrategyFeatureCalculator._volume_divergence(df_ind)
        sup_res = StrategyFeatureCalculator._support_resistance_zones(df_ind)
        liquidity = StrategyFeatureCalculator._liquidity_sweep(df_ind)

        tf_norm = timeframe.lower()
        if "minute" in tf_norm:
            orb = StrategyFeatureCalculator._opening_range_breakout(df_ind)
        else:
            orb = {"status": "not_applicable"}

        fib_zone = {}
        try:
            price = float(df_ind["close"].iloc[-1])
            fib_zone = StrategyFeatureCalculator._fib_zone_for_price(price, fib_day)
        except Exception:
            fib_zone = {"zone": "unknown"}

        return {
            "trend": trend,
            "volatility": vol_regime,
            "range_breakout": range_brk,
            "opening_range": orb,
            "hod_lod": hod_lod,
            "momentum_burst": momentum,
            "volume_climax": vol_climax,
            "pattern": patterns,
            "fib_zone": fib_zone,
            "mean_reversion": mean_rev,
            "volatility_shift": vol_shift,
            "rsi_macd_divergence": rsi_macd_div,
            "volume_divergence": vol_div,
            "support_resistance": sup_res,
            "liquidity_sweep": liquidity,
        }

    @staticmethod
    def summarize_strategies(
        candles_by_timeframe: Dict[str, Any],
        fib: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build per-timeframe and multi-timeframe strategy-style features."""

        if not _PANDAS_OK:
            return {
                "per_timeframe": {},
                "multi_timeframe": {"trend_alignment": "unknown"},
            }

        fib_day = None
        if fib and isinstance(fib, dict):
            if "swing_high" in fib and "retracements" in fib:
                fib_day = fib
            else:
                fib_day = fib.get("day")

        per_tf: Dict[str, Any] = {}

        for tf, candles in candles_by_timeframe.items():
            df = StrategyFeatureCalculator._safe_df(candles)
            summary = StrategyFeatureCalculator.summarize_timeframe(df, tf, fib_day=fib_day)
            if summary:
                per_tf[tf] = summary

        directions = [v.get("trend", {}).get("direction") for v in per_tf.values() if v.get("trend")]
        alignment = "unknown"
        if directions:
            if all(d == "uptrend" for d in directions):
                alignment = "aligned_up"
            elif all(d == "downtrend" for d in directions):
                alignment = "aligned_down"
            else:
                alignment = "mixed"

        return {
            "per_timeframe": per_tf,
            "multi_timeframe": {"trend_alignment": alignment},
        }
