import asyncio
import json
import os
from dotenv import load_dotenv
from fastapi import WebSocket, WebSocketDisconnect
from aiokafka import AIOKafkaProducer
from datetime import datetime,timedelta
from app.utils.connections import connected_clients
from app.db import database
from bson import ObjectId, errors as bson_errors

from app.v1.services.platform.iotdevices import list_nearby_devices_service_socket
from app.v1.models.platform.iotdevices import NearbyDevicesRequest, IotDeviceWithMappingList
#from app.v1.sockets.tracking_data import notify_nearby_subscribers

# 🔹 Pull .env into the process environment once, before any flag below is read
load_dotenv()
# Kafka is opt-in: only a case-insensitive "true" enables it; unset means "false"
USE_KAFKA = "true" == os.environ.get("USE_KAFKA", "false").lower()

# 🔹 Shared Kafka producer; stays None until start_kafka_producer() assigns it
producer: AIOKafkaProducer = None

# Registry of active "nearby devices" subscribers.
# NOTE(review): not referenced in the visible code — presumably filled by a
# nearby-subscription handler elsewhere in this module; verify.
active_nearby_subscribers = dict()

async def start_kafka_producer():
    """Create and start the module-level Kafka producer.

    The broker address comes from the ``KAFKA_BOOTSTRAP_SERVERS`` environment
    variable (``.env`` is loaded at import time). It falls back to the staging
    broker that used to be hard-coded, so existing deployments behave the same.

    The global ``producer`` is only assigned once ``start()`` has succeeded.
    If startup fails, the global keeps its previous value, and the
    ``USE_KAFKA and producer`` guards never see a producer that never started.
    Any startup exception propagates to the caller.
    """
    global producer
    bootstrap_servers = os.getenv("KAFKA_BOOTSTRAP_SERVERS", "stagingapi.movex.ai:9092")
    new_producer = AIOKafkaProducer(bootstrap_servers=bootstrap_servers)
    await new_producer.start()
    producer = new_producer

async def stop_kafka_producer():
    """Stop the module-level Kafka producer, if one is set.

    The global is reset to ``None`` before stopping. This makes the
    ``USE_KAFKA and producer`` checks in this module fall back to direct MongoDB
    writes instead of calling ``send`` on a stopped producer. Calling this when
    no producer is set is a no-op.
    """
    global producer
    if producer is None:
        return
    current, producer = producer, None
    await current.stop()

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow().isoformat()
#             if producer:
#                 await producer.send("vehicle_tracking", json.dumps(data).encode("utf-8"))
#             await websocket.send_text("Message queued successfully.")
#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]  # new collection

#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()

#             # 1. Identify device
#             device_id = data.get("device_id")
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Check for alert keywords in payload
#             payload_text = json.dumps(data).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 # 3. Map category from alert_category
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 alert_logs.insert_one(alert_doc)
#                 await websocket.send_text(f"✅ Alert stored under category: {category['category_name']}")
#             else:
#                 # 4. Normal tracking log
#                 tracking_logs.insert_one(data)
#                 if producer:
#                     await producer.send("vehicle_tracking", json.dumps(data, default=str).encode("utf-8"))
#                 await websocket.send_text("✅ Tracking log stored")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]

#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()

#             # 1. Identify device
#             device_id = data.get("device_id")
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Check for alert keywords
#             payload_text = json.dumps(data).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 # 3. Map category
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }

#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#                 await websocket.send_text(f"✅ Alert processed under {category['category_name']}")

#             else:
#                 if USE_KAFKA and producer:
#                     # ✅ Publish tracking to Kafka
#                     await producer.send(
#                         "vehicle_tracking",
#                         json.dumps(data, default=str).encode("utf-8")
#                     )
#                 else:
#                     tracking_logs.insert_one(data)

#                 await websocket.send_text("✅ Tracking log processed")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]
#     print("websocket_tracking_datawebsocket_tracking_datawebsocket_tracking_datawebsocket_tracking_data")
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()
#             print("received_atreceived_atreceived_at")
#             print(data)
#             # 1. Identify device
#             device_id = data.get("device_id")           

