import json
import operator
from datetime import datetime, timedelta

from aiokafka import AIOKafkaConsumer
from bson import ObjectId, json_util
from shapely.geometry import Point, Polygon

from app.db import database
from app.utils.connections import connected_clients
from app.v1.sockets.tracking_data import notify_nearby_subscribers

# Comparison operators allowed in geofence rule conditions, keyed by the
# operator string stored in a rule condition's ``operator`` field. Each value
# is a two-argument predicate ``f(actual, expected) -> bool``.
OPERATORS = {
    "==": operator.eq,
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "!=": operator.ne,
}

# NOTE: ``connected_clients`` (the WebSocket client registry) is imported from
# ``app.utils.connections``. Do not rebind it here: a module-level
# ``connected_clients = set()`` would shadow that shared registry, and the
# broadcast loop would then iterate a private empty set that no WebSocket
# handler ever adds clients to.

# Silence longer than this between two consecutive messages from the same
# device raises an inactivity + recovery alert pair.
_INACTIVITY_THRESHOLD = timedelta(minutes=30)


def _conditions_pass(conditions, data):
    """Return True if every rule condition holds for the message ``data``.

    Each condition is a mapping with ``parameter`` (key looked up in ``data``),
    ``operator`` (a key of ``OPERATORS``) and ``value``. Both sides are
    compared as floats. An unknown operator, a missing parameter, or a value
    that cannot be converted to float makes the whole check fail.
    """
    for cond in conditions:
        compare = OPERATORS.get(cond.get("operator"))
        if compare is None:
            return False
        actual = data.get(cond.get("parameter"))
        if actual is None:
            return False
        try:
            if not compare(float(actual), float(cond.get("value"))):
                return False
        except (TypeError, ValueError, OverflowError):
            return False
    return True


