Spaces:
Sleeping
Sleeping
update main py
Browse filesnow prints out all classes confidence
main.py
CHANGED
@@ -84,29 +84,15 @@ def get_preds(user_text, model=model, tokenizer=tokenizer, label_names=label_nam
|
|
84 |
|
85 |
# Process the outputs as needed
|
86 |
# For example, if you have a classification task:
|
87 |
-
predictions = torch.nn.functional.softmax(outputs.logits, dim=1)
|
|
|
88 |
|
89 |
-
|
90 |
-
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
confidence_scores = predictions.max(dim=1).values.cpu().numpy()
|
94 |
-
|
95 |
-
# Combine the category names, predicted label indices, and confidence scores into a list of tuples
|
96 |
-
predictions_with_confidence = list(zip(predicted_label_indices, confidence_scores))
|
97 |
-
|
98 |
-
# Sort the predictions with confidence scores in descending order
|
99 |
-
predictions_with_confidence_sorted = sorted(predictions_with_confidence, key=lambda x: x[1], reverse=True)
|
100 |
-
|
101 |
-
# Print all predictions with category names and confidence scores in descending order
|
102 |
-
category_list, conf_sc = [], []
|
103 |
-
for label_index, confidence_score in predictions_with_confidence_sorted:
|
104 |
-
category_name = model.config.id2label[str(label_index)]
|
105 |
-
print(f"Category: {category_name}, Confidence: {confidence_score:.2f}")
|
106 |
-
category_list.append(category_name)
|
107 |
-
conf_sc.append(confidence_score)
|
108 |
-
|
109 |
-
return category_list
|
110 |
|
111 |
@app.get("/", tags=["Home"])
|
112 |
def api_home(user_text: str):
|
@@ -116,6 +102,8 @@ def api_home(user_text: str):
|
|
116 |
def inference(user_text: str):
|
117 |
return get_preds(user_text=user_text)
|
118 |
|
|
|
|
|
119 |
|
120 |
if __name__=='__main__':
|
121 |
-
uvicorn.run('main:app', reload=True)
|
|
|
84 |
|
85 |
# Process the outputs as needed
|
86 |
# For example, if you have a classification task:
|
87 |
+
predictions = torch.nn.functional.softmax(outputs.logits, dim=1).tolist()
|
88 |
+
predictions = [item for sublist in predictions for item in sublist]
|
89 |
|
90 |
+
print(predictions)
|
91 |
+
|
92 |
+
dicti = dict(zip(label_names, predictions))
|
93 |
+
print(dicti)
|
94 |
|
95 |
+
return dicti
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
@app.get("/", tags=["Home"])
|
98 |
def api_home(user_text: str):
|
|
|
102 |
def inference(user_text: str):
|
103 |
return get_preds(user_text=user_text)
|
104 |
|
105 |
+
# get_preds(user_text=user_text)
|
106 |
+
|
107 |
|
108 |
if __name__=='__main__':
|
109 |
+
uvicorn.run('main:app', reload=True)
|