#             try:
#                 device_obj_id = ObjectId(device_id)
#             except bson_errors.InvalidId:
#                 await websocket.send_text("❌ Invalid device_id format")
#                 continue

#             device = devices.find_one({"_id": device_obj_id})

#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # Default tracking entry
#             tracking_entry = {
#                 **data,
#                 "account_id": device["account_id"],
#                 "is_alert": False
#             }

#             # 2. Detect alerts
#             # payload_text = json.dumps(data).lower()
#             payload_text = json.dumps(data, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })
#             print("device_iddevice_idddd")
#             print(keyword_doc)
#             if keyword_doc:
#                 # 3. Enrich tracking entry with alert info (NO status here)
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 tracking_entry.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 # 4. Insert into alert_logs with status
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#             # 5. Always insert into tracking_logs
#             if USE_KAFKA and producer:
#                 print("Coming IFFF part 24111111")
#                 await producer.send(
#                     "vehicle_tracking",
#                     json.dumps(tracking_entry, default=str).encode("utf-8")
#                 )
#             else:
#                 print("Coming ELSE part 246666")
#                 tracking_logs.insert_one(tracking_entry)

#             if tracking_entry.get("is_alert"):
#                 await websocket.send_text(f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})")
#             else:
#                 await websocket.send_text("✅ Stored in tracking_logs")

#     except WebSocketDisconnect:
#         pass

# async def websocket_tracking_data(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)
#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     devices = mongo["iot_devices"]
#     categories = mongo["alert_category"]
#     keywords = mongo["alert_keywords"]
#     fleets = mongo["fleets"]            # <-- added
#     workforce = mongo["workforce"]      # <-- added

#     print("websocket_tracking_data STARTED")
#     try:
#         while True:
#             message = await websocket.receive_text()
#             data = json.loads(message)
#             data["received_at"] = datetime.utcnow()
#             print("Incoming WebSocket Data:", data)

#             # 1. Identify device
#             device_id = data.get("device_id")
#             try:
#                 device_obj_id = ObjectId(device_id)
#             except bson_errors.InvalidId:
#                 await websocket.send_text("❌ Invalid device_id format")
#                 continue

#             device = devices.find_one({"_id": device_obj_id})
#             if not device:
#                 await websocket.send_text("Unknown device")
#                 continue

#             # 2. Prepare tracking entry
#             tracking_entry = {
#                 **data,
#                 "account_id": device["account_id"],
#                 "is_alert": False
#             }
#             print("=== SPLIT DEBUG START ===")
#             print(payload_text.split())
#             print("=== SPLIT DEBUG END ===")

#             # 3. Detect alerts
#             payload_text = json.dumps(data, default=str).lower()
#             keyword_doc = keywords.find_one({"keyword": {"$in": payload_text.split()}})
#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 tracking_entry.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": device["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": data,
#                     "received_at": data["received_at"],
#                     "status": "unread"
#                 }
#                 if USE_KAFKA and producer:
#                     await producer.send(
#                         "alerts",
#                         json.dumps(alert_doc, default=str).encode("utf-8")
#                     )
#                 else:
#                     alert_logs.insert_one(alert_doc)

#             # 4. Insert into tracking_logs
#             if USE_KAFKA and producer:
#                 await producer.send(
#                     "vehicle_tracking",
#                     json.dumps(tracking_entry, default=str).encode("utf-8")
#                 )
#             else:
#                 tracking_logs.insert_one(tracking_entry)

#             # 5. Update Fleet/Workforce with last location
#             try:
#                 coordinates = data.get("location")  # expected {"lat": ..., "lng": ...}
#                 if coordinates:
#                     last_location_update = {
#                         "last_location": {
#                             "coordinates": coordinates,
#                             "speed": data.get("speed"),
#                             "heading": data.get("heading"),
#                             "fuel": data.get("fuel"),
#                             "timestamp": data.get("timestamp"),
#                         },
#                         "last_updated": datetime.utcnow()
#                     }

