ADKU committed on
Commit 1130652 · verified · 1 Parent(s): ad54e4d

Updated model from GPT-2 to Mistral to improve response quality

Files changed (1)
  1. app.py +20 -22
app.py CHANGED
@@ -5,7 +5,7 @@ from rank_bm25 import BM25Okapi
 import torch
 import pandas as pd
 import gradio as gr
-from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
+from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
 import logging
 
 # Set up logging
@@ -53,12 +53,12 @@ try:
     sci_bert_model.eval()
     logger.info("SciBERT loaded")
 
-    # DistilGPT-2 for QA
-    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
-    gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2", cache_dir="/tmp/huggingface")
-    gpt2_model.to(device)
-    gpt2_model.eval()
-    logger.info("DistilGPT-2 loaded")
+    # Mistral-7B-Instruct for QA
+    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
+    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
+    mistral_model.to(device)
+    mistral_model.eval()
+    logger.info("Mistral-7B-Instruct loaded")
 except Exception as e:
     logger.error(f"Model loading failed: {e}")
     raise
@@ -114,7 +114,7 @@ def get_relevant_papers(query):
         logger.error(f"Search failed: {e}")
         return [], "Search failed. Please try again."
 
-# GPT-2 QA function with the best prompt
+# Mistral QA function with optimized prompt
 def answer_question(paper, question, history):
     if not paper:
         return [(question, "Please select a paper first!")], history
@@ -128,13 +128,11 @@ def answer_question(paper, question, history):
     title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
     abstract = paper.split(" - Abstract: ")[1].rstrip("...")
 
-    # Build the ultimate prompt
+    # Build the ultimate prompt with Mistral's instruction format
     prompt = (
-        "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning and any abstract or title you are given as input. "
+        "<s>[INST] You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
         "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's title and abstract. "
-        "Donot repeat the same sentence again and again no matter what, use your own intelligence to anser some vague question or question whos data is not with you."
-        "Be the best RESEARCH ASSISTANT ever existed"
-        "When asked about tech stacks or methods, use the following guidelines:\n"
+        "When asked about tech stacks or methods, follow these guidelines:\n"
         "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
         "2. If the abstract is vague (e.g., 'machine learning techniques'), infer the most likely tech stacks based on the context of crop prediction and modern research practices, and explain your reasoning.\n"
         "3. Always respond in a clear, concise format—use bullet points for lists (e.g., tech stacks) and short paragraphs for explanations.\n"
@@ -148,30 +146,30 @@ def answer_question(paper, question, history):
 
     # Add history if present
     if history:
-        prompt += "Previous conversation (if any, use for context):\n"
+        prompt += "Previous conversation (use for context):\n"
         for user_q, bot_a in history[-2:]:
             prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
 
-    prompt += f"Now, answer this question: {question}"
+    prompt += f"Now, answer this question: {question} [/INST]</s>"
 
-    logger.info(f"Prompt sent to GPT-2: {prompt[:200]}...")
+    logger.info(f"Prompt sent to Mistral: {prompt[:200]}...")
 
     # Generate response
-    inputs = gpt2_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
+    inputs = mistral_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
     inputs = {key: val.to(device) for key, val in inputs.items()}
     with torch.no_grad():
-        outputs = gpt2_model.generate(
+        outputs = mistral_model.generate(
             inputs["input_ids"],
-            max_new_tokens=150,
+            max_new_tokens=200,  # More tokens for detailed answers
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=gpt2_tokenizer.eos_token_id
+            pad_token_id=mistral_tokenizer.eos_token_id
         )
 
     # Decode and clean response
-    response = gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = response[len(prompt):].strip()
+    response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response[len(prompt):].strip()  # Remove prompt, including [INST] tags
 
     # Fallback for poor responses
     if not response or len(response) < 15:
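A note on the loading hunk above: Mistral-7B-Instruct has roughly 7B parameters, so a plain from_pretrained(...) followed by .to(device) materializes about 28 GB of fp32 weights before moving them, far beyond what DistilGPT-2 needed and more than a typical free Space provides. If memory is the constraint, a lighter load might look like the sketch below. This is a hedged example, not part of the commit, and it assumes accelerate (for device_map) and optionally bitsandbytes are listed in the Space's requirements.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
mistral_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp/huggingface")

# Optional 4-bit weights (assumption: bitsandbytes and a CUDA GPU are available on the Space)
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

mistral_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir="/tmp/huggingface",
    torch_dtype=torch.float16,   # halves memory versus the default fp32 load
    device_map="auto",           # accelerate places the layers on available devices
    # quantization_config=quant_config,  # uncomment to shrink the weights to roughly 4 GB
)
mistral_model.eval()

With device_map="auto", the explicit mistral_model.to(device) call in the diff would be dropped, since placement is handled at load time.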
 
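The prompt formatting and decoding in the last hunk are the most fragile parts of the change: the tokenizer already prepends <s> itself, truncation at max_length=400 can silently cut off the closing [/INST], and because skip_special_tokens=True alters the decoded text, response[len(prompt):] no longer lines up with the generated string and can swallow the start of the answer. A sturdier pattern, sketched under the assumption of a transformers version with chat-template support (the messages structure and the prompt_body name below are illustrative, not from the commit), lets the tokenizer build the instruction format and slices at the token level:

# prompt_body: the assembled instructions, paper context, history, and question (hypothetical name)
messages = [{"role": "user", "content": prompt_body}]
input_ids = mistral_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(device)

with torch.no_grad():
    outputs = mistral_model.generate(
        input_ids,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=mistral_tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens instead of string-slicing the prompt away
new_tokens = outputs[0][input_ids.shape[-1]:]
response = mistral_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

Slicing outputs[0] past input_ids.shape[-1] removes exactly the prompt tokens, so no string arithmetic against the original prompt is needed and the [INST] markers never leak into the reply.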