lyangas commited on
Commit
cbbf9d1
·
1 Parent(s): 28684eb

show TOP-3 predictions with proba

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -14,8 +14,11 @@ except Exception as e:
14
  print(f"ERROR: loading model failed with: {str(e)}")
15
 
16
  def classify(text):
17
- pred_classes = model.predict([text])
18
- output_text = ' '.join(pred_classes)
 
 
 
19
  return output_text
20
 
21
  print('INFO: starting gradio interface')
 
14
  print(f"ERROR: loading model failed with: {str(e)}")
15
 
16
  def classify(text):
17
+ embed = model._texts2vecs([text])
18
+ probs = model.classifier.predict_proba(embed)
19
+ best_n = np.flip(np.argsort(probs, axis=1,)[0,-3:])
20
+ preds = [f"{model.classifier.classes_[i]} - {probs[0][i]*100:.1f}%" for i in best_n]
21
+ output_text = '\n'.join(preds)
22
  return output_text
23
 
24
  print('INFO: starting gradio interface')