from datetime import datetime, timedelta
from bson import json_util, ObjectId
import json
from shapely.geometry import Point, Polygon
from aiokafka import AIOKafkaConsumer
from app.db import database
from app.utils.connections import connected_clients, device_subscriptions
from app.v1.sockets.tracking_data import notify_nearby_subscribers

# Lookup table from a comparison symbol to a two-argument callable that
# applies that comparison to (lhs, rhs) and returns the result.
OPERATORS = {
    "==": lambda lhs, rhs: lhs == rhs,
    ">": lambda lhs, rhs: lhs > rhs,
    "<": lambda lhs, rhs: lhs < rhs,
    ">=": lambda lhs, rhs: lhs >= rhs,
    "<=": lambda lhs, rhs: lhs <= rhs,
    "!=": lambda lhs, rhs: lhs != rhs,
}

# NOTE: `connected_clients` (the WebSocket client registry) is the object
# imported from app.utils.connections at the top of this file. It is
# deliberately NOT rebound to a fresh `set()` here: a local rebinding would
# shadow the import, so broadcast_to_clients would iterate a module-private
# set while still reading the shared `device_subscriptions` mapping.
# Presumably the socket handlers register clients in the shared registry --
# confirm against app.utils.connections.

# async def consume_tracking_data():
#     consumer = AIOKafkaConsumer(
#         "vehicle_tracking",
#         bootstrap_servers="127.0.0.1:9092",  # connect to Docker Kafka from host
#         group_id="tracking-group",
#         auto_offset_reset="earliest",         # read from beginning if no offset
#         enable_auto_commit=True
#     )
#     await consumer.start()
    
#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)

#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     geofences = mongo["geofences"]
#     geofence_rules = mongo["geofence_rules"]
#     geofence_mappings = mongo["geofence_rule_mapping"]
#     devices = mongo["iot_devices"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     try:
#         print("✅ Kafka consumer started")

#         async for msg in consumer:
#             print("📩 Message received from Kafka")
#             try:
#                 data = json.loads(msg.value.decode("utf-8"))
#                 print("🔹 Kafka payload:", data)
#             except Exception as e:
#                 print("❌ Error decoding message:", e)
#                 continue

#             data["received_at"] = datetime.utcnow()
#             data["alert_type"] = "W"
#             device_id = data.get("device_id")
#             coordinates = data.get("location", {}).get("coordinates", [])

#             # Insert tracking log
#             tracking_logs.insert_one(data)

#             alerts = []

#             # 1️⃣ Check inactivity (>30 mins)
#             last_entry = tracking_logs.find_one(
#                 {"device_id": device_id},
#                 sort=[("received_at", -1)]
#             )
#             if last_entry:
#                 last_time = last_entry["received_at"]
#                 current_time = datetime.utcnow()
#                 gap = current_time - last_time

#                 if gap > timedelta(minutes=30):
#                     alert = {
#                         "device_id": device_id,
#                         "type": "inactivity",
#                         "message": f"Device {device_id} inactive for over 30 minutes.",
#                         "triggered_at": current_time
#                     }
#                     alert_logs.insert_one(alert)
#                     alerts.append(alert)

#                     online_alert = {
#                         "device_id": device_id,
#                         "type": "recovery",
#                         "message": f"Device {device_id} is back online after {gap.seconds // 60} minutes of inactivity.",
#                         "triggered_at": current_time
#                     }
#                     alert_logs.insert_one(online_alert)
#                     alerts.append(online_alert)

#             # 2️⃣ Geofence rules
#             mappings = list(geofence_mappings.find({"assigned_entity_id": device_id}))

#             for mapping in mappings:
#                 geo_id = mapping.get("geofence_id")
#                 rule_id = mapping.get("geofence_rule_id")

#                 geofence = geofences.find_one({"_id": ObjectId(geo_id)})
#                 if not geofence or not coordinates:
#                     continue

#                 point = Point(coordinates)
#                 polygon = Polygon(geofence["loc"]["coordinates"][0])
#                 inside = polygon.contains(point)

#                 if mapping.get("trigger_events") == "exit" and not inside:
#                     rule = geofence_rules.find_one({"_id": ObjectId(rule_id)}) if rule_id else None
#                     conditions_passed = True

#                     if rule and rule.get("conditions"):
#                         for cond in rule["conditions"]:
#                             param = cond.get("parameter")
#                             op = cond.get("operator")
#                             expected = cond.get("value")

#                             if op not in OPERATORS:
#                                 conditions_passed = False
#                                 break

#                             actual = data.get(param)
#                             try:
#                                 expected_val = float(expected)
#                                 actual_val = float(actual) if actual is not None else None
#                                 if actual_val is None or not OPERATORS[op](actual_val, expected_val):
#                                     conditions_passed = False
#                                     break
#                             except Exception:
#                                 conditions_passed = False
#                                 break

#                     if conditions_passed:
#                         alert_msg = rule.get("alert_message") if rule else f"Device {device_id} exited geofence."
#                         alert = {
#                             "device_id": device_id,
#                             "type": "geofence-exit",
#                             "geofence": geofence["name"],
#                             "message": alert_msg,
#                             "triggered_at": data["received_at"]
#                         }
#                         alert_logs.insert_one(alert)
#                         alerts.append(alert)

