import asyncio
import json
from fastapi import WebSocket, WebSocketDisconnect
from aiokafka import AIOKafkaProducer
from datetime import datetime,timedelta
from app.utils.connections import connected_clients
from app.db import database

# Shared Kafka producer used by websocket_tracking_data. It stays None until
# start_kafka_producer() assigns it, and the handler skips publishing while it
# is None, so the annotation is effectively Optional[AIOKafkaProducer].
producer: AIOKafkaProducer = None


async def start_kafka_producer(bootstrap_servers: str = "localhost:9092"):
    """Create and start the module-wide Kafka producer.

    Args:
        bootstrap_servers: Kafka bootstrap address(es) handed to
            ``AIOKafkaProducer``. Defaults to ``"localhost:9092"``, the value
            that was previously hard-coded, so existing callers are unaffected.

    The module-level ``producer`` is only replaced once ``start()`` has
    succeeded. If starting fails, the exception propagates and ``producer``
    keeps its previous value instead of pointing the WebSocket handlers at an
    unstarted client.
    """
    global producer
    new_producer = AIOKafkaProducer(bootstrap_servers=bootstrap_servers)
    await new_producer.start()
    producer = new_producer

async def stop_kafka_producer():
    """Stop the module-wide Kafka producer, if one is set.

    The global ``producer`` is cleared *before* stopping it. Handlers then see
    "no producer" instead of calling ``send`` on a client that is shutting
    down. The global is cleared even if ``stop()`` raises, so a second call
    is a no-op.
    """
    global producer
    current, producer = producer, None
    if current:
        await current.stop()