#                     associated_type = device.get("associated_entity_type")
#                     associated_id = device.get("associated_entity_id")
#                     print("device_iddevice_iddevice_id")
#                     print(device_id)
#                     # ✅ Always update iot_devices
#                     devices.update_one(
#                         {"_id": ObjectId(device_id)},
#                         {"$set": last_location_update}
#                     )

#                     # ✅ Then update linked entity
#                     if associated_type == "vehicle" and associated_id:
#                         fleets.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     elif associated_type == "workforce" and associated_id:
#                         workforce.update_one(
#                             {"_id": ObjectId(associated_id)},
#                             {"$set": last_location_update}
#                         )
#                     else:
#                         # Fallback update if association is missing
#                         fleets.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )
#                         workforce.update_many(
#                             {"devices": device_id},
#                             {"$set": last_location_update}
#                         )

#                     await notify_nearby_subscribers(device_id)

#             except Exception as e:
#                 print(f"⚠️ Failed to update last location for {device_id}: {e}")


#             # 6. Send ack to WebSocket client
#             if tracking_entry.get("is_alert"):
#                 await websocket.send_text(f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})")
#             else:
#                 await websocket.send_text("✅ Stored in tracking_logs")

#     except WebSocketDisconnect:
#         print("⚠️ WebSocket disconnected")
#         pass

def _match_alert_keyword(keyword_docs, payload_text):
    """Return the first keyword document whose ``keyword`` occurs in *payload_text*.

    *payload_text* is expected to be lower-cased already. Each keyword is
    lower-cased before the substring test. Documents with a missing, non-string
    or empty ``keyword`` are skipped, because an empty string would otherwise
    match every payload. Returns ``None`` when nothing matches.
    """
    for keyword_doc in keyword_docs:
        keyword = keyword_doc.get("keyword")
        if isinstance(keyword, str) and keyword and keyword.lower() in payload_text:
            return keyword_doc
    return None


def _update_last_location(devices, fleets, workforce, device, device_obj_id, device_id, data):
    """Copy the latest position in *data* onto the device and its linked entity.

    The ``iot_devices`` document is always updated. The linked record is chosen
    from the device's ``associated_entity_type`` / ``associated_entity_id``.
    When that link is missing or of another type, every fleet and workforce
    document whose ``devices`` field matches *device_id* is updated instead.
    Database errors propagate to the caller.
    """
    update = {
        "$set": {
            "last_location": {
                "coordinates": data.get("location"),
                "speed": data.get("speed"),
                "heading": data.get("heading"),
                "fuel": data.get("fuel"),
                "timestamp": data.get("timestamp"),
            },
            "last_updated": datetime.utcnow(),
        }
    }

    # ✅ Always update iot_devices
    devices.update_one({"_id": device_obj_id}, update)

    # ✅ Then update linked entity
    associated_type = device.get("associated_entity_type")
    associated_id = device.get("associated_entity_id")
    if associated_type == "vehicle" and associated_id:
        fleets.update_one({"_id": ObjectId(associated_id)}, update)
    elif associated_type == "workforce" and associated_id:
        workforce.update_one({"_id": ObjectId(associated_id)}, update)
    else:
        # Fallback when the device has no usable association
        fleets.update_many({"devices": device_id}, update)
        workforce.update_many({"devices": device_id}, update)


