import asyncio
import json
import os
from dotenv import load_dotenv
from fastapi import WebSocket, WebSocketDisconnect
from aiokafka import AIOKafkaProducer
from datetime import datetime,timedelta
from app.utils.connections import connected_clients, device_subscriptions
from app.db import database
from bson import ObjectId, errors as bson_errors

from app.v1.services.platform.iotdevices import list_nearby_devices_service_socket
from app.v1.models.platform.iotdevices import NearbyDevicesRequest, IotDeviceWithMappingList
#from app.v1.sockets.tracking_data import notify_nearby_subscribers

# 🔹 Load env variables once, at import time, so the os.getenv() calls below
# can see values coming from a .env file.
load_dotenv()
# Kafka publishing is opt-in: only the value "true" (case-insensitive) enables it.
# Handlers fall back to direct MongoDB writes when this is False or no producer runs.
USE_KAFKA = os.getenv("USE_KAFKA", "false").lower() == "true"

# Shared Kafka producer; None until start_kafka_producer() has been awaited.
producer: "AIOKafkaProducer | None" = None
# NOTE(review): not referenced in the visible part of this module — presumably
# used by nearby-device subscription code further down; confirm before removing.
active_nearby_subscribers = {}

async def start_kafka_producer(bootstrap_servers=None):
    """Create and start the shared module-level Kafka producer.

    Args:
        bootstrap_servers: Kafka bootstrap address(es) passed to
            ``AIOKafkaProducer``. When omitted, the ``KAFKA_BOOTSTRAP_SERVERS``
            environment variable is used, falling back to ``"localhost:9092"``.

    The global ``producer`` is only assigned after ``start()`` returns. If
    startup raises, handlers that check ``if producer`` never see a client
    that was constructed but never started.
    """
    global producer
    if bootstrap_servers is None:
        bootstrap_servers = os.getenv("KAFKA_BOOTSTRAP_SERVERS", "localhost:9092")
    new_producer = AIOKafkaProducer(bootstrap_servers=bootstrap_servers)
    await new_producer.start()
    producer = new_producer

async def stop_kafka_producer():
    """Stop the shared Kafka producer, if one is running, and clear it.

    The global is reset to ``None`` *before* awaiting ``stop()``. As a result,
    the ``USE_KAFKA and producer`` checks in the handlers fall back to direct
    MongoDB writes instead of sending through a stopping or stopped client.
    Calling this when no producer is set is a no-op.
    """
    global producer
    current = producer
    if current is None:
        return
    producer = None
    await current.stop()

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow().isoformat()
#             if producer:
#                 await producer.send("vehicle_tracking", json.dumps(data).encode("utf-8"))
#             await websocket.send_text("Message queued successfully.")
#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]  # new collection

#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()

#             # 1. Identify device
#             device_id = data.get("device_id")
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Check for alert keywords in payload
#             payload_text = json.dumps(data).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 # 3. Map category from alert_category
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 alert_logs.insert_one(alert_doc)
#                 await websocket.send_text(f"✅ Alert stored under category: {category['category_name']}")
#             else:
#                 # 4. Normal tracking log
#                 tracking_logs.insert_one(data)
#                 if producer:
#                     await producer.send("vehicle_tracking", json.dumps(data, default=str).encode("utf-8"))
#                 await websocket.send_text("✅ Tracking log stored")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]

#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()

#             # 1. Identify device
#             device_id = data.get("device_id")
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Check for alert keywords
#             payload_text = json.dumps(data).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 # 3. Map category
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }

#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#                 await websocket.send_text(f"✅ Alert processed under {category['category_name']}")

#             else:
#                 if USE_KAFKA and producer:
#                     # ✅ Publish tracking to Kafka
#                     await producer.send(
#                         "vehicle_tracking",
#                         json.dumps(data, default=str).encode("utf-8")
#                     )
#                 else:
#                     tracking_logs.insert_one(data)

#                 await websocket.send_text("✅ Tracking log processed")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]
#     print("websocket_tracking_datawebsocket_tracking_datawebsocket_tracking_datawebsocket_tracking_data")
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()
#             print("received_atreceived_atreceived_at")
#             print(data)
#             # 1. Identify device
#             device_id = data.get("device_id")           

#             try:
#                 device_obj_id = ObjectId(device_id)
#             except bson_errors.InvalidId:
#                 await websocket.send_text("❌ Invalid device_id format")
#                 continue

#             device = devices.find_one({"_id": device_obj_id})

#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # Default tracking entry
#             tracking_entry = {
#                 **data,
#                 "account_id": device["account_id"],
#                 "is_alert": False
#             }

#             # 2. Detect alerts
#             # payload_text = json.dumps(data).lower()
#             payload_text = json.dumps(data, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })
#             print("device_iddevice_idddd")
#             print(keyword_doc)
#             if keyword_doc:
#                 # 3. Enrich tracking entry with alert info (NO status here)
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 tracking_entry.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 # 4. Insert into alert_logs with status
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#             # 5. Always insert into tracking_logs
#             if USE_KAFKA and producer:
#                 print("Coming IFFF part 24111111")
#                 await producer.send(
#                     "vehicle_tracking",
#                     json.dumps(tracking_entry, default=str).encode("utf-8")
#                 )
#             else:
#                 print("Coming ELSE part 246666")
#                 tracking_logs.insert_one(tracking_entry)

#             if tracking_entry.get("is_alert"):
#                 await websocket.send_text(f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})")
#             else:
#                 await websocket.send_text("✅ Stored in tracking_logs")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]
#     fleets = mongo["fleets"]            # <-- added
#     workforce = mongo["workforce"]      # <-- added

#     print("websocket_tracking_data STARTED")
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()
#             print("Incoming WebSocket Data:", data)

#             # 1. Identify device
#             device_id = data.get("device_id")
#             try:
#                 device_obj_id = ObjectId(device_id)
#             except bson_errors.InvalidId:
#                 await websocket.send_text("❌ Invalid device_id format")
#                 continue

#             device = devices.find_one({"_id": device_obj_id})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Prepare tracking entry
#             tracking_entry = {
#                 **data,
#                 "account_id": device["account_id"],
#                 "is_alert": False
#             }
#             print("=== SPLIT DEBUG START ===")
#             print(payload_text.split())
#             print("=== SPLIT DEBUG END ===")

