""" LLM chain retrieval """ import json import gradio as gr from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain from langchain.memory import ConversationBufferMemory from langchain_huggingface import HuggingFaceEndpoint from langchain_core.prompts import PromptTemplate # Add system template for RAG application PROMPT_TEMPLATE = """ You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise. Question: {question} Context: {context} Helpful Answer: """ # Initialize langchain LLM chain def initialize_llmchain( llm_model, huggingfacehub_api_token, temperature, max_tokens, top_k, vector_db, progress=gr.Progress(), ): """Initialize Langchain LLM chain""" progress(0.1, desc="Initializing HF tokenizer...") # HuggingFaceHub uses HF inference endpoints progress(0.5, desc="Initializing HF Hub...") # Use of trust_remote_code as model_kwargs # Warning: langchain issue # URL: https://github.com/langchain-ai/langchain/issues/6080 llm = HuggingFaceEndpoint( repo_id=llm_model, task="text-generation", temperature=temperature, max_new_tokens=max_tokens, top_k=top_k, huggingfacehub_api_token=huggingfacehub_api_token, ) progress(0.75, desc="Defining buffer memory...") memory = ConversationBufferMemory( memory_key="chat_history", output_key="answer", return_messages=True ) # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3}) retriever = vector_db.as_retriever() progress(0.8, desc="Defining retrieval chain...") with open('prompt_template.json', 'r') as file: system_prompt = json.load(file) prompt_template = system_prompt["prompt"] rag_prompt = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) qa_chain = ConversationalRetrievalChain.from_llm( llm, retriever=retriever, chain_type="stuff", memory=memory, combine_docs_chain_kwargs={"prompt": rag_prompt}, return_source_documents=True, # return_generated_question=False, verbose=False, ) progress(0.9, desc="Done!") return qa_chain def format_chat_history(message, chat_history): """Format chat history for llm chain""" formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history def invoke_qa_chain(qa_chain, message, history): """Invoke question-answering chain""" formatted_chat_history = format_chat_history(message, history) # print("formatted_chat_history",formatted_chat_history) # Generate response using QA chain response = qa_chain.invoke( {"question": message, "chat_history": formatted_chat_history} ) response_sources = response["source_documents"] response_answer = response["answer"] if response_answer.find("Helpful Answer:") != -1: response_answer = response_answer.split("Helpful Answer:")[-1] # Append user message and response to chat history new_history = history + [(message, response_answer)] # print ('chat response: ', response_answer) # print('DB source', response_sources) return qa_chain, new_history, response_sources