async def websocket_tracking_data(websocket: WebSocket):
    """Ingest live device telemetry over a WebSocket and persist it.

    Each text frame must be a JSON object that carries a ``device_id``: an
    ObjectId string of a document in ``iot_devices``. For every valid frame:

    1. Stamp it with ``received_at`` (``datetime.utcnow()``, naive UTC).
    2. Scan the serialized payload for a configured ``alert_keywords`` entry.
       On a match whose category exists, record an alert: Kafka topic
       ``alerts`` when ``USE_KAFKA`` and a producer are set, otherwise the
       ``alert_logs`` collection.
    3. Record the tracking entry: Kafka topic ``vehicle_tracking``, otherwise
       ``tracking_logs``.
    4. If ``location`` is present, refresh ``last_location`` on the device and
       its linked fleet/workforce record, then notify nearby subscribers.
       This step is best-effort; failures are printed and ignored.
    5. Send a short text acknowledgement back to the client.

    Malformed frames (bad JSON, non-object payloads, missing or invalid
    ``device_id``) get an error reply and are skipped; the socket stays open.
    The loop ends when the client disconnects. The Mongo generator is always
    closed, so any cleanup in ``get_mongo_db`` after its ``yield`` runs.
    """
    await websocket.accept()
    mongo_gen = database.get_mongo_db()
    mongo = next(mongo_gen)
    tracking_logs = mongo["tracking_logs"]
    alert_logs = mongo["alert_logs"]
    devices = mongo["iot_devices"]
    categories = mongo["alert_category"]
    keywords = mongo["alert_keywords"]
    fleets = mongo["fleets"]
    workforce = mongo["workforce"]

    print("websocket_tracking_data STARTED")
    try:
        while True:
            message = await websocket.receive_text()

            # Reject malformed frames one at a time instead of letting the
            # exception tear down the whole connection.
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_text("❌ Invalid JSON payload")
                continue
            if not isinstance(data, dict):
                await websocket.send_text("❌ Payload must be a JSON object")
                continue

            data["received_at"] = datetime.utcnow()
            print("Incoming WebSocket Data:", data)

            # 1. Identify device. ObjectId(None) would silently mint a brand-new
            # id, and ObjectId(<non-str>) raises TypeError, so handle both here.
            device_id = data.get("device_id")
            if device_id is None:
                await websocket.send_text("❌ Invalid device_id format")
                continue
            try:
                device_obj_id = ObjectId(device_id)
            except (bson_errors.InvalidId, TypeError):
                await websocket.send_text("❌ Invalid device_id format")
                continue

            device = devices.find_one({"_id": device_obj_id})
            if not device:
                await websocket.send_text("Unknown device")
                continue
            account_id = device.get("account_id")

            # 2. Prepare tracking entry
            tracking_entry = {**data, "account_id": account_id, "is_alert": False}

            # 3. Detect alerts: the first configured keyword that appears anywhere
            # in the serialized payload (field names included) wins.
            payload_text = json.dumps(data, default=str).lower()
            matched_keyword = _match_alert_keyword(
                list(keywords.find({}, {"keyword": 1, "category_id": 1})),
                payload_text,
            )

            category = None
            if matched_keyword is not None:
                category = categories.find_one({"_id": matched_keyword.get("category_id")})
                if category is None:
                    # Misconfigured keyword: keep the stream alive and still store
                    # the tracking point, just without alert enrichment.
                    print(
                        f"⚠️ Alert keyword {matched_keyword['keyword']!r} has no "
                        "matching alert_category; not raising an alert"
                    )

            if category is not None:
                print("Matched alert keyword:", matched_keyword["keyword"])
                alert_message = f"Alert detected: {matched_keyword['keyword']}"
                category_name = category.get("category_name")
                tracking_entry.update({
                    "category_id": category["_id"],
                    "category_name": category_name,
                    "message": alert_message,
                    "is_alert": True,
                })

                alert_doc = {
                    "device_id": device_id,
                    "account_id": account_id,
                    "category_id": category["_id"],
                    "category_name": category_name,
                    "message": alert_message,
                    "raw_payload": data,
                    "received_at": data["received_at"],
                    "status": "unread",
                }
                if USE_KAFKA and producer:
                    await producer.send(
                        "alerts",
                        json.dumps(alert_doc, default=str).encode("utf-8"),
                    )
                else:
                    alert_logs.insert_one(alert_doc)

            # 4. Insert into tracking_logs
            if USE_KAFKA and producer:
                await producer.send(
                    "vehicle_tracking",
                    json.dumps(tracking_entry, default=str).encode("utf-8"),
                )
            else:
                tracking_logs.insert_one(tracking_entry)

            # 5. Best-effort refresh of the last known location.
            # NOTE(review): notify_nearby_subscribers is not defined in the visible
            # part of this file (its import is commented out). It is assumed to be
            # defined further down; a NameError would be swallowed by the except below.
            try:
                if data.get("location"):
                    _update_last_location(
                        devices, fleets, workforce, device, device_obj_id, device_id, data
                    )
                    await notify_nearby_subscribers(device_id)
            except Exception as e:
                print(f"⚠️ Failed to update last location for {device_id}: {e}")

            # 6. Send ack to WebSocket client
            if tracking_entry.get("is_alert"):
                await websocket.send_text(
                    f"✅ Stored in tracking_logs + alert_logs ({tracking_entry['category_name']})"
                )
            else:
                await websocket.send_text("✅ Stored in tracking_logs")

    except WebSocketDisconnect:
        print("⚠️ WebSocket disconnected")
    finally:
        # Closing the generator runs any cleanup in get_mongo_db after its yield.
        mongo_gen.close()

