# RAG / app.py
# Author: PyroSama — commit a6bc888 ("Update app.py", verified), 5.68 kB
import gradio as gr
import utils
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough
import torch
import os
# The Mistral API key must be supplied through the environment (e.g. a deployment
# secret), never committed to source; ChatMistralAI below is constructed without an
# explicit key. A key that was previously hard-coded in this file should be treated
# as compromised and revoked.
if not os.environ.get('MISTRAL_API_KEY'):
    raise RuntimeError(
        "MISTRAL_API_KEY is not set; configure it as an environment variable or secret."
    )
class VectorData:
    """Own the embedding model, the Chroma vector store and the RAG chain.

    The store is persisted under ``chroma_db``. Documents are added/removed through
    the ``utils`` helpers, each of which returns a ``(retriever, vectorstore)`` pair;
    the RAG chain is rebuilt after every change so it never queries a stale retriever.

    Attributes:
        embeddings: Sentence-embedding model used by the vector store.
        vectorstore: The Chroma store backing retrieval.
        retriever: Retriever derived from ``vectorstore``.
        ingested_files: Base names of files ingested through ``add_file`` in this
            process. NOTE(review): documents already persisted in ``chroma_db`` from
            earlier runs are not listed here.
        prompt: Chat prompt with ``{context}`` and ``{question}`` slots.
        llm: The Mistral chat model.
        rag_chain: retriever -> prompt -> llm -> ``str`` pipeline.
    """

    def __init__(self):
        embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'
        model_kwargs = {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            "trust_remote_code": True,
        }
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs=model_kwargs,
        )
        self.vectorstore = Chroma(persist_directory="chroma_db", embedding_function=self.embeddings)
        self.retriever = self.vectorstore.as_retriever()
        self.ingested_files = []
        system_message = (
            "Answer the question based on the given context. Dont give any ans if context "
            "is not valid to question. Always give the source of context:\n"
            "{context}\n"
        )
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_message),
                ("human", "{question}"),
            ]
        )
        self.llm = ChatMistralAI(model="mistral-large-latest")
        self.rag_chain = self._build_chain()

    def _build_chain(self):
        """Return a fresh RAG chain bound to the current ``self.retriever``."""
        return (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def _file_table(self):
        """Return ``ingested_files`` as one-column rows for the Gradio table."""
        return [[name] for name in self.ingested_files]

    def add_file(self, file):
        """Ingest an uploaded file and return the refreshed file table.

        Args:
            file: The uploaded file object (must expose ``.name``), or ``None``
                when nothing was uploaded, in which case nothing changes.

        Returns:
            The list of ingested file names as ``[[name], ...]`` rows.
        """
        if file is not None:
            file_name = os.path.basename(file.name)
            self.retriever, self.vectorstore = utils.add_doc(file, self.vectorstore)
            # Record the name only after ingestion succeeded so the UI list and
            # the store cannot disagree if add_doc raises.
            self.ingested_files.append(file_name)
            self.rag_chain = self._build_chain()
        return self._file_table()

    def delete_file_by_name(self, file_name):
        """Remove a previously ingested file's documents, if it is known.

        Args:
            file_name: Base name as shown in the file table; unknown names are ignored.

        Returns:
            The refreshed ``[[name], ...]`` file table.
        """
        if file_name in self.ingested_files:
            self.retriever, self.vectorstore = utils.delete_doc(file_name, self.vectorstore)
            self.ingested_files.remove(file_name)
            # The retriever was replaced; rebind the chain to it.
            self.rag_chain = self._build_chain()
        return self._file_table()

    def delete_all_files(self):
        """Remove every document from the store and return an empty file table."""
        self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
        # Clear the list only once the backend deletion has succeeded.
        self.ingested_files.clear()
        self.rag_chain = self._build_chain()
        return []
# Single module-level VectorData instance shared by the question-answering and
# file-management callbacks. Constructing it at import time loads the embedding
# model and opens the persisted Chroma store.
data_obj = VectorData()
def answer_question(question):
    """Answer *question* with the RAG chain.

    Args:
        question: Text from the question textbox. ``None`` and blank or
            whitespace-only text are treated as "no question".

    Returns:
        The model's answer string, or ``"Please enter a question."`` when the
        input is blank.
    """
    # Guard against None as well as empty/whitespace-only text.
    cleaned = (question or "").strip()
    if not cleaned:
        return "Please enter a question."
    # The chain ends in StrOutputParser, so invoke() already returns a str.
    return data_obj.rag_chain.invoke(cleaned)
# Define the Gradio interface
with gr.Blocks() as rag_interface:
    # Title and description
    gr.Markdown("# RAG Interface")
    gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")
    with gr.Row():
        # Left column: file management
        with gr.Column():
            gr.Markdown("### File Management")
            # File upload and ingest
            file_input = gr.File(label="Upload File to Ingest")
            add_file_button = gr.Button("Ingest File")
            # Read-only table of ingested file names, one per row.
            ingested_files_box = gr.Dataframe(
                headers=["Files"],
                datatype="str",
                row_count=4,  # initial row count; NOTE(review): confirm scroll behavior for the installed Gradio version
                interactive=False,
            )
            # Radio buttons to choose the delete mode
            delete_option = gr.Radio(choices=["Delete by File Name", "Delete All Files"], label="Delete Option")
            file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
            delete_button = gr.Button("Delete Selected")

            def toggle_file_input(option):
                """Show the file-name textbox only in "Delete by File Name" mode."""
                return gr.update(visible=(option == "Delete by File Name"))

            delete_option.change(fn=toggle_file_input, inputs=delete_option, outputs=file_name_input)

            # Handle file ingestion; add_file returns the refreshed file table.
            add_file_button.click(
                fn=data_obj.add_file,
                inputs=file_input,
                outputs=ingested_files_box,
            )

            def delete_action(option, file_name):
                """Delete documents per the selected mode and return the refreshed table.

                Args:
                    option: The selected delete mode (``None`` if nothing is selected).
                    file_name: Text from the file-name box; surrounding whitespace
                        is ignored so pasted names still match.

                Returns:
                    The ``[[name], ...]`` rows for ``ingested_files_box``.
                """
                target = (file_name or "").strip()
                if option == "Delete by File Name" and target:
                    return data_obj.delete_file_by_name(target)
                if option == "Delete All Files":
                    return data_obj.delete_all_files()
                # Nothing to delete: redisplay the current list unchanged.
                return [[name] for name in data_obj.ingested_files]

            delete_button.click(
                fn=delete_action,
                inputs=[delete_option, file_name_input],
                outputs=ingested_files_box,
            )
        # Right column: question answering
        with gr.Column():
            gr.Markdown("### Ask a Question")
            question_input = gr.Textbox(label="Enter your question")
            ask_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", interactive=False)
            ask_button.click(fn=answer_question, inputs=question_input, outputs=answer_output)
# Launch the Gradio server only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not start a server.
if __name__ == "__main__":
    rag_interface.launch()