Update app.py
app.py
CHANGED
@@ -15,11 +15,6 @@ logger = logging.getLogger(__name__)
 # Set cache directory for Hugging Face models
 os.environ["HF_HOME"] = "/tmp/huggingface"

-# Get Hugging Face token from environment variable (set in Spaces secrets)
-HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    logger.warning("HF_TOKEN not set. Mistral model access may fail. Set it in Hugging Face Spaces secrets.")
-
 # Load dataset with error handling
 DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
 try:
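Note: the removed block only logged a warning when HF_TOKEN was absent; with the switch to an ungated checkpoint in the next hunk the token is no longer needed. If a gated model were brought back, the token would usually be passed to from_pretrained explicitly. A minimal sketch, assuming a recent transformers version; the model id is a placeholder, not part of this commit:

    import os
    from transformers import AutoModelForCausalLM

    # Hypothetical gated checkpoint; "some-org/gated-model" is only a placeholder id.
    model = AutoModelForCausalLM.from_pretrained(
        "some-org/gated-model",
        token=os.getenv("HF_TOKEN"),   # token read from Spaces secrets
        cache_dir="/tmp/huggingface",
    )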
@@ -58,12 +53,12 @@ try:
     sci_bert_model.eval()
     logger.info("SciBERT loaded")

-    #
-
-
-
-
-    logger.info("
+    # Qwen1.5-1.8B-Chat for QA (ungated)
+    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", cache_dir="/tmp/huggingface")
+    qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-1.8B-Chat", cache_dir="/tmp/huggingface")
+    qwen_model.to(device)
+    qwen_model.eval()
+    logger.info("Qwen1.5-1.8B-Chat loaded")
 except Exception as e:
     logger.error(f"Model loading failed: {e}")
     raise
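Note: the loading code above calls qwen_model.to(device), but device is defined earlier in app.py, outside the hunks shown here. A minimal sketch of the usual definition, assuming plain PyTorch device selection:

    import torch

    # Assumed definition of `device` used by qwen_model.to(device); not part of this diff.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")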
@@ -119,7 +114,7 @@ def get_relevant_papers(query):
         logger.error(f"Search failed: {e}")
         return [], "Search failed. Please try again."

-#
+# Qwen QA function with optimized prompt
 def answer_question(paper, question, history):
     if not paper:
         return [(question, "Please select a paper first!")], history
@@ -133,9 +128,10 @@ def answer_question(paper, question, history):
     title = paper.split(" - Abstract: ")[0].split(". ", 1)[1]
     abstract = paper.split(" - Abstract: ")[1].rstrip("...")

-    # Build
+    # Build prompt with Qwen's chat format
     prompt = (
-        "
+        "<|im_start|>user\n"
+        "You are Dr. Sage, the world's most brilliant and reliable research assistant, specializing in machine learning, deep learning, and agriculture. "
         "Your goal is to provide concise, accurate, and well-structured answers based on the given paper's title and abstract. "
         "When asked about tech stacks or methods, follow these guidelines:\n"
         "1. If the abstract explicitly mentions technologies (e.g., Python, TensorFlow), list them precisely with brief explanations.\n"
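Note: the added lines build Qwen's ChatML framing (<|im_start|> ... <|im_end|>) by hand inside the prompt string. A less error-prone alternative, assuming a transformers version with chat-template support, is to let the tokenizer render the framing; system_and_question below is a stand-in for the instruction text assembled in this function:

    # Sketch: render the ChatML wrapper via the tokenizer's chat template.
    messages = [{"role": "user", "content": system_and_question}]
    prompt = qwen_tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a prompt string, not token ids
        add_generation_prompt=True,  # append the opening assistant turn
    )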
@@ -155,26 +151,26 @@ def answer_question(paper, question, history):
     for user_q, bot_a in history[-2:]:
         prompt += f"User: {user_q}\nAssistant: {bot_a}\n"

-    prompt += f"Now, answer this question: {question}
+    prompt += f"Now, answer this question: {question}<|im_end|>\n<|im_start|>assistant"

-    logger.info(f"Prompt sent to
+    logger.info(f"Prompt sent to Qwen: {prompt[:200]}...")

     # Generate response
-    inputs =
+    inputs = qwen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=400)
     inputs = {key: val.to(device) for key, val in inputs.items()}
     with torch.no_grad():
-        outputs =
+        outputs = qwen_model.generate(
             inputs["input_ids"],
             max_new_tokens=200,
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
-            pad_token_id=
+            pad_token_id=qwen_tokenizer.eos_token_id
         )

-    # Decode and clean response
-    response =
-    response = response[len(prompt):].strip() # Remove prompt, including
+    # Decode and clean response
+    response = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response[len(prompt):].strip() # Remove prompt, including <|im_start|> tags

     # Fallback for poor responses
     if not response or len(response) < 15:
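Note: stripping the prompt with response[len(prompt):] assumes the decoded string reproduces the prompt character for character, which is not guaranteed once special tokens are skipped. A common alternative, sketched here with the same variable names as the hunk above, is to drop the prompt at the token level before decoding:

    # Sketch: decode only the newly generated tokens instead of slicing the string.
    prompt_len = inputs["input_ids"].shape[1]    # number of prompt tokens
    new_tokens = outputs[0][prompt_len:]         # generated continuation only
    response = qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()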