from datetime import datetime, timedelta
import pandas as pd
import logging

from kiteconnect import KiteConnect
from ta.trend import EMAIndicator
from scipy.signal import find_peaks

# import your existing base class
from .strategy import BearishDivergenceStrategy as BaseStrategy

# configure logging
# Module logging setup.
# NOTE(review): basicConfig() at import time configures the *root* logger for
# the whole process; consider moving this to the application entry point.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)
logger = logging.getLogger("BearishSellDivergenceStrategy")


class BearishSellDivergenceStrategy(BaseStrategy):
    """
    Intraday Bearish Divergence Short-Sell Strategy.

    Implements Steps 1–10 for at most the first ``max_symbols`` (default 25)
    symbols passed to :meth:`run_strategy`. Every successfully scanned symbol
    is returned together with its ``skip_reasons``; a symbol qualifies when
    that list is empty. Symbols whose scan raises are logged and omitted.
    """

    def __init__(
        self,
        api_key: str,
        api_secret: str,
        access_token: str,
        db,
        *,
        capital: float = 100_000,
        margin: float = 1.0,
        near_resistance_tol: float = 0.02,
        divergence_lookback: int = 15,
        ema_non_touch_bars: int = 11,
        max_symbols: int = 25,
    ):
        """
        Create the strategy.

        Args:
            api_key, api_secret, access_token, db: passed unchanged to the
                base strategy's constructor.
            capital: capital used for position sizing (Step 7).
            margin: multiplier applied to ``capital`` when sizing.
            near_resistance_tol: fractional tolerance (0.02 == 2 %) used by
                the "near pivot" and "near upper circuit" checks.
            divergence_lookback: maximum number of trailing 30-minute bars
                handed to ``detect_bearish_divergence``.
            ema_non_touch_bars: number of trailing 5-minute closes that must
                stay strictly above the 15-period EMA.
            max_symbols: maximum number of symbols scanned per
                :meth:`run_strategy` call.
        """
        super().__init__(api_key, api_secret, access_token, db)
        self.capital = capital
        self.margin = margin
        self.near_resistance_tol = near_resistance_tol
        self.divergence_lookback = divergence_lookback
        self.ema_non_touch_bars = ema_non_touch_bars
        self.max_symbols = max_symbols

    def is_near_upper_circuit(self, ltp: float, instrument_token: int, tol: float = None) -> bool:
        """
        Return True if ``ltp`` lies within ``tol`` (a fraction) below the upper limit.

        The limit is taken from the matching entry in ``self.instruments``,
        trying ``upper_price_range``, then ``upper_circuit_limit``, then
        ``day_high``. Returns False when the instrument is unknown or no
        positive limit is available. ``tol`` defaults to
        ``self.near_resistance_tol``.
        """
        inst = next(
            (i for i in self.instruments if i.get("instrument_token") == instrument_token),
            None,
        )
        if not inst:
            return False
        upper = (
            inst.get("upper_price_range")
            or inst.get("upper_circuit_limit")
            or inst.get("day_high")
            or 0
        )
        if upper <= 0:
            return False
        tolerance = tol if tol is not None else self.near_resistance_tol
        # An LTP at or above ``upper`` yields a ratio <= 0 and therefore counts as "near".
        return (upper - ltp) / upper <= tolerance

    def _ema_non_touch(self, df: pd.DataFrame, period: int = 15) -> bool:
        """
        Return True if the last ``ema_non_touch_bars`` closes all stay above the EMA.

        The EMA of ``close`` (window ``period``) is computed on ``df`` sorted by
        ``date``. Returns False when ``df`` is missing/empty, has fewer than
        ``ema_non_touch_bars`` rows, or when any inspected EMA value is NaN
        (too little history): NaN comparisons are always False, so an
        undefined EMA would otherwise be counted as "not touched".
        """
        bars = self.ema_non_touch_bars
        if df is None or df.empty or len(df) < bars:
            return False
        df = df.sort_values("date")
        ema = EMAIndicator(df["close"], window=period).ema_indicator()
        last = df["close"].tail(bars)
        last_ema = ema.tail(bars)
        if last_ema.isna().any():
            return False
        return not (last <= last_ema).any()

    def _detect_rising_wedge(self, df: pd.DataFrame) -> bool:
        """
        Compare the two most recent swing highs of ``df``.

        Peaks of ``high`` are found with ``find_peaks`` (at least 5 bars apart,
        prominence >= 0.5 % of the high). Returns True when the latest peak is
        lower than the previous one AND printed on lower volume. Returns False
        for fewer than 10 bars or fewer than two peaks.

        NOTE(review): this checks for a lower high on lower volume; a textbook
        rising wedge has higher highs with converging trendlines — confirm
        the intended pattern.
        """
        if df is None or len(df) < 10:
            return False
        if "date" in df.columns:
            # Peak order must be chronological for "previous vs latest" to hold.
            df = df.sort_values("date")
        highs = df["high"].to_numpy()
        vols = df["volume"].to_numpy()
        peaks, _ = find_peaks(highs, distance=5, prominence=highs * 0.005)
        if len(peaks) < 2:
            return False
        p1, p2 = peaks[-2], peaks[-1]
        return bool(highs[p2] < highs[p1] and vols[p2] < vols[p1])

    @staticmethod
    def _sorted_candles(df) -> pd.DataFrame:
        """Return ``df`` sorted by ``date``; an empty DataFrame for None/empty input."""
        if df is None or len(df) == 0:
            return pd.DataFrame()
        return df.sort_values("date")

    @staticmethod
    def _pct_gap(price: float, ref) -> float:
        """Percentage distance of ``price`` from ``ref``; 0.0 if ``ref`` is None, NaN or 0."""
        if ref is None or pd.isna(ref) or ref == 0:
            return 0.0
        return ((price - ref) / ref) * 100

    def run_strategy(self, symbols: list) -> list:
        """
        Run all 10 steps on the first ``max_symbols`` symbols.

        Every successfully scanned symbol is returned with its
        ``skip_reasons`` (an empty list means it qualifies); the reasons are
        logged at INFO level. A symbol whose scan raises is logged at ERROR
        level with its traceback and omitted from the results.

        Args:
            symbols: NSE trading symbols; ``None`` is treated as empty.

        Returns:
            List of result dicts in input order (not sorted by score).
        """
        if symbols is None:
            symbols = []
        results = []
        for symbol in list(symbols)[: self.max_symbols]:
            try:
                result = self._scan_short_candidate(symbol)
            except Exception as exc:
                # Boundary handler: one failing symbol must not abort the whole scan.
                logger.exception("Error scanning %s: %s", symbol, exc)
                continue

            if result["skip_reasons"]:
                logger.info("%s skipped: %s", symbol, result["skip_reasons"])
            else:
                logger.info("%s INCLUDED with score=%.2f", symbol, result["bear_score"])
            results.append(result)
        return results

    def _scan_short_candidate(self, symbol: str) -> dict:
        """Run Steps 2–10 for a single symbol and return its result dict."""
        skip_reasons = []
        # Per-symbol defaults (re-initialised every call, so nothing leaks between symbols).
        volume_drop = mom_drop = 0.0
        ema5_gap = ema9_gap = ema15_gap = vwap_gap = 0.0
        ema9 = vwap = None
        rsi_div = macd_div = False
        entry_cond = False
        pivot_level = ""

        # --- Step 2.1: Upper-circuit check ---
        token = self.get_instrument_token(symbol)
        quote_key = f"NSE:{symbol}"
        quote = self.kite.quote([quote_key]).get(quote_key, {}) or {}
        ltp = float(quote.get("last_price") or 0.0)
        if ltp <= 0:
            skip_reasons.append("invalid LTP")

        near_up = self.is_near_upper_circuit(ltp, token)
        if near_up:
            skip_reasons.append("near upper circuit")

        # fetch candles; sort once so every step sees chronological order
        df5m = self._sorted_candles(self.fetch_historical_data(token, "5minute", days=1))
        df30m = self._sorted_candles(self.fetch_historical_data(token, "30minute", days=1))
        df_day = self._sorted_candles(self.fetch_historical_data(token, "day", days=30))

        # --- Step 2: 15EMA non-touch ---
        non_touch = self._ema_non_touch(df5m)
        if not non_touch:
            skip_reasons.append("EMA touched")

        # --- Step 3: Volume & momentum drop (5 earlier vs 5 latest bars) ---
        last10 = df5m.tail(10)
        if len(last10) == 10:
            prv, rec = last10.head(5), last10.tail(5)
            prev_vol = prv["volume"].mean()
            if prev_vol > 0:
                volume_drop = ((prev_vol - rec["volume"].mean()) / prev_vol) * 100
            else:
                skip_reasons.append("zero previous volume")
            # momentum = mean absolute close-to-close change
            mom_prev = prv["close"].diff().abs().mean()
            mom_last = rec["close"].diff().abs().mean()
            if mom_prev > 0:
                mom_drop = ((mom_prev - mom_last) / mom_prev) * 100
            else:
                skip_reasons.append("zero previous momentum")
        else:
            skip_reasons.append("insufficient 5m bars")

        # --- Step 4: EMA/VWAP gaps (intraday) ---
        if not df5m.empty:
            close = df5m["close"]
            ema5 = EMAIndicator(close, window=5).ema_indicator().iloc[-1]
            ema9 = EMAIndicator(close, window=9).ema_indicator().iloc[-1]
            ema15 = EMAIndicator(close, window=15).ema_indicator().iloc[-1]
            # _pct_gap returns 0.0 for NaN EMAs (too little history) instead of NaN.
            ema5_gap = self._pct_gap(ltp, ema5)
            ema9_gap = self._pct_gap(ltp, ema9)
            ema15_gap = self._pct_gap(ltp, ema15)
            typical = (df5m["high"] + df5m["low"] + df5m["close"]) / 3
            total_vol = df5m["volume"].sum()
            if total_vol > 0:
                vwap = (typical * df5m["volume"]).sum() / total_vol
                vwap_gap = self._pct_gap(ltp, vwap)
            else:
                skip_reasons.append("zero intraday volume")
        else:
            skip_reasons.append("no intraday data")

        # --- Step 5: RSI divergence (30m) ---
        lb = min(self.divergence_lookback, len(df30m))
        if lb >= 2:
            rsi_div, macd_div = self.detect_bearish_divergence(df30m.tail(lb))
            if not (rsi_div or macd_div):
                skip_reasons.append("no divergence")
        else:
            skip_reasons.append("insufficient 30m bars")

        # --- Step 6: Rising wedge (5m) ---
        wedge = self._detect_rising_wedge(df5m)
        if not wedge:
            skip_reasons.append("no wedge")

        # --- Step 7: Position sizing ---
        qty = int((self.capital * self.margin) / ltp) if ltp > 0 else 0

        # --- Step 8: Pivot proximity (daily) ---
        # NOTE(review): tail(1) is the most recent daily bar, which during market
        # hours may be today's still-forming candle; classic pivots use the
        # previous session — confirm.
        pivots = self.calculate_pivot_points(df_day.tail(1)) if not df_day.empty else {}
        matched_level = next(
            (
                lvl for lvl, p in pivots.items()
                if p > 0 and abs(ltp - p) / p <= self.near_resistance_tol
            ),
            None,
        )
        near_res = matched_level is not None
        if near_res:
            pivot_level = matched_level
        else:
            skip_reasons.append("not near pivot")

        # composite bearish score (weights sum to 1.0)
        bear_score = (
            volume_drop * 0.2 +
            mom_drop * 0.2 +
            ((ema5_gap + ema9_gap + ema15_gap + vwap_gap) / 4) * 0.2 +
            (1 if (rsi_div or macd_div) else 0) * 0.2 +
            (1 if wedge else 0) * 0.1 +
            (1 if near_res else 0) * 0.1
        )

        # --- Step 9: Entry at 60th minute (latest close vs close 11 bars earlier) ---
        if len(df5m) >= 12:
            closes = df5m["close"]
            entry_cond = bool(closes.iloc[-1] > closes.iloc[-12])
            if not entry_cond:
                skip_reasons.append("entry condition failed")
        else:
            skip_reasons.append("no 60m entry data")

        # --- Step 10: Exit on support (EMA-9 or VWAP) ---
        # Target is the lower of this symbol's EMA-9 / VWAP, else entry - 3 %.
        # NOTE(review): for a short, min() picks the *farther* support, not the
        # closest one the step name suggests — confirm intent.
        entry = round(ltp, 2)
        supports = [v for v in (ema9, vwap) if v is not None and not pd.isna(v)]
        exit_price = min(supports) if supports else entry * 0.97
        target = round(exit_price, 2)
        stop = round(ltp * 1.01, 2)  # stop-loss 1 % above LTP

        return {
            "symbol":            symbol,
            "instrument_token":  token or 0,
            "entry_price":       entry,
            "target_price":      target,
            "stop_loss":         stop,
            "quantity":          qty,
            "bear_score":        round(bear_score, 2),
            "volume_drop_pct":   round(volume_drop, 2),
            "momentum_drop_pct": round(mom_drop, 2),
            "ema5_gap_pct":      round(ema5_gap, 2),
            "ema9_gap_pct":      round(ema9_gap, 2),
            "ema15_gap_pct":     round(ema15_gap, 2),
            "vwap_gap_pct":      round(vwap_gap, 2),
            "rsi_divergence":    rsi_div,
            "macd_divergence":   macd_div,
            "rising_wedge":      wedge,
            "ema_non_touch":     non_touch,
            "entry_condition":   entry_cond,
            "pivot_level":       pivot_level,
            "near_resistance":   near_res,
            "near_upper_circuit": near_up,
            "skip_reasons":      skip_reasons,
        }
