import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS
def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the smallest set of top dialects whose cumulative probability is >= P.

    Runs the classifier on `text`, ranks the label probabilities in
    descending order, and greedily accumulates labels until their total
    probability reaches the threshold `P` (nucleus / top-p selection).

    Args:
        model: A sequence-classification model returning an object with
            a `.logits` tensor of shape (1, num_labels).
        tokenizer: The matching tokenizer; called with
            `return_tensors="pt"`.
        text: The input string to classify.
        P: Cumulative-probability threshold in [0, 1].

    Returns:
        A list of dialect names (entries of `DIALECTS`, in index order)
        covering at least probability mass P.

    Raises:
        ValueError: If P is outside [0, 1].
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not 0 <= P <= 1:
        raise ValueError(f"P must be in [0, 1], got {P}")
    # Inference only — no need to track gradients.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten()
    # Derive the label count from the model output rather than
    # hard-coding it, so any label-set size works.
    num_labels = probabilities.numel()
    ranked = torch.topk(probabilities, num_labels).indices.tolist()
    selected = set()
    total_prob = 0.0
    for idx in ranked:
        selected.add(idx)
        total_prob += probabilities[idx].item()
        if total_prob >= P:
            break
    # Preserve original ordering: dialects are returned in index order,
    # not in probability order.
    return [DIALECTS[i] for i in sorted(selected)]