dl4ds_tutor / code /modules /llm_tutor.py
XThomasBU's picture
improvements
a052bdc
raw
history blame
2.96 kB
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.llms import CTransformers
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
import os
from modules.constants import *
from modules.helpers import get_prompt
from modules.chat_model_loader import ChatModelLoader
from modules.vector_db import VectorDB
class LLMTutor:
    """Assemble a retrieval-based question-answering chain from a config.

    The instance owns a ``VectorDB`` built from ``config`` and exposes helpers
    that load the chat model, fetch the prompt, and combine them with the
    vector store's retriever into a LangChain QA chain.
    """

    def __init__(self, config, logger=None):
        """Store ``config`` and create the vector database wrapper.

        Args:
            config: Nested mapping. This class reads
                ``config["embedding_options"]["embedd_files"]``,
                ``config["embedding_options"]["search_top_k"]`` and
                ``config["llm_params"]["use_history"]``; the whole mapping is
                also forwarded to ``VectorDB``, ``ChatModelLoader`` and
                ``get_prompt``.
            logger: Optional logger, forwarded to ``VectorDB`` only.
        """
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # When the config asks for it, (re)create the database and save it
        # up front, before any chain is built.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """Return the prompt used for QA retrieval, as given by ``get_prompt``."""
        return get_prompt(self.config)

    def retrieval_qa_chain(self, llm, prompt, db):
        """Build the retrieval QA chain for ``llm`` over ``db``.

        Args:
            llm: Language model handed to the LangChain chain constructor.
            prompt: Prompt used by the "stuff" documents chain.
            db: Vector store exposing ``as_retriever``.

        Returns:
            A ``ConversationalRetrievalChain`` with a fresh
            ``ConversationBufferMemory`` when
            ``config["llm_params"]["use_history"]`` is truthy, otherwise a
            ``RetrievalQA`` chain. Both return source documents.
        """
        use_history = self.config["llm_params"]["use_history"]
        # Both chain flavours use the same top-k retriever.
        retriever = db.as_retriever(
            search_kwargs={"k": self.config["embedding_options"]["search_top_k"]}
        )

        if use_history:
            # output_key tells the memory which chain output to record, since
            # the chain also returns the source documents.
            memory = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer"
            )
            return ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )

        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

    def load_llm(self):
        """Load and return the chat model via ``ChatModelLoader``."""
        return ChatModelLoader(self.config).load_chat_model()

    def qa_bot(self):
        """Load the database and the LLM, then return a ready QA chain.

        The loaded model is also kept on ``self.llm``.
        """
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        return self.retrieval_qa_chain(self.llm, qa_prompt, db)

    def final_result(self, query):
        """Answer ``query`` with a freshly built QA chain and return its response.

        A new chain is built on every call via ``qa_bot``, so the conversation
        memory created in ``retrieval_qa_chain`` does not carry over between
        calls; callers that need history should keep the chain returned by
        ``qa_bot`` and invoke it directly.

        Args:
            query: The user's question text.

        Returns:
            Whatever the chain returns when called with the question.
        """
        qa_chain = self.qa_bot()
        # ConversationalRetrievalChain takes its input under "question",
        # RetrievalQA under "query"; choose the key matching the chain that
        # retrieval_qa_chain builds for this config.
        input_key = "question" if self.config["llm_params"]["use_history"] else "query"
        return qa_chain({input_key: query})