import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P.

    The dialects are visited from the most to the least probable one, and are
    accumulated until their total softmax probability reaches P.

    The model is expected to generate logits for each dialect of the following
    dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Args:
        model: A sequence classification model whose output exposes `.logits`
            and whose `config.to_dict()` contains an "id2label" mapping.
        tokenizer: The tokenizer matching the model.
        text: The input sentence.
        P: The cumulative probability threshold in [0, 1] (default 0.9).

    Returns:
        The list of predicted dialect names, in label-index order.
    """
    assert P <= 1 and P >= 0

    num_dialects = 18

    # Inference only: no need to track gradients for the forward pass.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits

    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    # torch.topk returns the indices sorted by decreasing logit.
    topk_predictions = torch.topk(logits, num_dialects).indices.flatten().tolist()

    predictions = [0 for _ in range(num_dialects)]
    total_prob = 0

    # TODO: Assert that the list length is just 18
    for label_index in topk_predictions[:num_dialects]:
        total_prob += probabilities[label_index]
        predictions[label_index] = 1
        if total_prob >= P:
            break

    # The first label name tells which label ordering the checkpoint uses.
    first_label = str(model.config.to_dict()["id2label"][0])
    if first_label in ("LABEL_0", "Algeria"):
        return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]

    # Use the custom list of
    # https://huggingface.co/Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1
    DIALECTS_LIST = [
        "Oman",
        "Sudan",
        "Saudi_Arabia",
        "Kuwait",
        "Qatar",
        "Lebanon",
        "Jordan",
        "Syria",
        "Iraq",
        "Morocco",
        "Egypt",
        "Palestine",
        "Yemen",
        "Bahrain",
        "Algeria",
        "UAE",
        "Tunisia",
        "Libya",
    ]
    return [DIALECTS_LIST[i] for i, p in enumerate(predictions) if p == 1]


def prompt_chat_LLM(model, tokenizer, text):
    """Prompt the model to determine whether the input text is acceptable in
    each dialect of DIALECTS_WITH_LABELS.

    One yes/no question is asked per dialect, and a dialect is predicted when
    the model's answer contains the Arabic word for "yes".

    Args:
        model: A generative chat model exposing `generate`.
        tokenizer: The tokenizer matching the model (with a chat template).
        text: The input sentence.

    Returns:
        The list of dialects for which the model answered "yes".
    """
    # Decoder-only models return the prompt followed by the new tokens, while
    # encoder-decoder models return the new tokens only.
    is_encoder_decoder = getattr(
        getattr(model, "config", None), "is_encoder_decoder", False
    )

    predicted_dialects = []
    for dialect in DIALECTS_WITH_LABELS:
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]

        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )

        gen_tokens = model.generate(input_ids, max_new_tokens=20)

        # Only decode the answer: the prompt itself contains the word "نعم",
        # so decoding the whole sequence would make every dialect match.
        answer_tokens = gen_tokens[0]
        if not is_encoder_decoder:
            answer_tokens = answer_tokens[input_ids.shape[-1]:]
        gen_text = tokenizer.decode(answer_tokens, skip_special_tokens=True)

        # TODO: Add a condition for the case of "لا" and other responses (e.g., refuse to answer)
        if "نعم" in gen_text:
            predicted_dialects.append(dialect)

    return predicted_dialects


def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the validity in each dialect, by independently applying a sigmoid
    activation to each dialect's logit.

    Dialects with probabilities (sigmoid activations) above a threshold (set by
    default to 0.3) are predicted as valid.

    The model is expected to generate logits for each dialect of the following
    dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco,
    Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.

    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy
    from MBZUAI.

    Args:
        model: A sequence classification model whose output exposes `.logits`.
        tokenizer: The tokenizer matching the model.
        texts: The input to classify. The logits are flattened and only the
            first len(DIALECTS) values are mapped to labels, so a single
            text is effectively expected.
        threshold: The minimum sigmoid probability for a dialect to be
            predicted as valid (default 0.3).

    Returns:
        The list of dialect names whose probability is >= threshold.
    """
    # NOTE: the inputs are moved to CUDA whenever it is available, so the model
    # is presumably expected to already be on that device - confirm with callers.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    ## inputs
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    binary_predictions = (probabilities >= threshold).astype(int)

    # Map indices to actual labels
    predicted_dialects = [
        dialect
        for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
        if dialect_prediction == 1
    ]

    return predicted_dialects