Spaces:
Runtime error
Runtime error
Create vqa_accuracy.py
Browse files- vqa_accuracy.py +320 -0
vqa_accuracy.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datasets
|
2 |
+
import evaluate
|
3 |
+
import re
|
4 |
+
|
5 |
+
_DESCRIPTION = """
|
6 |
+
VQA accuracy is a evaluation metric which is robust to inter-human variability in phrasing the answers:
|
7 |
+
Acc(`ans`) = min{ # humans that said `ans` / 3, 1 }
|
8 |
+
Where `ans` is answered by machine. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators.
|
9 |
+
Note that to obtain results consistent with offical VQA evaluation, all inputs should be processed with `postprocess_generation` from testbed.data.vqav2.
|
10 |
+
"""
|
11 |
+
|
12 |
+
|
13 |
+
_KWARGS_DESCRIPTION = """
|
14 |
+
Args:
|
15 |
+
predictions (`list` of `str`): Predicted answers.
|
16 |
+
references (`list` of `str` lists): Ground truth answers.
|
17 |
+
answer_types (`list` of `str`, *optional*): Answer types corresponding to each questions.
|
18 |
+
questions_type (`list` of `str`, *optional*): Question types corresponding to each questions.
|
19 |
+
precision (`int`, defaults to 2): The precision of results.
|
20 |
+
Returns:
|
21 |
+
visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is 1.0, or the number of examples input, if `normalize` is set to `True`.. A higher accuracy means higher accuracy.
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
|
26 |
+
_CITATION = """
|
27 |
+
@InProceedings{{VQA},
|
28 |
+
author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh},
|
29 |
+
title = {{VQA}: {V}isual {Q}uestion {A}nswering},
|
30 |
+
booktitle = {International Conference on Computer Vision (ICCV)},
|
31 |
+
year = {2015},
|
32 |
+
}
|
33 |
+
"""
|
34 |
+
|
35 |
+
contractions = {
|
36 |
+
"aint": "ain't",
|
37 |
+
"arent": "aren't",
|
38 |
+
"cant": "can't",
|
39 |
+
"couldve": "could've",
|
40 |
+
"couldnt": "couldn't",
|
41 |
+
"couldn'tve": "couldn't've",
|
42 |
+
"couldnt've": "couldn't've",
|
43 |
+
"didnt": "didn't",
|
44 |
+
"doesnt": "doesn't",
|
45 |
+
"dont": "don't",
|
46 |
+
"hadnt": "hadn't",
|
47 |
+
"hadnt've": "hadn't've",
|
48 |
+
"hadn'tve": "hadn't've",
|
49 |
+
"hasnt": "hasn't",
|
50 |
+
"havent": "haven't",
|
51 |
+
"hed": "he'd",
|
52 |
+
"hed've": "he'd've",
|
53 |
+
"he'dve": "he'd've",
|
54 |
+
"hes": "he's",
|
55 |
+
"howd": "how'd",
|
56 |
+
"howll": "how'll",
|
57 |
+
"hows": "how's",
|
58 |
+
"Id've": "I'd've",
|
59 |
+
"I'dve": "I'd've",
|
60 |
+
"Im": "I'm",
|
61 |
+
"Ive": "I've",
|
62 |
+
"isnt": "isn't",
|
63 |
+
"itd": "it'd",
|
64 |
+
"itd've": "it'd've",
|
65 |
+
"it'dve": "it'd've",
|
66 |
+
"itll": "it'll",
|
67 |
+
"let's": "let's",
|
68 |
+
"maam": "ma'am",
|
69 |
+
"mightnt": "mightn't",
|
70 |
+
"mightnt've": "mightn't've",
|
71 |
+
"mightn'tve": "mightn't've",
|
72 |
+
"mightve": "might've",
|
73 |
+
"mustnt": "mustn't",
|
74 |
+
"mustve": "must've",
|
75 |
+
"neednt": "needn't",
|
76 |
+
"notve": "not've",
|
77 |
+
"oclock": "o'clock",
|
78 |
+
"oughtnt": "oughtn't",
|
79 |
+
"ow's'at": "'ow's'at",
|
80 |
+
"'ows'at": "'ow's'at",
|
81 |
+
"'ow'sat": "'ow's'at",
|
82 |
+
"shant": "shan't",
|
83 |
+
"shed've": "she'd've",
|
84 |
+
"she'dve": "she'd've",
|
85 |
+
"she's": "she's",
|
86 |
+
"shouldve": "should've",
|
87 |
+
"shouldnt": "shouldn't",
|
88 |
+
"shouldnt've": "shouldn't've",
|
89 |
+
"shouldn'tve": "shouldn't've",
|
90 |
+
"somebody'd": "somebodyd",
|
91 |
+
"somebodyd've": "somebody'd've",
|
92 |
+
"somebody'dve": "somebody'd've",
|
93 |
+
"somebodyll": "somebody'll",
|
94 |
+
"somebodys": "somebody's",
|
95 |
+
"someoned": "someone'd",
|
96 |
+
"someoned've": "someone'd've",
|
97 |
+
"someone'dve": "someone'd've",
|
98 |
+
"someonell": "someone'll",
|
99 |
+
"someones": "someone's",
|
100 |
+
"somethingd": "something'd",
|
101 |
+
"somethingd've": "something'd've",
|
102 |
+
"something'dve": "something'd've",
|
103 |
+
"somethingll": "something'll",
|
104 |
+
"thats": "that's",
|
105 |
+
"thered": "there'd",
|
106 |
+
"thered've": "there'd've",
|
107 |
+
"there'dve": "there'd've",
|
108 |
+
"therere": "there're",
|
109 |
+
"theres": "there's",
|
110 |
+
"theyd": "they'd",
|
111 |
+
"theyd've": "they'd've",
|
112 |
+
"they'dve": "they'd've",
|
113 |
+
"theyll": "they'll",
|
114 |
+
"theyre": "they're",
|
115 |
+
"theyve": "they've",
|
116 |
+
"twas": "'twas",
|
117 |
+
"wasnt": "wasn't",
|
118 |
+
"wed've": "we'd've",
|
119 |
+
"we'dve": "we'd've",
|
120 |
+
"weve": "we've",
|
121 |
+
"werent": "weren't",
|
122 |
+
"whatll": "what'll",
|
123 |
+
"whatre": "what're",
|
124 |
+
"whats": "what's",
|
125 |
+
"whatve": "what've",
|
126 |
+
"whens": "when's",
|
127 |
+
"whered": "where'd",
|
128 |
+
"wheres": "where's",
|
129 |
+
"whereve": "where've",
|
130 |
+
"whod": "who'd",
|
131 |
+
"whod've": "who'd've",
|
132 |
+
"who'dve": "who'd've",
|
133 |
+
"wholl": "who'll",
|
134 |
+
"whos": "who's",
|
135 |
+
"whove": "who've",
|
136 |
+
"whyll": "why'll",
|
137 |
+
"whyre": "why're",
|
138 |
+
"whys": "why's",
|
139 |
+
"wont": "won't",
|
140 |
+
"wouldve": "would've",
|
141 |
+
"wouldnt": "wouldn't",
|
142 |
+
"wouldnt've": "wouldn't've",
|
143 |
+
"wouldn'tve": "wouldn't've",
|
144 |
+
"yall": "y'all",
|
145 |
+
"yall'll": "y'all'll",
|
146 |
+
"y'allll": "y'all'll",
|
147 |
+
"yall'd've": "y'all'd've",
|
148 |
+
"y'alld've": "y'all'd've",
|
149 |
+
"y'all'dve": "y'all'd've",
|
150 |
+
"youd": "you'd",
|
151 |
+
"youd've": "you'd've",
|
152 |
+
"you'dve": "you'd've",
|
153 |
+
"youll": "you'll",
|
154 |
+
"youre": "you're",
|
155 |
+
"youve": "you've",
|
156 |
+
}
|
157 |
+
manualMap = {
|
158 |
+
"none": "0",
|
159 |
+
"zero": "0",
|
160 |
+
"one": "1",
|
161 |
+
"two": "2",
|
162 |
+
"three": "3",
|
163 |
+
"four": "4",
|
164 |
+
"five": "5",
|
165 |
+
"six": "6",
|
166 |
+
"seven": "7",
|
167 |
+
"eight": "8",
|
168 |
+
"nine": "9",
|
169 |
+
"ten": "10",
|
170 |
+
}
|
171 |
+
articles = ["a", "an", "the"]
|
172 |
+
|
173 |
+
periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
|
174 |
+
commaStrip = re.compile(r"(\d)(\,)(\d)")
|
175 |
+
punct = [
|
176 |
+
";",
|
177 |
+
r"/",
|
178 |
+
"[",
|
179 |
+
"]",
|
180 |
+
'"',
|
181 |
+
"{",
|
182 |
+
"}",
|
183 |
+
"(",
|
184 |
+
")",
|
185 |
+
"=",
|
186 |
+
"+",
|
187 |
+
"\\",
|
188 |
+
"_",
|
189 |
+
"-",
|
190 |
+
">",
|
191 |
+
"<",
|
192 |
+
"@",
|
193 |
+
"`",
|
194 |
+
",",
|
195 |
+
"?",
|
196 |
+
"!",
|
197 |
+
]
|
198 |
+
|
199 |
+
|
200 |
+
def processPunctuation(inText):
|
201 |
+
outText = inText
|
202 |
+
for p in punct:
|
203 |
+
if (p + " " in inText or " " + p in inText) or (
|
204 |
+
re.search(commaStrip, inText) != None
|
205 |
+
):
|
206 |
+
outText = outText.replace(p, "")
|
207 |
+
else:
|
208 |
+
outText = outText.replace(p, " ")
|
209 |
+
outText = periodStrip.sub("", outText, re.UNICODE)
|
210 |
+
return outText
|
211 |
+
|
212 |
+
|
213 |
+
def processDigitArticle(inText):
|
214 |
+
outText = []
|
215 |
+
tempText = inText.lower().split()
|
216 |
+
for word in tempText:
|
217 |
+
word = manualMap.setdefault(word, word)
|
218 |
+
if word not in articles:
|
219 |
+
outText.append(word)
|
220 |
+
else:
|
221 |
+
pass
|
222 |
+
for wordId, word in enumerate(outText):
|
223 |
+
if word in contractions:
|
224 |
+
outText[wordId] = contractions[word]
|
225 |
+
outText = " ".join(outText)
|
226 |
+
return outText
|
227 |
+
|
228 |
+
|
229 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
230 |
+
class VQAaccuracy(evaluate.Metric):
|
231 |
+
def _info(self):
|
232 |
+
return evaluate.MetricInfo(
|
233 |
+
description=_DESCRIPTION,
|
234 |
+
citation=_CITATION,
|
235 |
+
inputs_description=_KWARGS_DESCRIPTION,
|
236 |
+
features=datasets.Features(
|
237 |
+
{
|
238 |
+
"predictions": datasets.Value("string", id="sequence"),
|
239 |
+
"references": datasets.Sequence(
|
240 |
+
datasets.Value("string", id="sequence"), id="references"
|
241 |
+
),
|
242 |
+
"answer_types": datasets.Value("string", id="sequence"),
|
243 |
+
"question_types": datasets.Value("string", id="sequence"),
|
244 |
+
}
|
245 |
+
),
|
246 |
+
reference_urls=[
|
247 |
+
"https://visualqa.org/evaluation.html",
|
248 |
+
"https://github.com/GT-Vision-Lab/VQA/blob/master",
|
249 |
+
],
|
250 |
+
)
|
251 |
+
|
252 |
+
def _compute(
|
253 |
+
self,
|
254 |
+
predictions,
|
255 |
+
references,
|
256 |
+
answer_types=None,
|
257 |
+
question_types=None,
|
258 |
+
precision=2,
|
259 |
+
):
|
260 |
+
if answer_types is None:
|
261 |
+
answer_types = [None] * len(predictions)
|
262 |
+
|
263 |
+
if question_types is None:
|
264 |
+
question_types = [None] * len(predictions)
|
265 |
+
|
266 |
+
if not len(predictions) == len(answer_types) == len(question_types):
|
267 |
+
raise ValueError(
|
268 |
+
"The length of predictions, answer_types and question_types doesn't match."
|
269 |
+
)
|
270 |
+
|
271 |
+
total, ans_type_dict, ques_type_dict = [], {}, {}
|
272 |
+
|
273 |
+
for pred, gts, ans_type, ques_type in zip(
|
274 |
+
predictions, references, answer_types, question_types
|
275 |
+
):
|
276 |
+
# to align with offical data postprocess
|
277 |
+
pred = pred.replace("\n", " ").replace("\t", " ").strip()
|
278 |
+
pred = processDigitArticle(processPunctuation(pred))
|
279 |
+
gts = [processDigitArticle(processPunctuation(gt_ans)) for gt_ans in gts]
|
280 |
+
|
281 |
+
# calculate vqa accuracy
|
282 |
+
accuracy = []
|
283 |
+
for i in range(len(gts)):
|
284 |
+
other_gt = gts[:i] + gts[i + 1 :]
|
285 |
+
matching_ans = [item for item in other_gt if item == pred]
|
286 |
+
accuracy.append(min(1, len(matching_ans) / 3))
|
287 |
+
|
288 |
+
vqa_acc = sum(accuracy) / len(accuracy)
|
289 |
+
total.append(vqa_acc)
|
290 |
+
|
291 |
+
if ans_type is not None:
|
292 |
+
if ans_type not in ans_type_dict:
|
293 |
+
ans_type_dict[ans_type] = []
|
294 |
+
ans_type_dict[ans_type].append(vqa_acc)
|
295 |
+
|
296 |
+
if ques_type is not None:
|
297 |
+
if ques_type not in ques_type_dict:
|
298 |
+
ques_type_dict[ques_type] = []
|
299 |
+
ques_type_dict[ques_type].append(vqa_acc)
|
300 |
+
|
301 |
+
# the following key names follow the naming of the official evaluation results
|
302 |
+
result = {"overall": round(100 * sum(total) / len(total), precision)}
|
303 |
+
|
304 |
+
if len(ans_type_dict) > 0:
|
305 |
+
result["perAnswerType"] = {
|
306 |
+
ans_type: round(
|
307 |
+
100 * sum(accuracy_list) / len(accuracy_list), precision
|
308 |
+
)
|
309 |
+
for ans_type, accuracy_list in ans_type_dict.items()
|
310 |
+
}
|
311 |
+
|
312 |
+
if len(ques_type_dict) > 0:
|
313 |
+
result["perQuestionType"] = {
|
314 |
+
ques_type: round(
|
315 |
+
100 * sum(accuracy_list) / len(accuracy_list), precision
|
316 |
+
)
|
317 |
+
for ques_type, accuracy_list in ques_type_dict.items()
|
318 |
+
}
|
319 |
+
|
320 |
+
return result
|