async def websocket_realtime_updates(websocket: WebSocket):
    """Register a client in ``connected_clients`` until it disconnects.

    Presumably other code broadcasts real-time updates to the sockets in
    ``connected_clients``; that code is not visible here.

    Incoming frames are read and ignored, because reading is the only way this
    handler can observe the client's disconnect. A sleep-only loop never
    notices it, so the socket would stay registered forever. Deregistration
    happens in ``finally``, so it also runs if the task is cancelled.
    ``discard`` tolerates the socket having already been removed elsewhere.
    """
    await websocket.accept()
    connected_clients.add(websocket)
    print("✅ Client connected. Total:", len(connected_clients))

    try:
        while True:
            # receive() returns the raw ASGI message. Stop at the disconnect
            # event, because receiving again after it is an error.
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                break
    finally:
        connected_clients.discard(websocket)
        print("❌ Client disconnected. Remaining:", len(connected_clients))

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         device_id = query_data.get("device_id")
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # ✅ Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = start_date + timedelta(hours=1)   # Example: 1-hour window
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]
#         alert_logs = mongo["alert_logs"]
#         devices = mongo["iot_devices"]
#         categories = mongo["alert_category"]
#         keywords = mongo["alert_keywords"]

#         # Fetch historical points in that range
#         cursor = collection.find({
#             "device_id": device_id,
#             "received_at": {
#                 "$gte": start_date,
#                 "$lt": end_date
#             }
#         }, {
#             "location": 1, "speed": 1, "heading": 1,
#             "fuel": 1, "received_at": 1, "alert_type": 1
#         })

#         print("Query returned:", cursor)

#         # Stream points one-by-one
#         for doc in cursor:
#             # 🔎 Enrich alerts just like in websocket_tracking_data
#             enriched_doc = dict(doc)
#             enriched_doc["is_alert"] = False

#             payload_text = json.dumps(doc, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 enriched_doc.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": devices.find_one({"_id": ObjectId(device_id)})["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": doc,
#                     "received_at": doc["received_at"],
#                     "status": "historical"
#                 }
#                 alert_logs.insert_one(alert_doc)

#             # Send enriched doc to client
#             await websocket.send_text(json.dumps(enriched_doc, default=str))
#             await asyncio.sleep(0.2)

#         await websocket.send_text("done")
#         await websocket.close()

#         try:
#             next(mongo_gen)  # Cleanup DB connection
#         except StopIteration:
#             pass

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     mongo_gen = None

#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         #device_id = query_data.get("device_id")
#         device_id = "683d3ffdc355c05ce30f9c94"
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.now()   # ✅ current date and time
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]
#         alert_logs = mongo["alert_logs"]
#         devices = mongo["iot_devices"]
#         categories = mongo["alert_category"]
#         keywords = mongo["alert_keywords"]

#         # Fetch historical points in that range
#         cursor = collection.find(
#             {
#                 "device_id": device_id,
#                 "received_at": {
#                     "$gte": start_date,
#                     "$lt": end_date
#                 }
#             },
#             {
#                 "location": 1, "speed": 1, "heading": 1,
#                 "fuel": 1, "received_at": 1, "alert_type": 1
#             }
#         )