def _check_inactivity(mongo, device_id, previous_entry, now):
    """Raise inactivity/recovery alerts if the device was silent too long.

    ``previous_entry`` must be the device's most recent tracking log from
    *before* the current message (or None). When the gap between its
    ``received_at`` and ``now`` exceeds ``_INACTIVITY_THRESHOLD``, an
    inactivity alert and a recovery alert are written to ``alert_logs`` and
    returned; otherwise an empty list is returned.
    """
    last_time = previous_entry.get("received_at") if previous_entry else None
    if last_time is None:
        return []

    # NOTE(review): assumes stored ``received_at`` values come back as naive
    # UTC datetimes (matching datetime.utcnow()); a tz-aware Mongo client
    # would make this subtraction raise TypeError.
    gap = now - last_time
    if gap <= _INACTIVITY_THRESHOLD:
        return []

    threshold_minutes = int(_INACTIVITY_THRESHOLD.total_seconds() // 60)
    # total_seconds() rather than .seconds, which silently drops whole days.
    gap_minutes = int(gap.total_seconds() // 60)
    alerts = [
        {
            "device_id": device_id,
            "type": "inactivity",
            "message": f"Device {device_id} inactive for over {threshold_minutes} minutes.",
            "triggered_at": now,
        },
        {
            "device_id": device_id,
            "type": "recovery",
            "message": f"Device {device_id} is back online after {gap_minutes} minutes of inactivity.",
            "triggered_at": now,
        },
    ]
    alert_logs = mongo["alert_logs"]
    for alert in alerts:
        alert_logs.insert_one(alert)
    return alerts


def _check_geofences(mongo, device_id, coordinates, data):
    """Evaluate the device's "exit" geofence mappings for the current position.

    For each mapping whose ``trigger_events`` is ``"exit"``, the referenced
    geofence polygon is loaded. If ``coordinates`` lie outside it and the
    optional rule's conditions pass, a ``geofence-exit`` alert is stored in
    ``alert_logs``. Returns the list of alerts raised.
    """
    if not coordinates:
        return []

    geofences = mongo["geofences"]
    geofence_rules = mongo["geofence_rules"]
    alert_logs = mongo["alert_logs"]
    alerts = []

    for mapping in mongo["geofence_rule_mapping"].find({"assigned_entity_id": device_id}):
        # Only exit events are handled; skip others before any further lookups.
        if mapping.get("trigger_events") != "exit":
            continue

        geo_id = mapping.get("geofence_id")
        if not geo_id or not ObjectId.is_valid(geo_id):
            continue
        geofence = geofences.find_one({"_id": ObjectId(geo_id)})
        if not geofence:
            continue

        # NOTE(review): assumes geofence["loc"] is a GeoJSON Polygon whose first
        # ring is the outer boundary.
        polygon = Polygon(geofence["loc"]["coordinates"][0])
        if polygon.contains(Point(coordinates)):
            continue
        # NOTE(review): this fires on every message received while the device
        # is outside the fence, not only on the inside->outside transition.

        rule_id = mapping.get("geofence_rule_id")
        rule = (
            geofence_rules.find_one({"_id": ObjectId(rule_id)})
            if rule_id and ObjectId.is_valid(rule_id)
            else None
        )
        conditions = rule.get("conditions") if rule else None
        if conditions and not _conditions_pass(conditions, data):
            continue

        # Fall back to the default text when the rule is missing or has no message.
        message = (rule.get("alert_message") if rule else None) or f"Device {device_id} exited geofence."
        alert = {
            "device_id": device_id,
            "type": "geofence-exit",
            "geofence": geofence.get("name"),
            "message": message,
            "triggered_at": data["received_at"],
        }
        alert_logs.insert_one(alert)
        alerts.append(alert)

    return alerts


async def _update_last_location(mongo, device_id, coordinates, data):
    """Store the latest position on the fleet/workforce entity linked to the device.

    If ``device_id`` resolves to an ``iot_devices`` document, its associated
    vehicle (``fleets``) or workforce record is updated and nearby subscribers
    are notified. Otherwise every fleet/workforce document whose ``devices``
    field matches ``device_id`` is updated as a fallback.
    """
    fleets = mongo["fleets"]
    workforce = mongo["workforce"]
    update = {
        "$set": {
            "last_location": {
                "coordinates": coordinates,
                "speed": data.get("speed"),
                "heading": data.get("heading"),
                "fuel": data.get("fuel"),
                "timestamp": data.get("timestamp"),
            },
            "last_updated": datetime.utcnow(),
        }
    }

    device = (
        mongo["iot_devices"].find_one({"_id": ObjectId(device_id)})
        if ObjectId.is_valid(device_id)
        else None
    )
    if not device:
        print(f"⚠️ Device {device_id} not found. Fallback update applied")
        fleets.update_many({"devices": device_id}, update)
        workforce.update_many({"devices": device_id}, update)
        return

    associated_type = device.get("associated_entity_type")
    associated_id = device.get("associated_entity_id")
    if associated_id is not None and ObjectId.is_valid(associated_id):
        if associated_type == "vehicle":
            fleets.update_one({"_id": ObjectId(associated_id)}, update)
        elif associated_type == "workforce":
            workforce.update_one({"_id": ObjectId(associated_id)}, update)

    # push live updates
    await notify_nearby_subscribers(device_id)


async def _broadcast(data, alerts):
    """Send the tracking payload and then each alert to every WebSocket client.

    Clients whose send fails are removed from ``connected_clients``.
    """
    # Serialize once, outside the per-client loop. A payload that cannot be
    # serialized then raises here instead of evicting every client.
    payload = json.dumps(data, default=json_util.default)
    alert_payloads = [
        json.dumps({
            "type": "alert",
            "message": alert.get("message", ""),
            "device_id": alert.get("device_id"),
            "alert_type": alert.get("type"),
            "geofence": alert.get("geofence"),
            "triggered_at": alert.get("triggered_at"),
        }, default=json_util.default)
        for alert in alerts
    ]

    # Iterate a snapshot: failing clients are removed from the live collection.
    for ws in connected_clients.copy():
        try:
            await ws.send_text(payload)
            for alert_payload in alert_payloads:
                await ws.send_text(alert_payload)
        except Exception as e:
            # Guard against the client having already been removed elsewhere.
            if ws in connected_clients:
                connected_clients.remove(ws)
            print("❌ WebSocket error:", e)


async def _process_message(mongo, data):
    """Persist one decoded tracking message, raise alerts, and broadcast it."""
    tracking_logs = mongo["tracking_logs"]

    data["received_at"] = datetime.utcnow()
    data["alert_type"] = "W"
    device_id = data.get("device_id")
    coordinates = (data.get("location") or {}).get("coordinates") or []

    # Read the device's previous log BEFORE inserting this one; otherwise the
    # "latest" entry is always the message just stored and no gap is ever seen.
    previous_entry = None
    if device_id:
        previous_entry = tracking_logs.find_one(
            {"device_id": device_id},
            sort=[("received_at", -1)],
        )

    # Insert tracking log
    tracking_logs.insert_one(data)

    alerts = []
    if device_id:
        alerts.extend(_check_inactivity(mongo, device_id, previous_entry, data["received_at"]))
        alerts.extend(_check_geofences(mongo, device_id, coordinates, data))
        await _update_last_location(mongo, device_id, coordinates, data)
    else:
        # Without a device id, a {"devices": None} fallback update would match
        # every fleet/workforce document lacking a "devices" field.
        print("⚠️ Message has no device_id; skipping device-specific processing")

    await _broadcast(data, alerts)


async def consume_tracking_data(
    *,
    topic="vehicle_tracking",
    bootstrap_servers="127.0.0.1:9092",  # connect to Docker Kafka from host
    group_id="tracking-group",
):
    """Consume vehicle tracking messages from Kafka until stopped or cancelled.

    Each message is decoded as a JSON object and processed as follows:

    1. It is stored in ``tracking_logs``.
    2. Inactivity/recovery and geofence-exit alerts are written to ``alert_logs``.
    3. The linked fleet/workforce ``last_location`` is updated.
    4. The message and its alerts are broadcast to ``connected_clients``.

    A malformed message is logged and skipped instead of stopping the
    consumer. On exit the Kafka consumer is stopped and the Mongo handle
    obtained from ``database.get_mongo_db()`` is released.

    NOTE(review): the Mongo calls look synchronous (pymongo-style) and so
    block the event loop while they run; consider an async driver or
    ``asyncio.to_thread`` if throughput matters.
    """
    consumer = AIOKafkaConsumer(
        topic,
        bootstrap_servers=bootstrap_servers,
        group_id=group_id,
        auto_offset_reset="earliest",         # read from beginning if no offset
        enable_auto_commit=True
    )
    await consumer.start()

    mongo_gen = None
    try:
        # Acquire Mongo inside try so a failure here still stops the consumer.
        mongo_gen = database.get_mongo_db()
        mongo = next(mongo_gen)

        print("✅ Kafka consumer started")

        async for msg in consumer:
            print("📩 Message received from Kafka")
            try:
                data = json.loads(msg.value.decode("utf-8"))
                print("🔹 Kafka payload:", data)
            except Exception as e:
                print("❌ Error decoding message:", e)
                continue

            if not isinstance(data, dict):
                print("❌ Unexpected Kafka payload (expected a JSON object):", data)
                continue

            # Isolate per-message failures (bad ids, malformed geometry, ...)
            # so a single bad message cannot terminate the consumer.
            try:
                await _process_message(mongo, data)
            except Exception as e:
                print(f"❌ Error processing message for device {data.get('device_id')}:", e)

    finally:
        await consumer.stop()
        if mongo_gen is not None:
            # Resume the generator past its yield so its cleanup code runs.
            try:
                next(mongo_gen)
            except StopIteration:
                pass
