import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer

# Set cache directory for Hugging Face models
os.environ["HF_HOME"] = "/tmp/huggingface"

# Load dataset
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
df = pd.read_json(DATASET_PATH)

# Clean text
def clean_text(text):
    return text.strip().lower()

df["cleaned_abstract"] = df["abstract"].apply(clean_text)

# Precompute BM25 Index
tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
bm25 = BM25Okapi(tokenized_corpus)

# Load SciBERT for embeddings (preloaded globally)
sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sci_bert_model.to(device)
sci_bert_model.eval()

# Load GPT-2 for QA (using distilgpt2 for efficiency)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
gpt2_model.to(device)
gpt2_model.eval()

# Generate SciBERT embeddings
def generate_embeddings_sci_bert(texts, batch_size=32):
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = sci_bert_model(**inputs)
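        # Mean-pool the token embeddings into one fixed-size vector per text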
        embeddings = outputs.last_hidden_state.mean(dim=1)
        all_embeddings.append(embeddings.cpu().numpy())
        if device.type == "cuda":
            torch.cuda.empty_cache()  # free GPU memory between batches
    return np.concatenate(all_embeddings, axis=0)

# Precompute embeddings and FAISS index
abstracts = df["cleaned_abstract"].tolist()
embeddings = generate_embeddings_sci_bert(abstracts)
dimension = embeddings.shape[1]
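# IndexFlatL2 performs exact (brute-force) L2 search, which is adequate for a corpus of this size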
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings.astype(np.float32))

# Hybrid search function
def get_relevant_papers(query, top_k=5):
    if not query.strip():
        return []
    query_embedding = generate_embeddings_sci_bert([query])
    distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
    tokenized_query = query.split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
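    # Union the dense (FAISS) and lexical (BM25) candidates, then re-rank the pool by BM25 score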
    combined_indices = list(set(indices[0]) | set(bm25_top_indices))
    ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
    papers = []
    for i, index in enumerate(ranked_results[:top_k]):
        paper = df.iloc[index]
        papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
    return papers

# GPT-2 QA function
def answer_question(paper, question, history):
    # The first return value feeds the Chatbot component, which expects a list of
    # (user, bot) message pairs, so status messages are appended as (None, text)
    if not paper:
        return history + [(None, "Please search for and select a paper first!")], history
    if not question.strip():
        return history + [(None, "Please ask a question!")], history
    if question.lower() in ["exit", "done"]:
        return [(None, "Conversation ended. Select a new paper or search again!")], []
    
    # Extract title and abstract from the formatted dropdown string
    title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
    abstract = paper.split(" - Abstract: ")[1].removesuffix("...")
    
    # Build context with history
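    # (the prompt ends with "Assistant: " so GPT-2 continues the reply in the assistant's voice)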
    context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
    for user_q, bot_a in history:
        context += f"User: {user_q}\nAssistant: {bot_a}\n"
    context += f"User: {question}\nAssistant: "
    
    # Generate response
    inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = gpt2_model.generate(
            inputs["input_ids"],
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            pad_token_id=gpt2_tokenizer.eos_token_id
        )
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(context) can drift after the tokenization round-trip
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = gpt2_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    
    history.append((question, response))
    # Return the full history for the Chatbot component plus the updated state
    return history, history

# Gradio UI
with gr.Blocks(
    css="""
    .chatbot {height: 600px; overflow-y: auto;}
    .sidebar {width: 300px;}
    #main {display: flex; flex-direction: row;}
    """,
    theme=gr.themes.Default(primary_hue="blue")
) as demo:
    gr.Markdown("# ResearchGPT - Paper Search & Chat")
    with gr.Row(elem_id="main"):
        # Sidebar for search
        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
            gr.Markdown("### Search Papers")
            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
            search_btn = gr.Button("Search")
            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
            # Wrap the results in gr.update so they repopulate the dropdown's
            # choices instead of being interpreted as its selected value
            search_btn.click(
                fn=lambda q: gr.update(choices=get_relevant_papers(q), value=None),
                inputs=query_input,
                outputs=paper_dropdown
            )
        
        # Main chat area
        with gr.Column(scale=3):
            gr.Markdown("### Chat with Selected Paper")
            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
            chat_btn = gr.Button("Send")
            
            # State to store conversation history
            history_state = gr.State([])

            # Update selected paper
            paper_dropdown.change(
                fn=lambda x: x,
                inputs=paper_dropdown,
                outputs=selected_paper
            )
            
            # Handle chat (the Chatbot component scrolls to the newest message on its own)
            chat_btn.click(
                fn=answer_question,
                inputs=[selected_paper, question_input, history_state],
                outputs=[chatbot, history_state]
            )

# Launch the app
demo.launch()