Add a prompting-based method
- background_inference.py +11 -4
- constants.py +22 -0
- eval_utils.py +24 -1
background_inference.py CHANGED
@@ -4,7 +4,7 @@ import utils
 import datasets
 import eval_utils
 from constants import DIALECTS_WITH_LABELS
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
 
 model_name = sys.argv[1]
 commit_id = sys.argv[2]
@@ -19,10 +19,17 @@ utils.update_model_queue(
 )
 
 try:
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name, revision=commit_id
-    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, revision=commit_id, access_token=os.environ["HF_TOKEN"]
+    )
+    if inference_function == "prompt_chat_LLM":
+        model = AutoModel.from_pretrained(
+            model_name, revision=commit_id, access_token=os.environ["HF_TOKEN"]
+        )
+    else:
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, revision=commit_id, access_token=os.environ["HF_TOKEN"]
+        )
 
     # Load the dataset
     dataset_name = os.environ["DATASET_NAME"]
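The loader now dispatches on the configured inference function: chat-style prompting gets a generative model, while everything else keeps the sequence-classification head. Below is a minimal standalone sketch of that dispatch, not the repo's exact code: the `load_model` helper name is an assumption, it passes authentication as `token=` (the kwarg current transformers releases accept), and it uses `AutoModelForCausalLM` since a bare `AutoModel` has no language-modeling head for `.generate()`.

# A sketch (not the repo's code) of dispatching between a chat LLM and a classifier.
# Assumptions: `token=` is the auth kwarg in current transformers releases, and a
# causal-LM head is needed for .generate(), hence AutoModelForCausalLM.
import os

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


def load_model(model_name, commit_id, inference_function):
    token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id, token=token)
    if inference_function == "prompt_chat_LLM":
        # Generative chat model, prompted once per dialect in eval_utils.prompt_chat_LLM
        model = AutoModelForCausalLM.from_pretrained(model_name, revision=commit_id, token=token)
    else:
        # Multi-label classifier, scored by eval_utils.predict_top_p
        model = AutoModelForSequenceClassification.from_pretrained(model_name, revision=commit_id, token=token)
    return tokenizer, model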
constants.py CHANGED
@@ -18,6 +18,28 @@ DIALECTS = [
     "UAE",
     "Yemen",
 ]
+
+DIALECT_IN_ARABIC = {
+    "Algeria": "الجزائر",
+    "Bahrain": "البحرين",
+    "Egypt": "مصر",
+    "Iraq": "العراق",
+    "Jordan": "الأردن",
+    "Kuwait": "الكويت",
+    "Lebanon": "لبنان",
+    "Libya": "ليبيا",
+    "Morocco": "المغرب",
+    "Oman": "عمان",
+    "Palestine": "فلسطين",
+    "Qatar": "قطر",
+    "Saudi_Arabia": "المملكة العربية السعودية",
+    "Sudan": "السودان",
+    "Syria": "سوريا",
+    "Tunisia": "تونس",
+    "UAE": "الإمارات",
+    "Yemen": "اليمن",
+}
+
 assert len(DIALECTS) == 18
 
 DIALECTS_WITH_LABELS = [
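Since DIALECT_IN_ARABIC must stay in sync with DIALECTS, a guard in the spirit of the file's existing `assert len(DIALECTS) == 18` could catch drift early. A hypothetical check, not part of the commit:

# Hypothetical addition mirroring the existing asserts in constants.py:
# the Arabic-name mapping must cover exactly the dialects in DIALECTS.
assert set(DIALECT_IN_ARABIC) == set(DIALECTS)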
eval_utils.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from constants import DIALECTS, DIALECTS_WITH_LABELS
+from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC
 
 
 def predict_top_p(model, tokenizer, text, P=0.9):
@@ -25,3 +25,26 @@ def predict_top_p(model, tokenizer, text, P=0.9):
         break
 
     return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
+
+
+def prompt_chat_LLM(model, tokenizer, text):
+    """Prompt the chat model to determine whether the input text is acceptable in each of the 11 dialects."""
+    predicted_dialects = []
+    for dialect in DIALECTS_WITH_LABELS:
+        messages = [
+            {
+                "role": "user",
+                "content": f"حدد إذا كانت الجملة الأتية مقبولة في أحد اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب ب 'نعم' أو 'لا' فقط."
+                + "\n"
+                + f'الجملة: "{text}"',
+            },
+        ]
+        input_ids = tokenizer.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+        )
+        gen_tokens = model.generate(input_ids, max_new_tokens=20)
+        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
+        # TODO: Add a condition for the case of "لا" and other responses (e.g., refusing to answer)
+        if "نعم" in gen_text:
+            predicted_dialects.append(dialect)
+    return predicted_dialects
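The per-dialect prompt translates roughly as: "Determine whether the following sentence is acceptable in one of the dialects used in {country}. Answer only with 'نعم' (yes) or 'لا' (no)." followed by the sentence itself; a dialect is predicted whenever نعم appears in the generation. A hypothetical end-to-end call is sketched below; the model name is only an example (any instruction-tuned model with a chat template and Arabic coverage would do), not the repo's choice.

# Hypothetical usage of prompt_chat_LLM; the model name is an example only.
from transformers import AutoModelForCausalLM, AutoTokenizer

from eval_utils import prompt_chat_LLM

model_name = "CohereForAI/aya-23-8B"  # assumption: a chat-templated, Arabic-capable model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

text = "شلونك اليوم؟"  # "How are you today?" (common in Gulf dialects)
print(prompt_chat_LLM(model, tokenizer, text))
# e.g. ['Iraq', 'Kuwait', ...], depending on the model's نعم/لا answers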