#             # 3️⃣ Update last_location in fleet/workforce
#             device = devices.find_one({"_id": ObjectId(device_id)})

#             last_location_update = {
#                 "last_location": {
#                     "coordinates": coordinates,
#                     "speed": data.get("speed"),
#                     "heading": data.get("heading"),
#                     "fuel": data.get("fuel"),
#                     "timestamp": data.get("timestamp"),
#                 },
#                 "last_updated": datetime.utcnow()
#             }

#             if device:
#                 associated_type = device.get("associated_entity_type")
#                 associated_id = device.get("associated_entity_id")

#                 if associated_type == "vehicle":
#                     fleets.update_one({"_id": ObjectId(associated_id)}, {"$set": last_location_update})
#                 elif associated_type == "workforce":
#                     workforce.update_one({"_id": ObjectId(associated_id)}, {"$set": last_location_update})

#                 # push live updates
#                 await notify_nearby_subscribers(device_id)
#             else:
#                 print(f"⚠️ Device {device_id} not found. Fallback update applied")
#                 fleets.update_many({"devices": device_id}, {"$set": last_location_update})
#                 workforce.update_many({"devices": device_id}, {"$set": last_location_update})

#             # 4️⃣ Broadcast to WebSocket clients
#             for ws in connected_clients.copy():
#                 try:
#                     await ws.send_text(json.dumps(data, default=json_util.default))
#                     for alert in alerts:
#                         await ws.send_text(json.dumps({
#                             "type": "alert",
#                             "message": alert.get("message", ""),
#                             "device_id": alert.get("device_id"),
#                             "alert_type": alert.get("type"),
#                             "geofence": alert.get("geofence"),
#                             "triggered_at": alert.get("triggered_at")
#                         }, default=json_util.default))
#                 except Exception as e:
#                     connected_clients.remove(ws)
#                     print("❌ WebSocket error:", e)

#     finally:
#         await consumer.stop()
#         try:
#             next(mongo_gen)
#         except StopIteration:
#             pass

# async def broadcast_to_clients(payload: dict, alerts: list = None):
#     """
#     Push tracking + alerts to all connected WebSocket clients
#     """
#     for ws in connected_clients.copy():
#         try:
#             # push tracking update
#             await ws.send_text(json.dumps(payload, default=json_util.default))

#             # push alerts (if any)
#             if alerts:
#                 for alert in alerts:
#                     await ws.send_text(
#                         json.dumps(
#                             {
#                                 "type": "alert",
#                                 "message": alert.get("message", ""),
#                                 "device_id": alert.get("device_id"),
#                                 "alert_type": alert.get("type"),
#                                 "geofence": alert.get("geofence"),
#                                 "triggered_at": alert.get("triggered_at"),
#                             },
#                             default=json_util.default,
#                         )
#                     )
#         except Exception as e:
#             connected_clients.remove(ws)
#             print("❌ WebSocket error:", e)

# async def consume_tracking_data():
#     consumer = AIOKafkaConsumer(
#         "vehicle_tracking",
#         bootstrap_servers="127.0.0.1:9092",  # connect to Docker Kafka from host
#         group_id="tracking-group",
#         auto_offset_reset="earliest",  # read from beginning if no offset
#         enable_auto_commit=True,
#     )
#     await consumer.start()

#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)

#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     geofences = mongo["geofences"]
#     geofence_rules = mongo["geofence_rules"]
#     geofence_mappings = mongo["geofence_rule_mapping"]
#     devices = mongo["iot_devices"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     try:
#         print("✅ Kafka consumer started")

#         async for msg in consumer:
#             print("📩 Message received from Kafka")
#             try:
#                 data = json.loads(msg.value.decode("utf-8"))
#                 print("🔹 Kafka payload:", data)
#             except Exception as e:
#                 print("❌ Error decoding message:", e)
#                 continue

#             data["received_at"] = datetime.utcnow()
#             data["alert_type"] = "W"
#             device_id = data.get("device_id")
#             coordinates = data.get("location", {}).get("coordinates", [])

#             # Insert tracking log
#             tracking_logs.insert_one(data)

#             alerts = []

#             # 1️⃣ Inactivity check
#             last_entry = tracking_logs.find_one(
#                 {"device_id": device_id}, sort=[("received_at", -1)]
#             )
#             if last_entry:
#                 last_time = last_entry["received_at"]
#                 current_time = datetime.utcnow()
#                 gap = current_time - last_time

#                 if gap > timedelta(minutes=30):
#                     alert = {
#                         "device_id": device_id,
#                         "type": "inactivity",
#                         "message": f"Device {device_id} inactive for over 30 minutes.",
#                         "triggered_at": current_time,
#                     }
#                     alert_logs.insert_one(alert)
#                     alerts.append(alert)

