ADKU committed on
Commit cf496f0 · verified · 1 Parent(s): 9699ac9

Update app.py


Made a Gradio app and extended the project with a paper question-answering feature that uses a GPT-2 model, along with a few enhancements to improve the model's performance.

Files changed (1)
  1. app.py +120 -88
app.py CHANGED
@@ -4,130 +4,162 @@ import numpy as np
 from rank_bm25 import BM25Okapi
 import torch
 import pandas as pd
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, AutoTokenizer, AutoModel
+import gradio as gr
+from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
 
-# Set cache directory to /tmp/huggingface (fixes permission error)
+# Set cache directory for Hugging Face models
 os.environ["HF_HOME"] = "/tmp/huggingface"
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
-os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
 
-app = FastAPI()
-
-# Ensure the correct file path
+# Load dataset
 DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
-
 if not os.path.exists(DATASET_PATH):
     raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
-
-# Load dataset
 df = pd.read_json(DATASET_PATH)
 
-# Clean text function
+# Clean text
 def clean_text(text):
     return text.strip().lower()
 
 df["cleaned_abstract"] = df["abstract"].apply(clean_text)
 
 # Precompute BM25 Index
 tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
 bm25 = BM25Okapi(tokenized_corpus)
 
-# Load embedding models
-embedding_models = {
-    "BERT": "bert-base-uncased",
-    "DistilBERT": "distilbert-base-uncased",
-    "Sentence-BERT": "all-MiniLM-L6-v2",
-    "MiniLM": "sentence-transformers/all-MiniLM-L12-v2",
-    "SciBERT": "allenai/scibert_scivocab_uncased",
-}
-
-BATCH_SIZE = 32  # Batch size for processing
-
-# ✅ Function to clear GPU memory
-def clear_gpu_memory():
-    torch.cuda.empty_cache()
-
-# ✅ Generate embeddings using SciBERT
-def generate_embeddings_sci_bert(texts, batch_size=BATCH_SIZE):
-    model_name = "allenai/scibert_scivocab_uncased"
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/huggingface")
-    model = AutoModel.from_pretrained(model_name, cache_dir="/tmp/huggingface")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model.to(device)
-
+# Load SciBERT for embeddings (preloaded globally)
+sci_bert_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
+sci_bert_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir="/tmp/huggingface")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sci_bert_model.to(device)
+sci_bert_model.eval()
+
+# Load GPT-2 for QA (using distilgpt2 for efficiency)
+gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
+gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
+gpt2_model.to(device)
+gpt2_model.eval()
+
+# Generate SciBERT embeddings
+def generate_embeddings_sci_bert(texts, batch_size=32):
     all_embeddings = []
     for i in range(0, len(texts), batch_size):
-        batch = texts[i : i + batch_size]
-        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
+        batch = texts[i:i + batch_size]
+        inputs = sci_bert_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
         inputs = {key: val.to(device) for key, val in inputs.items()}
-
         with torch.no_grad():
-            outputs = model(**inputs)
-
+            outputs = sci_bert_model(**inputs)
         embeddings = outputs.last_hidden_state.mean(dim=1)
         all_embeddings.append(embeddings.cpu().numpy())
-
-    clear_gpu_memory()
-
+        torch.cuda.empty_cache()  # free cached GPU memory after each batch
     return np.concatenate(all_embeddings, axis=0)
 
-# Compute embeddings
+# Precompute embeddings and FAISS index
 abstracts = df["cleaned_abstract"].tolist()
-embeddings = generate_embeddings_sci_bert(abstracts, batch_size=BATCH_SIZE)
-
-# ✅ Initialize FAISS index
+embeddings = generate_embeddings_sci_bert(abstracts)
 dimension = embeddings.shape[1]
 faiss_index = faiss.IndexFlatL2(dimension)
 faiss_index.add(embeddings.astype(np.float32))
 
-# API Request Model
-class InputText(BaseModel):
-    query: str
-    top_k: int = 5
-
-# ✅ Hybrid Search Function
+# Hybrid search function
 def get_relevant_papers(query, top_k=5):
     if not query.strip():
-        return {"error": "Query is empty. Please enter a valid search query."}
-
-    # 1️⃣ Generate query embedding
-    query_embedding = generate_embeddings_sci_bert([query], batch_size=1)
-
-    # 2️⃣ Perform FAISS similarity search
+        return []
+    query_embedding = generate_embeddings_sci_bert([query])
     distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
-
-    # 3️⃣ Perform BM25 keyword search
     tokenized_query = query.split()
     bm25_scores = bm25.get_scores(tokenized_query)
     bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
-
-    # 4️⃣ Combine FAISS and BM25 results
     combined_indices = list(set(indices[0]) | set(bm25_top_indices))
     ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
-
-    # 5️⃣ Retrieve relevant papers
-    relevant_papers = []
+    papers = []
     for i, index in enumerate(ranked_results[:top_k]):
         paper = df.iloc[index]
-        relevant_papers.append({
-            "rank": i + 1,
-            "title": paper["title"],
-            "authors": paper["authors"],
-            "abstract": paper["cleaned_abstract"],
-        })
-
-    return {"results": relevant_papers}
-
-# ✅ FastAPI Endpoint
-@app.post("/predict/")
-async def predict(data: InputText):
-    return get_relevant_papers(data.query, data.top_k)
-
-# Run FastAPI
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0")
+        papers.append(f"{i+1}. {paper['title']} - Abstract: {paper['cleaned_abstract'][:200]}...")
+    return papers
+
+# GPT-2 QA function
+def answer_question(paper, question, history):
+    if not paper:
+        # Show a notice without storing it in the conversation state
+        return history + [(question, "Please select a paper first!")], history
+    if not question.strip():
+        return history + [(question, "Please ask a question!")], history
+    if question.lower() in ["exit", "done"]:
+        # Reset the conversation
+        return [(question, "Conversation ended. Select a new paper or search again!")], []
+
+    # Extract title and abstract from the dropdown string ("<rank>. <title> - Abstract: ...")
+    title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
+    abstract = paper.split(" - Abstract: ")[1].rstrip("...")
+
+    # Build the prompt from the paper plus the conversation so far
+    context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
+    for user_q, bot_a in history:
+        context += f"User: {user_q}\nAssistant: {bot_a}\n"
+    context += f"User: {question}\nAssistant: "
+
+    # Generate a response
+    inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
+    inputs = {key: val.to(device) for key, val in inputs.items()}
+    with torch.no_grad():
+        outputs = gpt2_model.generate(
+            **inputs,
+            max_new_tokens=100,
+            do_sample=True,
+            temperature=0.7,
+            top_k=50,
+            pad_token_id=gpt2_tokenizer.eos_token_id
+        )
+    # Decode only the newly generated tokens; slicing the decoded string by
+    # len(context) breaks whenever the prompt is truncated at 512 tokens
+    response = gpt2_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+
+    history.append((question, response))
+    # Return the full history for the Chatbot component, plus the updated state
+    return history, history
+
+# Gradio UI
+with gr.Blocks(
+    css="""
+    .chatbot {height: 600px; overflow-y: auto;}
+    .sidebar {width: 300px;}
+    #main {display: flex; flex-direction: row;}
+    """,
+    theme=gr.themes.Default(primary_hue="blue")
+) as demo:
+    gr.Markdown("# ResearchGPT - Paper Search & Chat")
+    with gr.Row(elem_id="main"):
+        # Sidebar for search
+        with gr.Column(scale=1, min_width=300, elem_classes="sidebar"):
+            gr.Markdown("### Search Papers")
+            query_input = gr.Textbox(label="Enter your search query", placeholder="e.g., machine learning in healthcare")
+            search_btn = gr.Button("Search")
+            paper_dropdown = gr.Dropdown(label="Select a Paper", choices=[], interactive=True)
+            # gr.update(...) replaces the dropdown's choices; returning the bare
+            # list from get_relevant_papers would only set the dropdown's value
+            search_btn.click(
+                fn=lambda q: gr.update(choices=get_relevant_papers(q), value=None),
+                inputs=query_input,
+                outputs=paper_dropdown
+            )
+
+        # Main chat area
+        with gr.Column(scale=3):
+            gr.Markdown("### Chat with Selected Paper")
+            selected_paper = gr.Textbox(label="Selected Paper", interactive=False)
+            chatbot = gr.Chatbot(label="Conversation", elem_classes="chatbot")
+            question_input = gr.Textbox(label="Ask a question", placeholder="e.g., What methods are used?")
+            chat_btn = gr.Button("Send")
+
+            # State to store conversation history
+            history_state = gr.State([])
+
+            # Update selected paper
+            paper_dropdown.change(
+                fn=lambda x: x,
+                inputs=paper_dropdown,
+                outputs=selected_paper
+            )
+
+            # Handle chat
+            chat_btn.click(
+                fn=answer_question,
+                inputs=[selected_paper, question_input, history_state],
+                outputs=[chatbot, history_state],
+                _js="() => {document.querySelector('.chatbot').scrollTop = document.querySelector('.chatbot').scrollHeight;}"
+            )
+
+# Launch the app
+demo.launch()
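A quick way to sanity-check the new flow after this change (a minimal sketch, not part of the commit: the query and question strings are invented, and it assumes the snippet sits inside app.py just above demo.launch(), where the index, models, and both functions already exist):

# Hypothetical smoke test; place immediately above demo.launch() in app.py
papers = get_relevant_papers("deep learning for medical imaging", top_k=3)
print(papers[0])    # e.g. "1. <title> - Abstract: <first 200 characters>..."
chat, history = answer_question(papers[0], "What methods are used?", [])
print(chat[-1][1])  # the GPT-2 generated answer to the question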