ADKU committed
Commit d5d95b0 · verified · 1 Parent(s): 1130652

Update app.py

Files changed (1): app.py (+10, -5)
app.py CHANGED
@@ -15,6 +15,11 @@ logger = logging.getLogger(__name__)
 # Set cache directory for Hugging Face models
 os.environ["HF_HOME"] = "/tmp/huggingface"
 
+# Get Hugging Face token from environment variable (set in Spaces secrets)
+HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    logger.warning("HF_TOKEN not set. Mistral model access may fail. Set it in Hugging Face Spaces secrets.")
+
 # Load dataset with error handling
 DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
 try:
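Note: this hunk threads HF_TOKEN into the from_pretrained calls further down. An alternative pattern, not used by this commit, is to register the token once per process via huggingface_hub; a minimal sketch, assuming huggingface_hub is available (it is a dependency of transformers):

    import logging
    import os

    from huggingface_hub import login

    logger = logging.getLogger(__name__)

    HF_TOKEN = os.getenv("HF_TOKEN")
    if HF_TOKEN:
        # login() stores the token process-wide, so later from_pretrained
        # calls authenticate without an explicit token argument.
        login(token=HF_TOKEN)
    else:
        logger.warning("HF_TOKEN not set. Mistral model access may fail.")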
@@ -53,9 +58,9 @@ try:
     sci_bert_model.eval()
     logger.info("SciBERT loaded")
 
-    # Mistral-7B-Instruct for QA
-    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
-    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface")
+    # Mistral-7B-Instruct for QA with token
+    mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
+    mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir="/tmp/huggingface", use_auth_token=HF_TOKEN)
     mistral_model.to(device)
     mistral_model.eval()
     logger.info("Mistral-7B-Instruct loaded")
@@ -160,14 +165,14 @@ def answer_question(paper, question, history):
     with torch.no_grad():
         outputs = mistral_model.generate(
             inputs["input_ids"],
-            max_new_tokens=200,  # More tokens for detailed answers
+            max_new_tokens=200,
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
             pad_token_id=mistral_tokenizer.eos_token_id
         )
 
-    # Decode and clean response
+    # Decode and clean response (preserve token structure)
     response = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
     response = response[len(prompt):].strip()  # Remove prompt, including [INST] tags
 
 
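Note on the response cleanup in the last hunk: response[len(prompt):] assumes the decoded string reproduces the prompt character for character, which a tokenizer round-trip does not guarantee (whitespace and special-token rendering can shift the offset). A more robust variant, sketched here against the names used in the diff rather than taken from the commit, slices at the token level, since generate() returns the prompt ids followed by the newly generated ids for causal LMs:

    # Decode only the newly generated ids; outputs[0] begins with the
    # prompt ids for decoder-only models like Mistral.
    prompt_len = inputs["input_ids"].shape[1]
    response = mistral_tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()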