Spaces:
Build error
Build error
from langchain import PromptTemplate | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.chat_models import ChatOpenAI | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain.vectorstores import FAISS | |
from langchain.chains import RetrievalQA, ConversationalRetrievalChain | |
from langchain.llms import CTransformers | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT | |
import os | |
from modules.constants import * | |
from modules.chat_model_loader import ChatModelLoader | |
from modules.vector_db import VectorDB | |
class LLMTutor:
    """Assemble a retrieval question-answering chain from a config mapping.

    The class reads these config entries directly:
    ``config["embedding_options"]["embedd_files"]``,
    ``config["embedding_options"]["search_top_k"]`` and
    ``config["llm_params"]["use_history"]``. The same config object is also
    handed unchanged to ``VectorDB`` and ``ChatModelLoader``.
    """

    def __init__(self, config, logger=None):
        """Store the config and set up the vector database wrapper.

        Args:
            config: Nested mapping of settings (see the class docstring for
                the keys read here).
            logger: Optional logger; it is only forwarded to ``VectorDB``.
        """
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # Create and save the database only when the config asks for it;
        # otherwise qa_bot() relies on whatever load_database() returns.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def _prompt_input_variables(self):
        """Return the prompt variable names for the configured chain type."""
        if self.config["llm_params"]["use_history"]:
            return ["context", "chat_history", "question"]
        # The history-free chain only supplies the retrieved context and the
        # question. Declaring "chat_history" there would make it a required
        # input that nothing ever provides.
        return ["context", "question"]

    def set_custom_prompt(self):
        """Build the prompt template used by the QA chain.

        Picks ``prompt_template_with_history`` when
        ``config["llm_params"]["use_history"]`` is truthy and
        ``prompt_template`` otherwise (both presumably come from the
        ``modules.constants`` star import -- confirm).

        Returns:
            A ``PromptTemplate`` whose input variables match the chain that
            ``retrieval_qa_chain`` builds for the same config.
        """
        if self.config["llm_params"]["use_history"]:
            template_text = prompt_template_with_history
        else:
            template_text = prompt_template
        return PromptTemplate(
            template=template_text,
            input_variables=self._prompt_input_variables(),
        )

    def retrieval_qa_chain(self, llm, prompt, db):
        """Create the retrieval chain for the given model, prompt and store.

        Args:
            llm: Language model passed straight to the chain constructor.
            prompt: Prompt used by the "stuff" documents chain.
            db: Vector store; only ``db.as_retriever(...)`` is called on it.

        Returns:
            A ``ConversationalRetrievalChain`` with a fresh
            ``ConversationBufferMemory`` when history is enabled, otherwise
            a ``RetrievalQA`` chain. Both return source documents.
        """
        # Both chain flavours share the same top-k retriever.
        retriever = db.as_retriever(
            search_kwargs={"k": self.config["embedding_options"]["search_top_k"]}
        )

        if self.config["llm_params"]["use_history"]:
            # output_key tells the memory which chain output to record,
            # since the chain also returns the source documents.
            memory = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer"
            )
            return ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )

        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def load_llm(self):
        """Return the chat model produced by ``ChatModelLoader`` for this config."""
        return ChatModelLoader(self.config).load_chat_model()

    def qa_bot(self):
        """Load the database and model, then return the assembled QA chain.

        Side effect: the loaded model is kept on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        return self.retrieval_qa_chain(self.llm, qa_prompt, db)
# Output function
def final_result(query):
    """Send ``query`` through the QA chain and return the chain's response.

    The chain is obtained from ``qa_bot()`` and invoked with a
    ``{"query": query}`` payload.

    NOTE(review): ``qa_bot`` is referenced as a bare name, but the only
    ``qa_bot`` visible in this file is a method of ``LLMTutor``. Unless a
    module-level ``qa_bot`` arrives through the star import, this call
    raises ``NameError`` -- confirm the intended target.

    NOTE(review): the payload key is always "query"; verify that it matches
    the input key of the chain built when ``use_history`` is enabled.
    """
    chain = qa_bot()
    return chain({"query": query})