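"""RAG retriever utilities built on LangChain and Chroma.

`RetrieverChain` wraps a Chroma vector store behind an MMR retriever and
exposes a conversational retrieval-augmented generation (RAG) chain.
"""
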
import logging

from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

from app.rag_pipeline.prompt_utils import qa_prompt
from app.rag_pipeline.chroma_client import get_chroma_client
from app.settings import Config

# Uncomment the block below to run this module standalone (outside the app package):
# import sys
# import os
# parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# sys.path.insert(0, parent_dir)
# from settings import Config
# from prompt_utils import qa_prompt
# from chroma_client import get_chroma_client

conf = Config()
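# Note: the LLM-loading parameters below (context window, batch size, GPU
# layers) are presumably consumed by the model initializer, not by this
# module directly.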
MODELS_PATH = conf.MODELS_PATH  # e.g. '/models'
CONTEXT_WINDOW_SIZE = 2048
MAX_NEW_TOKENS = 2048
N_BATCH = 512
N_GPU_LAYERS = 1
MODEL_ID = conf.MODEL_ID  # e.g. "TheBloke/Mistral-7B-v0.1-GGUF"
MODEL_BASENAME = conf.MODEL_BASENAME  # e.g. "mistral-7b-v0.1.Q4_0.gguf"
device_type = 'cpu'

logger = logging.getLogger(__name__)
class RetrieverChain:
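    """Wraps a Chroma vector store behind an MMR retriever and exposes a
    conversational retrieval-augmented generation (RAG) chain."""
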
    def __init__(self, collection_name, embedding_function, persist_directory):
        """Open (or create) the Chroma collection backing this retriever."""
        try:
            self.vector_db = get_chroma_client(collection_name, embedding_function, persist_directory)
        except Exception as e:
            logger.error(f"Error creating RetrieverChain: {e}")
            raise

    def get_retriever(self):
        """Return an MMR retriever over the vector store.

        `fetch_k` must be >= `k`: MMR first fetches `fetch_k` candidates,
        then re-ranks them for diversity and keeps the top `k`.
        """
        try:
            retriever = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20})
            return retriever
        except Exception as e:
            logger.error(f"Failed to get retriever: {e}")
            raise

    def get_conversational_rag_chain(self, llm):
        """Build a retrieval chain: MMR retriever -> stuff-documents QA chain."""
        try:
            # Call get_retriever() and check its result; comparing the bound
            # method itself against None would never be true.
            retriever = self.get_retriever()
            if retriever is None:
                logger.error("Retriever must not be None")
                raise ValueError("Retriever must not be None")
            if llm is None:
                logger.error("Model must not be None")
                raise ValueError("Model must not be None")
            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            return create_retrieval_chain(retriever, question_answer_chain)
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            raise

    def get_relevant_docs(self, user_input):
        """Retrieve MMR-ranked documents for `user_input`.

        Each returned document exposes `.page_content` (the original text)
        and `.metadata`.
        """
        try:
            docs = self.vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 6, "fetch_k": 20}).get_relevant_documents(user_input)
            logger.info(f"Relevant documents for {user_input}: {docs}")
            return docs
        except Exception as e:
            logger.error(f"Error getting relevant documents: {e}")
            raise

    def get_response(self, user_input, llm):
        """Run the full RAG chain and return the generated answer string."""
        try:
            qa_rag_chain = self.get_conversational_rag_chain(llm)
            response = qa_rag_chain.invoke({"input": user_input})
            return response['answer']
        except Exception as e:
            logger.error(f"Error getting response: {e}")
            raise

# Standalone smoke test: build the models and chain once, then answer queries
# in a loop until the user types 'exit'.
# if __name__ == "__main__":
#     import os
#     from model_initializer import initialize_models
#     parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
#     openai_api_key = conf.API_KEY
#     embedding_model, llm_model = initialize_models(openai_api_key, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
#     print(f"embedding_model: {embedding_model}")
#     print(f"llm_model: {llm_model}")
#     collection_name = 'AI_assignment'
#     persist_directory = 'D:/AI Assignment/vector_store'
#     print(f"persist_directory: {persist_directory}")
#     retriever_qa = RetrieverChain(
#         collection_name=collection_name, embedding_function=embedding_model, persist_directory=persist_directory)
#     while True:
#         user_query = input("Enter query: ")
#         if user_query.lower() == 'exit':
#             break
#         response = retriever_qa.get_response(user_input=user_query, llm=llm_model)
#         print(f"Response: {response}")