#                     online_alert = {
#                         "device_id": device_id,
#                         "type": "recovery",
#                         "message": f"Device {device_id} is back online after {gap.seconds // 60} minutes of inactivity.",
#                         "triggered_at": current_time,
#                     }
#                     alert_logs.insert_one(online_alert)
#                     alerts.append(online_alert)

#             # 2️⃣ Geofence rules
#             mappings = list(geofence_mappings.find({"assigned_entity_id": device_id}))

#             for mapping in mappings:
#                 geo_id = mapping.get("geofence_id")
#                 rule_id = mapping.get("geofence_rule_id")

#                 geofence = geofences.find_one({"_id": ObjectId(geo_id)})
#                 if not geofence or not coordinates:
#                     continue

#                 point = Point(coordinates)
#                 polygon = Polygon(geofence["loc"]["coordinates"][0])
#                 inside = polygon.contains(point)

#                 if mapping.get("trigger_events") == "exit" and not inside:
#                     rule = (
#                         geofence_rules.find_one({"_id": ObjectId(rule_id)})
#                         if rule_id
#                         else None
#                     )
#                     conditions_passed = True

#                     if rule and rule.get("conditions"):
#                         for cond in rule["conditions"]:
#                             param = cond.get("parameter")
#                             op = cond.get("operator")
#                             expected = cond.get("value")

#                             if op not in OPERATORS:
#                                 conditions_passed = False
#                                 break

#                             actual = data.get(param)
#                             try:
#                                 expected_val = float(expected)
#                                 actual_val = float(actual) if actual is not None else None
#                                 if actual_val is None or not OPERATORS[op](
#                                     actual_val, expected_val
#                                 ):
#                                     conditions_passed = False
#                                     break
#                             except Exception:
#                                 conditions_passed = False
#                                 break

#                     if conditions_passed:
#                         alert_msg = (
#                             rule.get("alert_message")
#                             if rule
#                             else f"Device {device_id} exited geofence."
#                         )
#                         alert = {
#                             "device_id": device_id,
#                             "type": "geofence-exit",
#                             "geofence": geofence["name"],
#                             "message": alert_msg,
#                             "triggered_at": data["received_at"],
#                         }
#                         alert_logs.insert_one(alert)
#                         alerts.append(alert)

#             # 3️⃣ Update last_location in fleet/workforce
#             device = devices.find_one({"_id": ObjectId(device_id)})

#             last_location_update = {
#                 "last_location": {
#                     "coordinates": coordinates,
#                     "speed": data.get("speed"),
#                     "heading": data.get("heading"),
#                     "fuel": data.get("fuel"),
#                     "timestamp": data.get("timestamp"),
#                 },
#                 "last_updated": datetime.utcnow(),
#             }

#             if device:
#                 associated_type = device.get("associated_entity_type")
#                 associated_id = device.get("associated_entity_id")

#                 if associated_type == "vehicle":
#                     fleets.update_one(
#                         {"_id": ObjectId(associated_id)}, {"$set": last_location_update}
#                     )
#                 elif associated_type == "workforce":
#                     workforce.update_one(
#                         {"_id": ObjectId(associated_id)}, {"$set": last_location_update}
#                     )

#                 # push live updates for nearby subscribers
#                 await notify_nearby_subscribers(device_id)
#             else:
#                 print(f"⚠️ Device {device_id} not found. Fallback update applied")
#                 fleets.update_many(
#                     {"devices": device_id}, {"$set": last_location_update}
#                 )
#                 workforce.update_many(
#                     {"devices": device_id}, {"$set": last_location_update}
#                 )

#             # 4️⃣ Broadcast full update to WebSocket clients
#             await broadcast_to_clients(data, alerts)

#     finally:
#         await consumer.stop()
#         try:
#             next(mongo_gen)
#         except StopIteration:
#             pass

# async def broadcast_to_clients(payload: dict, alerts: list = None):
#     """
#     Push tracking + alerts to all connected WebSocket clients
#     Only send to clients subscribed to the payload's device_id
#     """
#     device_id = payload.get("device_id")
#     for ws, subscribed_device in list(connected_clients.items()):
#         if subscribed_device != device_id:
#             continue  # Skip clients not subscribed to this device

#         try:
#             # Push tracking update
#             await ws.send_text(json.dumps(payload, default=json_util.default))

#             # Push alerts
#             if alerts:
#                 for alert in alerts:
#                     await ws.send_text(
#                         json.dumps({
#                             "type": "alert",
#                             "message": alert.get("message", ""),
#                             "device_id": alert.get("device_id"),
#                             "alert_type": alert.get("type"),
#                             "geofence": alert.get("geofence"),
#                             "triggered_at": alert.get("triggered_at")
#                         }, default=json_util.default)
#                     )
#         except Exception as e:
#             # Remove disconnected clients
#             connected_clients.pop(ws, None)
#             print("❌ WebSocket error:", e)

# async def broadcast_to_clients(payload: dict, alerts: list = None):
#     """
#     Push tracking + alerts to all connected WebSocket clients
#     Only send to clients subscribed to the payload's device_id
#     """
#     device_id = payload.get("device_id")

