Spaces:
Build error
Build error
import gradio as gr | |
import utils | |
from datasets import load_dataset, concatenate_datasets | |
from langchain.docstore.document import Document as LangchainDocument | |
from tqdm import tqdm | |
import pickle | |
from ragatouille import RAGPretrainedModel | |
import chunker | |
import retriver | |
import rag | |
import nltk | |
import config | |
import os | |
import warnings | |
import sys | |
import logging | |
# Quiet the console for the demo: suppress every Python warning and only let
# records of level ERROR or higher through the "langchain" logger.
warnings.filterwarnings("ignore")
logging.getLogger("langchain").setLevel(logging.ERROR)
class AnswerSystem:
    """Adapter that exposes a RAG answer generator as a UI callback.

    The wrapped ``rag_system`` must provide an ``answer(...)`` method that
    accepts the keyword arguments ``question``, ``temperature``,
    ``bm_25_flag``, ``semantic_flag``, ``num_retrieved_docs`` and
    ``num_docs_final`` and returns an ``(answer, relevant_docs)`` pair.
    """

    def __init__(self, rag_system, *, num_retrieved_docs: int = 10, num_docs_final: int = 5) -> None:
        """Store the generator and the retrieval limits forwarded on every query.

        Args:
            rag_system: Object providing the ``answer(...)`` method described
                in the class docstring.
            num_retrieved_docs: Value forwarded as ``num_retrieved_docs``
                (defaults to 10, the value that used to be hard-coded).
            num_docs_final: Value forwarded as ``num_docs_final``
                (defaults to 5, the value that used to be hard-coded).
        """
        self.rag_system = rag_system
        self.num_retrieved_docs = num_retrieved_docs
        self.num_docs_final = num_docs_final

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        """Answer ``question`` and render the supporting documents as text.

        Args:
            question: The user's question, forwarded unchanged.
            bm_25_flag: Forwarded as ``bm_25_flag`` to the RAG system.
            semantic_flag: Forwarded as ``semantic_flag`` to the RAG system.
            temperature: Forwarded as ``temperature`` to the RAG system.

        Returns:
            A ``(answer, formatted_docs)`` tuple. ``answer`` is whatever the
            RAG system returned; ``formatted_docs`` is a single string with
            one ``"Document N: ..."`` entry per relevant document, entries
            separated by a blank line (empty string when there are none).
        """
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=self.num_retrieved_docs,
            num_docs_final=self.num_docs_final,
        )
        # Number the documents from 1 so the listing reads naturally in the UI.
        formatted_docs = "\n\n".join(
            f"Document {position}: {doc}"
            for position, doc in enumerate(relevant_docs, start=1)
        )
        return answer, formatted_docs
def run_app(rag_model):
    """Build the Gradio question-answering page and launch it.

    Args:
        rag_model: RAG answer generator; it is wrapped in ``AnswerSystem``,
            whose ``answer_generate`` method serves as the button callback.
    """
    example_question = "What position does Josko Gvardiol play, and how much did Manchester City pay for him?"
    intro_markdown = (
        "\n"
        "# RealTimeData Monthly Collection - BBC News Documentation Assistant\n"
        "Welcome! This system is designed to help you explore and find insights from the RealTimeData Monthly Collection - BBC News dataset.\n"
        "For example:\n"
        f'- *"{example_question}"*\n'
    )

    with gr.Blocks() as demo:
        gr.Markdown(intro_markdown)

        # Input fields
        question_box = gr.Textbox(
            label="Enter your question:",
            placeholder=f"E.g., {example_question}",
        )
        bm25_toggle = gr.Checkbox(label="Enable BM25-based retrieval", value=True)
        semantic_toggle = gr.Checkbox(label="Enable Semantic Search", value=True)
        temperature_control = gr.Slider(
            label="Response Temperature",
            minimum=0.1,
            maximum=1.0,
            value=0.5,
            step=0.1,
        )

        # Search button
        search_trigger = gr.Button("Search")

        # Output fields
        answer_box = gr.Textbox(label="Answer", interactive=False, lines=5)
        documents_box = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic: the input order must match the positional parameters
        # of AnswerSystem.answer_generate (question, BM25 flag, semantic flag,
        # temperature).
        responder = AnswerSystem(rag_model)
        callback_inputs = [question_box, bm25_toggle, semantic_toggle, temperature_control]
        callback_outputs = [answer_box, documents_box]
        search_trigger.click(
            responder.answer_generate,
            inputs=callback_inputs,
            outputs=callback_outputs,
        )

    # Start the application
    demo.launch(debug=True, share=True)
def get_rag_data():
    """Return the chunked knowledge-base documents, building them if needed.

    If ``config.DOCUMENTS_PATH`` exists, the documents are unpickled from it.
    Otherwise every dataset configuration listed in
    ``config.AVAILABLE_DATASET_CONFIGS`` is loaded from
    ``RealTimeData/bbc_news_alltime`` (``train`` split), the splits are
    concatenated, each row is wrapped in a ``LangchainDocument``, the
    documents are split with ``chunker.split_documents`` and the result is
    pickled to ``config.DOCUMENTS_PATH`` for later runs.

    Returns:
        The processed documents, as loaded from the cache or as returned by
        ``chunker.split_documents``.
    """
    # NOTE(review): presumably these tokenizer resources are needed by the
    # chunking/retrieval helpers -- confirm against chunker/retriver.
    nltk.download('punkt')
    nltk.download('punkt_tab')

    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        # SECURITY: unpickling can execute arbitrary code. This is only safe
        # while DOCUMENTS_PATH is a cache file written by this application.
        with open(config.DOCUMENTS_PATH, "rb") as file:
            return pickle.load(file)

    print("Processing documents...")
    # The loop variable is deliberately NOT called ``config``: that name
    # belongs to the imported settings module used everywhere else here.
    datasets_list = [
        utils.align_features(
            load_dataset("RealTimeData/bbc_news_alltime", dataset_config)["train"]
        )
        for dataset_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
    ]
    ds = concatenate_datasets(datasets_list)

    # Row fields copied verbatim into each document's metadata.
    metadata_fields = ("title", "published_date", "authors", "section", "description", "link")
    raw_knowledge_base = [
        LangchainDocument(
            page_content=doc["content"],
            metadata={field: doc[field] for field in metadata_fields},
        )
        for doc in tqdm(ds)
    ]

    # NOTE(review): 512 is presumably the chunk size -- confirm in chunker.
    docs_processed = chunker.split_documents(512, raw_knowledge_base)

    print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
    with open(config.DOCUMENTS_PATH, "wb") as file:
        pickle.dump(docs_processed, file)
    return docs_processed
if __name__ == "__main__":
    # Script entry point: prepare the documents, build both retrieval
    # indexes and the reranker, then hand the assembled generator to the UI.
    docs_processed = get_rag_data()

    # Indexes are both built from the same processed documents.
    bm25 = retriver.create_bm25(docs_processed)
    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)

    # Reranker model named by config.CROSS_ENCODER_MODEL.
    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)

    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed, bm25=bm25, knowledge_index=KNOWLEDGE_VECTOR_DATABASE, reranker=RERANKER
    )
    run_app(rag_generator)