#             # 3. Detect alerts
#             payload_text = json.dumps(data, default=str).lower()
#             keyword_doc = keywords.find_one({"keyword": {"$in": payload_text.split()}})
#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 tracking_entry.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#             # 4. Insert into tracking_logs
#             if USE_KAFKA and producer:
#                 await producer.send(
#                     "vehicle_tracking",
#                     json.dumps(tracking_entry, default=str).encode("utf-8")
#                 )
#             else:
#                 tracking_logs.insert_one(tracking_entry)

#             # 5. Update Fleet/Workforce with last location
#             try:
#                 coordinates = data.get("location")  # expected {"lat": ..., "lng": ...}
#                 if coordinates:
#                     last_location_update = {
#                         "last_location": {
#                             "coordinates": coordinates,
#                             "speed": data.get("speed"),
#                             "heading": data.get("heading"),
#                             "fuel": data.get("fuel"),
#                             "timestamp": data.get("timestamp"),
#                         },
#                         "last_updated": datetime.utcnow()
#                     }

#                     associated_type = device.get("associated_entity_type")
#                     associated_id = device.get("associated_entity_id")
#                     print("device_iddevice_iddevice_id")
#                     print(device_id)
#                     # ✅ Always update iot_devices
#                     devices.update_one(
#                         {"_id": ObjectId(device_id)},
#                         {"$set": last_location_update}
#                     )

#                     # ✅ Then update linked entity
#                     if associated_type == "vehicle" and associated_id:
#                         fleets.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     elif associated_type == "workforce" and associated_id:
#                         workforce.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     else:
#                         # Fallback update if association is missing
#                         fleets.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )
#                         workforce.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )

#                     await notify_nearby_subscribers(device_id)

#             except Exception as e:
#                 print(f"⚠️ Failed to update last location for {device_id}: {e}")


#             # 6. Send ack to WebSocket client
#             if tracking_entry.get("is_alert"):
#                 await websocket.send_text(f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})")
#             else:
#                 await websocket.send_text("✅ Stored in tracking_logs")

#     except WebSocketDisconnect:
#         print("⚠️ WebSocket disconnected")
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     print("websocket_tracking_data STARTED")
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()
#             print("Incoming WebSocket Data:", data)

#             # 1. Identify device
#             device_id = data.get("device_id")
#             try:
#                 device_obj_id = ObjectId(device_id)
#             except bson_errors.InvalidId:
#                 await websocket.send_text("❌ Invalid device_id format")
#                 continue

#             device = devices.find_one({"_id": device_obj_id})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Prepare tracking entry
#             tracking_entry = {
#                 **data,
#                 "account_id": device["account_id"],
#                 "is_alert": False
#             }

#             # 🔹 Build searchable text from entire payload
#             payload_text = json.dumps(data, default=str).lower()

#             # 3. Detect alerts by scanning keywords
#             all_keywords = list(keywords.find({}, {"keyword": 1, "category_id": 1}))
#             matched_keyword = None
#             for kw in all_keywords:
#                 if kw["keyword"].lower() in payload_text:
#                     matched_keyword = kw
#                     break

#             print("=== ALERT DEBUG START ===")
#             print("Payload text:", payload_text)
#             print("Matched keyword:", matched_keyword)
#             print("=== ALERT DEBUG END ===")

#             if matched_keyword:
#                 category = categories.find_one({"_id": matched_keyword["category_id"]})
#                 tracking_entry.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {matched_keyword['keyword']}",
#                     "is_alert": True
#                 })

#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {matched_keyword['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#             # 4. Insert into tracking_logs
#             if USE_KAFKA and producer:
#                 await producer.send(
#                     "vehicle_tracking",
#                     json.dumps(tracking_entry, default=str).encode("utf-8")
#                 )
#             else:
#                 tracking_logs.insert_one(tracking_entry)

#             # 5. Update Fleet/Workforce with last location
#             try:
#                 coordinates = data.get("location")
#                 if coordinates:
#                     last_location_update = {
#                         "last_location": {
#                             "coordinates": coordinates,
#                             "speed": data.get("speed"),
#                             "heading": data.get("heading"),
#                             "fuel": data.get("fuel"),
#                             "timestamp": data.get("timestamp"),
#                         },
#                         "last_updated": datetime.utcnow()
#                     }

#                     associated_type = device.get("associated_entity_type")
#                     associated_id = device.get("associated_entity_id")

#                     # ✅ Always update iot_devices
#                     devices.update_one(
#                         {"_id": ObjectId(device_id)},
#                         {"$set": last_location_update}
#                     )

#                     # ✅ Then update linked entity
#                     if associated_type == "vehicle" and associated_id:
#                         fleets.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     elif associated_type == "workforce" and associated_id:
#                         workforce.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     else:
#                         fleets.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )
#                         workforce.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )

#                     await notify_nearby_subscribers(device_id)

#             except Exception as e:
#                 print(f"⚠️ Failed to update last location for {device_id}: {e}")

#             # 6. Send ack to WebSocket client
#             if tracking_entry.get("is_alert"):
#                 await websocket.send_text(
#                     f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})"
#                 )
#             else:
#                 await websocket.send_text("✅ Stored in tracking_logs")

#     except WebSocketDisconnect:
#         print("⚠️ WebSocket disconnected")
#         pass

# Broadcast tracking entry to subscribed clients
async def broadcast_tracking_update(tracking_entry):
    """Push *tracking_entry* to every realtime client subscribed to its device.

    A client is a target when ``device_subscriptions[ws]`` equals the entry's
    ``device_id``. The entry is serialized to JSON once (``default=str``
    stringifies non-JSON values such as datetimes). It is then sent to all
    targets concurrently, so the total wait is bounded by the slowest client
    rather than the sum over all clients. Any client whose send fails is
    removed from both ``connected_clients`` and ``device_subscriptions``.
    """
    device_id = tracking_entry.get("device_id")
    # Snapshot the targets before awaiting: the registries can change while
    # sends are in flight.
    targets = [ws for ws in connected_clients if device_subscriptions.get(ws) == device_id]
    if not targets:
        return

    payload = json.dumps(tracking_entry, default=str)
    results = await asyncio.gather(
        *(ws.send_text(payload) for ws in targets),
        return_exceptions=True,
    )
    for ws, result in zip(targets, results):
        if isinstance(result, Exception):
            connected_clients.discard(ws)
            device_subscriptions.pop(ws, None)
            print(f"❌ WebSocket broadcast error: {result}")