#     for ws in list(connected_clients):
#         subscribed_device = device_subscriptions.get(ws)
#         if subscribed_device != device_id:
#             continue  # Skip clients not subscribed to this device

#         try:
#             # Push tracking update
#             await ws.send_text(json.dumps(payload, default=json_util.default))

#             # Push alerts
#             if alerts:
#                 for alert in alerts:
#                     if alert.get("device_id") != device_id:
#                         continue
#                     await ws.send_text(json.dumps({
#                         "type": "alert",
#                         "message": alert.get("message", ""),
#                         "device_id": alert.get("device_id"),
#                         "alert_type": alert.get("type"),
#                         "geofence": alert.get("geofence"),
#                         "triggered_at": alert.get("triggered_at")
#                     }, default=json_util.default))

#         except Exception as e:
#             connected_clients.discard(ws)
#             device_subscriptions.pop(ws, None)
#             print("❌ WebSocket error:", e)

# async def consume_tracking_data():
#     consumer = AIOKafkaConsumer(
#         "vehicle_tracking",
#         bootstrap_servers="127.0.0.1:9092",
#         group_id="tracking-group",
#         auto_offset_reset="earliest",
#         enable_auto_commit=True
#     )
#     await consumer.start()

#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)

#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     geofences = mongo["geofences"]
#     geofence_rules = mongo["geofence_rules"]
#     geofence_mappings = mongo["geofence_rule_mapping"]
#     devices = mongo["iot_devices"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     try:
#         print("✅ Kafka consumer started")

#         async for msg in consumer:
#             try:
#                 data = json.loads(msg.value.decode("utf-8"))
#             except Exception as e:
#                 print("❌ Error decoding message:", e)
#                 continue

#             data["received_at"] = datetime.utcnow()
#             data["alert_type"] = "W"
#             device_id = data.get("device_id")
#             coordinates = data.get("location", {}).get("coordinates", [])

#             # Insert tracking log
#             tracking_logs.insert_one(data)

#             alerts = []

#             # 1️⃣ Inactivity check (>30 min)
#             last_entry = tracking_logs.find_one({"device_id": device_id}, sort=[("received_at", -1)])
#             if last_entry:
#                 gap = datetime.utcnow() - last_entry["received_at"]
#                 if gap > timedelta(minutes=30):
#                     inactivity_alert = {
#                         "device_id": device_id,
#                         "type": "inactivity",
#                         "message": f"Device {device_id} inactive for over 30 minutes.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     recovery_alert = {
#                         "device_id": device_id,
#                         "type": "recovery",
#                         "message": f"Device {device_id} back online after {gap.seconds // 60} min inactivity.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     alert_logs.insert_many([inactivity_alert, recovery_alert])
#                     alerts.extend([inactivity_alert, recovery_alert])

#             # 2️⃣ Geofence rules
#             mappings = list(geofence_mappings.find({"assigned_entity_id": device_id}))
#             for mapping in mappings:
#                 geo_id = mapping.get("geofence_id")
#                 rule_id = mapping.get("geofence_rule_id")
#                 geofence = geofences.find_one({"_id": ObjectId(geo_id)})
#                 if not geofence or not coordinates:
#                     continue

#                 point = Point(coordinates)
#                 polygon = Polygon(geofence["loc"]["coordinates"][0])
#                 inside = polygon.contains(point)

#                 if mapping.get("trigger_events") == "exit" and not inside:
#                     rule = geofence_rules.find_one({"_id": ObjectId(rule_id)}) if rule_id else None
#                     conditions_passed = True
#                     if rule and rule.get("conditions"):
#                         for cond in rule["conditions"]:
#                             param = cond.get("parameter")
#                             op = cond.get("operator")
#                             expected = cond.get("value")
#                             if op not in OPERATORS:
#                                 conditions_passed = False
#                                 break
#                             actual = data.get(param)
#                             try:
#                                 if actual is None or not OPERATORS[op](float(actual), float(expected)):
#                                     conditions_passed = False
#                                     break
#                             except:
#                                 conditions_passed = False
#                                 break

#                     if conditions_passed:
#                         alert_msg = rule.get("alert_message") if rule else f"Device {device_id} exited geofence."
#                         alert = {
#                             "device_id": device_id,
#                             "type": "geofence-exit",
#                             "geofence": geofence["name"],
#                             "message": alert_msg,
#                             "triggered_at": data["received_at"]
#                         }
#                         alert_logs.insert_one(alert)
#                         alerts.append(alert)

