|
import torch |
|
from constants import DIALECTS, DIALECTS_WITH_LABELS |
|
|
|
|
|
def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects with a cumulative confidence of at least P.

    Labels are taken in decreasing order of their logits, and their softmax
    probabilities are summed until the running total reaches ``P``.

    The model is expected to generate logits for each of the following
    dialects, in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Args:
        model: Callable invoked as ``model(**encoded)``; the returned object
            must expose a ``logits`` tensor whose first row holds one logit
            per label.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt")``; its result is unpacked
            as keyword arguments for ``model``.
        text: Input handed to the tokenizer. Only the first row of the
            logits is used, so a single text is expected.
        P: Cumulative probability threshold in ``[0, 1]`` (default 0.9).

    Returns:
        The entries of ``DIALECTS`` for the selected labels, listed in
        ``DIALECTS`` order (not in order of confidence).

    Raises:
        ValueError: If ``P`` is outside ``[0, 1]`` (including NaN).
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if not 0 <= P <= 1:
        raise ValueError(f"P must be between 0 and 1 (inclusive), got {P!r}")

    # Inference only: no need to record operations for autograd.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits

    row_logits = logits[0]
    # Take the label count from the model output rather than hard-coding it.
    num_labels = row_logits.shape[-1]
    probabilities = torch.softmax(row_logits, dim=-1).tolist()
    ranked_labels = torch.topk(row_logits, num_labels).indices.tolist()

    selected = set()
    total_prob = 0.0
    for label in ranked_labels:
        total_prob += probabilities[label]
        selected.add(label)
        if total_prob >= P:
            break

    # Report in label-index order so the output follows DIALECTS ordering.
    return [DIALECTS[label] for label in sorted(selected)]
|
|