#load & split data
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embed data
from langchain_mistralai.embeddings import MistralAIEmbeddings
# vector store
from langchain_community.vectorstores import FAISS
# prompt
from langchain.prompts import PromptTemplate
# memory
from langchain.memory import ConversationBufferMemory
#llm
from langchain_mistralai.chat_models import ChatMistralAI

#chain modules
from langchain.chains import RetrievalQA



# import PyPDF2
import os
import re
from dotenv import load_dotenv
load_dotenv()
from collections import defaultdict

# Mistral API key taken from the process environment (None when the variable is unset).
api_key = os.getenv("MISTRAL_API_KEY")

class RagModule():
    """Bundle the pieces of a Mistral-backed retrieval-augmented QA pipeline.

    Holds the configuration for text splitting, embeddings, the local FAISS
    vector store and the Mistral chat model, and exposes helpers that
    assemble them into LangChain ``RetrievalQA`` chains and format answers
    with their sources.
    """

    def __init__(self):
        # API key comes from the module-level ``api_key`` (MISTRAL_API_KEY env var).
        self.mistral_api_key = api_key
        # Embedding model used both for indexing and for querying the store.
        self.model_name_embedding = "mistral-embed"
        self.embedding_model = MistralAIEmbeddings(model=self.model_name_embedding, mistral_api_key=self.mistral_api_key)
        # Chunking parameters forwarded to RecursiveCharacterTextSplitter.
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        # Directory the FAISS index is loaded from (see get_faiss_db).
        self.db_faiss_path = "data/vector_store"
        # LLM parameters, consumed by load_mistral.
        self.llm_model = "mistral-small"
        # Kept under its historical name; passed to ChatMistralAI as ``max_tokens``.
        self.max_new_tokens = 512
        self.top_p = 0.5
        self.temperature = 0.1

    def split_text(self, text: str) -> list:
        """Split ``text`` into chunks with the configured text splitter.

        Args:
            text (str): Raw text to split.

        Returns:
            list: Chunk strings produced by ``self.text_splitter``.
        """
        return self.text_splitter.split_text(text)

    def get_metadata(self, texts: list) -> list:
        """Build one metadata dict per chunk, labelled by the chunk's position.

        Args:
            texts (list): Chunks, e.g. as returned by ``split_text``.

        Returns:
            list: ``[{"source": "Paragraphe: 0"}, {"source": "Paragraphe: 1"}, ...]``,
            one entry per element of ``texts``.
        """
        return [{"source": f'Paragraphe: {i}'} for i, _ in enumerate(texts)]

    def get_faiss_db(self):
        """Load the FAISS vector store saved under ``self.db_faiss_path``.

        Returns:
            FAISS: The loaded store, queried with ``self.embedding_model``.
        """
        # NOTE(review): FAISS.load_local unpickles the stored docstore, so only
        # point db_faiss_path at an index this application wrote itself. Recent
        # langchain_community releases refuse to load unless
        # allow_dangerous_deserialization=True is passed; confirm against the
        # installed version.
        db = FAISS.load_local(self.db_faiss_path, self.embedding_model)
        return db

    def set_custom_prompt(self, prompt_template: str):
        """Build the prompt used by the Q&A retrieval chains.

        Args:
            prompt_template (str): Template text; its ``{placeholders}`` become
                the prompt's input variables.

        Returns:
            PromptTemplate: Template created by ``PromptTemplate.from_template``.
        """
        return PromptTemplate.from_template(template=prompt_template)

    def load_mistral(self):
        """Instantiate the Mistral chat model with this module's LLM settings.

        Returns:
            ChatMistralAI: The configured chat model.
        """
        model_kwargs = {
            "mistral_api_key": self.mistral_api_key,
            "model": self.llm_model,
            # ChatMistralAI's output-length field is ``max_tokens``;
            # ``max_new_tokens`` is not one of its fields, so the limit was not
            # applied as intended.
            "max_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }
        return ChatMistralAI(**model_kwargs)

    def _build_qa_chain(self, db, prompt_template, k, memory=None):
        """Assemble a 'stuff' RetrievalQA chain over ``db`` (shared by the public builders).

        Args:
            db: Vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): Template text for the combine-documents step.
            k (int): Number of documents retrieved per query.
            memory: Optional memory object forwarded through ``chain_type_kwargs``.

        Returns:
            RetrievalQA: Chain configured with ``return_source_documents=True``.
        """
        chain_type_kwargs = {"prompt": self.set_custom_prompt(prompt_template)}
        if memory is not None:
            chain_type_kwargs["memory"] = memory

        return RetrievalQA.from_chain_type(
            llm=self.load_mistral(),
            chain_type='stuff',
            retriever=db.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def retrieval_qa_memory_chain(self, db, prompt_template):
        """Build a RetrievalQA chain (top-5 retrieval) with conversation memory.

        The memory stores the dialogue under the ``history`` key and reads the
        user turn from the ``question`` input. ``prompt_template`` presumably
        needs ``{history}`` and ``{question}`` placeholders; confirm against
        callers' templates.

        Args:
            db: Vector store exposing ``as_retriever``.
            prompt_template (str): Template text for the combine-documents step.

        Returns:
            RetrievalQA: Chain that also returns its source documents.
        """
        memory = ConversationBufferMemory(
            memory_key='history',
            input_key='question'
        )
        return self._build_qa_chain(db, prompt_template, k=5, memory=memory)

    def retrieval_qa_chain(self, db, prompt_template):
        """Build a stateless RetrievalQA chain (top-3 retrieval).

        Args:
            db: Vector store exposing ``as_retriever``.
            prompt_template (str): Template text for the combine-documents step.

        Returns:
            RetrievalQA: Chain that also returns its source documents.
        """
        # Previously called the non-existent self.load_llm(); the helper uses load_mistral().
        return self._build_qa_chain(db, prompt_template, k=3)

    def get_sources_document(self, source_documents: list) -> dict:
        """Group the pages of the retrieved documents by their source path.

        Args:
            source_documents (list): Documents returned by the RAG chain; each
                must have ``"source"`` and ``"page"`` keys in ``metadata``.

        Returns:
            dict: A ``defaultdict(list)`` in first-seen order, e.g.
            {
                path/to/file1 : [0, 1, 3],
                path/to/file2 : [5, 2]
            }
            Pages are kept in retrieval order; duplicates are not removed.

        Raises:
            KeyError: If a document's metadata lacks ``"source"`` or ``"page"``.
        """
        sources = defaultdict(list)
        for doc in source_documents:
            sources[doc.metadata["source"]].append(doc.metadata["page"])

        return sources

    def shape_answer_with_source(self, answer: str, sources: dict):
        """Append one ``Fichier: ... - Page: ...`` line per source to ``answer``.

        Args:
            answer (str): Answer text produced by the chain.
            sources (dict): Mapping of source path to pages, as built by
                ``get_sources_document``.

        Returns:
            str: ``answer``, a newline, then one ``"\\nFichier: <name> - Page: <pages>"``
            entry per source. Only the last ``/``-separated component of each
            path is shown; a source without a usable file name (e.g.
            ``"Paragraphe: 0"``) is shown in full instead of raising.
        """
        source_lines = []
        for path, pages in sources.items():
            # The former regex lookup raised IndexError when the path had no "/".
            file_name = path.rsplit("/", 1)[-1] or path
            source_lines.append(f"\nFichier: {file_name} - Page: {pages}")

        source_msg = "".join(source_lines)
        return f"{answer}\n{source_msg}"