import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P (set by default to 0.9).
    The model is expected to generate logits for each dialect of the following dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    """
    assert 0 <= P <= 1

    logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    topk_predictions = torch.topk(logits, 18).indices.flatten().tolist()

    predictions = [0 for _ in range(18)]
    total_prob = 0

    # Sanity check: the classifier must output exactly 18 dialect logits.
    assert len(probabilities) == 18

    for i in range(18):
        total_prob += probabilities[topk_predictions[i]]
        predictions[topk_predictions[i]] = 1
        if total_prob >= P:
            break

    if (
        str(model.config.to_dict()["id2label"][0]) == "LABEL_0"
        or str(model.config.to_dict()["id2label"][0]) == "Algeria"
    ):
        return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
    else:
        # Use the custom dialect label order of
        # https://huggingface.co/Abdelrahman-Rezk/bert-base-arabic-camelbert-msa-finetuned-Arabic_Dialect_Identification_model_1
        DIALECTS_LIST = [
            "Oman",
            "Sudan",
            "Saudi_Arabia",
            "Kuwait",
            "Qatar",
            "Lebanon",
            "Jordan",
            "Syria",
            "Iraq",
            "Morocco",
            "Egypt",
            "Palestine",
            "Yemen",
            "Bahrain",
            "Algeria",
            "UAE",
            "Tunisia",
            "Libya",
        ]

        return [DIALECTS_LIST[i] for i, p in enumerate(predictions) if p == 1]
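

# Illustrative usage sketch for predict_top_p (not part of the original module):
# assumes an 18-way dialect classifier loaded via Hugging Face transformers;
# the checkpoint path below is a placeholder, not a real model name.
#
#   from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
#   model = AutoModelForSequenceClassification.from_pretrained("<dialect-id-checkpoint>")
#   tokenizer = AutoTokenizer.from_pretrained("<dialect-id-checkpoint>")
#   # Returns the smallest set of dialects whose probabilities sum to >= 0.9.
#   dialects = predict_top_p(model, tokenizer, "example sentence", P=0.9)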


def prompt_chat_LLM(model, tokenizer, text):
    """Prompt the model to determine whether the input text is acceptable in each of the 11 dialects."""
    predicted_dialects = []
    for dialect in DIALECTS_WITH_LABELS:
        # The prompt asks (in Arabic): "Determine whether the following sentence
        # is acceptable in one of the dialects used in {country}. Answer only
        # with 'yes' or 'no'.", followed by: 'Sentence: "{text}"'.
        messages = [
            {
                "role": "user",
                "content": f"حدد إذا كانت الجملة الآتية مقبولة في إحدى اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
                + "\n"
                + f'الجملة: "{text}"',
            },
        ]
        input_ids = tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt",
        )
        gen_tokens = model.generate(input_ids, max_new_tokens=20)
        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
        # TODO: Add a condition for the case of "لا" (no) and other responses (e.g., refusal to answer)
        if "نعم" in gen_text:  # the model answered "yes"
            predicted_dialects.append(dialect)
    return predicted_dialects
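

# Illustrative usage sketch for prompt_chat_LLM (not part of the original
# module): assumes an Arabic-capable instruction-tuned chat model whose
# tokenizer defines a chat template; the checkpoint path is a placeholder.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("<arabic-chat-model>")
#   tokenizer = AutoTokenizer.from_pretrained("<arabic-chat-model>")
#   # Returns the dialects for which the model answered "نعم" (yes).
#   dialects = prompt_chat_LLM(model, tokenizer, "example sentence")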


def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict the validity in each dialect, by indepenently applying a sigmoid activation to each dialect's logit.
    Dialects with probabilities (sigmoid activations) above a threshold (set by defauly to 0.3) are predicted as valid.
    The model is expected to generate logits for each dialect of the following dialects in the same order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman, Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    Credits: method proposed by Ali Mekky, Lara Hassan, and Mohamed ELZeftawy from MBZUAI.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Ensure the model sits on the same device as the inputs prepared below.
    model = model.to(device)
    model.eval()

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )

    # Move the encoded inputs to the target device.
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # NOTE: flattening assumes a single input text; with a larger batch, only
    # the first text's 18 probabilities would pair with DIALECTS below.
    probabilities = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    binary_predictions = (probabilities >= threshold).astype(int)

    # Map indices to actual labels
    predicted_dialects = [
        dialect
        for dialect, dialect_prediction in zip(DIALECTS, binary_predictions)
        if dialect_prediction == 1
    ]

    return predicted_dialects
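

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not from the original module): assumes
    # a multi-label classifier whose 18 sigmoid outputs follow the DIALECTS
    # order documented above. The checkpoint path is a placeholder.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = "<multi-label-dialect-checkpoint>"  # placeholder, not a real model name
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    print(predict_binary_outcomes(model, tokenizer, ["example sentence"], threshold=0.3))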