Update app.py
app.py CHANGED
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
 # Set cache directory for Hugging Face models
 os.environ["HF_HOME"] = "/tmp/huggingface"
 
+# Get Hugging Face token from environment variable (set in Spaces secrets)
+HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    logger.warning("HF_TOKEN not set. Mistral model access may fail. Set it in Hugging Face Spaces secrets.")
+
 # Load dataset with error handling
 DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
 try:
@@ -53,9 +58,9 @@ try:
     sci_bert_model.eval()
     logger.info("SciBERT loaded")
 
-    # Mistral-7B-Instruct for QA
-    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
-    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
+    # Mistral-7B-Instruct for QA with token
+    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
+    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
     mistral_model.to(device)
     mistral_model.eval()
     logger.info("Mistral-7B-Instruct loaded")
@@ -160,14 +165,14 @@ def answer_question(paper, question, history):
     with torch.no_grad():
         outputs = mistral_model.generate(
             inputs["input_ids"],
-            max_new_tokens=200,
+            max_new_tokens=200,
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
             pad_token_id=mistral_tokenizer.eos_token_id
         )
 
-    # Decode and clean response
+    # Decode and clean response (preserve token structure)
     response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
     response = response[len(prompt):].strip()  # Remove prompt, including [INST] tags
 
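
Note (not part of this commit): recent transformers releases deprecate the use_auth_token argument in favor of token, and the token can also be registered once for the whole process with huggingface_hub.login(). A minimal sketch under that assumption:

    import os
    from huggingface_hub import login
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Read the token from the Spaces secret, exactly as app.py does.
    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN:
        login(token=HF_TOKEN)  # authenticates all subsequent Hub downloads

    # token= is the current replacement for the deprecated use_auth_token=.
    mistral_tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.1",
        cache_dir="/tmp/huggingface",
        token=HF_TOKEN,
    )
    mistral_model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.1",
        cache_dir="/tmp/huggingface",
        token=HF_TOKEN,
    )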
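
The prompt is stripped from the reply by character slicing (response[len(prompt):]), which assumes the decoded output starts with the prompt text verbatim. An alternative sketch (an assumption, not what app.py currently does) is to drop the prompt tokens before decoding:

    # Keep only the newly generated tokens, then decode just those.
    prompt_len = inputs["input_ids"].shape[-1]   # number of prompt tokens
    generated_ids = outputs[0][prompt_len:]      # generated continuation only
    response = mistral_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()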