#         docs = list(cursor)
#         print(f"Query returned {len(docs)} records")

#         # Stream points one-by-one
#         for doc in docs:
#             enriched_doc = dict(doc)
#             enriched_doc["is_alert"] = False

#             payload_text = json.dumps(doc, default=str).lower()
#             keyword_doc = keywords.find_one({
#                 "keyword": {"$in": payload_text.split()}
#             })

#             if keyword_doc:
#                 category = categories.find_one({"_id": keyword_doc["category_id"]})
#                 enriched_doc.update({
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "is_alert": True
#                 })

#                 # Insert into alert_logs
#                 alert_doc = {
#                     "device_id": device_id,
#                     "account_id": devices.find_one({"_id": ObjectId(device_id)})["account_id"],
#                     "category_id": category["_id"],
#                     "category_name": category["category_name"],
#                     "message": f"Alert detected: {keyword_doc['keyword']}",
#                     "raw_payload": doc,
#                     "received_at": doc["received_at"],
#                     "status": "historical"
#                 }
#                 alert_logs.insert_one(alert_doc)

#             # Send enriched doc to client
#             await websocket.send_text(json.dumps(enriched_doc, default=str))
#             await asyncio.sleep(0.2)

#         await websocket.send_text("done")
#         await websocket.close()

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         try:
#             await websocket.send_text(f"Error: {str(e)}")
#         except Exception:
#             pass
#         await websocket.close()
#     finally:
#         if mongo_gen:
#             mongo_gen.close()

# async def websocket_historical_route(websocket: WebSocket):
#     await websocket.accept()
#     try:
#         query = await websocket.receive_text()
#         query_data = json.loads(query)
#         #device_id = query_data.get("device_id")
#         device_id = "683d3ffdc355c05ce30f9c94"
#         datetime_str = query_data.get("date")  # e.g. "2025-06-01 09:00:00"

#         if not device_id or not datetime_str:
#             await websocket.send_text("Missing device_id or date.")
#             await websocket.close()
#             return

#         # ✅ Parse "YYYY-MM-DD HH:MM:SS"
#         try:
#             start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
#             end_date = datetime.now()   # ✅ current date and time
#         except ValueError:
#             await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
#             await websocket.close()
#             return

#         # DB connection
#         mongo_gen = database.get_mongo_db()
#         mongo = next(mongo_gen)
#         collection = mongo["tracking_logs"]

#         # Fetch historical points in that range
#         cursor = collection.find({
#             "device_id": device_id,
#             "received_at": {
#                 "$gte": start_date,
#                 "$lt": end_date
#             }
#         }, {"location": 1, "speed": 1, "heading": 1, "received_at": 1})


#         # ✅ print only after cursor is assigned
#         print("Query returned:", cursor)

#         # Stream points one-by-one
#         for doc in cursor:
#             print(json.dumps(doc, default=str))
#             await websocket.send_text(json.dumps(doc, default=str))
#             await asyncio.sleep(0.2)  # Simulate streaming

#         await websocket.send_text("done")
#         await websocket.close()

#         try:
#             next(mongo_gen)  # Cleanup DB connection
#         except StopIteration:
#             pass

#     except WebSocketDisconnect:
#         print("Client disconnected.")
#     except Exception as e:
#         print(f"Error: {str(e)}")
#         await websocket.send_text(f"Error: {str(e)}")
#         await websocket.close()

# Fields fetched from ``tracking_logs`` for a historical replay.
_HISTORY_PROJECTION = {
    "_id": 1,
    "device_id": 1,
    "location": 1,
    "speed": 1,
    "heading": 1,
    "fuel": 1,
    "received_at": 1,
    "account_id": 1,
    "is_alert": 1,
    "category_id": 1,
    "category_name": 1,
    "message": 1,
    "UUID": 1,
}


