ADKU committed on
Commit 43c1491 · verified · 1 Parent(s): 425d4bf

Update app.py


Made changes to the paper question-answering algorithm to make it more robust and accurate.

Files changed (1)
  1. app.py +30 -14
app.py CHANGED
@@ -79,7 +79,7 @@ def generate_embeddings_sci_bert(texts, batch_size=32):
         return np.concatenate(all_embeddings, axis=0)
     except Exception as e:
         logger.error(f"Embedding generation failed: {e}")
-        return np.zeros((len(texts), 768))  # Fallback to zero embeddings
+        return np.zeros((len(texts), 768))
 
 # Precompute embeddings and FAISS index
 try:
@@ -114,7 +114,7 @@ def get_relevant_papers(query):
         logger.error(f"Search failed: {e}")
         return [], "Search failed. Please try again."
 
-# GPT-2 QA function
+# GPT-2 QA function with direct prompting
 def answer_question(paper, question, history):
     if not paper:
         return [(question, "Please select a paper first!")], history
@@ -128,26 +128,42 @@ def answer_question(paper, question, history):
     title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
     abstract = paper.split(" - Abstract: ")[1].rstrip("...")
 
-    # Build context with history
-    context = f"Title: {title}\nAbstract: {abstract}\n\nPrevious conversation:\n"
-    for user_q, bot_a in history:
-        context += f"User: {user_q}\nAssistant: {bot_a}\n"
-    context += f"User: {question}\nAssistant: "
+    # Build a simple prompt
+    prompt = (
+        f"You are an expert assistant. Based on the following paper details:\n"
+        f"Title: {title}\n"
+        f"Abstract: {abstract}\n\n"
+        f"Answer this question: {question}"
+    )
 
-    # Generate response
-    inputs = gpt2_tokenizer(context, return_tensors="pt", truncation=True, max_length=512)
+    # Include recent history if available
+    if history:
+        prompt += "\n\nPrevious conversation:\n"
+        for user_q, bot_a in history[-2:]:  # Last 2 turns for context
+            prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
+
+    logger.info(f"Prompt sent to GPT-2: {prompt[:200]}...")
+
+    # Generate response directly
+    inputs = gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
     inputs = {key: val.to(device) for key, val in inputs.items()}
     with torch.no_grad():
         outputs = gpt2_model.generate(
             inputs["input_ids"],
-            max_new_tokens=100,
+            max_new_tokens=150,  # Longer responses for clarity
             do_sample=True,
-            temperature=0.7,
-            top_k=50,
+            temperature=0.8,
+            top_p=0.9,
             pad_token_id=gpt2_tokenizer.eos_token_id
         )
+
+    # Decode full output and extract response
     response = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = response[len(context):].strip()
+    response = response[len(prompt):].strip()  # Remove prompt from output
+
+    # Fallback for bad responses
+    if not response or len(response) < 10:
+        response = "I couldn’t generate a clear answer. Could you rephrase your question?"
 
     history.append((question, response))
     return history, history
@@ -218,7 +234,7 @@ with gr.Blocks(
     ).then(
         fn=lambda: "",
         inputs=None,
-        outputs=question_input  # Clear input
+        outputs=question_input
     )
 
 # Launch the app
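
For context, a minimal standalone sketch of the direct-prompting flow this commit introduces, assuming `gpt2_tokenizer` and `gpt2_model` are the stock Hugging Face `gpt2` tokenizer and LM-head model that app.py loads elsewhere. The `ask` helper and the example paper are illustrative, not part of the app. It also shows a common alternative to `response[len(prompt):]`: slicing the generated ids by input token count, which sidesteps mismatches when the tokenizer truncates the prompt to `max_length`.

# Hedged sketch of the commit's direct-prompting QA flow; `ask` is
# illustrative and not part of app.py.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

def ask(title, abstract, question, history=()):
    # Same direct prompt shape as answer_question() in app.py
    prompt = (
        f"You are an expert assistant. Based on the following paper details:\n"
        f"Title: {title}\n"
        f"Abstract: {abstract}\n\n"
        f"Answer this question: {question}"
    )
    if history:
        prompt += "\n\nPrevious conversation:\n"
        for user_q, bot_a in list(history)[-2:]:  # last 2 turns, as in the commit
            prompt += f"User: {user_q}\nAssistant: {bot_a}\n"

    inputs = gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = gpt2_model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=150,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )
    # Strip the prompt by token count instead of len(prompt): after
    # truncation, the decoded text may not match the raw prompt string.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return gpt2_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

print(ask(
    "Attention Is All You Need",
    "We propose the Transformer, a model architecture based solely on attention.",
    "What architecture does the paper propose?",
))

The token-count slice is the one deliberate departure from the committed code; everything else (prompt wording, 400-token input cap, sampling with temperature 0.8 and top_p 0.9, 150 new tokens) mirrors the diff above.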