SuiGio commited on
Commit
24ea49b
·
verified ·
1 Parent(s): 920e083

update main py

Browse files

now prints out all classes confidence

Files changed (1) hide show
  1. main.py +10 -22
main.py CHANGED
@@ -84,29 +84,15 @@ def get_preds(user_text, model=model, tokenizer=tokenizer, label_names=label_nam
84
 
85
  # Process the outputs as needed
86
  # For example, if you have a classification task:
87
- predictions = torch.nn.functional.softmax(outputs.logits, dim=1)
 
88
 
89
- # Get the predicted label indices for each prediction
90
- predicted_label_indices = predictions.argmax(dim=1).cpu().numpy()
 
 
91
 
92
- # Get the confidence scores (probabilities) for each prediction
93
- confidence_scores = predictions.max(dim=1).values.cpu().numpy()
94
-
95
- # Combine the category names, predicted label indices, and confidence scores into a list of tuples
96
- predictions_with_confidence = list(zip(predicted_label_indices, confidence_scores))
97
-
98
- # Sort the predictions with confidence scores in descending order
99
- predictions_with_confidence_sorted = sorted(predictions_with_confidence, key=lambda x: x[1], reverse=True)
100
-
101
- # Print all predictions with category names and confidence scores in descending order
102
- category_list, conf_sc = [], []
103
- for label_index, confidence_score in predictions_with_confidence_sorted:
104
- category_name = model.config.id2label[str(label_index)]
105
- print(f"Category: {category_name}, Confidence: {confidence_score:.2f}")
106
- category_list.append(category_name)
107
- conf_sc.append(confidence_score)
108
-
109
- return category_list
110
 
111
  @app.get("/", tags=["Home"])
112
  def api_home(user_text: str):
@@ -116,6 +102,8 @@ def api_home(user_text: str):
116
  def inference(user_text: str):
117
  return get_preds(user_text=user_text)
118
 
 
 
119
 
120
  if __name__=='__main__':
121
- uvicorn.run('main:app', reload=True)
 
84
 
85
  # Process the outputs as needed
86
  # For example, if you have a classification task:
87
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=1).tolist()
88
+ predictions = [item for sublist in predictions for item in sublist]
89
 
90
+ print(predictions)
91
+
92
+ dicti = dict(zip(label_names, predictions))
93
+ print(dicti)
94
 
95
+ return dicti
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  @app.get("/", tags=["Home"])
98
  def api_home(user_text: str):
 
102
  def inference(user_text: str):
103
  return get_preds(user_text=user_text)
104
 
105
+ # get_preds(user_text=user_text)
106
+
107
 
108
  if __name__=='__main__':
109
+ uvicorn.run('main:app', reload=True)