def _find_alert_keyword(keywords, payload_text):
    """Return the first ``alert_keywords`` doc whose keyword occurs in *payload_text*.

    *payload_text* is expected to be lower-cased already. Each keyword is
    lower-cased and tested as a substring. Empty or non-string keywords are
    skipped, because an empty string is a substring of every payload and
    would flag every message as an alert. Returns ``None`` when nothing matches.
    """
    for kw in keywords.find({}, {"keyword": 1, "category_id": 1}):
        keyword = kw.get("keyword")
        if isinstance(keyword, str) and keyword and keyword.lower() in payload_text:
            return kw
    return None


async def websocket_tracking_data(websocket: WebSocket):
    """Ingest tracking frames from a device over *websocket* until it disconnects.

    Each text frame should be a JSON object with a ``device_id`` (24-hex
    ObjectId string) plus arbitrary telemetry. For every frame:

    1. Malformed frames (invalid JSON, non-object payload, bad id) get an error
       reply and the loop continues, instead of tearing down the connection.
    2. The device is looked up in ``iot_devices``; unknown ids are rejected.
    3. The serialized payload is scanned for ``alert_keywords``; the first hit
       marks the entry as an alert and produces an ``alert_logs`` document.
    4. The tracking entry (and alert, if any) goes to Kafka when ``USE_KAFKA``
       is set and a producer exists, otherwise it is inserted into MongoDB.
    5. When the payload has a ``location``, ``last_location`` is updated on the
       device and on its linked fleet/workforce record.
    6. The entry is broadcast to subscribed realtime clients, and an ack is sent.

    NOTE(review): the MongoDB calls here are synchronous (their results are
    used without ``await``), so each one blocks the event loop while it runs.
    """
    await websocket.accept()
    mongo_gen = database.get_mongo_db()
    mongo = next(mongo_gen)
    tracking_logs = mongo["tracking_logs"]
    alert_logs = mongo["alert_logs"]
    devices = mongo["iot_devices"]
    categories = mongo["alert_category"]
    keywords = mongo["alert_keywords"]
    fleets = mongo["fleets"]
    workforce = mongo["workforce"]

    print("websocket_tracking_data STARTED")
    try:
        while True:
            message = await websocket.receive_text()

            # A single bad frame must not kill the device's connection.
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_text("❌ Invalid JSON payload")
                continue
            if not isinstance(data, dict):
                await websocket.send_text("❌ Payload must be a JSON object")
                continue

            data["received_at"] = datetime.utcnow()
            print("Incoming WebSocket Data:", data)

            # ObjectId(None) would silently mint a brand-new id, and
            # ObjectId(<int>) raises TypeError, so validate before converting.
            device_id = data.get("device_id")
            if not isinstance(device_id, str) or not ObjectId.is_valid(device_id):
                await websocket.send_text("❌ Invalid device_id format")
                continue
            device_obj_id = ObjectId(device_id)

            device = devices.find_one({"_id": device_obj_id})
            if not device:
                await websocket.send_text("Unknown device")
                continue

            account_id = device.get("account_id")
            tracking_entry = {**data, "account_id": account_id, "is_alert": False}
            payload_text = json.dumps(data, default=str).lower()

            # Detect alerts
            matched_keyword = _find_alert_keyword(keywords, payload_text)
            if matched_keyword:
                category_id = matched_keyword.get("category_id")
                # A dangling category reference must not crash the loop: keep
                # the alert, but leave its category name unset.
                category = categories.find_one({"_id": category_id}) or {}
                if not category:
                    print(f"⚠️ Alert category {category_id} not found for keyword {matched_keyword['keyword']!r}")
                category_name = category.get("category_name")
                alert_message = f"Alert detected: {matched_keyword['keyword']}"

                tracking_entry.update({
                    "category_id": category_id,
                    "category_name": category_name,
                    "message": alert_message,
                    "is_alert": True
                })
                alert_doc = {
                    "device_id": device_id,
                    "account_id": account_id,
                    "category_id": category_id,
                    "category_name": category_name,
                    "message": alert_message,
                    "raw_payload": data,
                    "received_at": data["received_at"],
                    "status": "unread"
                }
                if USE_KAFKA and producer:
                    await producer.send("alerts", json.dumps(alert_doc, default=str).encode("utf-8"))
                else:
                    alert_logs.insert_one(alert_doc)

            # Insert tracking logs
            if USE_KAFKA and producer:
                await producer.send("vehicle_tracking", json.dumps(tracking_entry, default=str).encode("utf-8"))
            else:
                tracking_logs.insert_one(tracking_entry)

            # Update last location (best effort: failures are logged, not fatal)
            try:
                coordinates = data.get("location")
                if coordinates:
                    location_set = {"$set": {
                        "last_location": {
                            "coordinates": coordinates,
                            "speed": data.get("speed"),
                            "heading": data.get("heading"),
                            "fuel": data.get("fuel"),
                            "timestamp": data.get("timestamp"),
                        },
                        "last_updated": datetime.utcnow()
                    }}
                    devices.update_one({"_id": device_obj_id}, location_set)

                    assoc_type = device.get("associated_entity_type")
                    assoc_id = device.get("associated_entity_id")
                    if assoc_type == "vehicle" and assoc_id:
                        fleets.update_one({"_id": ObjectId(assoc_id)}, location_set)
                    elif assoc_type == "workforce" and assoc_id:
                        workforce.update_one({"_id": ObjectId(assoc_id)}, location_set)
                    else:
                        # No explicit association: update any record listing this device.
                        fleets.update_many({"devices": device_id}, location_set)
                        workforce.update_many({"devices": device_id}, location_set)

                    # NOTE(review): notify_nearby_subscribers is not defined in the
                    # visible part of this module (its import at the top is
                    # commented out). If it is missing, the NameError is swallowed
                    # by the except below. Confirm it is in scope.
                    await notify_nearby_subscribers(device_id)
            except Exception as e:
                print(f"⚠️ Failed to update last location for {device_id}: {e}")

            # Broadcast to subscribed clients
            await broadcast_tracking_update(tracking_entry)

            # Ack to device
            if tracking_entry.get("is_alert"):
                await websocket.send_text(f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})")
            else:
                await websocket.send_text("✅ Stored in tracking_logs")

    except WebSocketDisconnect:
        print("⚠️ WebSocket disconnected")

