MLADI / eval_utils.py
import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC
def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects whose cumulative probability is at least P (0.9 by default).

    The model is expected to generate a logit for each of the following dialects, in this order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine,
    Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    """
    assert 0 <= P <= 1

    logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    topk_predictions = torch.topk(logits, 18).indices.flatten().tolist()

    predictions = [0 for _ in range(18)]
    total_prob = 0

    # TODO: Assert that the list length is just 18
    for i in range(18):
        total_prob += probabilities[topk_predictions[i]]
        predictions[topk_predictions[i]] = 1
        if total_prob >= P:
            break
    if (
        str(model.config.to_dict()["id2label"][0]) == "LABEL_0"
        or str(model.config.to_dict()["id2label"][0]) == "Algeria"
    ):
        return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
    else:
        # Use the custom label order of
        # https://huggingface.co/Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1
        DIALECTS_LIST = [
            "Oman",
            "Sudan",
            "Saudi_Arabia",
            "Kuwait",
            "Qatar",
            "Lebanon",
            "Jordan",
            "Syria",
            "Iraq",
            "Morocco",
            "Egypt",
            "Palestine",
            "Yemen",
            "Bahrain",
            "Algeria",
            "UAE",
            "Tunisia",
            "Libya",
        ]
        return [DIALECTS_LIST[i] for i, p in enumerate(predictions) if p == 1]
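

# A minimal usage sketch (not part of the original file): predict_top_p assumes a
# Hugging Face sequence-classification checkpoint that emits 18 dialect logits.
# The checkpoint name below is the one referenced in the comment above; the example
# sentence is only illustrative.
def _demo_predict_top_p():
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Returns the smallest set of dialects covering at least 90% of the probability mass.
    return predict_top_p(model, tokenizer, "الجو حلو النهارده", P=0.9)

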
def prompt_chat_LLM(model, tokenizer, text):
    """Prompt the model to determine whether the input text is acceptable in each of the 11 dialects."""
    predicted_dialects = []
    for dialect in DIALECTS_WITH_LABELS:
        # The Arabic prompt reads: 'Determine whether the following sentence is acceptable
        # in one of the dialects used in {country}. Answer only with "نعم" (yes) or "لا" (no).
        # Sentence: "{text}"'
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
        gen_tokens = model.generate(input_ids, max_new_tokens=20)
        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)

        # TODO: Add a condition for the case of "لا" and other responses (e.g., refusing to answer)
        if "نعم" in gen_text:
            predicted_dialects.append(dialect)

    return predicted_dialects
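

# A minimal usage sketch (assumption: the checkpoint name is a placeholder for whichever
# instruction-tuned chat model is being evaluated; any model whose tokenizer defines a
# chat template should work with prompt_chat_LLM).
def _demo_prompt_chat_LLM():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "<chat-model-checkpoint>"  # placeholder, not a real model id
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return prompt_chat_LLM(model, tokenizer, "الجو حلو النهارده")

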
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the validity of the text in each dialect by independently applying a sigmoid activation to each dialect's logit.

    Dialects with probabilities (sigmoid activations) above a threshold (0.3 by default) are predicted as valid.
    The model is expected to generate a logit for each of the following dialects, in this order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine,
    Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Make sure the model runs on the same device as the inputs.
    model = model.to(device)

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    # Move the encoded inputs to the target device
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Note: flattening the probabilities assumes a single input text
    probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    binary_predictions = (probabilities >= threshold).astype(int)

    # Map indices to actual labels
    predicted_dialects = [
        dialect
        for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
        if dialect_prediction == 1
    ]
    return predicted_dialects
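

# A minimal usage sketch (assumption: the checkpoint name is a placeholder for a model
# fine-tuned with one binary/sigmoid output per dialect, matching the label order in DIALECTS).
def _demo_predict_binary_outcomes():
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "<multi-label-dialect-checkpoint>"  # placeholder, not a real model id
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return predict_binary_outcomes(model, tokenizer, "الجو حلو النهارده", threshold=0.3)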