import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
# Set cache directory for Hugging Face models
os.environ["HF_HOME"] = "/tmp/huggingface"
# Load dataset
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
df = pd.read_json(DATASET_PATH)
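# The dataset is expected to provide at least "title" and "abstract" columns;
# both are used below when building the search index and the dropdown entries.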
# Clean text
def clean_text(text):
    return text.strip().lower()
df["cleaned_abstract"] = df["abstract"].apply(clean_text)
# Precompute BM25 Index
tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
bm25 = BM25Okapi(tokenized_corpus)
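# Illustrative note: bm25.get_scores(query_tokens) returns one relevance score
# per document in the corpus, i.e. an array of length len(df) for this dataset.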
# Load SciBERT for embeddings (preloaded globally)
sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sci_bert_model.to(device)
sci_bert_model.eval()
# Load GPT-2 for QA (using distilgpt2 for efficiency)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model.to(device)
gpt2_model.eval()
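# distilgpt2 has no dedicated padding token, which is why generate() below is
# passed pad_token_id=gpt2_tokenizer.eos_token_id to avoid padding warnings.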
# Generate SciBERT embeddings
def generate_embeddings_sci_bert(texts, batch_size=32):
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = sci_bert_model(**inputs)
        # Mean-pool the token embeddings into a single vector per text
        embeddings = outputs.last_hidden_state.mean(dim=1)
        all_embeddings.append(embeddings.cpu().numpy())
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return np.concatenate(all_embeddings, axis=0)
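# Usage sketch: generate_embeddings_sci_bert(["contrastive learning"]) yields a
# (1, 768) array -- SciBERT uses the BERT-base hidden size of 768.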
# Precompute embeddings and FAISS index
abstracts = df["cleaned_abstract"].tolist()
embeddings = generate_embeddings_sci_bert(abstracts)
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings.astype(np.float32))
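# Note: IndexFlatL2 ranks by Euclidean distance on unnormalized mean-pooled
# vectors. A common alternative (not used here) is cosine similarity, i.e.
# inner product over L2-normalized vectors:
#   faiss.normalize_L2(embeddings)            # in-place normalization
#   faiss_index = faiss.IndexFlatIP(dimension)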
# Hybrid search function
def get_relevant_papers(query, top_k=5):
    if not query.strip():
        return gr.update(choices=[])
    # Dense retrieval: nearest neighbours in the FAISS index
    query_embedding = generate_embeddings_sci_bert([query])
    distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
    # Sparse retrieval: BM25 over the tokenized corpus
    tokenized_query = query.split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    # Union of both candidate sets, re-ranked by BM25 score
    combined_indices = list(set(indices[0]) | set(bm25_top_indices))
    ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
    papers = []
    for i, index in enumerate(ranked_results[:top_k]):
        paper = df.iloc[index]
        papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
    # Return an update so the dropdown's choices (not its value) are replaced
    return gr.update(choices=papers)
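# The ranking above is BM25-only over the union of candidates; FAISS distances
# are used solely for candidate generation. One alternative (illustrative, not
# what this app does) is reciprocal rank fusion across both ranked lists:
#   def rrf_score(ranks, k=60):
#       # ranks: this document's rank in each retriever's result list
#       return sum(1.0 / (k + r) for r in ranks)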
# GPT-2 QA function
def answer_question(paper, question, history):
    if not question.strip():
        # Show a prompt in the chat without persisting it in the stored history
        return history + [(question, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
        return [(question, "Conversation ended. Select a new paper or search again!")], []
    # Extract title and abstract from the formatted dropdown string
    title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
    abstract = paper.split(" - Abstract: ")[1].rstrip("...")
    # Build the prompt: paper context plus prior conversation turns
    context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
    for user_q, bot_a in history:
        context += f"User: {user_q}\nAssistant: {bot_a}\n"
    context += f"User: {question}\nAssistant: "
    # Generate a response
    inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = gpt2_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=gpt2_tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(context) breaks whenever the tokenizer truncated the prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = gpt2_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append((question, response))
    # gr.Chatbot expects the full list of (user, assistant) tuples
    return history, history
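# Note: distilgpt2 is a small LM without instruction tuning, so the "answers"
# are best-effort continuations of the prompt rather than grounded QA.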
# Gradio UI
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
gr.Markdown("# ResearchGPT - Paper Search & Chat")
with gr.Row(elem_id="main"):
# Sidebar for search
with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
gr.Markdown("### Search Papers")
query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
search_btn = gr.Button("Search")
paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
search_btn.click(
fn=get_relevant_papers,
inputs=query_input,
outputs=paper_dropdown
)
# Main chat area
with gr.Column(scale=3):
gr.Markdown("### Chat with Selected Paper")
selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
chat_btn = gr.Button("Send")
# State to store conversation history
history_state = gr.State([])
# Update selected paper
paper_dropdown.change(
fn=lambda x: x,
inputs=paper_dropdown,
outputs=selected_paper
)
            # Handle chat; _js runs client-side to keep the chat window
            # scrolled to the bottom (this keyword was renamed to `js` in Gradio 4.x)
            chat_btn.click(
                fn=answer_question,
                inputs=[selected_paper, question_input, history_state],
                outputs=[chatbot, history_state],
                _js="() => {document.querySelector('.chatbot').scrollTop = document.querySelector('.chatbot').scrollHeight;}"
            )
# Launch the app
demo.launch()