AMR-KELEG commited on
Commit
24cf6c5
·
1 Parent(s): f818d64

Track the constants and eval modules

Browse files
Files changed (2) hide show
  1. constants.py +36 -0
  2. eval_utils.py +22 -0
constants.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DIALECTS = [
2
+ "Algeria",
3
+ "Bahrain",
4
+ "Egypt",
5
+ "Iraq",
6
+ "Jordan",
7
+ "Kuwait",
8
+ "Lebanon",
9
+ "Libya",
10
+ "Morocco",
11
+ "Oman",
12
+ "Palestine",
13
+ "Qatar",
14
+ "Saudi_Arabia",
15
+ "Sudan",
16
+ "Syria",
17
+ "Tunisia",
18
+ "UAE",
19
+ "Yemen",
20
+ ]
21
+ assert len(DIALECTS) == 18
22
+
23
+ DIALECTS_WITH_LABELS = [
24
+ "Algeria",
25
+ "Egypt",
26
+ "Iraq",
27
+ "Jordan",
28
+ "Morocco",
29
+ "Palestine",
30
+ "Saudi_Arabia",
31
+ "Sudan",
32
+ "Syria",
33
+ "Tunisia",
34
+ "Yemen",
35
+ ]
36
+ assert len(DIALECTS_WITH_LABELS) == 11
eval_utils.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from constants import DIALECTS, DIALECTS_WITH_LABELS
3
+
4
+
5
+ def predict_top_p(model, tokenizer, text, P=0.9):
6
+ """Predict the top dialects with an accumulative confidence of at least P."""
7
+ assert P <= 1 and P >= 0
8
+
9
+ logits = model(**tokenizer(text, return_tensors="pt")).logits
10
+ probabilities = torch.softmax(logits, dim=1).flatten().tolist()
11
+ topk_predictions = torch.topk(logits, 18).indices.flatten().tolist()
12
+
13
+ predictions = [0 for _ in range(18)]
14
+ total_prob = 0
15
+
16
+ for i in range(18):
17
+ total_prob += probabilities[topk_predictions[i]]
18
+ predictions[topk_predictions[i]] = 1
19
+ if total_prob >= P:
20
+ break
21
+
22
+ return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]