# rag/app.py
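"""Gradio app for a BBC News RAG assistant.

Loads the RealTimeData/bbc_news_alltime dataset, chunks it into a cached
knowledge base, builds BM25 and vector indexes via the local `retriver`
module, reranks with a ragatouille RAGPretrainedModel, and answers
questions through rag.RAGAnswerGenerator.
"""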
import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import sys
import logging
logging.getLogger("langchain").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
class AnswerSystem:
    """Thin adapter that exposes the RAG system's answer() method to the Gradio UI."""

    def __init__(self, rag_system) -> None:
        self.rag_system = rag_system

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=10,
            num_docs_final=5,
        )
        formatted_docs = "\n\n".join(
            f"Document {i + 1}: {doc}" for i, doc in enumerate(relevant_docs)
        )
        return answer, formatted_docs
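# Illustrative usage of AnswerSystem outside the UI (assumes rag.RAGAnswerGenerator.answer
# returns an (answer, relevant_docs) tuple, as the call above expects):
#
#   system = AnswerSystem(rag_generator)
#   answer, docs = system.answer_generate(
#       "What position does Josko Gvardiol play?", bm_25_flag=True, semantic_flag=True, temperature=0.5
#   )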
def run_app(rag_model):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # RealTimeData Monthly Collection - BBC News Documentation Assistant
            Welcome! This system helps you explore and find insights in the RealTimeData Monthly Collection - BBC News dataset.
            For example:
            - *"What position does Josko Gvardiol play, and how much did Manchester City pay for him?"*
            """
        )

        # Input fields
        question_input = gr.Textbox(
            label="Enter your question:",
            placeholder="E.g., What position does Josko Gvardiol play, and how much did Manchester City pay for him?",
        )
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)  # BM25 flag
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)  # Semantic flag
        temperature_slider = gr.Slider(
            label="Response Temperature", minimum=0.1, maximum=1.0, value=0.5, step=0.1
        )  # Temperature

        # Search button
        search_button = gr.Button("Search")

        # Output fields
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic: wire the button to the answer generator
        system = AnswerSystem(rag_model)
        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],  # all parameters
            outputs=[answer_output, docs_output],
        )

    # Launch the app; share=True exposes a temporary public URL
    demo.launch(debug=True, share=True)
def get_rag_data():
    nltk.download('punkt')
    nltk.download('punkt_tab')

    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "rb") as file:
            docs_processed = pickle.load(file)
    else:
        print("Processing documents...")
        # Use a distinct loop variable so the imported `config` module is not shadowed
        datasets_list = [
            utils.align_features(load_dataset("RealTimeData/bbc_news_alltime", dataset_config)["train"])
            for dataset_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
        ]
        ds = concatenate_datasets(datasets_list)

        RAW_KNOWLEDGE_BASE = [
            LangchainDocument(
                page_content=doc["content"],
                metadata={
                    "title": doc["title"],
                    "published_date": doc["published_date"],
                    "authors": doc["authors"],
                    "section": doc["section"],
                    "description": doc["description"],
                    "link": doc["link"],
                },
            )
            for doc in tqdm(ds)
        ]
        docs_processed = chunker.split_documents(512, RAW_KNOWLEDGE_BASE)

        print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "wb") as file:
            pickle.dump(docs_processed, file)

    return docs_processed
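# For reference, `config` is expected to define at least the constants used above.
# Values here are illustrative assumptions, not taken from this repository:
#
#   DOCUMENTS_PATH = "docs_processed.pkl"          # pickle cache of chunked documents
#   AVAILABLE_DATASET_CONFIGS = ["2024-01", ...]   # monthly configs of RealTimeData/bbc_news_alltime
#   CROSS_ENCODER_MODEL = "colbert-ir/colbertv2.0" # checkpoint loaded by RAGPretrainedModel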
if __name__ == '__main__':
    docs_processed = get_rag_data()

    bm25 = retriver.create_bm25(docs_processed)
    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)
    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)

    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed,
        bm25=bm25,
        knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
        reranker=RERANKER,
    )
    run_app(rag_generator)
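# Running locally (assumed entry point): `python app.py` processes and caches the dataset
# on the first run, then serves the Gradio UI; debug=True prints errors to the console,
# and share=True additionally prints a temporary public link.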