Chenxi Whitehouse commited on
Commit
4cac25e
·
1 Parent(s): 093ba74
README.md CHANGED
@@ -101,14 +101,19 @@ python -m src.reranking.question_generation_top_sentences
101
  ### 4. Rerank the QA pairs
102
  Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. See [rerank_questions.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/rerank_questions.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_3_rerank_qa.json).
103
  ```bash
104
- python -m reranking.rerank_questions
105
  ```
106
 
107
 
108
  ### 5. Veracity prediction
109
  Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label. See [veracity_prediction.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/prediction/veracity_prediction.py) for more argument options. We provide the prediction file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_vericity_prediction.json).
110
  ```bash
111
- python -m prediction.veracity_prediction
 
 
 
 
 
112
  ```
113
 
114
  The result for dev and the test set below. We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence.
 
101
  ### 4. Rerank the QA pairs
102
  Using a pre-trained BERT model [bert_dual_encoder.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_dual_encoder.ckpt), we rerank the QA paris and keep top 3 QA paris as evidence. See [rerank_questions.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/reranking/rerank_questions.py) for more argument options. We provide the output file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_top_3_rerank_qa.json).
103
  ```bash
104
+ python -m src.reranking.rerank_questions
105
  ```
106
 
107
 
108
  ### 5. Veracity prediction
