gregmialz commited on
Commit
d3b1c8a
1 Parent(s): a86b728

Update scorer.py

Browse files
Files changed (1) hide show
  1. scorer.py +88 -68
scorer.py CHANGED
@@ -1,81 +1,101 @@
1
  import json
2
  import re
3
  import string
 
4
 
5
  import numpy as np
6
 
7
- def normalize_text(text: str) -> str:
8
- "From QuAC"
9
- def remove_articles(text: str) -> str:
10
- return re.sub(r"\b(a|an|the)\b", " ", text)
11
 
12
- def white_space_fix(text: str) -> str:
13
- return " ".join(text.split())
 
 
 
 
 
 
 
 
14
 
15
- def homogeneize_numbers(text: str) -> str:
16
- try:
17
- return str(float(text))
18
- except ValueError:
19
- return text
20
-
21
- def remove_punc(text: str) -> str:
22
- exclude = set(string.punctuation)
23
- return "".join(ch for ch in text if ch not in exclude)
24
-
25
- def remove_punc2(text: str) -> str:
26
- "From Grégoire's code, removes all punctuation, nicer than remove_punc"
27
- translator = str.maketrans('', '', string.punctuation)
28
- return text.translate(translator)
29
-
30
- def lower(text: str) -> str:
31
- return text.lower()
32
-
33
- def _tokenize(text):
34
- return re.split(" ", text)
35
-
36
- tokens = [white_space_fix(remove_articles(homogeneize_numbers(remove_punc2(lower(t))))) for t in _tokenize(text)]
37
- return " ".join([t for t in tokens if t != ""]).strip()
38
-
39
- def extract_answer(input_str: str, prompt_sep: str = 'FINAL ANSWER: ') -> str:
40
- answer = input_str.split(prompt_sep)[-1].strip()
41
- return answer
42
 
43
- def extract_bow(input_str: str) -> list[str]:
44
- return input_str.split(" ")
 
 
 
 
45
 
46
- def numbers_equals_in_bow(gold_list: list, pred_list: list) -> bool:
47
- # Numbers in prediction bag of words
48
- pred_numbers = []
49
- for text in pred_list:
50
- try:
51
- pred_numbers.append(str(float(text)))
52
- except ValueError:
53
- continue
54
 
55
- for text in gold_list:
 
 
 
 
56
  try:
57
- number = str(float(text))
58
- if number not in pred_numbers:
59
- return False
60
  except ValueError:
61
- continue
62
-
63
- return True
64
-
65
- def affix_quasi_exact_match(gold: str, pred: str) -> float:
66
- if not pred:
67
- return 0
68
-
69
- normalized_pred = normalize_text(pred)
70
- normalized_gold = normalize_text(gold)
71
- bow_pred = extract_bow(pred)
72
- bow_gold = extract_bow(gold)
73
-
74
- if normalized_pred.startswith(normalized_gold) or normalized_pred.endswith(normalized_gold):
75
- if numbers_equals_in_bow(bow_gold, bow_pred):
76
- return 1
77
-
78
- return 0
79
-
80
- def question_scorer(gold: str, pred: str) -> float:
81
- return affix_quasi_exact_match(gold, pred)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import re
3
  import string
4
+ import warnings
5
 
6
  import numpy as np
7
 
 
 
 
 
8
 
9
+ def normalize_number_str(number_str: str) -> float:
10
+ # we replace these common units and commas to allow
11
+ # conversion to float
12
+ for char in ["$", "%", ","]:
13
+ number_str = number_str.replace(char, "")
14
+ try:
15
+ return float(number_str)
16
+ except ValueError:
17
+ print(f"String {number_str} cannot be normalized to number str.")
18
+ return float("inf")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ def split_string(
22
+ s: str,
23
+ char_list: list[str] = [",", ";"],
24
+ ) -> list[str]:
25
+ pattern = f"[{''.join(char_list)}]"
26
+ return re.split(pattern, s)
27
 
 
 
 
 
 
 
 
 
28
 
29
+ def question_scorer(
30
+ model_answer: str,
31
+ ground_truth: str,
32
+ ) -> bool:
33
+ def is_float(element: any) -> bool:
34
  try:
35
+ float(element)
36
+ return True
 
37
  except ValueError:
38
+ return False
39
+
40
+ # if gt is a number
41
+ if is_float(ground_truth):
42
+ print(f"Evaluating {model_answer} as a number.")
43
+ normalized_answer = normalize_number_str(model_answer)
44
+ return normalized_answer == float(ground_truth)
45
+
46
+ # if gt is a list
47
+ elif any(char in ground_truth for char in [",", ";"]):
48
+ print(f"Evaluating {model_answer} as a comma separated list.")
49
+ # question with the fish: normalization removes punct
50
+
51
+ gt_elems = split_string(ground_truth)
52
+ ma_elems = split_string(model_answer)
53
+
54
+ # check length is the same
55
+ if len(gt_elems) != len(ma_elems):
56
+ warnings.warn(
57
+ "Answer lists have different lengths, returning False.", UserWarning
58
+ )
59
+ return False
60
+
61
+ # compare each element as float or str
62
+ comparisons = []
63
+ for ma_elem, gt_elem in zip(ma_elems, gt_elems):
64
+ if is_float(gt_elem):
65
+ normalized_ma_elem = normalize_number_str(ma_elem)
66
+ comparisons.append(normalized_ma_elem == float(gt_elem))
67
+ else:
68
+ # we do not remove punct since comparisons can include punct
69
+ comparisons.append(
70
+ normalize_str(ma_elem, remove_punct=False)
71
+ == normalize_str(gt_elem, remove_punct=False)
72
+ )
73
+ return all(comparisons)
74
+
75
+ # if gt is a str
76
+ else:
77
+ print(f"Evaluating {model_answer} as a string.")
78
+ return normalize_str(model_answer) == normalize_str(ground_truth)
79
+
80
+
81
+ def normalize_str(input_str, remove_punct=True) -> str:
82
+ """
83
+ Normalize a string by:
84
+ - Removing all white spaces
85
+ - Optionally removing punctuation (if remove_punct is True)
86
+ - Converting to lowercase
87
+ Parameters:
88
+ - input_str: str, the string to normalize
89
+ - remove_punct: bool, whether to remove punctuation (default: True)
90
+ Returns:
91
+ - str, the normalized string
92
+ """
93
+ # Remove all white spaces. Required e.g for seagull vs. sea gull
94
+ no_spaces = re.sub(r"\s", "", input_str)
95
+
96
+ # Remove punctuation, if specified.
97
+ if remove_punct:
98
+ translator = str.maketrans("", "", string.punctuation)
99
+ return no_spaces.lower().translate(translator)
100
+ else:
101
+ return no_spaces.lower()