ADKU committed
Commit d6c0c81 · verified · 1 Parent(s): d5d95b0

Update app.py

Files changed (1):
  1. app.py  +18 -22
app.py CHANGED
@@ -15,11 +15,6 @@ logger = logging.getLogger(__name__)
 # Set cache directory for Hugging Face models
 os.environ["HF_HOME"] = "/tmp/huggingface"
 
-# Get Hugging Face token from environment variable (set in Spaces secrets)
-HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    logger.warning("HF_TOKEN not set. Mistral model access may fail. Set it in Hugging Face Spaces secrets.")
-
 # Load dataset with error handling
 DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
 try:
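Note on the removed token handling: Qwen1.5-1.8B-Chat is an ungated checkpoint, so no Hugging Face token is needed to download it. If a gated model were ever swapped back in, the current transformers convention is to pass token= rather than the deprecated use_auth_token=. A minimal sketch, assuming the token would again be exposed as a Spaces secret named HF_TOKEN and using a placeholder model id:

import os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical re-introduction of token handling for a gated checkpoint.
# HF_TOKEN is assumed to be set as a Spaces secret, as in the removed code.
HF_TOKEN = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    "some-org/some-gated-model",   # placeholder model id
    cache_dir="/tmp/huggingface",
    token=HF_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-gated-model",   # placeholder model id
    cache_dir="/tmp/huggingface",
    token=HF_TOKEN,
)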
@@ -58,12 +53,12 @@ try:
     sci_bert_model.eval()
     logger.info("SciBERT loaded")
 
-    # Mistral-7B-Instruct for QA with token
-    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
-    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
-    mistral_model.to(device)
-    mistral_model.eval()
-    logger.info("Mistral-7B-Instruct loaded")
+    # Qwen1.5-1.8B-Chat for QA (ungated)
+    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", cache_dir="/tmp/huggingface")
+    qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", cache_dir="/tmp/huggingface")
+    qwen_model.to(device)
+    qwen_model.eval()
+    logger.info("Qwen1.5-1.8B-Chat loaded")
 except Exception as e:
     logger.error(f"Model loading failed: {e}")
     raise
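The Qwen model here loads in full float32 by default, which needs roughly 7 GB of RAM for 1.8B parameters. If memory on the Space is tight, a half-precision load is one option; this is only a sketch, and it assumes device is defined earlier in app.py as in the original code:

import torch
from transformers import AutoModelForCausalLM

# Alternative loading sketch: half precision to roughly halve memory use.
qwen_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-1.8B-Chat",
    cache_dir="/tmp/huggingface",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
qwen_model.to(device)  # assumes `device` from the surrounding app
qwen_model.eval()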
@@ -119,7 +114,7 @@ def get_relevant_papers(query):
         logger.error(f"Search failed: {e}")
         return [], "Search failed. Please try again."
 
-# Mistral QA function with optimized prompt
+# Qwen QA function with optimized prompt
 def answer_question(paper, question, history):
     if not paper:
         return [(question, "Please select a paper first!")], history
@@ -133,9 +128,10 @@ def answer_question(paper, question, history):
     title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
     abstract = paper.split(" - Abstract: ")[1].rstrip("...")
 
-    # Build the ultimate prompt with Mistral's instruction format
+    # Build prompt with Qwen's chat format
     prompt = (
-        "<s>[INST] You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
+        "<|im_start|>user\n"
+        "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
         "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's title and abstract. "
         "When asked about tech stacks or methods, follow these guidelines:\n"
        "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
@@ -155,26 +151,26 @@ def answer_question(paper, question, history):
     for user_q, bot_a in history[-2:]:
         prompt += f"User: {user_q}\nAssistant: {bot_a}\n"
 
-    prompt += f"Now, answer this question: {question} [/INST]</s>"
+    prompt += f"Now, answer this question: {question}<|im_end|>\n<|im_start|>assistant"
 
-    logger.info(f"Prompt sent to Mistral: {prompt[:200]}...")
+    logger.info(f"Prompt sent to Qwen: {prompt[:200]}...")
 
     # Generate response
-    inputs = mistral_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
+    inputs = qwen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
     inputs = {key: val.to(device) for key, val in inputs.items()}
     with torch.no_grad():
-        outputs = mistral_model.generate(
+        outputs = qwen_model.generate(
             inputs["input_ids"],
             max_new_tokens=200,
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=mistral_tokenizer.eos_token_id
+            pad_token_id=qwen_tokenizer.eos_token_id
         )
 
-    # Decode and clean response (preserve token structure)
-    response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = response[len(prompt):].strip()  # Remove prompt, including [INST] tags
+    # Decode and clean response
+    response = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response[len(prompt):].strip()  # Remove prompt, including <|im_start|> tags
 
     # Fallback for poor responses
     if not response or len(response) < 15:
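One caveat with the response[len(prompt):] slicing: decoding with skip_special_tokens=True drops the <|im_start|>/<|im_end|> markers, so the decoded string is shorter than the original prompt and the slice can cut into the answer. A sketch of a more robust variant that removes the prompt at the token level instead, using the same variable names as the code above:

# Slice off the prompt tokens before decoding, so only newly generated
# tokens are converted back to text.
prompt_length = inputs["input_ids"].shape[1]
generated_ids = outputs[0][prompt_length:]
response = qwen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()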
 