""" | |
LLM chain retrieval | |
""" | |
import json | |
import gradio as gr | |
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from langchain_huggingface import HuggingFaceEndpoint | |
from langchain_core.prompts import PromptTemplate | |

# System prompt template for the RAG application
PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
Question: {question}
Context: {context}
Helpful Answer:
"""

# Initialize langchain LLM chain
def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
"""Initialize Langchain LLM chain""" | |
progress(0.1, desc="Initializing HF tokenizer...") | |
# HuggingFaceHub uses HF inference endpoints | |
progress(0.5, desc="Initializing HF Hub...") | |
# Use of trust_remote_code as model_kwargs | |
# Warning: langchain issue | |
# URL: https://github.com/langchain-ai/langchain/issues/6080 | |
llm = HuggingFaceEndpoint( | |
repo_id=llm_model, | |
task="text-generation", | |
temperature=temperature, | |
max_new_tokens=max_tokens, | |
top_k=top_k, | |
huggingfacehub_api_token=huggingfacehub_api_token, | |
) | |
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    with open("prompt_template.json", "r") as file:
        system_prompt = json.load(file)
    prompt_template = system_prompt["prompt"]
    rag_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        # return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
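

# Illustrative usage sketch (not part of the original module): build a small
# vector store and hand it to initialize_llmchain(). The FAISS store via
# langchain_community, the embedding model, the LLM repo id, and the generation
# settings below are assumptions chosen for the example only.
def build_demo_qa_chain(texts, huggingfacehub_api_token):
    """Illustrative helper: index `texts` and return a ready-to-use QA chain."""
    from langchain_community.vectorstores import FAISS
    from langchain_huggingface import HuggingFaceEmbeddings

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vector_db = FAISS.from_texts(texts, embeddings)
    return initialize_llmchain(
        llm_model="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=huggingfacehub_api_token,
        temperature=0.5,
        max_tokens=1024,
        top_k=3,
        vector_db=vector_db,
    )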


def format_chat_history(message, chat_history):
    """Format chat history for llm chain"""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
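

# For example (illustrative values): format_chat_history("next question", [("Hi", "Hello!")])
# returns ["User: Hi", "Assistant: Hello!"]. Note that the current `message` is
# not added to the formatted history; only past turns are.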


def invoke_qa_chain(qa_chain, message, history):
    """Invoke question-answering chain"""
    formatted_chat_history = format_chat_history(message, history)
    # print("formatted_chat_history", formatted_chat_history)
    # Generate response using QA chain
    response = qa_chain.invoke(
        {"question": message, "chat_history": formatted_chat_history}
    )
    response_sources = response["source_documents"]
    response_answer = response["answer"]
    # Keep only the text after the "Helpful Answer:" marker from the prompt template
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
    # print('chat response: ', response_answer)
    # print('DB source', response_sources)
    return qa_chain, new_history, response_sources
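

# Illustrative sketch (not part of this module as written): one way to wire
# invoke_qa_chain() into a Gradio Blocks chat UI. The component layout and
# labels are assumptions, not the application's actual interface.
def build_demo_ui(qa_chain):
    """Illustrative helper: minimal chat UI around invoke_qa_chain()."""
    with gr.Blocks() as demo:
        chain_state = gr.State(qa_chain)
        chatbot = gr.Chatbot()
        message = gr.Textbox(label="Ask a question about the indexed documents")

        def respond(user_message, chat_history, chain):
            chain, new_history, _sources = invoke_qa_chain(
                chain, user_message, chat_history
            )
            return "", new_history, chain

        message.submit(
            respond,
            inputs=[message, chatbot, chain_state],
            outputs=[message, chatbot, chain_state],
        )
    return demo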