# MLADI / eval_utils.py
import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects whose cumulative confidence is at least P."""
    assert 0 <= P <= 1

    # Compute the per-dialect probabilities for the input text.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()

    # Rank all dialects by logit (equivalently, by probability).
    topk_predictions = torch.topk(logits, len(DIALECTS)).indices.flatten().tolist()

    # Greedily add dialects until their cumulative probability reaches P.
    predictions = [0 for _ in range(len(DIALECTS))]
    total_prob = 0
    for i in range(len(DIALECTS)):
        total_prob += probabilities[topk_predictions[i]]
        predictions[topk_predictions[i]] = 1
        if total_prob >= P:
            break

    return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
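

# A minimal usage sketch (not part of the original module). It assumes a
# Hugging Face sequence-classification checkpoint whose label order matches
# DIALECTS; the checkpoint identifier below is a placeholder, not the
# project's actual model name.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "model-name-or-path"  # hypothetical checkpoint identifier
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.eval()

    # Print the dialects covering at least 90% of the predicted probability mass.
    print(predict_top_p(model, tokenizer, "كيف حالك اليوم؟", P=0.9))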