import pandas as pd
from scipy.signal import find_peaks
import numpy as np
import time
import requests
from kiteconnect import KiteConnect
from ta import trend, momentum
import logging
import os
from typing import List, Dict, Tuple, Optional
from bs4 import BeautifulSoup
from fastapi import APIRouter, Depends, HTTPException
from app.db import database
from app.v1.dependencies.auth import get_current_userdetails
from app.v1.services.zerodha.client import ZerodhaClient
import difflib
import json
from datetime import datetime, timedelta

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("BearishDivergenceStrategy")

router = APIRouter()

class BearishDivergenceStrategy:
    def __init__(self, api_key: str, api_secret: str, access_token: str, db):
        self.kite = KiteConnect(api_key=api_key)
        self.kite.set_access_token(access_token)
        self.db = db
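        # EMA stack used by is_price_above_all_emas as a bullish trend filter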
        self.ema_periods = [5, 9, 15, 21, 30, 55, 100, 200]
        self.instruments = self._load_instruments()

    def _load_instruments(self) -> List[Dict]:
        """Load instruments with caching"""
        CACHE_EXPIRY = timedelta(hours=12)
        
        cached = self.db["zerodha_instruments"].find_one({"type": "nse_equity"})
        if cached and (datetime.now() - cached["last_updated"]) < CACHE_EXPIRY:
            return cached["instruments"]
            
        try:
            instruments = self.kite.instruments("NSE")
            self.db["zerodha_instruments"].update_one(
                {"type": "nse_equity"},
                {"$set": {
                    "instruments": instruments,
                    "last_updated": datetime.now()
                }},
                upsert=True
            )
            return instruments
        except Exception as e:
            logger.error(f"Instruments load failed: {str(e)}")
            return cached["instruments"] if cached else []


    def map_et_to_zerodha(self, et_names: List[str]) -> Dict[str, Optional[str]]:
        """Improved mapping with better name cleaning.

        First try fuzzy matching against Zerodha instrument "name"; if that
        fails for a given ET entry, fall back to the GPT-based mapper in
        `list.map_company_to_symbol` to avoid hard failures for names like
        TRANSFORMERS_RECTIFIERS_INDIA_LTD.
        """
        symbol_map: Dict[str, Optional[str]] = {}
        name_to_symbol = {
            inst['name'].upper().replace('-', ' '): inst['tradingsymbol']
            for inst in self.instruments
        }

        # 1) Primary: fuzzy match on cleaned names
        unmapped: List[str] = []
        for et_name in et_names:
            try:
                clean_name = (
                    et_name.replace('_', ' ')
                    .replace(' LTD', '')
                    .replace(' LIMITED', '')
                    .replace(' INDIA', '')
                    .strip()
                    .upper()
                )

                matches = difflib.get_close_matches(
                    clean_name,
                    name_to_symbol.keys(),
                    n=1,
                    cutoff=0.7,
                )

                if matches:
                    symbol_map[et_name] = name_to_symbol[matches[0]]
                else:
                    symbol_map[et_name] = None
                    unmapped.append(clean_name)

            except Exception as e:
                logger.warning(f"Failed to map {et_name} via fuzzy match: {str(e)}")
                symbol_map[et_name] = None
                unmapped.append(et_name.replace('_', ' ').upper())

        # 2) Fallback: GPT-based mapping for any still-unmapped names
        try:
            from .list import map_company_to_symbol

            # Deduplicate while preserving order
            seen = set()
            fallback_input: List[str] = []
            for name in unmapped:
                if name not in seen:
                    seen.add(name)
                    fallback_input.append(name)

            if fallback_input:
                logger.info("GPT fallback for ET→NSE mapping: %s", fallback_input)
                gpt_symbols = map_company_to_symbol(fallback_input)

                # Build a simple index from cleaned name to suggested symbol
                gpt_map: Dict[str, str] = {}
                for idx, company in enumerate(fallback_input):
                    if idx < len(gpt_symbols):
                        sym = (gpt_symbols[idx] or "").strip().upper()
                        if sym:
                            gpt_map[company] = sym

                # Apply GPT suggestions to original keys that are still None
                for original, current in list(symbol_map.items()):
                    if current:
                        continue
                    cleaned = (
                        original.replace('_', ' ')
                        .replace(' LTD', '')
                        .replace(' LIMITED', '')
                        .replace(' INDIA', '')
                        .strip()
                        .upper()
                    )
                    if cleaned in gpt_map:
                        symbol_map[original] = gpt_map[cleaned]

        except Exception as e:
            logger.warning("GPT fallback for ET mapping failed: %s", e)

        return symbol_map
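
    # Illustrative call (the resulting tradingsymbol is hypothetical and depends on
    # the cached instrument list and/or the GPT fallback):
    #   self.map_et_to_zerodha(["TRANSFORMERS_RECTIFIERS_INDIA_LTD"])
    #   -> {"TRANSFORMERS_RECTIFIERS_INDIA_LTD": "<tradingsymbol or None>"}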
    
    def get_instrument_token(self, symbol: str) -> Optional[int]:
        """Get instrument token from cached data"""
        for inst in self.instruments:
            if inst['tradingsymbol'] == symbol:
                return inst['instrument_token']
        return None

    
    def fetch_historical_data(self, instrument_token: int, interval: str, 
                             days: int) -> pd.DataFrame:
        """Fetch historical data for the given instrument"""
        to_date = datetime.now()
        from_date = to_date - timedelta(days=days)
        
        data = self.kite.historical_data(
            instrument_token=instrument_token,
            from_date=from_date,
            to_date=to_date,
            interval=interval,
            continuous=False,
            oi=False
        )
        return pd.DataFrame(data)
    
    def is_price_above_all_emas(self, df: pd.DataFrame) -> bool:
        """Check if price is above all EMAs"""
        if df.empty:
            return False
            
        for period in self.ema_periods:
            ema_indicator = trend.EMAIndicator(df['close'], window=period)
            ema = ema_indicator.ema_indicator()
            if df['close'].iloc[-1] <= ema.iloc[-1]:
                return False
        return True
    
    def is_near_upper_circuit(
        self,
        ltp: float,
        instrument_token: int,
        tol: float = 0.02
    ) -> bool:
        """
        Check if `ltp` is within `tol` (e.g. 2%) of the exchange-defined
        upper-circuit limit for this instrument. Falls back to False if
        we can't find the true circuit value.
        """
        # 1) Look up the instrument in your cached list
        inst = next(
            (i for i in self.instruments
             if i.get("instrument_token") == instrument_token),
            None
        )
        if not inst:
            return False

        # 2) Pull the circuit-limit field if the cached instrument data carries one.
        #    (The standard Kite instruments dump does not include circuit limits;
        #    the full quote() response exposes `upper_circuit_limit` instead, so this
        #    returns False unless the cache has been enriched with one of these fields.)
        upper = inst.get("upper_price_range") or inst.get("upper_circuit_limit")
        if not upper or upper <= 0:
            return False

        # 3) True proximity check
        return (upper - ltp) / upper <= tol
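
    # Worked example: upper limit 100.0, ltp 99.0, tol 0.02
    #   (100.0 - 99.0) / 100.0 = 0.01 <= 0.02 -> True (within 2% of the circuit)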
    
    def calculate_volume_drop(self, df: pd.DataFrame, days: int = 3) -> List[float]:
        """
        Compute the % change in volume for each of the last `days` days:
        drop_pct_i = (volume_i-1 – volume_i) / volume_i-1 * 100
        
        Returns a list of floats [drop_day1, drop_day2, …], where:
        drop_day1 = % change from 4 days ago → 3 days ago
        drop_day2 = % change from 3 days ago → 2 days ago
        drop_day3 = % change from 2 days ago → yesterday
        """
        # need at least days+1 rows to compare pairs
        if len(df) < days + 1:
            return [0.0] * days

        vols = df['volume'].iloc[-(days+1):].values  # e.g. [..., vol[-4], vol[-3], vol[-2], vol[-1]]
        drops = []
        # for i in range(1..days): compare vols[i-1] → vols[i]
        for i in range(1, days+1):
            prev = vols[i-1]
            curr = vols[i]
            if prev == 0:
                drops.append(0.0)
            else:
                drops.append(((prev - curr) / prev) * 100)

        return [round(d, 2) for d in drops]
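
    # Worked example (hypothetical volumes) with days=3: [100, 80, 60, 90]
    #   -> [20.0, 25.0, -50.0]  (positive = volume fell vs the prior bar, negative = it rose)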

    
    def calculate_pivot_points(self, df: pd.DataFrame) -> Dict[str, float]:
        """Calculate standard pivot points"""
        if df.empty:
            return {}
            
        prev = df.iloc[-1]
        high, low, close = prev['high'], prev['low'], prev['close']
        pivot = (high + low + close) / 3
        return {
            'R1': 2 * pivot - low,
            'R2': pivot + (high - low),
            'R3': pivot + 2 * (high - low),
            'S1': 2 * pivot - high,
            'S2': pivot - (high - low),
            'S3': pivot - 2 * (high - low)
        }
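
    # Worked example: high=110, low=100, close=105 -> pivot = 105
    #   R1=110, R2=115, R3=125, S1=100, S2=95, S3=85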
    
    def is_near_resistance(self, price: float, pivot_points: Dict[str, float]) -> bool:
        """Check if price is near resistance levels"""
        for level in ['R1', 'R2', 'R3']:
            resistance = pivot_points.get(level, 0)
            if resistance > 0 and abs(price - resistance) / price <= 0.02:
                return True
        return False
    
    def calculate_ema5_gap(self, price: float, df: pd.DataFrame) -> float:
        """Calculate gap between price and EMA5"""
        if df.empty:
            return 0.0
            
        ema_indicator = trend.EMAIndicator(df['close'], window=5)
        ema5 = ema_indicator.ema_indicator().iloc[-1]
        if ema5 == 0:
            return 0.0
        return ((price - ema5) / ema5) * 100
    
    def detect_bearish_divergence(self, df: pd.DataFrame) -> Tuple[bool, bool]:
        """
        Detect bearish divergence using RSI and MACD on the given DataFrame.

        Returns (rsi_divergence, macd_divergence).
        """
        # need at least enough bars to find two peaks
        if len(df) < 15:
            return False, False

        # 1) Compute indicators
        close = df['close'].values
        rsi = momentum.RSIIndicator(df['close'], window=14).rsi().fillna(0).values
        macd = trend.MACD(df['close']).macd_diff().fillna(0).values

        # 2) Find peaks
        # - price: require at least 1% move and 5 bars apart
        price_peaks, _ = find_peaks(close, distance=5, prominence=close * 0.01)
        # - RSI & MACD: only distance rule (can also add prominence if you like)
        rsi_peaks, _  = find_peaks(rsi, distance=5)
        macd_peaks, _ = find_peaks(macd, distance=5)
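
        # Note: price and indicator peaks are located independently, so the last two
        # peaks of each series may fall on different bars; the checks below compare
        # each series' own last two peaks rather than values aligned to the price peaks.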

        # need two swing highs in price
        if len(price_peaks) < 2:
            return False, False

        # take the last two peaks
        p1, p2 = price_peaks[-2], price_peaks[-1]

        # RSI divergence?
        rsi_div = False
        if len(rsi_peaks) >= 2:
            r1, r2 = rsi_peaks[-2], rsi_peaks[-1]
            # price made a higher high, RSI made a lower high
            if close[p2] > close[p1] and rsi[r2] < rsi[r1]:
                rsi_div = True

        # MACD divergence?
        macd_div = False
        if len(macd_peaks) >= 2:
            m1, m2 = macd_peaks[-2], macd_peaks[-1]
            if close[p2] > close[p1] and macd[m2] < macd[m1]:
                macd_div = True

        return rsi_div, macd_div
    
    def find_peaks(self, series: pd.Series, lookback: int = 3) -> List[int]:
        """Find peak indices in a series"""
        peaks = []
        for i in range(lookback, len(series)-lookback):
            if (series.iloc[i] == series.iloc[i-lookback:i+lookback+1].max() and
                series.iloc[i] > series.iloc[i-1] and
                series.iloc[i] > series.iloc[i+1]):
                peaks.append(i)
        return peaks
    
    def calculate_score(self, volume_drop: float, ema5_gap: float, 
                       rsi_div: bool, macd_div: bool, near_res: bool) -> float:
        """Calculate a score for the stock based on bearish factors"""
        score = 0
        
        # Volume drop contributes up to 30 points
        score += min(30, volume_drop * 0.3)
        
        # EMA5 gap contributes up to 30 points
        score += min(30, ema5_gap * 0.6)
        
        # Divergence and resistance factors
        if rsi_div and macd_div:
            score += 25
        elif rsi_div or macd_div:
            score += 15
            
        if near_res:
            score += 15
            
        return min(100, score)
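
    # Worked example (hypothetical inputs): volume_drop=20.0, ema5_gap=10.0,
    # rsi_div=True, macd_div=False, near_res=True
    #   -> min(30, 6.0) + min(30, 6.0) + 15 + 15 = 42.0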
    
    @staticmethod
    def get_tradingsymbols_from_names(company_names: List[str],
                                      instruments: List[Dict]) -> List[str]:
        tradingsymbols = []
        instrument_names = {inst['name'].upper(): inst['tradingsymbol']
                            for inst in instruments}

        for cname in company_names:
            cleaned = cname.replace('_', ' ').upper()
            match = difflib.get_close_matches(
                cleaned, instrument_names.keys(), n=1, cutoff=0.6)
            if match:
                tradingsymbols.append(instrument_names[match[0]])
            else:
                logger.warning(f"No match found for {cname}")
        return tradingsymbols

    # ---- New BULL modules ----
    def calculate_volume_rise(self, df: pd.DataFrame, days: int = 3) -> List[float]:
        """Compute % change in volume rise for bullish setup"""
        if len(df) < days + 1:
            return [0.0] * days
        vols = df['volume'].iloc[-(days+1):].values
        rises = []
        for i in range(1, days+1):
            prev, curr = vols[i-1], vols[i]
            if prev == 0 or curr <= prev:
                rises.append(0.0)
            else:
                rises.append(((curr - prev) / prev) * 100)
        return [round(r, 2) for r in rises]
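
    # Worked example (hypothetical volumes) with days=3: [100, 80, 120, 150]
    #   -> [0.0, 50.0, 25.0]  (declines are reported as 0.0; only rises contribute)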

    def detect_bullish_divergence(self, df: pd.DataFrame) -> Tuple[bool, bool]:
        """Detect bullish divergence using RSI and MACD"""
        if len(df) < 15:
            return False, False
        close = df['close'].values
        rsi = momentum.RSIIndicator(df['close'], window=14).rsi().fillna(0).values
        macd = trend.MACD(df['close']).macd_diff().fillna(0).values
        troughs, _ = find_peaks(-close, distance=5, prominence=close * 0.01)
        rsi_tr, _ = find_peaks(-rsi, distance=5)
        macd_tr, _ = find_peaks(-macd, distance=5)
        if len(troughs) < 2:
            return False, False
        t1, t2 = troughs[-2], troughs[-1]
        bull_rsi = False
        if len(rsi_tr) >= 2:
            r1, r2 = rsi_tr[-2], rsi_tr[-1]
            if close[t2] < close[t1] and rsi[r2] > rsi[r1]:
                bull_rsi = True
        bull_macd = False
        if len(macd_tr) >= 2:
            m1, m2 = macd_tr[-2], macd_tr[-1]
            if close[t2] < close[t1] and macd[m2] > macd[m1]:
                bull_macd = True
        return bull_rsi, bull_macd

    def is_near_support(self, price: float, pivots: Dict[str, float]) -> bool:
        """Check if price is near support levels"""
        for lvl in ['S1', 'S2', 'S3']:
            sup = pivots.get(lvl, 0)
            if sup > 0 and abs(price - sup) / price <= 0.02:
                return True
        return False

    def calculate_score_bull(self, volume_rise: float, ema5_gap: float,
                             rsi_div: bool, macd_div: bool, near_sup: bool) -> float:
        """Calculate a score for the stock based on bullish factors"""
        score = 0
        score += min(30, volume_rise * 0.3)
        score += min(30, ema5_gap * 0.6)
        if rsi_div and macd_div:
            score += 25
        elif rsi_div or macd_div:
            score += 15
        if near_sup:
            score += 15
        return min(100, score)

    def calculate_ema5_gap_bull(self, price: float, df: pd.DataFrame) -> float:
        """Gap between price and EMA5 in % (positive if price is above EMA5, negative if below)."""
        if df.empty:
            return 0.0
        ema5 = trend.EMAIndicator(df['close'], window=5).ema_indicator().iloc[-1]
        if ema5 == 0:
            return 0.0
        return round(((price - ema5) / ema5) * 100, 2)

    def run_strategy(self, symbols: List[str]) -> List[Dict]:
        """
        Run the bearish divergence strategy on the given symbols,
        but relax all skip-conditions to ensure every symbol is analyzed.
        Also runs a parallel bullish analysis and makes a final BUY/SELL/HOLD decision.
        """
        results = []

        # thresholds for relaxed checks
        MIN_DAILY_BARS = 15
        MIN_30M_BARS   = 30
        VOL_DROP_THRESH = 15.0
        VOL_RISE_THRESH = 15.0  # for bullish volume rise

        for symbol in symbols:
            try:
                token = self.get_instrument_token(symbol)
                if not token:
                    logger.info(f"{symbol}: no instrument token, skipping (historical data unavailable)")
                    continue

                # fetch live quote
                quote = self.kite.quote([f"NSE:{symbol}"]).get(f"NSE:{symbol}", {})
                if not quote:
                    logger.info(f"{symbol}: no quote data, skipping")
                    continue
                ltp = quote.get('last_price', np.nan)
                day_high = quote.get('ohlc', {}).get('high', 0.0)

                # 1) upper circuit proximity (log but do not skip)
                if self.is_near_upper_circuit(ltp, token):
                    logger.info(f"{symbol}: near upper circuit (ltp {ltp}, day high {day_high}), continuing analysis")

                # 2) fetch history
                df_day = self.fetch_historical_data(token, "day", 30).sort_values('date')
                df_30 = self.fetch_historical_data(token, "30minute", 5).sort_values('date')

                # relaxed bar count logging
                if len(df_day) < MIN_DAILY_BARS:
                    logger.info(f"{symbol}: only {len(df_day)} daily bars (<{MIN_DAILY_BARS}), continuing")
                if len(df_30) < MIN_30M_BARS:
                    logger.info(f"{symbol}: only {len(df_30)} 30m bars (<{MIN_30M_BARS}), continuing")

                # 3) EMA alignment
                if not self.is_price_above_all_emas(df_day):
                    logger.info(f"{symbol}: price not above all EMAs, continuing")

                # 4) volume drop
                vol_drops = self.calculate_volume_drop(df_day.tail(4), days=3)
                avg_vol_drop = sum(vol_drops) / len(vol_drops)
                if avg_vol_drop < VOL_DROP_THRESH:
                    logger.info(f"{symbol}: avg volume drop {avg_vol_drop:.2f}% < {VOL_DROP_THRESH}%, continuing")
                
                # 5) pivot points & resistance
                pivots = self.calculate_pivot_points(df_day.tail(1))
                near_res = self.is_near_resistance(ltp, pivots)

                # 6) EMA5 gap
                ema5_gap = self.calculate_ema5_gap(ltp, df_day)

                # 7) divergence (bearish)
                rsi_div, macd_div = self.detect_bearish_divergence(df_30.tail(50))
                if not (rsi_div or macd_div):
                    logger.info(f"{symbol}: no bearish divergence, continuing")

                # 8) scoring (bearish)
                bear_score = self.calculate_score(avg_vol_drop, ema5_gap, rsi_div, macd_div, near_res)
                logger.info(f"{symbol}: bearish score {bear_score:.2f}")

                # ---- BULLISH STRATEGY ----
                # 9) bullish volume rise
                vol_rises = self.calculate_volume_rise(df_day.tail(4), days=3)
                avg_vol_rise = sum(vol_rises) / len(vol_rises)
                if avg_vol_rise < VOL_RISE_THRESH:
                    logger.info(f"{symbol}: avg volume rise {avg_vol_rise:.2f}% < {VOL_RISE_THRESH}%, continuing bullish check")

                # 10) support proximity (same pivot levels as the resistance check)
                near_sup = self.is_near_support(ltp, pivots)

                # 11) EMA5 gap (bullish context: negative when price is below EMA5)
                ema5_gap_bull = self.calculate_ema5_gap_bull(ltp, df_day)

                # 12) divergence (bullish)
                rsi_bull, macd_bull = self.detect_bullish_divergence(df_30.tail(50))
                if not (rsi_bull or macd_bull):
                    logger.info(f"{symbol}: no bullish divergence, continuing")

                # 13) scoring (bullish)
                bull_score = self.calculate_score_bull(avg_vol_rise, ema5_gap_bull, rsi_bull, macd_bull, near_sup)
                logger.info(f"{symbol}: bullish score {bull_score:.2f}")

                # 14) final decision
                if bull_score - bear_score > 50:
                    decision = "BUY"
                    entry = round(ltp, 2)
                    target = round(ltp * 1.03, 2)
                elif bear_score - bull_score > 50:
                    decision = "SELL"
                    entry = round(ltp, 2)
                    target = round(ltp * 0.97, 2)
                else:
                    decision = "HOLD"
                    entry = round(ltp, 2)
                    target = None

                # record results
                results.append({
                    "symbol": symbol,
                    "instrument_token": token,
                    # bearish metrics
                    "volume_drop": round(avg_vol_drop, 2),
                    "ema5_gap": round(ema5_gap, 2),
                    "divergence_confirmed": "RSI & MACD" if (rsi_div and macd_div) else ("RSI" if rsi_div else ("MACD" if macd_div else "")),
                    "near_resistance": near_res,
                    "bear_score": round(bear_score, 2),
                    # bullish metrics
                    "volume_rise": round(avg_vol_rise, 2),
                    "ema5_gap_bull": round(ema5_gap_bull, 2),
                    "divergence_bull": "RSI & MACD" if (rsi_bull and macd_bull) else ("RSI" if rsi_bull else ("MACD" if macd_bull else "")),
                    "near_support": near_sup,
                    "bull_score": round(bull_score, 2),
                    # decision
                    "decision": decision,
                    "entry_price": entry,
                    "target_price": target
                })

            except Exception as e:
                logger.error(f"{symbol}: ERROR → {e}", exc_info=True)

        # sort by bearish score by default (or could sort by decision or combined metric)
        return sorted(results, key=lambda x: x['bear_score'], reverse=True)
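
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the credentials, the pymongo-style
# `db` handle, and the symbol list below are placeholders, not real values.
#
#   strategy = BearishDivergenceStrategy(
#       api_key="<api_key>",
#       api_secret="<api_secret>",
#       access_token="<access_token>",
#       db=mongo_client["<database>"],
#   )
#   picks = strategy.run_strategy(["RELIANCE", "TCS"])
#   for row in picks:
#       print(row["symbol"], row["decision"], row["bear_score"], row["bull_score"])
# ---------------------------------------------------------------------------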
