|
import torch |
|
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC |
|
|
|
|
|
def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P (set by default to 0.9).

    Dialects are ranked by the model's logits (highest first) and added one at
    a time until the summed softmax probability of the selected dialects is at
    least ``P``. At least one dialect is therefore always returned.

    The model is expected to generate logits for each dialect of the following
    dialects in the same order:

    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    If the model's first label (``id2label[0]``) is neither "LABEL_0" nor
    "Algeria", a fixed alternative label order (starting with "Oman") is
    assumed instead.

    Args:
        model: Sequence-classification model whose output has a ``.logits``
            attribute.
        tokenizer: Tokenizer matching ``model``.
        text: A single input text.
        P: Cumulative-probability threshold in [0, 1].

    Returns:
        The selected dialect names, in the model's label-index order.

    Raises:
        ValueError: If ``P`` is outside [0, 1].
    """
    # An explicit check (unlike ``assert``) survives ``python -O``; it also
    # rejects NaN.
    if not 0 <= P <= 1:
        raise ValueError(f"P must be in [0, 1], got {P!r}")

    # Send the inputs to wherever the model's weights live so both CPU- and
    # GPU-resident models work; truncate so over-long text cannot overflow
    # the model's maximum sequence length.
    device = next(model.parameters()).device
    inputs = tokenizer(text, truncation=True, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Python floats, so the running sum below is accumulated in double precision.
    probabilities = torch.softmax(logits, dim=-1).tolist()
    # Label indices ranked by logit, highest first (same order as by probability).
    ranked = torch.topk(logits, logits.shape[-1]).indices.tolist()

    selected = set()
    cumulative = 0.0
    for idx in ranked:
        cumulative += probabilities[idx]
        selected.add(idx)
        if cumulative >= P:
            break

    first_label = str(model.config.id2label[0])
    if first_label in ("LABEL_0", "Algeria"):
        labels = DIALECTS
    else:
        # Alternative label order for models whose id2label does not start
        # with "LABEL_0"/"Algeria".
        labels = (
            "Oman",
            "Sudan",
            "Saudi_Arabia",
            "Kuwait",
            "Qatar",
            "Lebanon",
            "Jordan",
            "Syria",
            "Iraq",
            "Morocco",
            "Egypt",
            "Palestine",
            "Yemen",
            "Bahrain",
            "Algeria",
            "UAE",
            "Tunisia",
            "Libya",
        )

    # Sorting the indices keeps the output in label-index order.
    return [labels[i] for i in sorted(selected)]
|
|
|
|
|
def prompt_chat_LLM(model, tokenizer, text):
    """Prompt the model to determine whether the input text is acceptable in each dialect.

    For every dialect in ``DIALECTS_WITH_LABELS``, the chat model is asked (in
    Arabic) whether ``text`` is acceptable in a dialect used in that country,
    and to answer only 'نعم' (yes) or 'لا' (no). A dialect is predicted when
    the model's generated answer contains 'نعم'.

    Args:
        model: Chat model supporting ``generate``.
        tokenizer: Tokenizer with a chat template matching ``model``.
        text: The sentence to judge.

    Returns:
        List of dialects (from ``DIALECTS_WITH_LABELS``) judged acceptable.
    """
    device = next(model.parameters()).device
    # Decoder-only models return the prompt followed by the new tokens from
    # generate(). The prompt itself contains 'نعم', so it must be stripped
    # before checking the answer. Encoder-decoder models return decoder
    # tokens only.
    output_includes_prompt = not getattr(model.config, "is_encoder_decoder", False)

    predicted_dialects = []
    for dialect in DIALECTS_WITH_LABELS:
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
        ).to(device)

        gen_tokens = model.generate(input_ids, max_new_tokens=20)

        if output_includes_prompt:
            answer_tokens = gen_tokens[0, input_ids.shape[-1]:]
        else:
            answer_tokens = gen_tokens[0]
        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

        if "نعم" in answer:
            predicted_dialects.append(dialect)
    return predicted_dialects
|
|
|
|
|
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the validity in each dialect, by independently applying a sigmoid activation to each dialect's logit.

    Dialects with probabilities (sigmoid activations) at or above a threshold
    (set by default to 0.3) are predicted as valid.

    The model is expected to generate logits for each dialect of the following
    dialects in the same order:

    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.

    Args:
        model: Sequence-classification model whose output has a ``.logits``
            attribute.
        tokenizer: Tokenizer matching ``model``.
        texts: Input text (or batch of texts). Only the FIRST sequence's
            predictions are returned.
        threshold: Minimum sigmoid probability for a dialect to count as valid.

    Returns:
        List of dialect names (from ``DIALECTS``) predicted as valid, in label order.
    """
    # Send inputs to the device the model's weights are on, so CPU- and
    # GPU-resident models both work (rather than assuming CUDA whenever it
    # exists).
    device = next(model.parameters()).device

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    # NOTE(review): only the first sequence of the batch is scored, matching the
    # original flatten-then-zip behaviour; multi-text batches are not split out.
    probabilities = torch.sigmoid(logits[0]).cpu()
    is_valid = (probabilities >= threshold).tolist()

    return [dialect for dialect, valid in zip(DIALECTS, is_valid) if valid]
|
|