Spaces:
Build error
Build error
from langchain import PromptTemplate | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain | |
from langchain.llms import CTransformers | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT | |
import os | |
from modules.constants import * | |
from modules.helpers import get_prompt | |
from modules.chat_model_loader import ChatModelLoader | |
from modules.vector_db import VectorDB, VectorDBScore | |
class LLMTutor:
    """Assemble a retrieval-augmented QA chain from a configuration mapping.

    Configuration keys read directly by this class:

    * ``config["embedding_options"]``: ``embedd_files``, ``db_option``,
      ``score_threshold``, ``search_top_k``
    * ``config["llm_params"]``: ``use_history``, ``memory_window``

    Further keys may be consumed by the project helpers this class delegates
    to (``VectorDB``, ``ChatModelLoader``, ``get_prompt``).
    """

    def __init__(self, config, logger=None):
        """Store ``config`` and create the vector-database wrapper.

        Args:
            config: Nested configuration mapping (see class docstring).
            logger: Optional logger, forwarded to ``VectorDB``.
        """
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # When requested, build the database from the source files and save
        # it; ``qa_bot`` later reads it back through ``load_database``.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the QA prompt built from the configuration.

        Delegates to ``get_prompt``. The same prompt is used by both chain
        variants built in ``retrieval_qa_chain``.
        """
        return get_prompt(self.config)

    def retrieval_qa_chain(self, llm, prompt, db):
        """Build the LangChain QA chain that answers questions over ``db``.

        Args:
            llm: Chat model, as returned by ``load_llm``.
            prompt: Prompt used by the "stuff" document-combining step.
            db: Vector store, as returned by ``VectorDB.load_database``.

        Returns:
            A ``ConversationalRetrievalChain`` with windowed buffer memory when
            ``llm_params.use_history`` is truthy, otherwise a ``RetrievalQA``
            chain. Both are configured to return their source documents.

        Raises:
            ValueError: If ``embedding_options.db_option`` is not one of
                "FAISS", "Chroma" or "RAGatouille".
        """
        embedding_options = self.config["embedding_options"]
        db_option = embedding_options["db_option"]

        if db_option in ("FAISS", "Chroma"):
            # Similarity search filtered by ``score_threshold``, returning at
            # most ``search_top_k`` documents.
            retriever = VectorDBScore(
                vectorstore=db,
                search_type="similarity_score_threshold",
                search_kwargs={
                    "score_threshold": embedding_options["score_threshold"],
                    "k": embedding_options["search_top_k"],
                },
            )
        elif db_option == "RAGatouille":
            retriever = db.as_langchain_retriever(
                k=embedding_options["search_top_k"]
            )
        else:
            # Without this branch ``retriever`` would be unbound below and the
            # failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported embedding_options.db_option {db_option!r}; "
                "expected 'FAISS', 'Chroma' or 'RAGatouille'."
            )

        llm_params = self.config["llm_params"]
        if llm_params["use_history"]:
            # output_key must match the chain's answer key so the memory knows
            # which output to store when source documents are also returned.
            memory = ConversationBufferWindowMemory(
                k=llm_params["memory_window"],
                memory_key="chat_history",
                return_messages=True,
                output_key="answer",
            )
            return ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )

        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def load_llm(self):
        """Load and return the chat model described by the configuration."""
        return ChatModelLoader(self.config).load_chat_model()

    def qa_bot(self):
        """Load the database and LLM, then return a ready-to-use QA chain.

        Side effect: the loaded model is stored on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        return self.retrieval_qa_chain(self.llm, qa_prompt, db)

    def final_result(self, query):
        """Answer a single ``query`` with a freshly built QA chain.

        Each call goes through ``qa_bot``, which reloads the database and the
        model and creates new conversation memory. Callers that need chat
        history across turns should keep the chain returned by ``qa_bot`` and
        invoke it directly.

        Args:
            query: The user's question.

        Returns:
            The output mapping produced by the chain.
        """
        chain = self.qa_bot()
        # RetrievalQA reads its input from "query", whereas
        # ConversationalRetrievalChain expects "question" (its "chat_history"
        # input is filled in by the attached memory).
        if self.config["llm_params"]["use_history"]:
            input_key = "question"
        else:
            input_key = "query"
        return chain.invoke({input_key: query})