109
  Finally, given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model [bert_veracity.ckpt](https://huggingface.co/chenxwh/AVeriTeC/blob/main/pretrained_models/bert_veracity.ckpt) to predict the veracity label. See [veracity_prediction.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/prediction/veracity_prediction.py) for more argument options. We provide the prediction file for this step on the dev set [here](https://huggingface.co/chenxwh/AVeriTeC/blob/main/data_store/dev_vericity_prediction.json).
110
  ```bash
111
+ python -m src.prediction.veracity_prediction
112
+ ```
113
+
114
+ Then evaluate the veracity prediction performance with (see [evaluate_veracity.py](https://huggingface.co/chenxwh/AVeriTeC/blob/main/src/prediction/evaluate_veracity.py) for more argument options):
115
+ ```bash
116
+ python -m src.prediction.evaluate_veracity
117
  ```
118
 
119
  The result for dev and the test set below. We recommend using 0.25 as cut-off score for evaluating the relevance of the evidence.
src/prediction/evaluate_veracity.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import scipy
4
+ import numpy as np
5
+ import sklearn
6
+ import nltk
7
+ from nltk import word_tokenize
8
+
9
+
10
+ def pairwise_meteor(candidate, reference):
11
+ return nltk.translate.meteor_score.single_meteor_score(
12
+ word_tokenize(reference), word_tokenize(candidate)
13
+ )
14
+
15
+
16
+ def compute_all_pairwise_scores(src_data, tgt_data, metric):
17
+ scores = np.empty((len(src_data), len(tgt_data)))
18
+
19
+ for i, src in enumerate(src_data):
20
+ for j, tgt in enumerate(tgt_data):
21
+ scores[i][j] = metric(src, tgt)
22
+
23
+ return scores
24
+
25
+
26
+ def print_with_space(left, right, left_space=40):
27
+ print_spaces = " " * (left_space - len(left))
28
+ print(left + print_spaces + right)
29
+
30
+
31
+ class AVeriTeCEvaluator:
32
+
33
+ verdicts = [
34
+ "Supported",
35
+ "Refuted",
36
+ "Not Enough Evidence",
37
+ "Conflicting Evidence/Cherrypicking",
38
+ ]
39
+ pairwise_metric = None
40
+ max_questions = 10
41
+ metric = None
42
+ averitec_reporting_levels = [0.1, 0.2, 0.25, 0.3, 0.4, 0.5]
43
+
44
+ def __init__(self, metric="meteor"):
45
+ self.metric = metric
46
+ if metric == "meteor":
47
+ self.pairwise_metric = pairwise_meteor
48
+
49
+ def evaluate_averitec_veracity_by_type(self, srcs, tgts, threshold=0.25):
50
+ types = {}
51
+ for src, tgt in zip(srcs, tgts):
52
+ score = self.compute_pairwise_evidence_score(src, tgt)
53
+
54
+ if score <= threshold:
55
+ score = 0
56
+
57
+ for t in tgt["claim_types"]:
58
+ if t not in types:
59
+ types[t] = []
60
+
61
+ types[t].append(score)
62
+
63
+ return {t: np.mean(v) for t, v in types.items()}
64
+
65
+ def evaluate_averitec_score(self, srcs, tgts):
66
+ scores = []
67
+ for src, tgt in zip(srcs, tgts):
68
+ score = self.compute_pairwise_evidence_score(src, tgt)
69
+
70
+ this_example_scores = [0.0 for _ in self.averitec_reporting_levels]
71
+ for i, level in enumerate(self.averitec_reporting_levels):
72
+ if score > level:
73
+ this_example_scores[i] = src["pred_label"] == tgt["label"]
74
+
75
+ scores.append(this_example_scores)
76
+
77
+ return np.mean(np.array(scores), axis=0)
78
+
79
+ def evaluate_veracity(self, src, tgt):
80
+ src_labels = [x["pred_label"] for x in src]
81
+ tgt_labels = [x["label"] for x in tgt]
82
+
83
+ acc = np.mean([s == t for s, t in zip(src_labels, tgt_labels)])
84
+
85
+ f1 = {
86
+ self.verdicts[i]: x
87
+ for i, x in enumerate(
88
+ sklearn.metrics.f1_score(
89
+ tgt_labels, src_labels, labels=self.verdicts, average=None
90
+ )
91
+ )
92
+ }
93
+ f1["macro"] = sklearn.metrics.f1_score(
94
+ tgt_labels, src_labels, labels=self.verdicts, average="macro"
95
+ )
96
+ f1["acc"] = acc
97
+ return f1
98
+
99
+ def evaluate_questions_only(self, srcs, tgts):
100
+ all_utils = []
101
+ for src, tgt in zip(srcs, tgts):
102
+ if "evidence" not in src:
103
+ # If there was no evidence, use the string evidence
104
+ src_questions = self.extract_full_comparison_strings(
105
+ src, is_target=False
106
+ )[: self.max_questions]
107
+ else:
108
+ src_questions = [
109
+ qa["question"] for qa in src["evidence"][: self.max_questions]
110
+ ]
111
+ tgt_questions = [qa["question"] for qa in tgt["questions"]]
112
+
113
+ pairwise_scores = compute_all_pairwise_scores(
114
+ src_questions, tgt_questions, self.pairwise_metric
115
+ )
116
+
117
+ assignment = scipy.optimize.linear_sum_assignment(
118
+ pairwise_scores, maximize=True
119
+ )
120
+
121
+ assignment_utility = pairwise_scores[assignment[0], assignment[1]].sum()
122
+
123
+ # Reweight to account for unmatched target questions
124
+ reweight_term = 1 / float(len(tgt_questions))
125
+ assignment_utility *= reweight_term
126
+
127
+ all_utils.append(assignment_utility)
128
+
129
+ return np.mean(all_utils)
130
+
131
+ def get_n_best_qau(self, srcs, tgts, n=3):
132
+ all_utils = []
133
+ for src, tgt in zip(srcs, tgts):
134
+ assignment_utility = self.compute_pairwise_evidence_score(src, tgt)
135
+
136
+ all_utils.append(assignment_utility)
137
+
138
+ idxs = np.argsort(all_utils)[::-1][:n]
139
+
140
+ examples = [
141
+ (
142
+ (
143
+ srcs[i]["questions"]
144
+ if "questions" in srcs[i]
145
+ else srcs[i]["string_evidence"]
146
+ ),
147
+ tgts[i]["questions"],
148
+ all_utils[i],
149
+ )
150
+ for i in idxs
151
+ ]
152
+
153
+ return examples
154
+
155
+ def compute_pairwise_evidence_score(self, src, tgt):
156
+ """Different key is used for reference_data and prediction.
157
+ For the prediction, the format is
158
+ {"evidence": [
159
+ {
160
+ "question": "What does the increased federal medical assistance percentage mean for you?",
161
+ "answer": "Appendix A: Applicability of the Increased Federal Medical Assistance Percentage ",
162
+ "url": "https://www.medicaid.gov/federal-policy-guidance/downloads/smd21003.pdf"
163
+ }],
164
+ "pred_label": "Supported"}
165
+
166
+ And for the data with fold label:
167
+ {"questions": [
168
+ {
169
+ "question": "Where was the claim first published",
170
+ "answers": [
171
+ {
172
+ "answer": "It was first published on Sccopertino",
173
+ "answer_type": "Abstractive",
174
+ "source_url": "https://web.archive.org/web/20201129141238/https://scoopertino.com/exposed-the-imac-disaster-that-almost-was/",
175
+ "source_medium": "Web text",
176
+ "cached_source_url": "https://web.archive.org/web/20201129141238/https://scoopertino.com/exposed-the-imac-disaster-that-almost-was/"
177
+ }
178
+ ]
179
+ }]
180
+ "label": "Refuted"}
181
+ """
182
+
183
+ src_strings = self.extract_full_comparison_strings(src, is_target=False)[
184
+ : self.max_questions
185
+ ]
186
+ tgt_strings = self.extract_full_comparison_strings(tgt)
187
+ pairwise_scores = compute_all_pairwise_scores(
188
+ src_strings, tgt_strings, self.pairwise_metric
189
+ )
190
+ assignment = scipy.optimize.linear_sum_assignment(
191
+ pairwise_scores, maximize=True
192
+ )
193
+ assignment_utility = pairwise_scores[assignment[0], assignment[1]].sum()
194
+
195
+ # Reweight to account for unmatched target questions
196
+ reweight_term = 1 / float(len(tgt_strings))
197
+ assignment_utility *= reweight_term
198
+ return assignment_utility
199
+
200
+ def evaluate_questions_and_answers(self, srcs, tgts):
201
+ all_utils = []
202
+ for src, tgt in zip(srcs, tgts):
203
+ src_strings = self.extract_full_comparison_strings(src, is_target=False)[
204
+ : self.max_questions
205
+ ]
206
+ tgt_strings = self.extract_full_comparison_strings(tgt)
207
+
208
+ pairwise_scores = compute_all_pairwise_scores(
209
+ src_strings, tgt_strings, self.pairwise_metric
210
+ )
211
+
212
+ assignment = scipy.optimize.linear_sum_assignment(
213
+ pairwise_scores, maximize=True
214
+ )
215
+
216
+ assignment_utility = pairwise_scores[assignment[0], assignment[1]].sum()
217
+
218
+ # Reweight to account for unmatched target questions
219
+ reweight_term = 1 / float(len(tgt_strings))
220
+ assignment_utility *= reweight_term
221
+
222
+ all_utils.append(assignment_utility)
223
+
224
+ return np.mean(all_utils)
225
+
226
+ def extract_full_comparison_strings(self, example, is_target=True):
227
+ example_strings = []
228
+
229
+ if is_target:
230
+ if "questions" in example:
231
+ for evidence in example["questions"]:
232
+ # If the answers is not a list, make them a list:
233
+ if not isinstance(evidence["answers"], list):
234
+ evidence["answers"] = [evidence["answers"]]
235
+
236
+ for answer in evidence["answers"]:
237
+ example_strings.append(
238
+ evidence["question"] + " " + answer["answer"]
239
+ )
240
+ if (
241
+ "answer_type" in answer
242
+ and answer["answer_type"] == "Boolean"
243
+ ):
244
+ example_strings[-1] += ". " + answer["boolean_explanation"]
245
+ if len(evidence["answers"]) == 0:
246
+ example_strings.append(
247
+ evidence["question"] + " No answer could be found."
248
+ )
249
+ else:
250
+ if "evidence" in example:
251
+ for evidence in example["evidence"]:
252
+ example_strings.append(
253
+ evidence["question"] + " " + evidence["answer"]
254
+ )
255
+
256
+ if "string_evidence" in example:
257
+ for full_string_evidence in example["string_evidence"]:
258
+ example_strings.append(full_string_evidence)
259
+ return example_strings
260
+
261
+
262
+ if __name__ == "__main__":
263
+ parser = argparse.ArgumentParser(description="Evaluate the veracity prediction.")
264
+ parser.add_argument(
265
+ "-i",
266
+ "--prediction_file",
267
+ default="data_store/dev_veracity.json",
268
+ help="Json file with claim, evidence, and veracity prediction.",
269
+ )
270
+ parser.add_argument(
271
+ "--label_file",
272
+ default="data/dev.json",
273
+ help="Json file with labels.",
274
+ )
275
+ args = parser.parse_args()
276
+
277
+ with open(args.prediction_file) as f:
278
+ predictions = json.load(f)
279
+
280
+ with open(args.label_file) as f:
281
+ references = json.load(f)
282
+
283
+ scorer = AVeriTeCEvaluator()
284
+ q_score = scorer.evaluate_questions_only(predictions, references)
285
+ print_with_space("Question-only score (HU-" + scorer.metric + "):", str(q_score))
286
+ p_score = scorer.evaluate_questions_and_answers(predictions, references)
287
+ print_with_space("Question-answer score (HU-" + scorer.metric + "):", str(p_score))
288
+ print("====================")
289
+
290
+ v_score = scorer.evaluate_veracity(predictions, references)
291
+ print("Veracity F1 scores:")
292
+ for k, v in v_score.items():
293
+ print_with_space(" * " + k + ":", str(v))
294
+
295
+ print("--------------------")
296
+ print("AVeriTeC scores:")
297
+
298
+ v_score = scorer.evaluate_averitec_score(predictions, references)
299
+
300
+ for i, level in enumerate(scorer.averitec_reporting_levels):
301
+ print_with_space(
302
+ " * Veracity scores (" + scorer.metric + " @ " + str(level) + "):",
303
+ str(v_score[i]),
304
+ )
305
+ print("--------------------")
306
+ type_scores = scorer.evaluate_averitec_veracity_by_type(
307
+ predictions, references, threshold=0.2
308
+ )
309
+ for t, v in type_scores.items():
310
+ print_with_space(" * Veracity scores (" + t + "):", str(v))
311
+ print("--------------------")
312
+ type_scores = scorer.evaluate_averitec_veracity_by_type(
313
+ predictions, references, threshold=0.3
314
+ )
315
+ for t, v in type_scores.items():
316
+ print_with_space(" * Veracity scores (" + t + "):", str(v))
src/prediction/veracity_prediction.py CHANGED
@@ -24,7 +24,7 @@ if __name__ == "__main__":
24
  parser.add_argument(
25
  "-i",
26
  "--claim_with_evidence_file",
27
- default="data/dev_top3_questions.json",
28
  help="Json file with claim and top question-answer pairs as evidence.",
29
  )
30
  parser.add_argument(
@@ -41,8 +41,10 @@ if __name__ == "__main__":
41
  )
42
  args = parser.parse_args()
43
 
 
44
  with open(args.claim_with_evidence_file) as f:
45
- examples = json.load(f)
 
46
 
47
  bert_model_name = "bert-base-uncased"
48
 
@@ -113,7 +115,7 @@ if __name__ == "__main__":
113
  "claim_id": example["claim_id"],
114
  "claim": example["claim"],
115
  "evidence": example["evidence"],
116
- "label": LABEL[answer],
117
  }
118
  predictions.append(json_data)
119
 
 
24
  parser.add_argument(
25
  "-i",
26
  "--claim_with_evidence_file",
27
+ default="data_store/dev_top_3_rerank_qa.json",
28
  help="Json file with claim and top question-answer pairs as evidence.",
29
  )
30
  parser.add_argument(
 
41
  )
42
  args = parser.parse_args()
43
 
44
+ examples = []
45
  with open(args.claim_with_evidence_file) as f:
46
+ for line in f:
47
+ examples.append(json.loads(line))
48
 
49
  bert_model_name = "bert-base-uncased"
50
 
 
115
  "claim_id": example["claim_id"],
116
  "claim": example["claim"],
117
  "evidence": example["evidence"],
118
+ "pred_label": LABEL[answer],
119
  }
120
  predictions.append(json_data)
121
 
src/reranking/rerank_questions.py CHANGED
@@ -23,7 +23,7 @@ if __name__ == "__main__":
23
  parser.add_argument(
24
  "-o",
25
  "--output_file",
26
- default="data/dev_top_3_rerank_qa.json",
27
  help="Json file with the top3 reranked questions.",
28
  )
29
  parser.add_argument(
@@ -40,8 +40,10 @@ if __name__ == "__main__":
40
  )
41
  args = parser.parse_args()
42
 
 
43
  with open(args.top_k_qa_file) as f:
44
- examples = json.load(f)
 
45
 
46
  bert_model_name = "bert-base-uncased"
47
 
 
23
  parser.add_argument(
24
  "-o",
25
  "--output_file",
26
+ default="data_store/dev_top_3_rerank_qa.json",
27
  help="Json file with the top3 reranked questions.",
28
  )
29
  parser.add_argument(
 
40
  )
41
  args = parser.parse_args()
42
 
43
+ examples = []
44
  with open(args.top_k_qa_file) as f:
45
+ for line in f:
46
+ examples.append(json.loads(line))
47
 
48
  bert_model_name = "bert-base-uncased"
49