import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set cache directory for Hugging Face models
os.environ["HF_HOME"] = "/tmp/huggingface"
# Load dataset with error handling
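# The JSON file is expected to provide at least 'title' and 'abstract' fields for each paper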
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
try:
if not os.path.exists(DATASET_PATH):
raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
df = pd.read_json(DATASET_PATH)
logger.info("Dataset loaded successfully")
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
raise
# Clean text
def clean_text(text):
return text.strip().lower() if isinstance(text, str) else ""
df["cleaned_abstract"] = df["abstract"].apply(clean_text)
# Precompute BM25 Index
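# BM25 works on whitespace-tokenized abstracts; the text is already lower-cased by clean_text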
try:
tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
bm25 = BM25Okapi(tokenized_corpus)
logger.info("BM25 index created")
except Exception as e:
logger.error(f"BM25 index creation failed: {e}")
raise
# Load models with error handling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
try:
# SciBERT for embeddings
sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model.to(device)
sci_bert_model.eval()
logger.info("SciBERT loaded")
# DistilGPT-2 for QA
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model.to(device)
gpt2_model.eval()
logger.info("DistilGPT-2 loaded")
except Exception as e:
logger.error(f"Model loading failed: {e}")
raise
# Generate SciBERT embeddings
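# Texts are encoded in batches; the last hidden layer is mean-pooled over tokens to give one 768-dim vector per text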
def generate_embeddings_sci_bert(texts, batch_size=32):
try:
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = {key: val.to(device) for key, val in inputs.items()}
with torch.no_grad():
outputs = sci_bert_model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
all_embeddings.append(embeddings.cpu().numpy())
torch.cuda.empty_cache()
return np.concatenate(all_embeddings, axis=0)
except Exception as e:
logger.error(f"Embedding generation failed: {e}")
return np.zeros((len(texts), 768))
# Precompute embeddings and FAISS index
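# IndexFlatL2 performs exact (brute-force) L2 search over the abstract embeddings, which is adequate for a small corpus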
try:
abstracts = df["cleaned_abstract"].tolist()
embeddings = generate_embeddings_sci_bert(abstracts)
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings.astype(np.float32))
logger.info("FAISS index created")
except Exception as e:
logger.error(f"FAISS index creation failed: {e}")
raise
# Hybrid search function
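# Unions the top-5 dense (FAISS) and top-5 sparse (BM25) candidates, then ranks the pooled set by BM25 score and returns up to 5 papers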
def get_relevant_papers(query):
if not query.strip():
return [], "Please enter a search query."
try:
query_embedding = generate_embeddings_sci_bert([query])
distances, indices = faiss_index.search(query_embedding.astype(np.float32), 5)
tokenized_query = query.split()
bm25_scores = bm25.get_scores(tokenized_query)
bm25_top_indices = np.argsort(bm25_scores)[::-1][:5]
combined_indices = list(set(indices[0]) | set(bm25_top_indices))
ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
papers = []
for i, index in enumerate(ranked_results[:5]):
paper = df.iloc[index]
papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
return papers, "Search completed."
except Exception as e:
logger.error(f"Search failed: {e}")
return [], "Search failed. Please try again."
# DistilGPT-2 QA function with a structured instruction prompt
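# Builds the prompt from role instructions, the selected paper's title and abstract, and the last two turns of chat history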
def answer_question(paper, question, history):
if not paper:
return [(question, "Please select a paper first!")], history
if not question.strip():
return [(question, "Please ask a question!")], history
if question.lower() in ["exit", "done"]:
return [("Conversation ended.", "Select a new paper or search again!")], []
try:
# Extract title and abstract
title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
        # removesuffix strips only the literal "..." marker; rstrip("...") would also eat real trailing periods
        abstract = paper.split(" - Abstract: ")[1].removesuffix("...")
# Build the ultimate prompt
prompt = (
"You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning and any abstract or title you are given as input. "
"Your goal is to provide concise, accurate, and well-structured answers based on the given paper's title and abstract. "
"Donot repeat the same sentence again and again no matter what, use your own intelligence to anser some vague question or question whos data is not with you."
"Be the best RESEARCH ASSISTANT ever existed"
"When asked about tech stacks or methods, use the following guidelines:\n"
"1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
"2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
"3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
"4. If the question requires prior conversation context, refer to it naturally to maintain coherence.\n"
"5. If the abstract lacks enough detail, supplement with plausible, domain-specific suggestions and note they are inferred.\n"
"6. Avoid speculation or fluff—stick to facts or educated guesses grounded in the field.\n\n"
f"Here’s the paper:\n"
f"Title: {title}\n"
f"Abstract: {abstract}\n\n"
)
# Add history if present
if history:
prompt += "Previous conversation (if any, use for context):\n"
for user_q, bot_a in history[-2:]:
prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
prompt += f"Now, answer this question: {question}"
logger.info(f"Prompt sent to GPT-2: {prompt[:200]}...")
# Generate response
        # Leave room for 150 new tokens inside DistilGPT-2's 1024-token context window
        inputs = gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=874)
inputs = {key: val.to(device) for key, val in inputs.items()}
with torch.no_grad():
outputs = gpt2_model.generate(
inputs["input_ids"],
max_new_tokens=150,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=gpt2_tokenizer.eos_token_id
)
        # Decode only the newly generated tokens; slicing the decoded text by len(prompt)
        # misaligns whenever the prompt was truncated during tokenization
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = gpt2_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
# Fallback for poor responses
if not response or len(response) < 15:
response = (
"The abstract doesn’t provide specific technologies, but based on crop prediction with machine learning and deep learning, likely tech stacks include:\n"
"- Python: Core language for ML/DL.\n"
"- TensorFlow or PyTorch: Frameworks for deep learning models.\n"
"- Scikit-learn: For traditional ML algorithms.\n"
"- Pandas/NumPy: For data handling and preprocessing."
)
history.append((question, response))
return history, history
except Exception as e:
logger.error(f"QA failed: {e}")
history.append((question, "Sorry, I couldn’t process that. Try again!"))
return history, history
# Gradio UI
with gr.Blocks(
css="""
.chatbot {height: 600px; overflow-y: auto;}
.sidebar {width: 300px;}
#main {display: flex; flex-direction: row;}
""",
theme=gr.themes.Default(primary_hue="blue")
) as demo:
gr.Markdown("# ResearchGPT - Paper Search & Chat")
with gr.Row(elem_id="main"):
# Sidebar for search
with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
gr.Markdown("### Search Papers")
query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
search_btn = gr.Button("Search")
paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
search_status = gr.Textbox(label="Search Status", interactive=False)
# State to store paper choices
paper_choices_state = gr.State([])
search_btn.click(
fn=get_relevant_papers,
inputs=query_input,
outputs=[paper_choices_state, search_status]
).then(
fn=lambda choices: gr.update(choices=choices, value=None),
inputs=paper_choices_state,
outputs=paper_dropdown
)
# Main chat area
with gr.Column(scale=3):
gr.Markdown("### Chat with Selected Paper")
selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
chat_btn = gr.Button("Send")
# State to store conversation history
history_state = gr.State([])
# Update selected paper and reset history
paper_dropdown.change(
fn=lambda x: (x, []),
inputs=paper_dropdown,
outputs=[selected_paper, history_state]
).then(
fn=lambda: [],
inputs=None,
outputs=chatbot
)
# Handle chat
chat_btn.click(
fn=answer_question,
inputs=[selected_paper, question_input, history_state],
outputs=[chatbot, history_state]
).then(
fn=lambda: "",
inputs=None,
outputs=question_input
)
# Launch the app
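# Binding to 0.0.0.0 on port 7860 is the standard configuration for a Hugging Face Space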
demo.launch(server_name="0.0.0.0", server_port=7860)