import logging

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

from app.rag_pipeline.prompt_utils import qa_prompt
from app.rag_pipeline.chroma_client import get_chroma_client
from app.settings import Config
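
# Note: create_stuff_documents_chain expects the prompt to expose a {context}
# variable (where the retrieved documents are stuffed), and
# create_retrieval_chain supplies the user's question under the "input" key.
# As a rough sketch, the imported qa_prompt is therefore assumed to look
# something like this (the real one lives in app.rag_pipeline.prompt_utils):
#
#   from langchain_core.prompts import ChatPromptTemplate
#   qa_prompt = ChatPromptTemplate.from_messages([
#       ("system", "Answer the question using only this context:\n\n{context}"),
#       ("human", "{input}"),
#   ])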

conf = Config()


MODELS_PATH = conf.MODELS_PATH  # e.g. '/models'

# Generation settings (not used in this module).
CONTEXT_WINDOW_SIZE = 2048
MAX_NEW_TOKENS = 2048
N_BATCH = 512
N_GPU_LAYERS = 1

MODEL_ID = conf.MODEL_ID  # e.g. "TheBloke/Mistral-7B-v0.1-GGUF"
MODEL_BASENAME = conf.MODEL_BASENAME  # e.g. "mistral-7b-v0.1.Q4_0.gguf"
device_type = 'cpu'

logger = logging.getLogger(__name__)

class RetrieverChain:
    """Wraps a Chroma vector store and builds retrieval-QA chains over it."""

    def __init__(self, collection_name, embedding_function, persist_directory):
        try:
            self.vector_db = get_chroma_client(collection_name, embedding_function, persist_directory)
        except Exception as e:
            logger.error(f"Error creating RetrieverChain: {e}")
            raise

    def get_retriever(self):
        """Return an MMR retriever over the vector store."""
        try:
            # fetch_k is the candidate pool that MMR re-ranks down to k, so
            # it must be at least k; the original fetch_k=2 < k=5 capped the
            # results at two documents. 20 is LangChain's default.
            return self.vector_db.as_retriever(
                search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20}
            )
        except Exception as e:
            logger.error(f"Failed to get retriever: {e}")
            raise

    def get_conversational_rag_chain(self, llm):
        """Build a retrieval chain that stuffs retrieved docs into qa_prompt."""
        try:
            # The original code compared the bound method itself to None
            # (always False); call it and validate the result instead.
            retriever = self.get_retriever()
            if retriever is None:
                logger.error("Retriever must not be None")
                raise ValueError("Retriever must not be None")
            if llm is None:
                logger.error("Model must not be None")
                raise ValueError("Model must not be None")

            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            return create_retrieval_chain(retriever, question_answer_chain)
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            raise
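
    # Rough data flow of the assembled chain, per create_retrieval_chain's
    # contract:
    #   {"input": question} -> retriever -> {"context": [Document, ...]}
    #                       -> qa_prompt -> llm -> {"answer": generated text}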

    
    def get_relevant_docs(self, user_input):
        """Return documents relevant to user_input, ranked with MMR."""
        try:
            # As in get_retriever, fetch_k must be at least k for MMR to
            # have a pool to re-rank (the original fetch_k=3 < k=6).
            docs = self.vector_db.as_retriever(
                search_type="mmr", search_kwargs={"k": 6, "fetch_k": 20}
            ).get_relevant_documents(user_input)
            # Each Document exposes the original text as .page_content and
            # any associated metadata as .metadata.
            logger.info(f"Relevant documents for {user_input!r}: {docs}")
            return docs
        except Exception as e:
            logger.error(f"Error getting relevant documents: {e}")
            raise

    def get_response(self, user_input, llm):
        """Run the RAG chain and return the generated answer string."""
        try:
            qa_rag_chain = self.get_conversational_rag_chain(llm)
            # The chain returns a dict with "input", "context" (the
            # retrieved documents), and "answer".
            response = qa_rag_chain.invoke({"input": user_input})
            return response['answer']
        except Exception as e:
            logger.error(f"Error getting response: {e}")
            raise




if __name__ == "__main__":
    # Simple interactive smoke test. The import path assumes
    # model_initializer sits in the same package as this module (the
    # original used a bare local import plus sys.path manipulation).
    from app.rag_pipeline.model_initializer import initialize_models

    openai_api_key = conf.API_KEY

    embedding_model, llm_model = initialize_models(
        openai_api_key, model_id=MODEL_ID, model_basename=MODEL_BASENAME
    )

    print(f"embedding_model: {embedding_model}")
    print(f"llm_model: {llm_model}")

    collection_name = 'AI_assignment'
    persist_directory = 'D:/AI Assignment/vector_store'
    print(f"persist_directory: {persist_directory}")

    # Build the chain once, outside the loop, instead of once per query.
    retriever_qa = RetrieverChain(
        collection_name=collection_name,
        embedding_function=embedding_model,
        persist_directory=persist_directory,
    )

    while True:
        user_query = input("Enter query (or 'exit'): ")
        if user_query.lower() == 'exit':
            break
        response = retriever_qa.get_response(user_input=user_query, llm=llm_model)
        print(f"Response: {response}")