arxiv-RAG / app.py
jharrison27's picture
Upload 2 files
3dbe475 verified
raw
history blame
5.82 kB
import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv
from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search
# Constants
# Result count passed to both search backends (get_rag / get_arxiv_live_search).
RETRIEVE_RESULTS = 20
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',  # sentinel choice: ask_llm checks for this string and skips generation
]
DEFAULT_LLM_MODEL = LLM_MODELS[1]
# Extra keyword arguments forwarded to text_generation(); sampling is turned off.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup
# Load the prebuilt ColBERT index; a failure here is fatal on purpose (nothing works without it).
RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
try:
    # Run one top-1 query at startup so a broken retriever is reported immediately
    # rather than on the first user search.
    gr.Info("Setting up retriever, please wait...")
    rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
    gr.Info("Retriever working successfully!")
except Exception as err:
    gr.Warning(f"Retriever not working: {str(err)}")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def _read_index_date(readme_path="README.md"):
    """Return the index date recorded in *readme_path*, formatted 'DD Mon YYYY'.

    The file is searched for a marker of the form
    ``Index Last Updated : YYYY-MM-DD``.

    Returns:
        The reformatted date string, or ``None`` when the file cannot be read,
        the marker is absent, or the date is not a real calendar date.
        This header is cosmetic, so none of these cases should stop the app.
    """
    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            mdfile = f.read()
    except (OSError, UnicodeDecodeError):
        return None
    match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if match is None:
        # Previously match.group() raised AttributeError here and crashed startup.
        return None
    try:
        return datetime.strptime(match.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except ValueError:
        # Matches the digit pattern but is not a valid date (e.g. month 13).
        return None


formatted_date = _read_index_date()
if formatted_date is None:
    index_info = "Semantic Search"
else:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
# index_info doubles as the dropdown label that selects the local semantic index.
database_choices = [index_info, 'Arxiv Search - Latest - (EXPERIMENTAL)']
# Arxiv API setup
arx_client = arxiv.Client()
is_arxiv_available = True
try:
    # Probe the live API once at startup. A network/API error must not take the
    # whole app down; it just disables the live-search option, the same fallback
    # update_with_rag_md applies per request.
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as e:
    print(f"Arxiv search check failed: {str(e)}")
    check_arxiv_result = []
if len(check_arxiv_result) == 0:
    is_arxiv_available = False
    print("Arxiv search not working, switching to default search ...")
    # Only offer the local semantic index in the source dropdown.
    database_choices = [index_info]
# Gradio UI setup
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)
    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                database_src = gr.Dropdown(choices=database_choices, value=index_info, label='Search Source')
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)
    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden box that carries the prompt built by update_with_rag_md into ask_llm.
    # Named prompt_box (was `input`) so it no longer shadows the builtin input().
    prompt_box = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)

    def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
        """Search for papers and build both the results Markdown and the LLM prompt.

        Args:
            search_query: the user's query text (the parameter shadows the
                Textbox of the same name; only the submitted value is used).
            llm_results_use: how many of the top results are included in the
                LLM prompt; every result is still listed in the Markdown.
            database_choice: selected search source. ``index_info`` means the
                local semantic index; anything else means live Arxiv search.
            llm_model_picked: forwarded to ``get_prompt_text``.

        Returns:
            ``(markdown, prompt)`` for the results Markdown and the hidden
            prompt box.

        Live Arxiv search falls back to the semantic index, with a UI warning,
        when it raises or returns no results.
        """
        database_to_use = database_choice
        rag_out = None
        if database_choice != index_info:
            try:
                rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
                if len(rag_out) == 0:
                    # Previously this fallback happened silently.
                    gr.Warning("Arxiv Search returned no results, switching to semantic search ...")
                    rag_out = None
            except Exception as e:
                gr.Warning(f"Arxiv Search not working: {str(e)}, switching to semantic search ...")
                rag_out = None
        if rag_out is None:
            rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
            database_to_use = index_info

        md_parts = [mark_text]
        prompt_parts = []
        for i, rag_answer in enumerate(rag_out):
            if i < llm_results_use:
                # Top results also contribute a numbered entry to the LLM context.
                md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=database_to_use, return_prompt_formatting=True)
                prompt_parts.append(f"{i+1}. {prompt_text}")
            else:
                md_text_paper = get_md_text_abstract(rag_answer, source=database_to_use)
            md_parts.append(md_text_paper)
        prompt = get_prompt_text(search_query, "".join(prompt_parts), llm_model_picked=llm_model_picked)
        return "".join(md_parts), prompt

    def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
        """Generate an answer for *prompt* and yield text for the answer box.

        This function is a generator, and Gradio only displays *yielded*
        values; a ``return value`` inside a generator is discarded. So every
        path yields its final text. Previously the non-streaming and "model
        disabled" paths returned their text and nothing was shown.

        Args:
            prompt: prompt produced by ``update_with_rag_md``.
            llm_model_picked: model id, or the sentinel ``'None'`` to skip
                generation.
            stream_outputs: yield the growing answer token by token instead of
                once.
        """
        model_disabled_text = "LLM Model is disabled"
        if llm_model_picked == 'None':
            if stream_outputs:
                # Reveal the message one character at a time to mimic streaming.
                for end in range(1, len(model_disabled_text) + 1):
                    yield model_disabled_text[:end]
            else:
                yield model_disabled_text
            # Without this return the code fell through and queried the API
            # with the model name 'None'.
            return
        output = ""
        try:
            # NOTE(review): InferenceClient is not imported anywhere in this file
            # (presumably huggingface_hub.InferenceClient); add that import.
            # Constructing it inside the try at least turns the NameError, or a
            # bad model id, into a UI warning rather than an unhandled error.
            client = InferenceClient(llm_model_picked)
            response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)
            if stream_outputs:
                for token in response:
                    output += token
                    # NOTE(review): the displayed value is whatever
                    # helper.SaveResponseAndRead returns; confirm it returns the
                    # text (a None would blank the answer box).
                    yield SaveResponseAndRead(output)
            else:
                yield response
        except Exception as e:
            gr.Warning(f"LLM Inference failed: {str(e)}")
            # Clear the answer box so a stale previous answer is not left showing.
            yield ""

    # Retrieval runs first; ask_llm is only triggered if that step succeeded.
    search_query.submit(update_with_rag_md, [search_query, llm_results, database_src, llm_model], [gr_md, prompt_box]).success(ask_llm, [prompt_box, llm_model, stream_results], output_text)

demo.queue().launch()