def _history_doc_to_payload(doc):
    """Map one ``tracking_logs`` document to the dict streamed to the client.

    ``_id`` is stringified, ``is_alert`` defaults to ``False`` and
    ``received_at`` is rendered with ``isoformat()`` (``None`` if absent).
    """
    received_at = doc.get("received_at")
    return {
        "id": str(doc["_id"]),
        "device_id": doc.get("device_id"),
        "location": doc.get("location"),
        "speed": doc.get("speed"),
        "heading": doc.get("heading"),
        "fuel": doc.get("fuel"),
        "UUID": doc.get("UUID"),
        "account_id": doc.get("account_id"),
        "category_id": doc.get("category_id"),
        "category_name": doc.get("category_name"),
        "message": doc.get("message"),
        "is_alert": doc.get("is_alert", False),
        "received_at": received_at.isoformat() if received_at else None,
    }


async def websocket_historical_route(websocket: WebSocket):
    """Replay a device's stored tracking points over a WebSocket.

    After accepting the connection, one text frame is read and parsed as JSON
    with the keys:

    * ``device_id`` -- matched verbatim against ``tracking_logs.device_id``.
    * ``date`` -- start of the replay window as ``"YYYY-MM-DD HH:MM:SS"``
      (e.g. ``"2025-06-01 09:00:00"``).

    Every matching document with ``date <= received_at < utcnow()`` is sent
    oldest-first, one JSON text frame per document, with a 0.2 s pause between
    frames. A final ``"done"`` frame is sent and the socket is closed.

    Missing fields or a bad date format are reported as a plain-text frame
    followed by a close. Any other error is reported (best effort) as
    ``"Error: <message>"``. The DB session is always released on exit.
    """
    await websocket.accept()
    mongo_gen = None  # opened lazily; released in ``finally``
    try:
        query_data = json.loads(await websocket.receive_text())
        if not isinstance(query_data, dict):
            # A JSON array/scalar cannot carry the required fields.
            query_data = {}

        device_id = query_data.get("device_id")
        datetime_str = query_data.get("date")

        if not device_id or not datetime_str:
            await websocket.send_text("Missing device_id or date.")
            await websocket.close()
            return

        try:
            start_date = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            await websocket.send_text("Invalid date format. Use YYYY-MM-DD HH:MM:SS")
            await websocket.close()
            return
        # NOTE(review): the parsed start is naive and compared with utcnow(),
        # so this assumes clients send UTC timestamps -- confirm.
        end_date = datetime.utcnow()

        mongo_gen = database.get_mongo_db()
        collection = next(mongo_gen)["tracking_logs"]

        cursor = collection.find(
            {
                "device_id": device_id,
                "received_at": {"$gte": start_date, "$lt": end_date},
            },
            _HISTORY_PROJECTION,
        ).sort("received_at", 1)

        # NOTE(review): this is a synchronous cursor iterated inside a
        # coroutine; each batch fetch blocks the event loop.
        for doc in cursor:
            await websocket.send_text(json.dumps(_history_doc_to_payload(doc), default=str))
            await asyncio.sleep(0.2)  # pace the replay for the client

        await websocket.send_text("done")
        await websocket.close()

    except WebSocketDisconnect:
        print("⚠️ Client disconnected.")
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        try:
            await websocket.send_text(f"Error: {str(e)}")
            await websocket.close()
        except Exception as send_error:
            # The socket may already be closed; reporting is best effort.
            print(f"❌ Could not report error to client: {send_error}")
    finally:
        # Release the DB session on every exit path, not only on success.
        if mongo_gen is not None:
            try:
                next(mongo_gen)
            except StopIteration:
                pass
        