async def websocket_tracking_data(websocket: WebSocket):
    """Accept vehicle-tracking messages over a WebSocket and forward them to Kafka.

    Behaviour per text frame:

    - It must be a JSON object. Otherwise the client gets an error reply and
      the connection stays open.
    - The object is stamped with ``received_at``: naive UTC, ISO-8601 string
      from ``datetime.utcnow().isoformat()``.
    - The UTF-8 encoded JSON is published to the ``vehicle_tracking`` topic.
    - The client is told the message was queued. If no producer has been
      started, it is told the message was *not* queued.

    Runs until the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            message = await websocket.receive_text()

            # A malformed frame should not tear down the whole connection.
            try:
                data = json.loads(message)
            except ValueError:
                await websocket.send_text("Invalid JSON.")
                continue
            if not isinstance(data, dict):
                await websocket.send_text("Invalid payload: expected a JSON object.")
                continue

            data["received_at"] = datetime.utcnow().isoformat()

            if not producer:
                # Don't claim success when nothing was published.
                await websocket.send_text("Kafka producer unavailable; message not queued.")
                continue

            await producer.send("vehicle_tracking", json.dumps(data).encode("utf-8"))
            await websocket.send_text("Message queued successfully.")
    except WebSocketDisconnect:
        pass

async def websocket_realtime_updates(websocket: WebSocket):
    """Keep a client registered in ``connected_clients`` while it is connected.

    This handler never sends anything itself. Presumably another component
    pushes updates to the sockets in ``connected_clients`` (TODO confirm).
    Frames sent by the client are read and discarded solely so that a
    disconnect is noticed. The socket is always removed from the registry
    when the handler exits.
    """
    await websocket.accept()
    # NOTE(review): connected_clients is used as a set here (.add/.discard)
    # but as a dict (item assignment/del) in websocket_workforce_schedules.
    # A single shared object cannot satisfy both; confirm the type declared in
    # app.utils.connections and consider separate registries.
    connected_clients.add(websocket)
    print("✅ Client connected. Total:", len(connected_clients))

    try:
        # Sleeping alone never observes a disconnect: it is only reported
        # through receive()/send(). receive() (rather than receive_text())
        # also tolerates binary frames.
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): another component may already have dropped a dead socket.
        connected_clients.discard(websocket)
        print("❌ Client disconnected. Remaining:", len(connected_clients))

def _finish_db_generator(mongo_gen):
    """Run the teardown half of a yield-style DB provider (``database.get_mongo_db()``).

    Resumes the generator with ``next()`` so that code after its ``yield``
    runs, as the handler's original cleanup did. A teardown error is printed
    rather than propagated.
    """
    try:
        next(mongo_gen)
    except StopIteration:
        return
    except Exception as e:
        print(f"Error closing DB connection: {str(e)}")
        return
    # Unexpected second yield: close the generator so it cannot linger.
    mongo_gen.close()


async def websocket_historical_route(websocket: WebSocket):
    """Stream one device's tracking points for a single calendar day.

    Protocol:

    1. The client sends one JSON text message,
       ``{"device_id": ..., "date": "YYYY-MM-DD"}``.
    2. The server replies with one JSON text message per matching
       ``tracking_logs`` document. Documents are projected to ``location``,
       ``speed``, ``heading`` and ``received_at``, and non-JSON values are
       serialised with ``str``. Messages are sent 0.2 s apart.
    3. The server sends the literal text ``"done"`` and closes the socket.

    A missing field gets ``"Missing device_id or date."``. Any other failure
    is reported as ``"Error: <message>"`` when the socket is still writable.
    The DB connection is always released.
    """
    await websocket.accept()
    mongo_gen = None
    try:
        query = await websocket.receive_text()
        query_data = json.loads(query)
        if not isinstance(query_data, dict):
            query_data = {}  # non-object JSON carries no usable fields
        device_id = query_data.get("device_id")
        date_str = query_data.get("date")  # Format: YYYY-MM-DD

        if not device_id or not date_str:
            await websocket.send_text("Missing device_id or date.")
            await websocket.close()
            return

        # Half-open day window: [start_date, start_date + 1 day).
        start_date = datetime.strptime(date_str, "%Y-%m-%d")
        end_date = start_date + timedelta(days=1)

        mongo_gen = database.get_mongo_db()
        mongo = next(mongo_gen)
        collection = mongo["tracking_logs"]

        # NOTE(review): websocket_tracking_data stamps "received_at" as an
        # ISO-8601 *string*. This datetime range only matches documents whose
        # received_at is stored as a date (naive datetimes are presumably
        # treated as UTC by the driver). Confirm how tracking_logs is written.
        cursor = collection.find(
            {
                "device_id": device_id,
                "received_at": {"$gte": start_date, "$lt": end_date},
            },
            {"location": 1, "speed": 1, "heading": 1, "received_at": 1},
        )

        # NOTE(review): the cursor is iterated synchronously, so each batch
        # fetch blocks the event loop; consider offloading for large histories.
        for doc in cursor:
            await websocket.send_text(json.dumps(doc, default=str))
            await asyncio.sleep(0.2)  # pace the stream for the client

        await websocket.send_text("done")
        await websocket.close()

    except WebSocketDisconnect:
        print("Client disconnected.")
    except Exception as e:
        print(f"Error: {str(e)}")
        try:
            await websocket.send_text(f"Error: {str(e)}")
            await websocket.close()
        except Exception:
            # The socket may already be closed (e.g. the failure happened
            # while sending or after close()); nobody is left to notify.
            pass
    finally:
        # Release the DB connection on every exit path, including errors
        # and mid-stream disconnects.
        if mongo_gen is not None:
            _finish_db_generator(mongo_gen)

async def websocket_workforce_schedules(websocket: WebSocket):
    """Register a workforce member's socket under its ``workforce_id``.

    Protocol: after the connection is accepted, the client sends one JSON text
    message such as ``{"workforce_id": ...}``.

    - If no usable ``workforce_id`` can be read, the server replies
      ``"Missing workforce_id."`` and closes. This includes malformed or
      non-object JSON.
    - Otherwise the socket is stored as ``connected_clients[workforce_id]``
      until the client disconnects, and then the entry is removed.

    This handler never sends schedule data itself. Presumably another
    component looks the socket up in ``connected_clients`` (TODO confirm).
    """
    await websocket.accept()

    # NOTE(review): connected_clients is used as a dict here but as a set
    # (.add/.remove) in websocket_realtime_updates. A single shared object
    # cannot satisfy both; confirm the type declared in app.utils.connections.
    workforce_id = None  # bound up front: the client may leave before identifying
    registered = False
    try:
        msg = await websocket.receive_text()
        try:
            data = json.loads(msg)
        except ValueError:
            data = None  # malformed JSON is treated like a missing id
        if isinstance(data, dict):
            workforce_id = data.get("workforce_id")

        if not workforce_id:
            await websocket.send_text("Missing workforce_id.")
            await websocket.close()
            return

        connected_clients[workforce_id] = websocket
        registered = True
        print(f"✅ Workforce {workforce_id} connected. Total clients: {len(connected_clients)}")

        # Keep reading so a disconnect is actually observed; sleeping alone
        # never raises WebSocketDisconnect. Incoming frames are ignored.
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        # Remove only our own entry: if the same workforce_id reconnected,
        # the slot now belongs to the newer socket.
        if registered and connected_clients.get(workforce_id) is websocket:
            del connected_clients[workforce_id]
            print(f"❌ Workforce {workforce_id} disconnected. Remaining: {len(connected_clients)}")


