Kamichanw committed (verified)
Commit a9c85a7 · 1 Parent(s): aec8e5b

Create vqa_accuracy.py

Files changed (1):
  1. vqa_accuracy.py +320 -0

vqa_accuracy.py ADDED:

import datasets
import evaluate
import re

_DESCRIPTION = """
VQA accuracy is an evaluation metric that is robust to inter-human variability in phrasing the answers:

    Acc(`ans`) = min{ (# humans that said `ans`) / 3, 1 }

where `ans` is the machine-generated answer. To be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 subsets of the ten human annotators.
Note that to obtain results consistent with the official VQA evaluation, all inputs should be processed with `postprocess_generation` from testbed.data.vqav2.
"""

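# Worked example of the formula above (hypothetical counts): if 4 of the 10
# annotators answered "yes" and the model also answers "yes", every
# leave-one-out subset of 9 annotators still contains at least 3 "yes", so each
# subset scores min(3/3, 1) = 1.0 and the averaged accuracy is 1.0. If only 2
# annotators answered "yes", a subset contains 1 or 2 matching answers, and the
# average is (2 * 1/3 + 8 * 2/3) / 10 = 0.6.
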
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted answers.
    references (`list` of `list` of `str`): Ground truth answers, typically the 10 human answers for each question.
    answer_types (`list` of `str`, *optional*): Answer type of each question.
    question_types (`list` of `str`, *optional*): Question type of each question.
    precision (`int`, defaults to 2): The number of decimal places in the results.
Returns:
    visual question answering accuracy (`dict`): A dictionary whose "overall" key holds the accuracy as a
    percentage between 0 and 100; higher is better. If `answer_types` or `question_types` are provided,
    accuracies grouped by type are reported under "perAnswerType" and "perQuestionType" respectively.
"""


_CITATION = """
@InProceedings{{VQA},
author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh},
title = {{VQA}: {V}isual {Q}uestion {A}nswering},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2015},
}
"""

# The normalization tables below are copied verbatim from the official VQA
# evaluation script (vqaEval.py), including idiosyncratic entries such as
# "let's": "let's" and "somebody'd": "somebodyd", so that scores stay
# consistent with the official evaluation.
contractions = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}
manualMap = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
articles = ["a", "an", "the"]

periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
commaStrip = re.compile(r"(\d)(\,)(\d)")
punct = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]

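# Behavior notes (these exact patterns come from the official script): commas
# between digits are removed, e.g. "1,000" -> "1000", and periods not followed
# by a digit are stripped, e.g. "cat." -> "cat". The `(?!<=\d)` lookahead in
# `periodStrip` is kept as-is from the official vqaEval.py, even though a
# negative lookbehind `(?<!\d)` was presumably intended.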
def processPunctuation(inText):
    outText = inText
    for p in punct:
        if (p + " " in inText or " " + p in inText) or (
            re.search(commaStrip, inText) is not None
        ):
            outText = outText.replace(p, "")
        else:
            outText = outText.replace(p, " ")
    # Note: passing re.UNICODE as the positional `count` argument mirrors the
    # official script; it caps the number of substitutions rather than acting
    # as a flag, which is harmless for short answers.
    outText = periodStrip.sub("", outText, re.UNICODE)
    return outText

def processDigitArticle(inText):
    outText = []
    tempText = inText.lower().split()
    for word in tempText:
        # Note: `setdefault` also inserts unseen words into `manualMap`; this
        # mutation is inherited from the official script.
        word = manualMap.setdefault(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = " ".join(outText)
    return outText

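# Example of the two normalization steps chained (hypothetical input):
#     processDigitArticle(processPunctuation("There are Two dogs!"))
# strips the "!", lowercases, maps the word "two" to "2", and would drop any
# articles ("a", "an", "the"), yielding "there are 2 dogs".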
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class VQAaccuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="references"
                    ),
                    "answer_types": datasets.Value("string", id="sequence"),
                    "question_types": datasets.Value("string", id="sequence"),
                }
            ),
            reference_urls=[
                "https://visualqa.org/evaluation.html",
                "https://github.com/GT-Vision-Lab/VQA/blob/master",
            ],
        )
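    # Note (an assumption about the `evaluate` loader's behavior): because
    # "answer_types" and "question_types" are declared as input features above,
    # `compute()` appears to expect them alongside `predictions` and
    # `references` on every call, even though `_compute` defaults them to None.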

    def _compute(
        self,
        predictions,
        references,
        answer_types=None,
        question_types=None,
        precision=2,
    ):
        if answer_types is None:
            answer_types = [None] * len(predictions)

        if question_types is None:
            question_types = [None] * len(predictions)

        if not len(predictions) == len(answer_types) == len(question_types):
            raise ValueError(
                "The lengths of predictions, answer_types and question_types don't match."
            )

        total, ans_type_dict, ques_type_dict = [], {}, {}

        for pred, gts, ans_type, ques_type in zip(
            predictions, references, answer_types, question_types
        ):
            # normalize answers to align with the official data post-processing
            pred = pred.replace("\n", " ").replace("\t", " ").strip()
            pred = processDigitArticle(processPunctuation(pred))
            gts = [processDigitArticle(processPunctuation(gt_ans)) for gt_ans in gts]

            # calculate VQA accuracy, averaged over all leave-one-out subsets
            # of the human annotators
            accuracy = []
            for i in range(len(gts)):
                other_gt = gts[:i] + gts[i + 1 :]
                matching_ans = [item for item in other_gt if item == pred]
                accuracy.append(min(1, len(matching_ans) / 3))

            vqa_acc = sum(accuracy) / len(accuracy)
            total.append(vqa_acc)

            if ans_type is not None:
                if ans_type not in ans_type_dict:
                    ans_type_dict[ans_type] = []
                ans_type_dict[ans_type].append(vqa_acc)

            if ques_type is not None:
                if ques_type not in ques_type_dict:
                    ques_type_dict[ques_type] = []
                ques_type_dict[ques_type].append(vqa_acc)

        # the following key names follow the naming of the official evaluation results
        result = {"overall": round(100 * sum(total) / len(total), precision)}

        if len(ans_type_dict) > 0:
            result["perAnswerType"] = {
                ans_type: round(
                    100 * sum(accuracy_list) / len(accuracy_list), precision
                )
                for ans_type, accuracy_list in ans_type_dict.items()
            }

        if len(ques_type_dict) > 0:
            result["perQuestionType"] = {
                ques_type: round(
                    100 * sum(accuracy_list) / len(accuracy_list), precision
                )
                for ques_type, accuracy_list in ques_type_dict.items()
            }

        return result
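

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical data. Instantiating the class
    # directly is just for a quick local check; the module would typically be
    # loaded with `evaluate.load`, and predictions should first be processed
    # with `postprocess_generation` from testbed.data.vqav2.
    metric = VQAaccuracy()
    results = metric.compute(
        predictions=["2", "yes"],
        references=[
            ["2"] * 10,  # all 10 annotators agree
            ["yes"] * 3 + ["no"] * 7,  # only 3 annotators said "yes"
        ],
        answer_types=["number", "yes/no"],
        question_types=["how many", "is the"],
    )
    # Expected: the first example scores 1.0, the second
    # (3 * 2/3 + 7 * 1) / 10 = 0.9, giving
    # {'overall': 95.0, 'perAnswerType': {'number': 100.0, 'yes/no': 90.0},
    #  'perQuestionType': {'how many': 100.0, 'is the': 90.0}}
    print(results)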