async def websocket_realtime_updates(websocket: WebSocket):
    """Keep a client registered in ``connected_clients`` for as long as it is connected.

    The handler itself sends nothing; it only holds the socket in the shared
    ``connected_clients`` set (presumably consumed by the broadcast helpers in
    this module — confirm) and removes it again once the client goes away.

    Incoming frames are read and ignored. Reading is what lets us observe the
    ASGI ``websocket.disconnect`` event; a pure sleep loop never notices that
    the peer is gone, so the socket would never be unregistered.
    """
    await websocket.accept()
    connected_clients.add(websocket)
    print("✅ Client connected. Total:", len(connected_clients))

    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    except WebSocketDisconnect:
        pass
    finally:
        # discard(): tolerate another code path having already removed this socket.
        connected_clients.discard(websocket)
        print("❌ Client disconnected. Remaining:", len(connected_clients))

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # ✅ Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = start_date + timedelta(hours=1)   # Example: 1-hour window
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]
#         alert_logs = mongo["alert_logs"]
#         devices = mongo["iot_devices"]
#         categories = mongo["alert_category"]
#         keywords = mongo["alert_keywords"]

#         # Fetch historical points in that range
#         cursor = collection.find({
#             "device_id": device_id,
#             "received_at": {
#                 "$gte": start_date,
#                 "$lt": end_date
#             }
#         }, {
#             "location": 1, "speed": 1, "heading": 1,
#             "fuel": 1, "received_at": 1, "alert_type": 1
#         })

#         print("Query returned:", cursor)

#         # Stream points one-by-one
#         for doc in cursor:
#             # 🔎 Enrich alerts just like in websocket_tracking_data
#             enriched_doc = dict(doc)
#             enriched_doc["is_alert"] = False

#             payload_text = json.dumps(doc, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 enriched_doc.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": devices.find_one({"_id": ObjectId(device_id)})["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": doc,
#                     "received_at": doc["received_at"],
#                     "status": "historical"
#                 }
#                 alert_logs.insert_one(alert_doc)

#             # Send enriched doc to client
#             await websocket.send_text(json.dumps(enriched_doc, default=str))
#             await asyncio.sleep(0.2)

#         await websocket.send_text("done")
#         await websocket.close()

#         try:
#             next(mongo_gen)  # Cleanup DB connection
#         except StopIteration:
#             pass

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = None

#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         #device_id = query_data.get("device_id")
#         device_id = "683d3ffdc355c05ce30f9c94"
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.now()   # ✅ current date and time
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]
#         alert_logs = mongo["alert_logs"]
#         devices = mongo["iot_devices"]
#         categories = mongo["alert_category"]
#         keywords = mongo["alert_keywords"]

#         # Fetch historical points in that range
#         cursor = collection.find(
#             {
#                 "device_id": device_id,
#                 "received_at": {
#                     "$gte": start_date,
#                     "$lt": end_date
#                 }
#             },
#             {
#                 "location": 1, "speed": 1, "heading": 1,
#                 "fuel": 1, "received_at": 1, "alert_type": 1
#             }
#         )

#         docs = list(cursor)
#         print(f"Query returned {len(docs)} records")

#         # Stream points one-by-one
#         for doc in docs:
#             enriched_doc = dict(doc)
#             enriched_doc["is_alert"] = False

#             payload_text = json.dumps(doc, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 enriched_doc.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 # Insert into alert_logs
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": devices.find_one({"_id": ObjectId(device_id)})["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": doc,
#                     "received_at": doc["received_at"],
#                     "status": "historical"
#                 }
#                 alert_logs.insert_one(alert_doc)

#             # Send enriched doc to client
#             await websocket.send_text(json.dumps(enriched_doc, default=str))
#             await asyncio.sleep(0.2)

#         await websocket.send_text("done")
#         await websocket.close()

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         try:
#             await websocket.send_text(f"Error: {str(e)}")
#         except Exception:
#             pass
#         await websocket.close()
#     finally:
#         if mongo_gen:
#             mongo_gen.close()

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         #device_id = query_data.get("device_id")
#         device_id = "683d3ffdc355c05ce30f9c94"
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # ✅ Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.now()   # ✅ current date and time
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         # Fetch historical points in that range
#         cursor = collection.find({
#             "device_id": device_id,
#             "received_at": {
#                 "$gte": start_date,
#                 "$lt": end_date
#             }
#         }, {"location": 1, "speed": 1, "heading": 1, "received_at": 1})


#         # ✅ print only after cursor is assigned
#         print("Query returned:", cursor)

#         # Stream points one-by-one
#         for doc in cursor:
#             print(json.dumps(doc, default=str))
#             await websocket.send_text(json.dumps(doc, default=str))
#             await asyncio.sleep(0.2)  # Simulate streaming

#         await websocket.send_text("done")
#         await websocket.close()

#         try:
#             next(mongo_gen)  # Cleanup DB connection
#         except StopIteration:
#             pass

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

