AMR-KELEG committed
Commit e103605 · 1 Parent(s): 8f6e384

Add a prompting-based method

Files changed (3)
  1. background_inference.py +11 -4
  2. constants.py +22 -0
  3. eval_utils.py +24 -1
background_inference.py CHANGED
@@ -4,7 +4,7 @@ import utils
 import datasets
 import eval_utils
 from constants import DIALECTS_WITH_LABELS
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

 model_name = sys.argv[1]
 commit_id = sys.argv[2]
@@ -19,10 +19,17 @@ utils.update_model_queue(
 )

 try:
-    tokenizer = AutoTokenizer.from_pretrained(model_name, revision=commit_id)
-    model = AutoModelForSequenceClassification.from_pretrained(
-        model_name, revision=commit_id
-    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, revision=commit_id, token=os.environ["HF_TOKEN"]
+    )
+    if inference_function == "prompt_chat_LLM":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name, revision=commit_id, token=os.environ["HF_TOKEN"]
+        )
+    else:
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, revision=commit_id, token=os.environ["HF_TOKEN"]
+        )

     # Load the dataset
     dataset_name = os.environ["DATASET_NAME"]
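The hunks above reference inference_function, which is defined outside the visible context. A minimal dispatch sketch follows, assuming the name arrives as a third CLI argument; the argument position and the INFERENCE_FUNCTIONS mapping are illustrative assumptions, not part of the commit.

# Hypothetical sketch, not part of the commit: resolve the requested
# inference function by name and dispatch to it.
import sys

import eval_utils

# Assumption: the function name is passed as a third CLI argument.
inference_function = sys.argv[3]  # e.g., "prompt_chat_LLM" or "predict_top_p"

INFERENCE_FUNCTIONS = {
    "prompt_chat_LLM": eval_utils.prompt_chat_LLM,
    "predict_top_p": eval_utils.predict_top_p,
}

# Both functions share the (model, tokenizer, text) signature and return
# the list of dialects in which the text is judged acceptable.
predict = INFERENCE_FUNCTIONS[inference_function]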
constants.py CHANGED
@@ -18,6 +18,28 @@ DIALECTS = [
     "UAE",
     "Yemen",
 ]
+
+DIALECT_IN_ARABIC = {
+    "Algeria": "الجزائر",
+    "Bahrain": "البحرين",
+    "Egypt": "مصر",
+    "Iraq": "العراق",
+    "Jordan": "الأردن",
+    "Kuwait": "الكويت",
+    "Lebanon": "لبنان",
+    "Libya": "ليبيا",
+    "Morocco": "المغرب",
+    "Oman": "عمان",
+    "Palestine": "فلسطين",
+    "Qatar": "قطر",
+    "Saudi_Arabia": "المملكة العربية السعودية",
+    "Sudan": "السودان",
+    "Syria": "سوريا",
+    "Tunisia": "تونس",
+    "UAE": "الإمارات",
+    "Yemen": "اليمن",
+}
+
 assert len(DIALECTS) == 18

 DIALECTS_WITH_LABELS = [
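Since prompt_chat_LLM in eval_utils.py indexes this mapping with every entry of DIALECTS_WITH_LABELS, a missing key would only surface as a KeyError at inference time. A small sanity check along these lines (not part of the commit) would catch that at import time:

# Hypothetical sanity check, not part of the commit: the Arabic mapping
# should cover exactly the dialects listed in DIALECTS.
from constants import DIALECTS, DIALECT_IN_ARABIC

assert set(DIALECT_IN_ARABIC) == set(DIALECTS)
assert len(DIALECT_IN_ARABIC) == 18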
eval_utils.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from constants import DIALECTS, DIALECTS_WITH_LABELS
+from constants import DIALECTS, DIALECTS_WITH_LABELS, DIALECT_IN_ARABIC


 def predict_top_p(model, tokenizer, text, P=0.9):
@@ -25,3 +25,26 @@ def predict_top_p(model, tokenizer, text, P=0.9):
         break

     return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
+
+
+def prompt_chat_LLM(model, tokenizer, text):
+    """Prompt the chat LLM to determine whether the input text is acceptable in each of the 11 dialects."""
+    predicted_dialects = []
+    for dialect in DIALECTS_WITH_LABELS:
+        messages = [
+            {
+                "role": "user",
+                "content": f"حدد ما إذا كانت الجملة الآتية مقبولة في إحدى اللهجات المستخدمة في {DIALECT_IN_ARABIC[dialect]}. أجب بـ'نعم' أو 'لا' فقط."
+                + "\n"
+                + f'الجملة: "{text}"',
+            },
+        ]
+        input_ids = tokenizer.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+        )
+        gen_tokens = model.generate(input_ids, max_new_tokens=20)
+        gen_text = tokenizer.decode(gen_tokens[0][input_ids.shape[1]:], skip_special_tokens=True)
+        # TODO: Add a condition for the case of "لا" and other responses (e.g., refusals to answer)
+        if "نعم" in gen_text:
+            predicted_dialects.append(dialect)
+    return predicted_dialects
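A minimal usage sketch of the new function follows. The model name is an arbitrary example of an instruction-tuned model that ships a chat template; the commit does not prescribe a specific model.

# Usage sketch, not part of the commit. The model below is an illustrative
# choice; any chat model with a chat template and a causal-LM head works.
from transformers import AutoModelForCausalLM, AutoTokenizer

import eval_utils

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumption, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Returns the subset of DIALECTS_WITH_LABELS for which the model answered
# "نعم" ("yes") to the acceptability question.
print(eval_utils.prompt_chat_LLM(model, tokenizer, "شلونك؟"))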