async def websocket_workforce_schedules(websocket: WebSocket):
    """Register a workforce member's WebSocket in ``connected_clients``.

    The first text frame must be JSON containing ``workforce_id``. The socket
    is stored in ``connected_clients`` under that id and kept open until the
    client disconnects, at which point the registration is removed.
    Presumably other code pushes schedule updates through that registry
    (confirm). Frames received after registration are read and discarded.
    """
    await websocket.accept()
    # Bound up front so cleanup is safe even if the first receive fails.
    workforce_id = None

    try:
        data = json.loads(await websocket.receive_text())
        workforce_id = data.get("workforce_id")

        if not workforce_id:
            await websocket.send_text("Missing workforce_id.")
            await websocket.close()
            return

        connected_clients[workforce_id] = websocket
        print(f"✅ Workforce {workforce_id} connected. Total clients: {len(connected_clients)}")

        # Block on receive rather than sleeping: only a receive surfaces the
        # client's disconnect (as WebSocketDisconnect) and ends this loop.
        while True:
            await websocket.receive_text()

    except WebSocketDisconnect:
        pass
    finally:
        # Unregister only if the entry is still THIS socket; a newer
        # connection for the same workforce_id may have replaced it.
        if workforce_id is not None and connected_clients.get(workforce_id) is websocket:
            del connected_clients[workforce_id]
            print(f"❌ Workforce {workforce_id} disconnected. Remaining: {len(connected_clients)}")

async def websocket_nearby_devices(websocket: WebSocket):
    """Stream the list of IoT devices near a point to a WebSocket client.

    The first text frame must be JSON accepted by ``NearbyDevicesRequest``.
    Its ``latitude``, ``longitude`` and ``radius_km`` are stored (with the DB
    handle) in ``active_nearby_subscribers`` so ``notify_nearby_subscribers``
    can push refreshed lists. The current list is sent immediately as a
    serialized ``IotDeviceWithMappingList``. The socket is then held open
    until the client disconnects; any further frames are read and ignored.

    The subscription and the DB session are released on every exit path,
    including validation or parsing errors (which still propagate).
    """
    await websocket.accept()
    mongo_gen = database.get_mongo_db()
    db = next(mongo_gen)

    try:
        # Receive initial subscription request
        data = json.loads(await websocket.receive_text())
        payload = NearbyDevicesRequest(**data)

        # The DB handle is kept with the subscription because
        # notify_nearby_subscribers re-runs the query with it.
        active_nearby_subscribers[websocket] = {
            "latitude": payload.latitude,
            "longitude": payload.longitude,
            "radius_km": payload.radius_km,
            "db": db,
        }

        # Send initial device list
        result = list_nearby_devices_service_socket(
            payload.latitude, payload.longitude, payload.radius_km, db
        )
        response = IotDeviceWithMappingList(**result)
        await websocket.send_text(json.dumps(response.dict(), default=str))

        # Keep socket open; receiving is what surfaces the disconnect.
        while True:
            await websocket.receive_text()  # ignore pings/extra messages

    except WebSocketDisconnect:
        print("❌ Nearby devices WS disconnected")
    finally:
        # Unsubscribe first so no notifier uses the DB handle after release.
        active_nearby_subscribers.pop(websocket, None)
        try:
            next(mongo_gen)
        except StopIteration:
            pass


# Notify subscribers when device location changes
async def notify_nearby_subscribers(updated_device):
    """Push a refreshed nearby-device list to every active subscriber.

    Each subscriber's stored query (``latitude``, ``longitude``,
    ``radius_km``, ``db``) is re-run and the serialized
    ``IotDeviceWithMappingList`` is sent to its socket. ``updated_device`` is
    currently unused: every subscriber is refreshed regardless of which device
    changed.

    Failures are isolated per subscriber. If building the update fails, the
    error is logged and the subscriber is kept. If sending fails, the socket
    is treated as dead and the subscriber is dropped, so later notifications
    do not keep retrying it.
    """
    # Iterate a snapshot: sends await, so subscribers may (un)register meanwhile.
    for ws, sub in list(active_nearby_subscribers.items()):
        if ws not in active_nearby_subscribers:
            continue  # unsubscribed while an earlier send was in flight

        try:
            result = list_nearby_devices_service_socket(
                sub["latitude"], sub["longitude"], sub["radius_km"], sub["db"]
            )
            message = json.dumps(IotDeviceWithMappingList(**result).dict(), default=str)
        except Exception as e:
            print(f"⚠️ Failed building update for subscriber: {e}")
            continue

        try:
            await ws.send_text(message)
        except Exception as e:
            print(f"⚠️ Failed sending update to subscriber: {e}")
            active_nearby_subscribers.pop(ws, None)