async def websocket_historical_route(websocket: WebSocket):
    """Stream a device's stored tracking records from a start time up to now.

    Protocol (text frames):
      1. The client sends one JSON object
         ``{"device_id": ..., "date": "YYYY-MM-DD HH:MM:SS"}``.
      2. Matching ``tracking_logs`` documents with
         ``date <= received_at < datetime.utcnow()`` are sent oldest-first,
         one JSON object per frame, 0.2 s apart.
      3. The text frame "done" is sent and the socket is closed.

    Missing fields or a malformed date are reported as a plain-text frame and the
    socket is closed. Any other error is reported as "Error: <message>"
    (best effort; the socket may already be gone).

    The start date is parsed as a naive datetime and compared against
    ``datetime.utcnow()``; presumably stored ``received_at`` values are naive
    UTC as well — confirm.

    The DB handle obtained from ``database.get_mongo_db()`` is released on every
    exit path, including disconnects and errors.
    """
    await websocket.accept()
    mongo_gen = None

    try:
        query_data = json.loads(await websocket.receive_text())
        device_id = query_data.get("device_id")
        datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

        if not device_id or not datetime_str:
            await websocket.send_text("Missing device_id or date.")
            await websocket.close()
            return

        try:
            start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
            await websocket.close()
            return
        end_date = datetime.utcnow()

        mongo_gen = database.get_mongo_db()
        collection = next(mongo_gen)["tracking_logs"]

        cursor = collection.find(
            {
                "device_id": device_id,
                "received_at": {"$gte": start_date, "$lt": end_date},
            },
            {
                "_id": 1,
                "device_id": 1,
                "location": 1,
                "speed": 1,
                "heading": 1,
                "fuel": 1,
                "received_at": 1,
                "account_id": 1,
                "is_alert": 1,
                "category_id": 1,
                "category_name": 1,
                "message": 1,
                "UUID": 1,
            },
        ).sort("received_at", 1)

        for doc in cursor:
            received_at = doc.get("received_at")
            if isinstance(received_at, datetime):
                received_at = received_at.isoformat()
            elif not received_at:
                received_at = None
            # Any other truthy value (e.g. a string) is passed through and
            # serialized by json.dumps(default=str) instead of crashing the stream.

            response = {
                "id": str(doc["_id"]),
                "device_id": doc.get("device_id"),
                "location": doc.get("location"),
                "speed": doc.get("speed"),
                "heading": doc.get("heading"),
                "fuel": doc.get("fuel"),
                "UUID": doc.get("UUID"),
                "account_id": doc.get("account_id"),
                "category_id": doc.get("category_id"),
                "category_name": doc.get("category_name"),
                "message": doc.get("message"),
                "is_alert": doc.get("is_alert", False),
                "received_at": received_at,
            }
            await websocket.send_text(json.dumps(response, default=str))
            await asyncio.sleep(0.2)  # pace the replay for the client

        await websocket.send_text("done")
        await websocket.close()

    except WebSocketDisconnect:
        print("⚠️ Client disconnected.")
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        try:
            await websocket.send_text(f"Error: {str(e)}")
            await websocket.close()
        except Exception:
            pass  # socket already closed; the error has been logged above
    finally:
        if mongo_gen is not None:
            # Advance the generator so get_mongo_db()'s post-yield cleanup runs.
            next(mongo_gen, None)
        
async def websocket_workforce_schedules(websocket: WebSocket):
    """Register a workforce member's socket under its ``workforce_id`` until it disconnects.

    The client's first text frame must be JSON containing ``workforce_id``. If it
    is missing, "Missing workforce_id." is sent and the socket is closed.
    Otherwise the socket is stored as ``connected_clients[workforce_id]``.
    Later frames are read and ignored; reading is what detects the disconnect.

    On exit the entry is removed, but only if it still points at *this* socket,
    so a newer connection registered under the same id is not evicted.
    """
    await websocket.accept()
    workforce_id = None
    registered = False

    try:
        msg = await websocket.receive_text()
        data = json.loads(msg)
        workforce_id = data.get("workforce_id")

        if not workforce_id:
            await websocket.send_text("Missing workforce_id.")
            await websocket.close()
            return

        # NOTE(review): other handlers in this module treat connected_clients as
        # a set (.add / .discard). If it is a set, this item assignment raises
        # TypeError. Confirm which registry workforce sockets are meant to use.
        connected_clients[workforce_id] = websocket
        registered = True
        print(f"✅ Workforce {workforce_id} connected. Total clients: {len(connected_clients)}")

        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break

    except WebSocketDisconnect:
        pass
    finally:
        # `registered` guards against workforce_id being unset (disconnect before
        # the first frame) and against a failed registration.
        if registered and connected_clients.get(workforce_id) is websocket:
            del connected_clients[workforce_id]
            print(f"❌ Workforce {workforce_id} disconnected. Remaining: {len(connected_clients)}")

async def websocket_nearby_devices(websocket: WebSocket):
    """Subscribe a client to the list of devices near a point.

    Protocol:
      1. The client sends one JSON frame parsed as ``NearbyDevicesRequest``
         (``latitude``, ``longitude``, ``radius_km``).
      2. The subscription is stored in ``active_nearby_subscribers`` together
         with the DB handle, so ``notify_nearby_subscribers`` can re-run the
         query later.
      3. The current nearby-device list is sent as ``IotDeviceWithMappingList``
         JSON.
      4. Further client frames are read and ignored until disconnect.

    The subscription entry and the DB handle from ``database.get_mongo_db()``
    are released on every exit path. Errors other than a disconnect (bad JSON,
    validation, service failures) still propagate to the framework as before,
    but no longer leave a stale subscriber or an unreleased DB connection behind.
    """
    await websocket.accept()
    mongo_gen = database.get_mongo_db()
    db = next(mongo_gen)

    try:
        # Receive initial subscription request
        init_message = await websocket.receive_text()
        data = json.loads(init_message)
        payload = NearbyDevicesRequest(**data)

        # Store subscriber (the db handle is shared with notify_nearby_subscribers)
        active_nearby_subscribers[websocket] = {
            "latitude": payload.latitude,
            "longitude": payload.longitude,
            "radius_km": payload.radius_km,
            "db": db,
        }

        # Send initial device list
        result = list_nearby_devices_service_socket(
            payload.latitude, payload.longitude, payload.radius_km, db
        )
        response = IotDeviceWithMappingList(**result)
        await websocket.send_text(json.dumps(response.dict(), default=str))

        # Keep socket open; receive_text raises WebSocketDisconnect when the client leaves
        while True:
            await websocket.receive_text()  # ignore pings/extra messages

    except WebSocketDisconnect:
        print("❌ Nearby devices WS disconnected")
    finally:
        # Unsubscribe before releasing the db handle the subscription refers to.
        active_nearby_subscribers.pop(websocket, None)
        # Advance the generator so get_mongo_db()'s post-yield cleanup runs.
        next(mongo_gen, None)


# Notify subscribers when device location changes
async def notify_nearby_subscribers(updated_device):
    """Re-send the nearby-device list to every subscriber in ``active_nearby_subscribers``.

    For each subscription, the stored query (latitude, longitude, radius_km, db)
    is re-run through ``list_nearby_devices_service_socket``. The result is wrapped
    in ``IotDeviceWithMappingList`` and sent as JSON text.

    Failures are handled per subscriber: the error is printed and the remaining
    subscribers are still notified.

    ``updated_device`` is accepted for the caller's convenience but is not
    currently used to filter which subscribers get notified.
    """
    # Iterate over a snapshot: the dict may change while we await sends.
    snapshot = tuple(active_nearby_subscribers.items())

    for subscriber_ws, params in snapshot:
        try:
            listing = IotDeviceWithMappingList(
                **list_nearby_devices_service_socket(
                    params["latitude"],
                    params["longitude"],
                    params["radius_km"],
                    params["db"],
                )
            )
            encoded = json.dumps(listing.dict(), default=str)
            await subscriber_ws.send_text(encoded)
        except Exception as exc:
            print(f"⚠️ Failed sending update to subscriber: {exc}")

