import logging #import create_history_aware_retriever, from langchain.chains import create_retrieval_chain from langchain.chains.combine_documents import create_stuff_documents_chain from app.rag_pipeline.prompt_utils import qa_prompt from app.rag_pipeline.chroma_client import get_chroma_client from app.settings import Config # from prompt_utils import qa_prompt # from chroma_client import get_chroma_client # import sys # import os # parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # sys.path.insert(0, parent_dir) # from settings import Config conf = Config() MODELS_PATH = conf.MODELS_PATH #'/models' CONTEXT_WINDOW_SIZE = 2048 MAX_NEW_TOKENS = 2048 N_BATCH= 512 N_GPU_LAYERS = 1 MODEL_ID = conf.MODEL_ID #"TheBloke/Mistral-7B-v0.1-GGUF" MODEL_BASENAME = conf.MODEL_BASENAME # "mistral-7b-v0.1.Q4_0.gguf" device_type = 'cpu' logger = logging.getLogger(__name__) class RetrieverChain: def __init__(self, collection_name, embedding_function, persist_directory): try: self.vector_db = get_chroma_client(collection_name, embedding_function, persist_directory) except Exception as e: logger.error(f"Error creating RetrieverChain: {e}") raise def get_retriever(self): try: retriever = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 2}) return retriever except Exception as e: logger.error(f"Failed to get retriever: {e}") raise def get_conversational_rag_chain(self, llm): try: if self.get_retriever is None: logger.error(f"Retriever must not be None") raise ValueError("Retriever must not be None") if llm is None: logger.error(f"Model must not be None") raise ValueError("Model must not be None") question_answer_chain = create_stuff_documents_chain(llm, qa_prompt) return create_retrieval_chain(self.get_retriever(), question_answer_chain) except Exception as e: logger.error(f"Error creating RAG chain: {e}") raise def get_relevent_docs(self, user_input): try: docs = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 6, "fetch_k": 3}).get_relevant_documents(user_input) logger.info(f"Relevent documents for {user_input}: {docs}") # Access the retrieved documents # print("Relevent Docs") # for doc in docs: # print(doc.page_content) # Access the original text # print(doc.metadata) # Access any metadata associated with the document # print("Relevent Docs end") return docs except Exception as e: logger.error(f"Error getting response: {e}") raise def get_response(self, user_input, llm): try: qa_rag_chain = self.get_conversational_rag_chain(llm) response = qa_rag_chain.invoke({"input": user_input}) return response['answer'] except Exception as e: logger.error(f"Error getting response: {e}") raise # if __name__ == "__main__": # import os # from model_initializer import initialize_models # parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) # openai_api_key = conf.API_KEY # embedding_model, llm_model = initialize_models(openai_api_key,model_id=MODEL_ID, model_basename=MODEL_BASENAME) # print(f"embeddi_modelng: {embedding_model}") # print(f"llm_model: {llm_model}") # collection_name = 'AI_assignment' # persist_directory = f'D:/AI Assignment/vector_store' # print(f"persist_directory: {persist_directory}") # while True: # print("Enter query: ") # user_query = input() # if user_query.lower() == 'exit': # break # retriever_qa = RetrieverChain( # collection_name=collection_name, embedding_function=embedding_model, persist_directory=persist_directory) # response = retriever_qa.get_response(user_input = user_query, llm= llm_model) # print(f"Response: {response}")