#             # 3️⃣ Update last_location
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             last_location_update = {
#                 "last_location": {
#                     "coordinates": coordinates,
#                     "speed": data.get("speed"),
#                     "heading": data.get("heading"),
#                     "fuel": data.get("fuel"),
#                     "timestamp": data.get("timestamp")
#                 },
#                 "last_updated": datetime.utcnow()
#             }

#             if device:
#                 assoc_type = device.get("associated_entity_type")
#                 assoc_id = device.get("associated_entity_id")
#                 if assoc_type == "vehicle" and assoc_id:
#                     fleets.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})
#                 elif assoc_type == "workforce" and assoc_id:
#                     workforce.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})

#                 # Push to nearby subscribers
#                 await notify_nearby_subscribers(device_id)
#             else:
#                 fleets.update_many({"devices": device_id}, {"$set": last_location_update})
#                 workforce.update_many({"devices": device_id}, {"$set": last_location_update})

#             # 4️⃣ Broadcast only to subscribed clients
#             await broadcast_to_clients(data, alerts)

#     finally:
#         await consumer.stop()
#         try:
#             next(mongo_gen)
#         except StopIteration:
#             pass

# async def consume_tracking_data():
#     consumer = AIOKafkaConsumer(
#         "vehicle_tracking",
#         bootstrap_servers="127.0.0.1:9092",
#         group_id="tracking-group",
#         auto_offset_reset="earliest",
#         enable_auto_commit=True
#     )
#     await consumer.start()

#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)

#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     geofences = mongo["geofences"]
#     geofence_rules = mongo["geofence_rules"]
#     geofence_mappings = mongo["geofence_rule_mapping"]
#     devices = mongo["iot_devices"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     try:
#         print("✅ Kafka consumer started")
#         async for msg in consumer:
#             try:
#                 data = json.loads(msg.value.decode("utf-8"))
#             except Exception as e:
#                 print("❌ Error decoding Kafka message:", e)
#                 continue

#             device_id = data.get("device_id")
#             if not device_id:
#                 continue

#             data["received_at"] = datetime.utcnow()
#             data["alert_type"] = "W"
#             coordinates = data.get("location", {}).get("coordinates", [])

#             # 1️⃣ Insert tracking log
#             tracking_logs.insert_one(data)

#             alerts = []

#             # 2️⃣ Inactivity check (>30 min)
#             last_entry = tracking_logs.find_one(
#                 {"device_id": device_id}, sort=[("received_at", -1)]
#             )
#             if last_entry:
#                 gap = datetime.utcnow() - last_entry["received_at"]
#                 if gap > timedelta(minutes=30):
#                     inactivity_alert = {
#                         "device_id": device_id,
#                         "type": "inactivity",
#                         "message": f"Device {device_id} inactive for over 30 minutes.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     recovery_alert = {
#                         "device_id": device_id,
#                         "type": "recovery",
#                         "message": f"Device {device_id} back online after {gap.seconds // 60} min inactivity.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     alert_logs.insert_many([inactivity_alert, recovery_alert])
#                     alerts.extend([inactivity_alert, recovery_alert])

#             # 3️⃣ Geofence rules
#             mappings = list(geofence_mappings.find({"assigned_entity_id": device_id}))
#             for mapping in mappings:
#                 geo_id = mapping.get("geofence_id")
#                 rule_id = mapping.get("geofence_rule_id")
#                 geofence = geofences.find_one({"_id": ObjectId(geo_id)})
#                 if not geofence or not coordinates:
#                     continue

#                 point = Point(coordinates)
#                 polygon = Polygon(geofence["loc"]["coordinates"][0])
#                 inside = polygon.contains(point)

#                 if mapping.get("trigger_events") == "exit" and not inside:
#                     rule = geofence_rules.find_one({"_id": ObjectId(rule_id)}) if rule_id else None
#                     conditions_passed = True
#                     if rule and rule.get("conditions"):
#                         for cond in rule["conditions"]:
#                             param = cond.get("parameter")
#                             op = cond.get("operator")
#                             expected = cond.get("value")
#                             if op not in OPERATORS:
#                                 conditions_passed = False
#                                 break
#                             actual = data.get(param)
#                             try:
#                                 if actual is None or not OPERATORS[op](float(actual), float(expected)):
#                                     conditions_passed = False
#                                     break
#                             except:
#                                 conditions_passed = False
#                                 break

#                     if conditions_passed:
#                         alert_msg = rule.get("alert_message") if rule else f"Device {device_id} exited geofence."
#                         alert = {
#                             "device_id": device_id,
#                             "type": "geofence-exit",
#                             "geofence": geofence["name"],
#                             "message": alert_msg,
#                             "triggered_at": data["received_at"]
#                         }
#                         alert_logs.insert_one(alert)
#                         alerts.append(alert)

#             # 4️⃣ Update last_location
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             last_location_update = {
#                 "last_location": {
#                     "coordinates": coordinates,
#                     "speed": data.get("speed"),
#                     "heading": data.get("heading"),
#                     "fuel": data.get("fuel"),
#                     "timestamp": data.get("timestamp")
#                 },
#                 "last_updated": datetime.utcnow()
#             }

#             if device:
#                 assoc_type = device.get("associated_entity_type")
#                 assoc_id = device.get("associated_entity_id")
#                 if assoc_type == "vehicle" and assoc_id:
#                     fleets.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})
#                 elif assoc_type == "workforce" and assoc_id:
#                     workforce.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})

#                 # Push to nearby subscribers (optional)
#                 await notify_nearby_subscribers(device_id)
#             else:
#                 fleets.update_many({"devices": device_id}, {"$set": last_location_update})
#                 workforce.update_many({"devices": device_id}, {"$set": last_location_update})

#             # 5️⃣ Broadcast only to clients subscribed to this device
#             await broadcast_to_clients(data, alerts)

#     finally:
#         await consumer.stop()
#         try:
#             next(mongo_gen)
#         except StopIteration:
#             pass

async def broadcast_to_clients(payload: dict, alerts: list = None):
    """
    Push tracking + alerts only to clients subscribed to the payload's device_id.

    Args:
        payload: Tracking record to forward; its ``device_id`` selects the
            recipients. A payload without ``device_id`` is sent to nobody.
        alerts: Optional list of alert dicts. Only alerts whose ``device_id``
            equals the payload's are forwarded, each as a separate
            ``{"type": "alert", ...}`` message after the tracking payload.

    A socket whose send fails is removed from ``connected_clients`` and
    ``device_subscriptions``; the error is printed, not raised.
    """
    device_id = payload.get("device_id")
    if device_id is None:
        # Sockets without a subscription also map to None in
        # device_subscriptions, so without this guard they would "match".
        return

    # NOTE(review): this module rebinds `connected_clients` to a fresh set()
    # after importing it from app.utils.connections; confirm the websocket
    # handlers register sockets in this same set, otherwise nothing is sent.
    subscribers = [
        ws for ws in list(connected_clients)
        if device_subscriptions.get(ws) == device_id
    ]
    if not subscribers:
        return

    # Serialize once, outside the per-socket loop: the messages are identical
    # for every subscriber, and a serialization failure is a server-side
    # problem that must not evict healthy clients.
    try:
        messages = [json.dumps(payload, default=json_util.default)]
        messages.extend(
            json.dumps({
                "type": "alert",
                "message": alert.get("message", ""),
                "device_id": alert.get("device_id"),
                "alert_type": alert.get("type"),
                "geofence": alert.get("geofence"),
                "triggered_at": alert.get("triggered_at")
            }, default=json_util.default)
            for alert in (alerts or [])
            if alert.get("device_id") == device_id
        )
    except (TypeError, ValueError, AttributeError) as e:
        print("❌ Serialization error:", e)
        return

    for ws in subscribers:
        try:
            for message in messages:
                await ws.send_text(message)
        except Exception as e:
            # Best-effort delivery: drop the broken socket and keep going.
            connected_clients.discard(ws)
            device_subscriptions.pop(ws, None)
            print("❌ WebSocket error:", e)

# async def consume_tracking_data():
#     consumer = AIOKafkaConsumer(
#         "vehicle_tracking",
#         bootstrap_servers="127.0.0.1:9092",
#         group_id="tracking-group",
#         auto_offset_reset="earliest",
#         enable_auto_commit=True
#     )
#     await consumer.start()

#     mongo_gen = database.get_mongo_db()
#     mongo = next(mongo_gen)

#     tracking_logs = mongo["tracking_logs"]
#     alert_logs = mongo["alert_logs"]
#     geofences = mongo["geofences"]
#     geofence_rules = mongo["geofence_rules"]
#     geofence_mappings = mongo["geofence_rule_mapping"]
#     devices = mongo["iot_devices"]
#     fleets = mongo["fleets"]
#     workforce = mongo["workforce"]

#     try:
#         print("✅ Kafka consumer started")
#         async for msg in consumer:
#             try:
#                 data = json.loads(msg.value.decode("utf-8"))
#             except Exception as e:
#                 print("❌ Error decoding Kafka message:", e)
#                 continue

#             device_id = data.get("device_id")
#             if not device_id:
#                 continue

#             data["received_at"] = datetime.utcnow()
#             data["alert_type"] = "W"
#             coordinates = data.get("location", {}).get("coordinates", [])

#             alerts = []

#             # 1️⃣ Inactivity check before inserting new record
#             last_entry = tracking_logs.find_one(
#                 {"device_id": device_id}, sort=[("received_at", -1)]
#             )
#             if last_entry:
#                 gap = datetime.utcnow() - last_entry["received_at"]
#                 if gap > timedelta(minutes=30):
#                     inactivity_alert = {
#                         "device_id": device_id,
#                         "type": "inactivity",
#                         "message": f"Device {device_id} inactive for over 30 minutes.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     recovery_alert = {
#                         "device_id": device_id,
#                         "type": "recovery",
#                         "message": f"Device {device_id} back online after {gap.seconds // 60} min inactivity.",
#                         "triggered_at": datetime.utcnow()
#                     }
#                     alert_logs.insert_many([inactivity_alert, recovery_alert])
#                     alerts.extend([inactivity_alert, recovery_alert])

#             # 2️⃣ Insert new tracking record
#             tracking_logs.insert_one(data)

#             # 3️⃣ Geofence rules
#             mappings = list(geofence_mappings.find({"assigned_entity_id": device_id}))
#             for mapping in mappings:
#                 geo_id = mapping.get("geofence_id")
#                 rule_id = mapping.get("geofence_rule_id")
#                 geofence = geofences.find_one({"_id": ObjectId(geo_id)})
#                 if not geofence or not coordinates:
#                     continue

#                 point = Point(coordinates)
#                 polygon = Polygon(geofence["loc"]["coordinates"][0])
#                 inside = polygon.contains(point)

#                 if mapping.get("trigger_events") == "exit" and not inside:
#                     rule = geofence_rules.find_one({"_id": ObjectId(rule_id)}) if rule_id else None
#                     conditions_passed = True
#                     if rule and rule.get("conditions"):
#                         for cond in rule["conditions"]:
#                             param = cond.get("parameter")
#                             op = cond.get("operator")
#                             expected = cond.get("value")
#                             if op not in OPERATORS:
#                                 conditions_passed = False
#                                 break
#                             actual = data.get(param)
#                             try:
#                                 if actual is None or not OPERATORS[op](float(actual), float(expected)):
#                                     conditions_passed = False
#                                     break
#                             except:
#                                 conditions_passed = False
#                                 break

#                     if conditions_passed:
#                         alert_msg = rule.get("alert_message") if rule else f"Device {device_id} exited geofence."
#                         alert = {
#                             "device_id": device_id,
#                             "type": "geofence-exit",
#                             "geofence": geofence["name"],
#                             "message": alert_msg,
#                             "triggered_at": data["received_at"]
#                         }
#                         alert_logs.insert_one(alert)
#                         alerts.append(alert)

#             # 4️⃣ Update last_location
#             device = devices.find_one({"_id": ObjectId(device_id)})
#             last_location_update = {
#                 "last_location": {
#                     "coordinates": coordinates,
#                     "speed": data.get("speed"),
#                     "heading": data.get("heading"),
#                     "fuel": data.get("fuel"),
#                     "timestamp": data.get("timestamp")
#                 },
#                 "last_updated": datetime.utcnow()
#             }

#             if device:
#                 assoc_type = device.get("associated_entity_type")
#                 assoc_id = device.get("associated_entity_id")
#                 if assoc_type == "vehicle" and assoc_id:
#                     fleets.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})
#                 elif assoc_type == "workforce" and assoc_id:
#                     workforce.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})

#                 await notify_nearby_subscribers(device_id)
#             else:
#                 fleets.update_many({"devices": device_id}, {"$set": last_location_update})
#                 workforce.update_many({"devices": device_id}, {"$set": last_location_update})

#             # 5️⃣ Broadcast only to subscribed clients
#             await broadcast_to_clients(data, alerts)

#     finally:
#         await consumer.stop()
#         try:
#             next(mongo_gen)
#         except StopIteration:
#             pass

def _rule_conditions_met(rule, data):
    """Return True when every condition of ``rule`` holds for ``data``.

    A missing rule, or a rule without conditions, passes. A condition fails
    when its operator is not in ``OPERATORS``, when the payload lacks the
    parameter, or when either side cannot be converted to ``float``.
    """
    if not rule or not rule.get("conditions"):
        return True

    for cond in rule["conditions"]:
        op = cond.get("operator")
        if op not in OPERATORS:
            return False
        actual = data.get(cond.get("parameter"))
        if actual is None:
            return False
        try:
            if not OPERATORS[op](float(actual), float(cond.get("value"))):
                return False
        except (TypeError, ValueError, OverflowError):
            # Non-numeric operand: treat the condition as not satisfied.
            return False
    return True


async def consume_tracking_data():
    """
    Consume the "vehicle_tracking" Kafka topic and process each record.

    For every decoded message that carries a ``device_id`` this:
      1. writes inactivity/recovery alerts to ``alert_logs`` when the previous
         record for the device is more than 30 minutes old,
      2. inserts the record into ``tracking_logs``,
      3. writes a "geofence-exit" alert for each mapped geofence the reported
         point is outside of (subject to the mapped rule's conditions),
      4. updates ``last_location`` on the associated fleet/workforce document,
      5. calls ``broadcast_tracking_update`` with the record.

    A failure while processing one message is printed and the loop moves on
    to the next message, so a single malformed record cannot stop the
    consumer. The consumer is stopped and the Mongo generator advanced when
    the loop ends for any reason.

    NOTE(review): the Mongo collection calls below are not awaited, so they
    presumably block the event loop while they run — confirm the driver.
    """
    consumer = AIOKafkaConsumer(
        "vehicle_tracking",
        bootstrap_servers="stagingapi.movex.ai:9092",  # hard-coded broker address
        group_id="tracking-group",
        auto_offset_reset="earliest",  # read from beginning if no offset
        enable_auto_commit=True
    )
    await consumer.start()

    mongo_gen = None
    try:
        # Mongo setup lives inside the try so a failure here still stops the
        # consumer that was just started.
        mongo_gen = database.get_mongo_db()
        mongo = next(mongo_gen)

        tracking_logs = mongo["tracking_logs"]
        alert_logs = mongo["alert_logs"]
        geofences = mongo["geofences"]
        geofence_rules = mongo["geofence_rules"]
        geofence_mappings = mongo["geofence_rule_mapping"]
        devices = mongo["iot_devices"]
        fleets = mongo["fleets"]
        workforce = mongo["workforce"]

        print("✅ Kafka consumer started")
        async for msg in consumer:
            try:
                data = json.loads(msg.value.decode("utf-8"))
            except Exception as e:
                print("❌ Error decoding Kafka message:", e)
                continue

            # Valid JSON that is not an object cannot carry a device_id.
            if not isinstance(data, dict):
                continue

            device_id = data.get("device_id")
            if not device_id:
                continue

            try:
                data["received_at"] = datetime.utcnow()
                data["alert_type"] = "W"
                # `location` may be present but null; treat that as "no coordinates".
                coordinates = (data.get("location") or {}).get("coordinates") or []
                # NOTE(review): alerts are persisted below but this list is not
                # passed to broadcast_tracking_update — confirm that is intended.
                alerts = []

                # 1️⃣ Inactivity check (must run before the new record is inserted)
                last_entry = tracking_logs.find_one({"device_id": device_id}, sort=[("received_at", -1)])
                last_seen = last_entry.get("received_at") if last_entry else None
                if last_seen is not None:
                    gap = datetime.utcnow() - last_seen
                    if gap > timedelta(minutes=30):
                        # total_seconds() includes whole days; `.seconds` alone
                        # would report a 25-hour outage as 60 minutes.
                        idle_minutes = int(gap.total_seconds() // 60)
                        inactivity_alert = {
                            "device_id": device_id,
                            "type": "inactivity",
                            "message": f"Device {device_id} inactive for over 30 minutes.",
                            "triggered_at": datetime.utcnow()
                        }
                        recovery_alert = {
                            "device_id": device_id,
                            "type": "recovery",
                            "message": f"Device {device_id} back online after {idle_minutes} min inactivity.",
                            "triggered_at": datetime.utcnow()
                        }
                        alert_logs.insert_many([inactivity_alert, recovery_alert])
                        alerts.extend([inactivity_alert, recovery_alert])

                # 2️⃣ Insert tracking record
                tracking_logs.insert_one(data)

                # 3️⃣ Geofence rules (only "exit" triggers produce alerts)
                if coordinates:
                    for mapping in geofence_mappings.find({"assigned_entity_id": device_id}):
                        if mapping.get("trigger_events") != "exit":
                            continue

                        geofence = geofences.find_one({"_id": ObjectId(mapping.get("geofence_id"))})
                        if not geofence:
                            continue

                        # Uses the first ring of the stored geometry as the polygon boundary.
                        polygon = Polygon(geofence["loc"]["coordinates"][0])
                        if polygon.contains(Point(coordinates)):
                            continue

                        rule_id = mapping.get("geofence_rule_id")
                        rule = geofence_rules.find_one({"_id": ObjectId(rule_id)}) if rule_id else None
                        if not _rule_conditions_met(rule, data):
                            continue

                        # Fall back to the default text when the rule has no message.
                        alert_msg = (rule or {}).get("alert_message") or f"Device {device_id} exited geofence."
                        alert = {
                            "device_id": device_id,
                            "type": "geofence-exit",
                            "geofence": geofence["name"],
                            "message": alert_msg,
                            "triggered_at": data["received_at"]
                        }
                        alert_logs.insert_one(alert)
                        alerts.append(alert)

                # 4️⃣ Update last_location
                # A device_id that is not a valid ObjectId cannot match `_id`;
                # skip the lookup and use the `devices`-list fallback below.
                device = (
                    devices.find_one({"_id": ObjectId(device_id)})
                    if ObjectId.is_valid(device_id)
                    else None
                )
                last_location_update = {
                    "last_location": {
                        "coordinates": coordinates,
                        "speed": data.get("speed"),
                        "heading": data.get("heading"),
                        "fuel": data.get("fuel"),
                        "timestamp": data.get("timestamp")
                    },
                    "last_updated": datetime.utcnow()
                }

                if device:
                    assoc_type = device.get("associated_entity_type")
                    assoc_id = device.get("associated_entity_id")
                    if assoc_type == "vehicle" and assoc_id:
                        fleets.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})
                    elif assoc_type == "workforce" and assoc_id:
                        workforce.update_one({"_id": ObjectId(assoc_id)}, {"$set": last_location_update})

                    await notify_nearby_subscribers(device_id)
                else:
                    fleets.update_many({"devices": device_id}, {"$set": last_location_update})
                    workforce.update_many({"devices": device_id}, {"$set": last_location_update})

                # 5️⃣ Broadcast to subscribed clients (LIVE UPDATES)
                # NOTE(review): broadcast_tracking_update is not defined in the
                # visible part of this module — confirm it is in scope.
                await broadcast_tracking_update(data)
            except Exception as e:
                # Isolate per-message failures so one bad record does not end
                # the consume loop; task cancellation is not an Exception and
                # still propagates.
                print("❌ Error processing Kafka message:", e)

    finally:
        try:
            await consumer.stop()
        finally:
            if mongo_gen is not None:
                # Advance the generator so any code after its yield runs.
                try:
                    next(mongo_gen)
                except StopIteration:
                    pass