# Replay recorded tracking data, then stream live movement updates
# async def websocket_historical_and_live(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)

#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.utcnow()
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         # 1️⃣ Send historical data
#         cursor = collection.find(
#             {"device_id": device_id, "received_at": {"$gte": start_date, "$lt": end_date}}
#         ).sort("received_at", 1)

#         for doc in cursor:
#             response = {
#                 "id": str(doc["_id"]),
#                 "device_id": doc.get("device_id"),
#                 "location": doc.get("location"),
#                 "speed": doc.get("speed"),
#                 "heading": doc.get("heading"),
#                 "fuel": doc.get("fuel"),
#                 "UUID": doc.get("UUID"),
#                 "account_id": doc.get("account_id"),
#                 "category_id": doc.get("category_id"),
#                 "category_name": doc.get("category_name"),
#                 "message": doc.get("message"),
#                 "is_alert": doc.get("is_alert", False),
#                 "received_at": doc.get("received_at").isoformat() if doc.get("received_at") else None
#             }
#             await websocket.send_text(json.dumps(response, default=str))
#             await asyncio.sleep(0.1)

#         await websocket.send_text("history_done")

#         # 2️⃣ Subscribe client for live updates
#         connected_clients[websocket] = device_id
#         print(f"✅ Client subscribed for live updates: {device_id}")

#         # 3️⃣ Keep alive
#         while True:
#             await asyncio.sleep(10)

#     except WebSocketDisconnect:
#         if websocket in connected_clients:
#             del connected_clients[websocket]
#             print(f"❌ Client {device_id} disconnected. Remaining: {len(connected_clients)}")
#     except Exception as e:
#         print(f"❌ Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# async def websocket_historical_and_live(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)

#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.utcnow()
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         # 1️⃣ Send historical data
#         cursor = collection.find(
#             {"device_id": device_id, "received_at": {"$gte": start_date, "$lt": end_date}}
#         ).sort("received_at", 1)

#         for doc in cursor:
#             response = {
#                 "id": str(doc["_id"]),
#                 "device_id": doc.get("device_id"),
#                 "location": doc.get("location"),
#                 "speed": doc.get("speed"),
#                 "heading": doc.get("heading"),
#                 "fuel": doc.get("fuel"),
#                 "UUID": doc.get("UUID"),
#                 "account_id": doc.get("account_id"),
#                 "category_id": doc.get("category_id"),
#                 "category_name": doc.get("category_name"),
#                 "message": doc.get("message"),
#                 "is_alert": doc.get("is_alert", False),
#                 "received_at": doc.get("received_at").isoformat() if doc.get("received_at") else None
#             }
#             await websocket.send_text(json.dumps(response, default=str))
#             await asyncio.sleep(0.05)

#         await websocket.send_text("history_done")

#         # 2️⃣ Subscribe client for live updates
#         connected_clients.add(websocket)
#         device_subscriptions[websocket] = device_id
#         print(f"✅ Client subscribed for live updates: {device_id}")

#         # 3️⃣ Keep connection alive
#         while True:
#             await asyncio.sleep(10)

#     except WebSocketDisconnect:
#         connected_clients.discard(websocket)
#         device_subscriptions.pop(websocket, None)
#         print(f"❌ Client {device_id} disconnected. Remaining: {len(connected_clients)}")
#     except Exception as e:
#         print(f"❌ Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# async def websocket_historical_and_live(websocket: WebSocket):
#     await websocket.accept()
#     device_id = None

