from datetime import datetime
import pandas as pd
import logging

from kiteconnect import KiteConnect
from ta.trend import EMAIndicator
from ta.momentum import RSIIndicator
from scipy.signal import find_peaks

from .strategy import BearishDivergenceStrategy as BaseStrategy

# Root logging is configured at import time: INFO and above, rendered as
# "<asctime> <LEVELNAME> <message>". logging.basicConfig does nothing if the
# root logger already has handlers, so an application that configures logging
# first keeps its own setup.
# NOTE(review): configuring the root logger from a library module affects the
# whole process; consider moving this to the application entry point.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used by the strategy for per-symbol include/skip messages.
logger = logging.getLogger("BearishSellDivergenceStrategy")


class BearishSellDivergenceStrategy(BaseStrategy):
    """
    Intraday Bearish Divergence Short-Sell Strategy.

    Steps 1–10 implemented as separate, concise methods. Market data is
    obtained through the base strategy (``get_instrument_token``,
    ``fetch_historical_data``, ``kite`` and ``instruments``); this class adds
    the screening, scoring and position-sizing logic on top of it.
    """

    def __init__(self, api_key, api_secret, access_token, db,
                 capital=100_000, margin=1.0, near_resistance_tol=0.02,
                 divergence_lookback=15, ema_non_touch_bars=11, rsi_timeframe='30minute'):
        """
        Args:
            api_key, api_secret, access_token, db: forwarded unchanged to the
                base strategy constructor.
            capital: capital allotted per position (used by ``calc_quantity``).
            margin: multiplier applied to ``capital`` when sizing a position.
            near_resistance_tol: fractional tolerance (0.02 == 2%) used both for
                upper-circuit proximity and for pivot-level proximity.
            divergence_lookback: number of most recent bars scanned for an RSI
                divergence.
            ema_non_touch_bars: number of most recent 5-minute closes that must
                stay strictly above the 15-period EMA.
            rsi_timeframe: candle interval handed to ``fetch_historical_data``
                for the RSI divergence check.
        """
        super().__init__(api_key, api_secret, access_token, db)
        self.capital = capital
        self.margin = margin
        self.near_res_tol = near_resistance_tol
        self.div_lb = divergence_lookback
        self.ema_bars = ema_non_touch_bars
        self.rsi_tf = rsi_timeframe

    # ------------------------------------------------------------------ helpers

    def _get_token_ltp(self, symbol):
        """Return ``(instrument_token, last_traded_price)`` for an NSE *symbol*.

        The price falls back to 0.0 when the quote response has no entry for
        the symbol or its ``last_price`` is missing/None.
        """
        key = f"NSE:{symbol}"
        tok = self.get_instrument_token(symbol)
        quote = self.kite.quote([key]).get(key, {})
        return tok, float(quote.get("last_price") or 0.0)

    def _sorted_bars(self, token, interval="5minute", days=1):
        """Fetch *days* of *interval* candles for *token*, ordered oldest first."""
        return self.fetch_historical_data(token, interval, days=days).sort_values("date")

    @staticmethod
    def _with_parsed_dates(df):
        """Return ``(bars, today)`` for a candle frame.

        ``bars`` is a copy of *df* whose ``date`` column is parsed to datetimes
        and sorted ascending; ``today`` is the current date expressed in the
        same timezone as that column (local time when the column is naive), so
        the "is this bar from today" comparison does not depend on the
        server's clock timezone.
        """
        bars = df.copy()
        bars["date"] = pd.to_datetime(bars["date"])
        bars = bars.sort_values("date")
        today = pd.Timestamp.now(tz=bars["date"].dt.tz).date()
        return bars, today

    # ------------------------------------------------------------------ steps

    # Step 2.1
    def is_near_upper_circuit(self, ltp, token):
        """Return True when *ltp* is within ``near_res_tol`` of the upper circuit.

        The limit is read from the entry of ``self.instruments`` whose
        ``instrument_token`` equals *token*, using ``upper_price_range`` or
        else ``upper_circuit_limit``. When no entry or no limit is found the
        result is False. A price at or above the limit also counts as near.
        NOTE(review): confirm the instrument records actually carry one of
        these keys; otherwise this check is always False.
        """
        inst = next((i for i in self.instruments if i["instrument_token"] == token), None)
        upper = (inst.get("upper_price_range") or inst.get("upper_circuit_limit") or 0) if inst else 0
        return bool(upper and (upper - ltp) / upper <= self.near_res_tol)

    # Step 2
    def ema_non_touch(self, symbol):
        """Return True when the last ``ema_bars`` 5-minute closes all stay above EMA15.

        Returns False when there are fewer than ``ema_bars`` bars, or when the
        EMA is still warming up (NaN) for any of the inspected bars — a
        missing EMA cannot confirm that price never touched it.
        """
        tok = self.get_instrument_token(symbol)
        df = self._sorted_bars(tok)
        if len(df) < self.ema_bars:
            return False
        e15 = EMAIndicator(df.close, window=15).ema_indicator()
        recent_close = df.close.tail(self.ema_bars)
        recent_ema = e15.tail(self.ema_bars)
        if recent_ema.isna().any():
            return False
        # "Touch" means a close at or below the EMA.
        return bool((recent_close > recent_ema).all())

    # Step 3
    def vol_mom_drop(self, symbol):
        """Compare the last 5 five-minute bars against the 5 before them.

        Returns ``(ok, volume_drop_pct, momentum_drop_pct)``. ``ok`` is True
        when both mean volume and momentum (mean absolute bar-to-bar close
        change inside each 5-bar window) are lower in the recent window. The
        percentages are rounded to 2 decimals and are 0.0 when the earlier
        window's mean is zero. With fewer than 10 bars returns
        ``(False, 0.0, 0.0)``.
        """
        tok = self.get_instrument_token(symbol)
        df = self._sorted_bars(tok).tail(10)
        if len(df) < 10:
            return False, 0.0, 0.0
        prev, rec = df.head(5), df.tail(5)
        prev_vol, rec_vol = prev.volume.mean(), rec.volume.mean()
        prev_mom = prev.close.diff().abs().mean()
        rec_mom = rec.close.diff().abs().mean()
        ok = bool(rec_vol < prev_vol and rec_mom < prev_mom)
        vd = (prev_vol - rec_vol) / prev_vol * 100 if prev_vol else 0.0
        md = (prev_mom - rec_mom) / prev_mom * 100 if prev_mom else 0.0
        return ok, round(vd, 2), round(md, 2)

    # Step 4
    def gap_metrics(self, symbol, ltp) -> dict:
        """Return ``{"ema5_gap_pct": ...}``: % distance of *ltp* above today's EMA5.

        Only today's 5-minute bars feed the EMA. The gap is 0.0 when there are
        no bars for today or the EMA is not yet defined (fewer than 5 bars) or
        zero, so a missing value never propagates NaN into the score.
        """
        tok = self.get_instrument_token(symbol)
        raw = self.fetch_historical_data(tok, "5minute", days=1)
        if raw.empty:
            return {"ema5_gap_pct": 0.0}
        bars, today = self._with_parsed_dates(raw)
        closes = bars.loc[bars["date"].dt.date == today, "close"]
        if closes.empty:
            return {"ema5_gap_pct": 0.0}
        e5 = EMAIndicator(closes, window=5).ema_indicator().iloc[-1]
        if pd.isna(e5) or not e5:
            return {"ema5_gap_pct": 0.0}
        return {"ema5_gap_pct": round((ltp - e5) / e5 * 100, 2)}

    # Step 5
    # --- STEP 5: RSI Divergence with configurable timeframe ---
    def detect_rsi_divergence(self, symbol, timeframe=None, history_days=5):
        """Detect a bearish RSI divergence over the last ``div_lb`` bars.

        The two most recent price peaks (``find_peaks`` with distance 3 and a
        0.5%-of-price prominence) must form a higher high while RSI(14) at
        the later peak is lower than at the earlier one.

        Args:
            symbol: NSE trading symbol.
            timeframe: candle interval; defaults to ``self.rsi_tf``.
            history_days: days of candles fetched. RSI is computed on this full
                history *before* trimming to ``div_lb`` bars, so its 14-bar
                warm-up does not consume the lookback window.

        Returns False when fewer than 5 bars are available, fewer than two
        peaks are found, or RSI is undefined at either peak.
        """
        tf = timeframe or self.rsi_tf
        tok = self.get_instrument_token(symbol)
        df = self._sorted_bars(tok, tf, days=history_days)
        window = df.tail(self.div_lb)
        if len(window) < 5:
            return False
        rsi = RSIIndicator(df.close, window=14).rsi().tail(self.div_lb).to_numpy()
        highs = window.high.to_numpy()
        peaks, _ = find_peaks(highs, distance=3, prominence=highs * 0.005)
        if len(peaks) < 2:
            return False
        earlier, latest = peaks[-2], peaks[-1]
        if pd.isna(rsi[earlier]) or pd.isna(rsi[latest]):
            return False
        return bool(
            highs[latest] > highs[earlier]      # higher price high
            and rsi[latest] < rsi[earlier]      # ...confirmed by a lower RSI high
        )

    # Step 6
    # --- STEP 6: Rising Wedge — price higher highs + decreasing volume ---
    def detect_rising_wedge(self, symbol):
        """Return True when the last two 5-minute price peaks rise on falling volume.

        Peaks come from ``find_peaks`` (distance 5, 0.5%-of-price prominence).
        This is a simplified wedge proxy: it checks only higher highs and
        tapering volume at those peaks, not converging trendlines.
        """
        tok = self.get_instrument_token(symbol)
        df = self._sorted_bars(tok)
        highs, vols = df.high.to_numpy(), df.volume.to_numpy()
        peaks, _ = find_peaks(highs, distance=5, prominence=highs * 0.005)
        if len(peaks) < 2:
            return False
        earlier, latest = peaks[-2], peaks[-1]
        return bool(highs[latest] > highs[earlier] and vols[latest] < vols[earlier])

    # Step 7
    def calc_quantity(self, price):
        """Return the share count affordable with ``capital * margin`` at *price*.

        At least 1 share for a valid price; 0 when *price* is not a positive
        number (e.g. the 0.0 fallback LTP), instead of dividing by zero.
        """
        if not price > 0:
            return 0
        return max(int((self.capital * self.margin) / price), 1)

    # Step 8
    # --- STEP 8: Extended pivot levels R1–R3 & S1–S3 ---
    def detect_near_pivot(self, symbol, ltp):
        """Check whether *ltp* is within ``near_res_tol`` of a classic floor pivot.

        Pivots are computed from the most recent *completed* daily candle
        (today's row, if the feed already includes it, is excluded because it
        is still forming). Returns ``(True, level_name)`` for the closest
        qualifying level (ties resolved in R1, R2, R3, S1, S2, S3 order) or
        ``(False, '')`` when none qualifies or data is unavailable.
        """
        if not ltp > 0:
            return False, ''
        tok = self.get_instrument_token(symbol)
        raw = self.fetch_historical_data(tok, "day", days=30)
        if raw.empty:
            return False, ''
        daily, today = self._with_parsed_dates(raw)
        completed = daily[daily["date"].dt.date < today]
        if completed.empty:
            return False, ''
        row = completed.iloc[-1]
        H, L, C = row.high, row.low, row.close
        pivot = (H + L + C) / 3
        pivots = {
            'R1': 2*pivot - L,
            'R2': pivot + (H - L),
            'R3': H + 2*(pivot - L),
            'S1': 2*pivot - H,
            'S2': pivot - (H - L),
            'S3': L - 2*(H - pivot)
        }
        distances = {
            lvl: abs(ltp - val) / val
            for lvl, val in pivots.items()
            if val > 0 and abs(ltp - val) / val <= self.near_res_tol
        }
        if not distances:
            return False, ''
        return True, min(distances, key=distances.get)

    # Step 9
    def entry_condition(self, symbol):
        """Return True when the latest 5-minute close exceeds the close 11 bars earlier.

        Compares ``close.iloc[-1]`` with ``close.iloc[-12]``; needs at least 12 bars.
        """
        tok = self.get_instrument_token(symbol)
        closes = self._sorted_bars(tok).close
        return len(closes) >= 12 and bool(closes.iloc[-1] > closes.iloc[-12])

    # Step 10
    def exit_price(self, symbol):
        """Return the short-cover target: the lower of EMA9 and VWAP (5-minute bars).

        Non-positive or undefined candidates (e.g. EMA9 with fewer than 9 bars,
        VWAP with zero total volume) are ignored; returns 0.0 when neither is
        usable or no bars are available. Rounded to 2 decimals.
        """
        tok = self.get_instrument_token(symbol)
        df = self._sorted_bars(tok)
        if df.empty:
            return 0.0
        e9 = EMAIndicator(df.close, window=9).ema_indicator().iloc[-1]
        typical = (df.high + df.low + df.close) / 3
        total_vol = df.volume.sum()
        vwap = (typical * df.volume).sum() / total_vol if total_vol else 0
        # NaN > 0 is False, so an undefined EMA drops out here.
        candidates = [p for p in (e9, vwap) if p > 0]
        return round(min(candidates), 2) if candidates else 0.0

    # scoring
    def calculate_score(self, vd, md, vwap_gap, rsi, wedge, pivot_ok):
        """Weighted bearish score, rounded to 2 decimals.

        ``vd``, ``md`` and ``vwap_gap`` (currently fed the EMA5 gap) are
        percentages weighted 0.2 each; the boolean flags add 0.2 (RSI
        divergence), 0.1 (wedge) and 0.1 (near pivot).
        NOTE(review): the percentage terms are unbounded while the flags add
        at most 0.4 in total, so the percentages dominate the ranking.
        """
        return round(
            vd * 0.2 +
            md * 0.2 +
            vwap_gap * 0.2 +
            (1 if rsi else 0) * 0.2 +
            (1 if wedge else 0) * 0.1 +
            (1 if pivot_ok else 0) * 0.1,
            2
        )

    # process single symbol
    def process_symbol(self, symbol):
        """Run every step for *symbol* and return a result record.

        Hard filters (missing LTP, near upper circuit, EMA touched, no
        volume/momentum drop, entry condition failed) are collected in
        ``skip_reasons``; all metrics are still computed so skipped symbols
        can be inspected. Logs whether the symbol was skipped or included.
        """
        token, ltp = self._get_token_ltp(symbol)
        near_up = self.is_near_upper_circuit(ltp, token)
        non_touch = self.ema_non_touch(symbol)
        vol_ok, vd, md = self.vol_mom_drop(symbol)
        entry_ok = self.entry_condition(symbol)

        skips = []
        if not ltp > 0:
            skips.append("no LTP")
        if near_up:
            skips.append("near upper circuit")
        if not non_touch:
            skips.append("EMA touched")
        if not vol_ok:
            skips.append("no vol/mom drop")
        if not entry_ok:
            skips.append("entry condition failed")

        ema5_gap = self.gap_metrics(symbol, ltp)["ema5_gap_pct"]
        rsi_div = self.detect_rsi_divergence(symbol)
        wedge = self.detect_rising_wedge(symbol)
        pivot_ok, pivot_lvl = self.detect_near_pivot(symbol, ltp)
        qty = self.calc_quantity(ltp)
        xp = self.exit_price(symbol)
        score = self.calculate_score(vd, md, ema5_gap, rsi_div, wedge, pivot_ok)

        if skips:
            logger.info("%s skipped: %s", symbol, skips)
        else:
            logger.info("%s INCLUDED (score=%s)", symbol, score)

        return {
            "symbol":             symbol,
            "instrument_token":   token,
            "entry_price":        round(ltp, 2),
            "target_price":       xp,
            "quantity":           qty,
            "bear_score":         score,
            "volume_drop_pct":    vd,
            "momentum_drop_pct":  md,
            "ema5_gap_pct":       ema5_gap,
            "rsi_divergence":     rsi_div,
            "rising_wedge":       wedge,
            "pivot_level":        pivot_lvl,
            "near_upper_circuit": near_up,
            "skip_reasons":       skips
        }

    def run_strategy(self, symbols, max_symbols=15):
        """Process up to *max_symbols* symbols and return records sorted by score.

        Args:
            symbols: any iterable of NSE trading symbols.
            max_symbols: cap on how many symbols are processed (None = all).

        Returns:
            List of ``process_symbol`` records ordered by ``bear_score``
            descending; an empty list when no symbols are given.
        """
        results = [self.process_symbol(s) for s in list(symbols)[:max_symbols]]
        if not results:
            return []
        return (
            pd.DataFrame(results)
            .sort_values('bear_score', ascending=False)
            .reset_index(drop=True)
            .to_dict(orient='records')
        )
