import os
import logging

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
)
# from llama_index.core import get_response_synthesizer
# from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
# from llama_index.core.retrievers import VectorIndexRetriever

# from llama_index.core.response_synthesizers import ResponseMode
# from transformers import AutoTokenizer
# from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.vector_stores.qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.embeddings.fastembed import FastEmbedEmbedding


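# Qdrant connection settings read from the environment (e.g. a Qdrant Cloud
# endpoint and its API key); both must be set before instantiating ChatPDF.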
QDRANT_API_URL = os.getenv('QDRANT_API_URL')
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ChatPDF:
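    """Retrieval-augmented chat over local documents, built on LlamaIndex
    with a Qdrant vector store, FastEmbed embeddings, and a llama.cpp LLM."""

    # Set by ingest(); ask() refuses to answer until a document is indexed.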
    query_engine = None

    model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_k_m.gguf"
    # model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf"
    # model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"

    # Zephyr-style prompt formatters, kept for reference. If re-enabled, mark
    # them @staticmethod: passed as self.messages_to_prompt, a plain def would
    # receive self as an extra first argument and fail.
    # @staticmethod
    # def messages_to_prompt(messages):
    #     prompt = ""
    #     for message in messages:
    #         if message.role == 'system':
    #             prompt += f"<|system|>\n{message.content}</s>\n"
    #         elif message.role == 'user':
    #             prompt += f"<|user|>\n{message.content}</s>\n"
    #         elif message.role == 'assistant':
    #             prompt += f"<|assistant|>\n{message.content}</s>\n"

    #     if not prompt.startswith("<|system|>\n"):
    #         prompt = "<|system|>\n</s>\n" + prompt

    #     prompt = prompt + "<|assistant|>\n"

    #     return prompt

    # @staticmethod
    # def completion_to_prompt(completion):
    #     return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"


    def __init__(self):
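        """Set up the text splitter, Qdrant vector store, embedding model, and
        LLM, then register them as the global LlamaIndex Settings."""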
        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)

        logger.info("initializing the vector store related objects")
        # client = QdrantClient(host="localhost", port=6333)
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        # client = QdrantClient(":memory:")
        self.vector_store = QdrantVectorStore(
            client=client,
            collection_name="rag_documents",
            # enable_hybrid=True
        )

        logger.info("initializing the FastEmbedEmbedding")
        self.embed_model = FastEmbedEmbedding(
            # model_name="BAAI/bge-small-en"
        )

        llm = LlamaCPP(
            model_url=self.model_url,
            temperature=0.1,
            max_new_tokens=256,
            context_window=3900,
            # generate_kwargs={},
            # model_kwargs={"n_gpu_layers": -1},
            # messages_to_prompt=self.messages_to_prompt,
            # completion_to_prompt=self.completion_to_prompt,
            verbose=True,
        )

        # tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
        # tokenizer.save_pretrained("./models/tokenizer/")

        logger.info("initializing the global settings")
        Settings.text_splitter = self.text_parser
        Settings.embed_model = self.embed_model
        Settings.llm = llm
        # Settings.tokenizer = tokenizer
        Settings.transformations = [self.text_parser]

    def ingest(self, files_dir: str):
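        """Load every document under files_dir, split it into chunks, embed
        the chunks, and index them in the Qdrant-backed vector store."""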
        text_chunks = []
        doc_ids = []
        nodes = []

        docs = SimpleDirectoryReader(input_dir=files_dir).load_data()

        logger.info("enumerating docs")
        for doc_idx, doc in enumerate(docs):
            curr_text_chunks = self.text_parser.split_text(doc.text)
            text_chunks.extend(curr_text_chunks)
            doc_ids.extend([doc_idx] * len(curr_text_chunks))

        logger.info("enumerating text_chunks")
        for idx, text_chunk in enumerate(text_chunks):
            node = TextNode(text=text_chunk)
            src_doc = docs[doc_ids[idx]]
            node.metadata = src_doc.metadata
            nodes.append(node)

        logger.info("enumerating nodes")
        for node in nodes:
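            # MetadataMode.ALL embeds the chunk text together with its
            # metadata, so the vector reflects everything stored in Qdrant.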
            node_embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )
            node.embedding = node_embedding

        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        # logger.info("configure retriever")
        # retriever = VectorIndexRetriever(
        #     index=index,
        #     similarity_top_k=6,
        #     # vector_store_query_mode="hybrid"
        # )

        # logger.info("configure response synthesizer")
        # response_synthesizer = get_response_synthesizer(
        #     # streaming=True,
        #     response_mode=ResponseMode.COMPACT,
        # )

        # logger.info("assemble query engine")
        # self.query_engine = RetrieverQueryEngine(
        #     retriever=retriever,
        #     response_synthesizer=response_synthesizer,
        # )

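        # Build the default query engine directly from the index; streaming=True
        # makes query() return a streaming response with a token generator.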
        self.query_engine = index.as_query_engine(
            streaming=True,
            # similarity_top_k=6,
        )

    def ask(self, query: str):
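        """Answer query against the ingested documents, returning a token
        generator (or a prompt string if nothing has been ingested yet)."""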
        if not self.query_engine:
            return "Please, add a PDF document first."

        logger.info("retrieving the response to the query")
        # response = self.query_engine.query(str_or_query_bundle=query)
        streaming_response = self.query_engine.query(query)
        return streaming_response.response_gen

    def clear(self):
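        """Drop the query engine so new documents must be ingested first."""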
        self.query_engine = None
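

# Minimal usage sketch; "./docs" and the example question are placeholders.
# Assumes QDRANT_API_URL and QDRANT_API_KEY are set in the environment; the
# first run downloads the GGUF model from model_url.
if __name__ == "__main__":
    chat = ChatPDF()
    chat.ingest("./docs")
    # ask() yields tokens because the query engine was built with streaming=True.
    for token in chat.ask("What are these documents about?"):
        print(token, end="", flush=True)
    print()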