from kiteconnect import KiteConnect, KiteTicker
from typing import Optional, Union
import pandas as pd
import logging
import os
import time
import random

# Module-wide logger shared by every ZerodhaClient method in this file.
logger = logging.getLogger("ZerodhaClient")
# Set this logger's own threshold to INFO.
logger.setLevel(logging.INFO)


class ZerodhaClient:
    """Thin convenience wrapper around a ``KiteConnect`` client.

    Adds three things on top of the raw client:

    * session helpers that keep ``self.access_token`` and the underlying
      client in sync,
    * an in-memory, per-exchange cache of the instruments master,
    * throttling plus retry-with-backoff for calls that may hit rate limits
      (instruments, quotes, historical data).

    Retry behaviour is configured through environment variables read once in
    ``__init__``: ``ZERODHA_MAX_RETRIES``, ``ZERODHA_RETRY_BASE_SLEEP_SEC`` and
    ``ZERODHA_THROTTLE_SLEEP_SEC``.
    """

    def __init__(self, api_key: str, api_secret: str, access_token: Optional[str] = None):
        """Create the underlying client and apply ``access_token`` if given."""
        self.api_key = api_key
        self.api_secret = api_secret
        self.access_token = access_token
        self.kite = KiteConnect(api_key=self.api_key)

        if self.access_token:
            self.kite.set_access_token(self.access_token)

        # In-memory cache to reduce repeated instrument master downloads.
        # Both dicts are keyed by the normalised (upper-case) exchange name.
        self._instruments_cache = {}
        self._instruments_cache_ts = {}

        # Simple retry controls for Kite rate limits. Negative values make no
        # sense (and a negative retry count used to skip the API call
        # entirely), so they are clamped to zero.
        self._max_retries = max(0, int(os.getenv("ZERODHA_MAX_RETRIES", "3")))
        self._retry_base_sleep = max(
            0.0, float(os.getenv("ZERODHA_RETRY_BASE_SLEEP_SEC", "0.6"))
        )
        self._throttle_sleep = float(os.getenv("ZERODHA_THROTTLE_SLEEP_SEC", "0.05"))

    def _should_retry(self, e: Exception) -> bool:
        """Return True when ``e`` looks like a rate-limit error worth retrying.

        An exception qualifies if it carries a ``code`` attribute equal to 429
        or if its message mentions a rate limit.
        """
        if getattr(e, "code", None) == 429:
            return True
        msg = str(e).lower()
        return ("too many requests" in msg) or ("429" in msg) or ("rate limit" in msg)

    def _sleep_backoff(self, attempt: int) -> None:
        """Sleep before the next retry: exponential backoff with a little jitter."""
        delay = self._retry_base_sleep * (2 ** max(attempt, 0))
        delay = delay + random.random() * 0.15
        # Guard against a negative base sleep; time.sleep rejects negatives.
        time.sleep(max(delay, 0.0))

    def _call_with_retry(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with throttling and rate-limit retries.

        At least one attempt is always made. Errors accepted by
        ``_should_retry`` are retried up to ``self._max_retries`` times with
        backoff; any other error, or the last failed attempt, is re-raised
        unchanged.
        """
        retries = max(int(self._max_retries), 0)
        for attempt in range(retries + 1):
            if self._throttle_sleep > 0:
                time.sleep(self._throttle_sleep)
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if attempt >= retries or not self._should_retry(e):
                    raise
                self._sleep_backoff(attempt)
        # Not reached: the final attempt either returns or re-raises above.

    def get_instruments(self, exchange: str = "NSE", ttl_seconds: int = 43200) -> list:
        """Return the instruments master for ``exchange``, cached in memory.

        A cached copy younger than ``ttl_seconds`` is returned as-is; otherwise
        the list is fetched (with rate-limit retries) and the cache refreshed.
        A missing or blank ``exchange`` falls back to ``"NSE"``.
        """
        ex = (exchange or "").strip().upper() or "NSE"
        ts = self._instruments_cache_ts.get(ex)
        if (
            ts is not None
            and (time.time() - ts) < ttl_seconds
            and ex in self._instruments_cache
        ):
            return self._instruments_cache[ex]

        instruments = self._call_with_retry(self.kite.instruments, ex)
        self._instruments_cache[ex] = instruments
        self._instruments_cache_ts[ex] = time.time()
        return instruments

    def generate_login_url(self) -> str:
        """Return the Kite login URL."""
        return self.kite.login_url()

    def generate_session(self, request_token: str) -> dict:
        """Exchange ``request_token`` for an access token and store it.

        Returns the session data from the client; any failure is logged and
        re-raised.
        """
        try:
            session_data = self.kite.generate_session(
                request_token=request_token,
                api_secret=self.api_secret,
            )
            self.access_token = session_data["access_token"]
            self.kite.set_access_token(self.access_token)
            logger.info("Zerodha session established")
            return session_data
        except Exception:
            logger.exception("Failed to generate Zerodha session")
            raise

    def refresh_token(self, refresh_token: str) -> dict:
        """(Optional) Renew the access token using a refresh token.

        Returns the renewal response; any failure is logged and re-raised.
        """
        try:
            # The KiteConnect client exposes token renewal as
            # renew_access_token(refresh_token, api_secret).
            session_data = self.kite.renew_access_token(refresh_token, self.api_secret)
            self.access_token = session_data["access_token"]
            self.kite.set_access_token(self.access_token)
            logger.info("Access token refreshed")
            return session_data
        except Exception:
            logger.exception("Failed to refresh access token")
            raise

    def place_order(
        self,
        symbol: str,
        quantity: int,
        transaction_type: str,
        exchange: str = "NSE",
        order_type: str = "MARKET",
        product: str = "CNC",
        validity: str = "DAY",
        price: Optional[float] = None,
        trigger_price: Optional[float] = None,
        disclosed_quantity: Optional[int] = None,
        tag: Optional[str] = None
    ) -> str:
        """Place a regular-variety order and return its order id.

        ``order_type`` defaults to ``"MARKET"`` but any value is forwarded to
        the client. Optional fields left as ``None`` are omitted from the
        request. Failures are logged (with the symbol) and re-raised.
        """
        try:
            payload = {
                "variety": self.kite.VARIETY_REGULAR,
                "tradingsymbol": symbol,
                "exchange": exchange,
                "transaction_type": transaction_type,
                "quantity": quantity,
                "order_type": order_type,
                "product": product,
                "validity": validity,
                "price": price,
                "trigger_price": trigger_price,
                "disclosed_quantity": disclosed_quantity,
                "tag": tag,
            }
            # Remove None values so unset optional fields are not sent.
            payload = {k: v for k, v in payload.items() if v is not None}

            order_id = self.kite.place_order(**payload)
            logger.info("Order placed successfully: %s", order_id)
            return order_id
        except Exception:
            logger.exception("Order placement failed for %s", symbol)
            raise

    def get_profile(self) -> dict:
        """Return the client's ``profile()`` result; log and re-raise on failure."""
        try:
            return self.kite.profile()
        except Exception:
            logger.exception("Failed to fetch profile")
            raise

    def get_holdings(self) -> list:
        """Return the client's ``holdings()`` result; log and re-raise on failure."""
        try:
            return self.kite.holdings()
        except Exception:
            logger.exception("Error fetching holdings")
            raise

    def get_positions(self) -> dict:
        """Return the client's ``positions()`` result; log and re-raise on failure."""
        try:
            return self.kite.positions()
        except Exception:
            logger.exception("Error fetching positions")
            raise

    def get_order_history(self, order_id: str) -> list:
        """Return the history of ``order_id``; log and re-raise on failure."""
        try:
            return self.kite.order_history(order_id=order_id)
        except Exception:
            logger.exception("Error fetching order history")
            raise

    def get_orders(self) -> list:
        """Return the client's ``orders()`` result; log and re-raise on failure."""
        try:
            return self.kite.orders()
        except Exception:
            logger.exception("Error fetching orders")
            raise

    def get_trades(self) -> list:
        """Return recent executed trades (Kite tradebook).

        Note: KiteConnect returns the user's trades for the day/session.
        """
        try:
            return self.kite.trades()
        except Exception:
            logger.exception("Error fetching trades")
            raise

    def get_margins(self, segment: Optional[str] = None) -> dict:
        """Return available margins (account-level).

        KiteConnect:
        - margins() => all segments
        - margins(segment) => specific segment, e.g. "equity"

        A ``None`` or blank ``segment`` requests all segments.
        """
        try:
            wanted = str(segment).strip() if segment else ""
            if wanted:
                return self.kite.margins(wanted)
            return self.kite.margins()
        except Exception:
            logger.exception("Error fetching margins")
            raise

    def get_order_margins(self, orders: list) -> list:
        """Return per-order margin requirements.

        Input: list of order dicts for KiteConnect `order_margins`.
        This is read-only and does not place orders.
        """
        try:
            return self.kite.order_margins(orders)
        except Exception:
            logger.exception("Error fetching order margins")
            raise

    def get_quote(self, symbols: list) -> dict:
        """Fetch quotes with rate-limit retries.

        symbols = ['NSE:RELIANCE', 'NSE:TCS']

        The final failure (after retries, or a non-retryable error) is logged
        once and re-raised.
        """
        try:
            return self._call_with_retry(self.kite.quote, symbols)
        except Exception:
            logger.exception("Error fetching quotes")
            raise

    def get_historical_data(
        self,
        instrument_token: Union[str, int],
        interval: str,
        from_date: str,
        to_date: str,
        continuous: bool = False,
        oi: bool = False
    ) -> pd.DataFrame:
        """Fetch historical OHLC data as a DataFrame indexed by ``date``.

        The request is throttled and retried on rate-limit errors. An empty
        response yields an empty DataFrame (no index is set). The final
        failure is logged once and re-raised.
        """
        try:
            data = self._call_with_retry(
                self.kite.historical_data,
                instrument_token=instrument_token,
                from_date=from_date,
                to_date=to_date,
                interval=interval,
                continuous=continuous,
                oi=oi,
            )
            df = pd.DataFrame(data)
            if not df.empty:
                df['date'] = pd.to_datetime(df['date'])
                df.set_index('date', inplace=True)
            return df
        except Exception as e:
            logger.error("Error fetching historical data: %s", e)
            raise

    def get_historical_data_records(
        self,
        instrument_token: Union[str, int],
        *,
        interval: str,
        from_date,
        to_date,
        continuous: bool = False,
        oi: bool = False,
    ) -> list:
        """Fetch historical data as raw Kite records with retry/throttle.

        This is used for intraday-only snapshots where the caller wants the
        raw list (and does its own trimming/caching). A falsy response is
        returned as an empty list. The final failure is logged once and
        re-raised.
        """
        try:
            data = self._call_with_retry(
                self.kite.historical_data,
                instrument_token=instrument_token,
                from_date=from_date,
                to_date=to_date,
                interval=interval,
                continuous=continuous,
                oi=oi,
            )
            return data or []
        except Exception as e:
            logger.error("Error fetching historical records: %s", e)
            raise