#     try:
#         # 1️⃣ Receive subscription query from client
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.utcnow()
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # 2️⃣ DB connection and fetch historical data
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         cursor = collection.find(
#             {"device_id": device_id, "received_at": {"$gte": start_date, "$lt": end_date}}
#         ).sort("received_at", 1)

#         for doc in cursor:
#             response = {
#                 "id": str(doc["_id"]),
#                 "device_id": doc.get("device_id"),
#                 "location": doc.get("location"),
#                 "speed": doc.get("speed"),
#                 "heading": doc.get("heading"),
#                 "fuel": doc.get("fuel"),
#                 "UUID": doc.get("UUID"),
#                 "account_id": doc.get("account_id"),
#                 "category_id": doc.get("category_id"),
#                 "category_name": doc.get("category_name"),
#                 "message": doc.get("message"),
#                 "is_alert": doc.get("is_alert", False),
#                 "received_at": doc.get("received_at").isoformat() if doc.get("received_at") else None
#             }
#             await websocket.send_text(json.dumps(response, default=str))
#             await asyncio.sleep(0.05)  # small delay to avoid flooding

#         await websocket.send_text("history_done")

#         # 3️⃣ Subscribe client for live updates
#         connected_clients.add(websocket)
#         device_subscriptions[websocket] = device_id
#         print(f"✅ Client subscribed for live updates: {device_id}")

#         # 4️⃣ Keep connection alive
#         while True:
#             await asyncio.sleep(10)

#     except WebSocketDisconnect:
#         if websocket in connected_clients:
#             connected_clients.discard(websocket)
#             device_subscriptions.pop(websocket, None)
#             print(f"❌ Client {device_id} disconnected. Remaining: {len(connected_clients)}")
#     except Exception as e:
#         print(f"❌ Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# async def websocket_historical_and_live(websocket: WebSocket):
async def websocket_historical_and_live(websocket: WebSocket):
    """Replay a device's stored tracking history, then subscribe it to live updates.

    Protocol (text frames):
      1. The client sends one JSON object with ``device_id`` (required) and
         optionally ``date``, ``from_date``, ``to_date``, each formatted
         "YYYY-MM-DD HH:MM:SS".
      2. Matching ``tracking_logs`` records are streamed oldest-first, one JSON
         object per frame, followed by "history_done".
      3. The socket is added to ``connected_clients`` and
         ``device_subscriptions[websocket] = device_id`` for live updates, and
         stays registered until the client disconnects.

    ``received_at`` window (first matching rule wins):
      - from_date and to_date -> from_date <= t <= to_date
      - from_date             -> from_date <= t < now
      - date (legacy)         -> date <= t < now
      - to_date only          -> t <= to_date
      - no dates              -> the device's full history
    Dates are naive datetimes and "now" is ``datetime.utcnow()``; presumably
    stored timestamps are naive UTC — confirm.

    The DB handle from ``database.get_mongo_db()`` is released as soon as the
    history has been sent (it is not needed during the live phase), and in any
    case on exit. The live subscription is always removed on exit.
    """
    await websocket.accept()
    device_id = None
    mongo_gen = None

    def parse_date(value):
        """Return the parsed datetime, or None when value is empty or malformed."""
        if value:
            try:
                return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                return None
        return None

    def release_db():
        """Run get_mongo_db()'s post-yield cleanup at most once."""
        nonlocal mongo_gen
        if mongo_gen is not None:
            gen, mongo_gen = mongo_gen, None
            next(gen, None)

    try:
        # 1️⃣ Receive subscription query
        query_data = json.loads(await websocket.receive_text())

        device_id = query_data.get("device_id")
        date_str = query_data.get("date")
        from_date_str = query_data.get("from_date")
        to_date_str = query_data.get("to_date")

        if not device_id:
            await websocket.send_text("Missing device_id.")
            await websocket.close()
            return

        date_val = parse_date(date_str)
        from_date = parse_date(from_date_str)
        to_date = parse_date(to_date_str)

        # Reject any field that was supplied but failed to parse.
        for field, raw, parsed in (
            ("date", date_str, date_val),
            ("from_date", from_date_str, from_date),
            ("to_date", to_date_str, to_date),
        ):
            if raw and parsed is None:
                await websocket.send_text(f"Invalid date format for '{field}'. Use YYYY-MM-DD HH:MM:SS")
                await websocket.close()
                return

        # 2️⃣ Build date filter logic
        now = datetime.utcnow()
        if from_date and to_date:
            date_filter = {"$gte": from_date, "$lte": to_date}
        elif from_date:
            date_filter = {"$gte": from_date, "$lt": now}
        elif date_val:
            date_filter = {"$gte": date_val, "$lt": now}
        elif to_date:
            # Previously ignored, which streamed the full history instead.
            date_filter = {"$lte": to_date}
        else:
            date_filter = {}

        query_object = {"device_id": device_id}
        if date_filter:
            query_object["received_at"] = date_filter

        # 3️⃣ DB connection and history replay (pymongo cursor -> plain for loop)
        mongo_gen = database.get_mongo_db()
        collection = next(mongo_gen)["tracking_logs"]
        cursor = collection.find(query_object).sort("received_at", 1)

        for doc in cursor:
            received_at = doc.get("received_at")
            response = {
                "id": str(doc["_id"]),
                "device_id": doc.get("device_id"),
                "location": doc.get("location"),
                "speed": doc.get("speed"),
                "heading": doc.get("heading"),
                "fuel": doc.get("fuel"),
                "UUID": doc.get("UUID"),
                "account_id": doc.get("account_id"),
                "category_id": doc.get("category_id"),
                "category_name": doc.get("category_name"),
                "message": doc.get("message"),
                "is_alert": doc.get("is_alert", False),
                "received_at": received_at.isoformat() if received_at else None,
            }
            await websocket.send_text(json.dumps(response, default=str))
            await asyncio.sleep(0.05)

        # The cursor is fully consumed; the live phase does not need the DB.
        release_db()

        await websocket.send_text("history_done")

        # 4️⃣ Subscribe to live updates
        connected_clients.add(websocket)
        device_subscriptions[websocket] = device_id

        # 5️⃣ Stay subscribed until the client disconnects. Reading frames is
        # what surfaces the disconnect; incoming frames are otherwise ignored.
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        try:
            await websocket.send_text(f"Error: {str(e)}")
            await websocket.close()
        except Exception:
            pass  # socket already closed; the error has been logged above
    finally:
        connected_clients.discard(websocket)
        device_subscriptions.pop(websocket, None)
        release_db()



# Returns all past historical data in a single response message
# async def websocket_historical_and_live_all(websocket: WebSocket):
#     await websocket.accept()
#     device_id = None

#     try:
#         # 1️⃣ Receive subscription query from client
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.utcnow()
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # 2️⃣ DB connection and fetch historical data
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         cursor = collection.find(
#             {"device_id": device_id, "received_at": {"$gte": start_date, "$lt": end_date}}
#         ).sort("received_at", 1)

#         # Convert all docs into a list (single push)
#         history_data = []
#         for doc in cursor:
#             history_data.append({
#                 "id": str(doc["_id"]),
#                 "device_id": doc.get("device_id"),
#                 "location": doc.get("location"),
#                 "speed": doc.get("speed"),
#                 "heading": doc.get("heading"),
#                 "fuel": doc.get("fuel"),
#                 "UUID": doc.get("UUID"),
#                 "account_id": doc.get("account_id"),
#                 "category_id": doc.get("category_id"),
#                 "category_name": doc.get("category_name"),
#                 "message": doc.get("message"),
#                 "is_alert": doc.get("is_alert", False),
#                 "received_at": doc.get("received_at").isoformat() if doc.get("received_at") else None
#             })

#         # 3️⃣ Send entire history data once
#         await websocket.send_text(json.dumps({
#             "type": "history",
#             "device_id": device_id,
#             "data": history_data
#         }, default=str))

#         await websocket.send_text("history_done")

#         # 4️⃣ Subscribe to live
#         connected_clients.add(websocket)
#         device_subscriptions[websocket] = device_id
#         print(f"✅ Client subscribed for live updates: {device_id}")

#         # 5️⃣ Keep connection alive
#         while True:
#             await asyncio.sleep(10)

#     except WebSocketDisconnect:
#         if websocket in connected_clients:
#             connected_clients.discard(websocket)
#             device_subscriptions.pop(websocket, None)
#             print(f"❌ Client {device_id} disconnected. Remaining: {len(connected_clients)}")

#     except Exception as e:
#         print(f"❌ Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# Return all matching historical data in a single message, then keep the socket subscribed for live updates.
async def websocket_historical_and_live_all(websocket: WebSocket):
    """Send a device's stored tracking history in one frame, then subscribe the socket to live updates.

    Client -> server (first frame, JSON object):
        ``device_id`` (required); optionally ``from_date`` / ``to_date`` and/or the
        legacy ``date`` field, each formatted ``YYYY-MM-DD HH:MM:SS``.

    Server -> client:
        1. one JSON text frame ``{"type": "history", "device_id": ..., "data": [...]}``
           holding every matching ``tracking_logs`` document, sorted by
           ``received_at`` ascending;
        2. the literal text frame ``history_done``;
        3. the socket is then added to ``connected_clients`` /
           ``device_subscriptions`` for live updates until the client disconnects.

    Date window applied to ``received_at`` (start = ``from_date``, else ``date``):
        * start and ``to_date`` -> ``start <= received_at <= to_date``
        * start only           -> ``start <= received_at < utcnow()``
        * ``to_date`` only      -> ``received_at <= to_date``
        * none                 -> full history

    Validation failures (missing ``device_id``, malformed dates) send a plain-text
    message and close the socket. Unexpected errors are reported as
    ``"Error: <message>"`` on a best-effort basis. However the handler ends, the
    socket is removed from the live-update registries.
    """
    await websocket.accept()
    device_id = None

    try:
        # 1️⃣ Receive client query
        query_data = json.loads(await websocket.receive_text())
        if not isinstance(query_data, dict):
            # A non-object payload (list, number, string) cannot carry a device_id.
            query_data = {}

        device_id = query_data.get("device_id")
        date_str = query_data.get("date")  # old behavior
        from_date_str = query_data.get("from_date")
        to_date_str = query_data.get("to_date")

        if not device_id:
            await websocket.send_text("Missing device_id.")
            await websocket.close()
            return

        date_val = _parse_history_datetime(date_str)
        from_date = _parse_history_datetime(from_date_str)
        to_date = _parse_history_datetime(to_date_str)

        # 🔥 Reject any value that was supplied but could not be parsed (checked in this order).
        for raw_value, parsed_value, field_name in (
            (date_str, date_val, "date"),
            (from_date_str, from_date, "from_date"),
            (to_date_str, to_date, "to_date"),
        ):
            if raw_value and parsed_value is None:
                await websocket.send_text(
                    f"Invalid format for '{field_name}'. Use YYYY-MM-DD HH:MM:SS"
                )
                await websocket.close()
                return

        # 2️⃣ Build the Mongo query
        query_object = {"device_id": device_id}
        date_filter = _build_received_at_filter(date_val, from_date, to_date)
        if date_filter:
            query_object["received_at"] = date_filter

        # 3️⃣-4️⃣ The pymongo cursor is synchronous. Run the query in a worker thread
        # so a large history does not stall every other connection on the event loop.
        loop = asyncio.get_running_loop()
        history_data = await loop.run_in_executor(
            None, _load_tracking_history, query_object
        )

        # 5️⃣ Push all historical data at once
        await websocket.send_text(json.dumps({
            "type": "history",
            "device_id": device_id,
            "data": history_data
        }, default=str))

        await websocket.send_text("history_done")

        # 6️⃣ Add to live update subscription
        connected_clients.add(websocket)
        device_subscriptions[websocket] = device_id
        print(f"✅ Client subscribed for live updates: {device_id}")

        # 7️⃣ Keep the handler alive until the client disconnects. Reading from the
        # socket (rather than just sleeping) is what surfaces the disconnect.
        # Any frames the client sends in the meantime are ignored.
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
        print(f"❌ Disconnected: {device_id}")

    except WebSocketDisconnect:
        if websocket in connected_clients:
            print(f"❌ Disconnected: {device_id}")

    except Exception as e:
        print(f"❌ Error: {str(e)}")
        try:
            await websocket.send_text(f"Error: {str(e)}")
            await websocket.close()
        except Exception:
            # Best effort only: the socket may already be closed, and there is
            # no other channel to report the failure on.
            pass

    finally:
        # Always drop the live subscription, whichever path ended the handler
        # (normal disconnect, error, or task cancellation).
        connected_clients.discard(websocket)
        device_subscriptions.pop(websocket, None)


def _parse_history_datetime(value):
    """Parse a ``YYYY-MM-DD HH:MM:SS`` string; return None when it is absent/falsy or malformed."""
    if not value:
        return None
    try:
        return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
    except (TypeError, ValueError):
        # TypeError: non-string JSON value (e.g. a number); ValueError: wrong format.
        return None


def _build_received_at_filter(date_val, from_date, to_date):
    """Return the Mongo condition on ``received_at`` for the requested window ({} means no filter)."""
    start = from_date if from_date is not None else date_val
    if start is not None and to_date is not None:
        return {"$gte": start, "$lte": to_date}
    if start is not None:
        # Open-ended range: cap at the current naive-UTC time.
        return {"$gte": start, "$lt": datetime.utcnow()}
    if to_date is not None:
        return {"$lte": to_date}
    return {}


def _serialize_tracking_log(doc):
    """Convert one ``tracking_logs`` document into a JSON-friendly dict."""
    received_at = doc.get("received_at")
    if isinstance(received_at, datetime):
        received_at = received_at.isoformat()
    elif not received_at:
        received_at = None
    # Any other truthy value is passed through unchanged; json.dumps(default=str) stringifies it if needed.
    return {
        "id": str(doc["_id"]),
        "device_id": doc.get("device_id"),
        "location": doc.get("location"),
        "speed": doc.get("speed"),
        "heading": doc.get("heading"),
        "fuel": doc.get("fuel"),
        "UUID": doc.get("UUID"),
        "account_id": doc.get("account_id"),
        "category_id": doc.get("category_id"),
        "category_name": doc.get("category_name"),
        "message": doc.get("message"),
        "is_alert": doc.get("is_alert", False),
        "received_at": received_at,
    }


def _load_tracking_history(query_object):
    """Run the blocking ``tracking_logs`` query; return serialized docs sorted by ``received_at`` ascending."""
    mongo_gen = database.get_mongo_db()
    try:
        mongo = next(mongo_gen)
        cursor = mongo["tracking_logs"].find(query_object).sort("received_at", 1)
        return [_serialize_tracking_log(doc) for doc in cursor]
    finally:
        # Finish the generator now, so any cleanup get_mongo_db() performs after its
        # yield runs deterministically rather than whenever it is garbage-collected.
        mongo_gen.close()
