mouadenna committed on
Commit
0cb0826
1 Parent(s): 5d31b2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -35,16 +35,15 @@ tokenizer = AutoTokenizer.from_pretrained("medalpaca/medalpaca-7b")
35
  #load the first interface
36
 
37
def fn(*args):
    """Collect the symptoms picked in every checkbox group and run the classifier.

    Each positional argument is one group's list of selected symptoms. The
    flattened selection is stored in the module-level ``symptoms`` list and
    passed to ``predd`` together with the pre-loaded random-forest model.

    Raises:
        gr.Error: if fewer than 3 or more than 17 symptoms are selected.
    """
    global symptoms
    selected = []
    for group in args:
        selected.extend(group)
    count = len(selected)
    if count > 17:
        raise gr.Error("Please select a maximum of 17 symptoms.")
    if count < 3:
        raise gr.Error("Please select at least 3 symptoms.")
    symptoms = selected  # Update global symptoms list
    return predd(loaded_rf, symptoms)
48
 
49
 
50
 
@@ -77,17 +76,16 @@ demo = gr.Interface(
77
 
78
 
79
def predict(message, history):
    """Chat callback: wrap *message* in a Q&A prompt and generate an answer.

    Args:
        message: the user's question.
        history: prior chat turns (required by the chat-interface signature,
            not used here).

    Returns:
        The generated answer with the echoed prompt text stripped out.
    """
    # FIX: the original prompt contained "/n" — a typo for the newline
    # escape "\n" — which injected a literal "/n" into the prompt.
    prompt = f"""
    Answer the following question:
    {message}\n
    Answer:
    """
    batch = tokenizer(prompt, return_tensors='pt')
    # Mixed-precision inference context for the generation call.
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=100)
    # Decode, then remove the prompt echo so only the model's answer remains.
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True).replace(prompt, "")
91
 
92
 
93
  loaded_rf = joblib.load("model_joblib")
 
35
  #load the first interface
36
 
37
def fn(*args):
    """Flatten the symptom selections from all checkbox groups and classify them.

    Each positional argument is one checkbox group's list of chosen symptoms.
    The combined selection is saved to the module-level ``symptoms`` list, a
    fresh random-forest model is loaded from ``model_joblib``, and the
    prediction is returned via ``predd``.

    Raises:
        gr.Error: if fewer than 3 or more than 17 symptoms are selected.
    """
    global symptoms
    chosen = []
    for group in args:
        chosen.extend(group)
    if len(chosen) > 17:
        raise gr.Error("Please select a maximum of 17 symptoms.")
    if len(chosen) < 3:
        raise gr.Error("Please select at least 3 symptoms.")
    symptoms = chosen  # Update global symptoms list
    # NOTE(review): the model is re-loaded on every call here — presumably
    # deliberate in this revision, but worth confirming (a module-level
    # ``loaded_rf`` also exists).
    rf_model = joblib.load("model_joblib")
    return predd(rf_model, symptoms)
 
47
 
48
 
49
 
 
76
 
77
 
78
def predict(message, history):
    """Chat callback: build a Q&A prompt from *message* and generate an answer.

    Args:
        message: the user's question.
        history: previous chat turns (kept for the chat-interface signature,
            unused in this function).

    Returns:
        The decoded model output with the echoed prompt removed.
    """
    # FIX: "/n" in the original f-string was a typo for the newline escape
    # "\n", so a literal "/n" ended up inside the prompt.
    prompt = f"""
    Answer the following question:
    {message}\n
    Answer:
    """
    batch = tokenizer(prompt, return_tensors='pt')
    # Run generation under autocast for mixed-precision inference.
    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**batch, max_new_tokens=100)
    # Strip the prompt echo so only the newly generated answer is returned.
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True).replace(prompt, "")
 
89
 
90
 
91
loaded_rf = joblib.load("model_joblib")  # load the serialized random-forest model once at import time