from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import httpx
import pinecone
from pinecone import Index
from elasticsearch import AsyncElasticsearch
import weaviate
import logging
from typing import Dict, Any
from sentence_transformers import SentenceTransformer
from pymilvus import Milvus, DataType

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

# Module logger; the root logging configuration below is set to INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class VectorQueryRequest(BaseModel):
    """Request payload consumed by ``VectorDBQueryHandler.perform_vector_search``.

    NOTE(review): presumably bound to a FastAPI route as the request body —
    confirm where this model is used as an endpoint parameter.
    """

    api_key: str  # credential handed to the handler for backend authentication
    vector_db_type: str  # backend selector, e.g. "pinecone", "elasticsearch", "weaviate"
    user_query: str  # free-text query to search for
    top_k: int = 5  # number of results requested from the backend

class VectorDBQueryHandler:
    """Run a single text query against one of several vector/search backends.

    The backend is chosen by ``vector_details["type"]``. Backends that search
    by client-side embedding (Pinecone, Milvus) encode the query with a
    SentenceTransformer model. The other backends receive the raw query text.

    Recognized ``vector_details`` keys:
        auth          -- API key / bearer token for the backend.
        type          -- backend name (see ``vector_query``).
        url           -- backend endpoint (Astra: secure-connect bundle path).
        index         -- index / collection / class / corpus name.
        top_k         -- maximum number of results (default 5).
        model         -- SentenceTransformer name (default "all-MiniLM-L6-v2").
        customer_id   -- Vectara customer id.
        client_id, client_secret -- Astra credentials.
        keyspace, table_name     -- Astra keyspace and table
                                    (table_name defaults to ``index``).
    """

    # Backends whose query path needs a client-side embedding of the text.
    _EMBEDDING_BACKENDS = frozenset({"pinecone", "milvus"})

    # Loaded SentenceTransformer models keyed by model name. Shared across
    # instances so a model is not reloaded from disk on every request.
    _model_cache: Dict[str, Any] = {}

    def __init__(self, vector_details: Dict[str, Any], user_query: str):
        """Capture the backend configuration and the query text.

        Args:
            vector_details: Backend configuration (see class docstring).
            user_query: The free-text query to search for.
        """
        self.api_key = vector_details.get("auth")
        self.vector_db_type = vector_details.get("type")
        self.endpoint = vector_details.get("url")
        self.user_query = user_query
        self.index_name = vector_details.get("index")
        self.top_k = vector_details.get("top_k", 5)
        # The model *name* is kept apart from ``self.model``, which holds the
        # loaded model object (or None).
        self.model_name = vector_details.get("model", "all-MiniLM-L6-v2")

        # Vectara-specific settings.
        self.customer_id = vector_details.get("customer_id")

        # Astra-specific settings.
        self.client_id = vector_details.get("client_id")
        self.client_secret = vector_details.get("client_secret")
        self.keyspace = vector_details.get("keyspace")
        self.table_name = vector_details.get("table_name", self.index_name)

        # Load the (large) embedding model only for backends that use it.
        if self.vector_db_type in self._EMBEDDING_BACKENDS:
            self.model = self.initialize_model()
        else:
            self.model = None

    def initialize_model(self):
        """Load the SentenceTransformer named ``self.model_name``, or fetch it from the cache.

        Returns:
            The loaded model, or ``None`` if loading failed (the error is
            logged). Callers that need the model go through
            ``_require_model``, which turns ``None`` into an error.
        """
        cached = self._model_cache.get(self.model_name)
        if cached is not None:
            return cached
        try:
            model = SentenceTransformer(self.model_name)
        except Exception:
            # A bad or unsupported model name must not break construction.
            logger.exception("Error loading model %s", self.model_name)
            return None
        self._model_cache[self.model_name] = model
        return model

    def _require_model(self):
        """Return the embedding model, loading it on demand.

        Raises:
            HTTPException: 500 if the model cannot be loaded.
        """
        if self.model is None:
            self.model = self.initialize_model()
        if self.model is None:
            raise HTTPException(
                status_code=500,
                detail=f"Embedding model '{self.model_name}' could not be loaded.",
            )
        return self.model

    async def vector_query(self):
        """Dispatch the query to the backend named by ``self.vector_db_type``.

        Returns:
            The backend's result list. For an unknown or unimplemented backend,
            returns ``{"error": "Unsupported vector database type"}``.
        """
        # Chroma has no query implementation yet, so it is reported as unsupported.
        dispatch = {
            "pinecone": self.query_pinecone,
            "elasticsearch": self.query_elasticsearch,
            "weaviate": self.query_weaviate,
            "milvus": self.query_milvus,
            "vectara": self.query_vectara,
            "astra": self.query_astra,
        }
        query = dispatch.get(self.vector_db_type)
        if query is None:
            return {"error": "Unsupported vector database type"}
        return await query()

    async def query_pinecone(self):
        """Embed the query and run a top-k similarity search on a Pinecone index.

        Returns:
            A list of ``{"id", "score", "link"}`` dicts, one per match.
        """
        model = self._require_model()

        # Pinecone v2-style client: module-level init, then an Index by name.
        pinecone.init(api_key=self.api_key)
        index = pinecone.Index(self.index_name)

        # One query, so one flat vector is passed via ``vector=``. The response
        # then carries a flat ``matches`` list (not a per-query batch).
        query_vector = model.encode(self.user_query).tolist()
        response = index.query(vector=query_vector, top_k=self.top_k)

        matches = response["matches"] if "matches" in response else []
        results = []
        for match in matches:
            doc_details = await self.fetch_document_details(match["id"])
            results.append({
                "id": match["id"],
                "score": match["score"],
                "link": doc_details.get("link", "No link available"),
            })
        return results

    async def fetch_document_details(self, doc_id):
        """Return display details for ``doc_id``.

        Placeholder: always returns a synthetic ``example.com`` link. Replace
        this with real metadata retrieval.
        """
        return {"link": f"http://example.com/document/{doc_id}"}

    async def query_elasticsearch(self):
        """Run a ``simple_query_string`` full-text search over ``title`` and ``content``.

        Returns:
            A list of ``{"id", "score", "source"}`` dicts.

        Raises:
            HTTPException: 500 if the search fails.
        """
        es = AsyncElasticsearch(
            [self.endpoint],
            headers={"Authorization": f"ApiKey {self.api_key}"},
        )
        query_body = {
            "query": {
                "simple_query_string": {
                    "query": self.user_query,
                    "fields": ["title", "content"],  # adjust to your document structure
                }
            },
            "size": self.top_k,
        }
        try:
            response = await es.search(index=self.index_name, body=query_body)
            hits = response["hits"]["hits"]
            logger.info("Found %d hits for query: %s", len(hits), self.user_query)
            return [
                {
                    "id": hit["_id"],
                    "score": hit["_score"],
                    "source": hit["_source"],  # document content and metadata
                }
                for hit in hits
            ]
        except Exception as e:
            logger.exception("Error performing search in Elasticsearch: %s", e)
            raise HTTPException(
                status_code=500,
                detail="An error occurred while performing the search.",
            ) from e
        finally:
            await es.close()

    async def query_weaviate(self):
        """Run a ``nearText`` search on the Weaviate class ``self.index_name``.

        This assumes the class has a text vectorizer module configured so that
        ``nearText`` works.

        Returns:
            A list of ``{"id", "name", "content"}`` dicts.

        Raises:
            RuntimeError: if Weaviate reports GraphQL errors.
        """
        weaviate_client = weaviate.Client(self.endpoint)
        try:
            # GraphQL Get returns only the fields explicitly requested. The
            # object id lives under ``_additional``.
            response = (
                weaviate_client.query
                .get(self.index_name, ["name", "content"])
                .with_additional(["id"])
                .with_near_text({"concepts": [self.user_query]})
                .with_limit(self.top_k)
                .do()
            )
            if response.get("errors"):
                raise RuntimeError(f"Weaviate query failed: {response['errors']}")

            hits = (response.get("data") or {}).get("Get", {}).get(self.index_name) or []
            return [
                {
                    "id": (hit.get("_additional") or {}).get("id"),
                    "name": hit.get("name", "No name available"),
                    "content": hit.get("content", "No content available"),
                }
                for hit in hits
            ]
        except Exception as e:
            logger.error("Error performing search in Weaviate: %s", e)
            raise

    async def query_milvus(self):
        """Embed the query and run a top-k L2 search on a Milvus collection.

        Returns:
            A list of ``{"id", "distance"}`` dicts. The list is empty if Milvus
            reports a non-OK status.
        """
        model = self._require_model()
        milvus_client = Milvus(uri=self.endpoint)
        try:
            query_vector = model.encode([self.user_query]).tolist()
            search_params = {
                "collection_name": self.index_name,
                "query_records": query_vector,
                "top_k": self.top_k,
                # The metric is named as a plain string, so no enum import is needed.
                "params": {"metric_type": "L2", "params": {"nprobe": 10}},
            }
            status, results = milvus_client.search(**search_params)
            if not status.OK():
                logger.error("Milvus search error: %s", status)
                return []

            # results[0] holds the hits for the single query vector.
            return [
                # Use 'score': 1 / (1 + hit.distance) if a similarity is preferred.
                {"id": hit.id, "distance": hit.distance}
                for hit in results[0]
            ]
        except Exception as e:
            logger.error("Error performing search in Milvus: %s", e)
            raise
        finally:
            milvus_client.close()

    async def query_vectara(self):
        """POST the query to the configured Vectara endpoint.

        Returns:
            A list of ``{"id", "score"}`` dicts.

        Raises:
            HTTPException: 502 if Vectara answers with a 4xx/5xx status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "query": self.user_query,
            "customer_id": self.customer_id,
            "corpus_key": self.index_name,  # assumed to map to Vectara's corpus concept
            "top_k": self.top_k,
        }
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.endpoint, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
                hits = data.get("results", [])  # adjust to the actual response structure
                return [
                    {"id": hit["documentId"], "score": hit.get("score")}
                    for hit in hits
                ]
        except httpx.HTTPStatusError as e:
            logger.error("HTTP error occurred during Vectara query: %s", e)
            # Surface upstream failures instead of silently returning None.
            raise HTTPException(
                status_code=502,
                detail=f"Vectara returned HTTP {e.response.status_code}",
            ) from e
        except Exception as e:
            logger.error("Error performing search in Vectara: %s", e)
            raise

    async def query_astra(self):
        """Look up rows in an AstraDB (Cassandra) table matching the query text.

        Returns:
            A list of ``{"id", "content"}`` dicts.
        """
        auth_provider = PlainTextAuthProvider(
            username=self.client_id, password=self.client_secret
        )
        cluster = Cluster(
            cloud={"secure_connect_bundle": self.endpoint},
            auth_provider=auth_provider,
        )
        try:
            session = cluster.connect()
            # Identifiers cannot be bound as parameters. keyspace/table_name
            # must come from trusted configuration, never from user input.
            # ``your_search_field`` is a placeholder for the real column.
            cql_query = (
                f"SELECT * FROM {self.keyspace}.{self.table_name} "
                "WHERE your_search_field = ?"  # prepared statements bind with '?'
            )
            prepared = session.prepare(cql_query)
            rows = session.execute(prepared, [self.user_query])
            return [{"id": row.id, "content": row.content} for row in rows]
        except Exception as e:
            logger.error("Error querying AstraDB: %s", e)
            raise
        finally:
            # Shutting down the Cluster also closes every session it created.
            cluster.shutdown()

    @staticmethod
    async def perform_vector_search(query_request: "VectorQueryRequest"):
        """Build a handler from an API request and run the search.

        Returns:
            ``{"results": [...]}``.

        Raises:
            HTTPException: 400 for an unsupported backend. Backend
                HTTPExceptions are re-raised unchanged. Any other error
                becomes a 500.
        """
        handler = VectorDBQueryHandler(
            {
                "auth": query_request.api_key,
                "type": query_request.vector_db_type,
                "index": "your_index_name",  # TODO: pass the real index name
                "top_k": query_request.top_k,
            },
            query_request.user_query,
        )
        try:
            results = await handler.vector_query()
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

        if isinstance(results, dict) and "error" in results:
            raise HTTPException(status_code=400, detail=results["error"])
        return {"results": results}
