AMR-KELEG committed on
Commit
6df4a25
·
1 Parent(s): 571f9ec

Add a new evaluation function

Browse files
Files changed (1) hide show
  1. eval_utils.py +31 -0
eval_utils.py CHANGED
@@ -77,3 +77,34 @@ def prompt_chat_LLM(model, tokenizer, text):
77
  if "نعم" in gen_text:
78
  predicted_dialects.append(dialect)
79
  return predicted_dialects
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  if "نعم" in gen_text:
78
  predicted_dialects.append(dialect)
79
  return predicted_dialects
80
+
81
+
82
def predict_binary_outcomes(model, tokenizer, texts, threshold=0.3):
    """Predict which dialects a text is valid in.

    A sigmoid activation is applied independently to each dialect's logit,
    so any number of dialects (including none) may be predicted. Dialects
    whose probability is greater than or equal to ``threshold`` are
    considered predicted.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask``
            keyword tensors and returning an object with a ``logits``
            attribute.
        tokenizer: Callable that encodes ``texts`` and returns a mapping
            holding ``"input_ids"`` and ``"attention_mask"`` tensors.
        texts: Text passed to the tokenizer. NOTE(review): only the first
            row of the logits is mapped to labels, so this is effectively
            a single-text API -- confirm callers never rely on batching.
        threshold: Inclusive probability cut-off (set by default to 0.3).

    Returns:
        The entries of ``DIALECTS`` whose probability reaches the
        threshold, in ``DIALECTS`` order. Assumes the model's logits are
        ordered like ``DIALECTS`` -- TODO confirm against training code.
    """
    # Follow the model's own device when it exposes one; moving the inputs
    # to CUDA merely because it is available fails for a CPU-resident model.
    device = getattr(model, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encodings = tokenizer(
        texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    probabilities = torch.sigmoid(logits).cpu().numpy()
    # Keep only the first row so that probabilities belonging to another
    # text in the batch can never be paired with a dialect name.
    if probabilities.ndim > 1:
        probabilities = probabilities[:1]
    probabilities = probabilities.reshape(-1)

    # Map positions back to dialect names.
    return [
        dialect
        for dialect, probability in zip(DIALECTS, probabilities)
        if probability >= threshold
    ]