import datasets
import evaluate
import re

_DESCRIPTION = """
VQA accuracy is an evaluation metric that is robust to inter-human variability in phrasing the answers:
$$
\\text{Acc}(ans) = \\min \\left( \\frac{\\text{# humans that said }ans}{3}, 1 \\right)
$$
where `ans` is the answer predicted by the machine. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators.
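For example, if the predicted answer was given by 2 of the 10 annotators, two of the ten leave-one-out subsets contain one matching answer (accuracy 1/3) and eight contain two (accuracy 2/3), so the averaged accuracy is (2 * 1/3 + 8 * 2/3) / 10 = 0.6.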
""" | |

_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted answers.
    references (`list` of `list` of `str`): Ground truth answers.
    answer_types (`list` of `str`, *optional*): Answer types, one per question.
    question_types (`list` of `str`, *optional*): Question types, one per question.
Returns:
    overall (`float`): Overall VQA accuracy. Minimum possible value is 0. Maximum possible value is 100.
    perAnswerType (`dict` of `str` to `float`, *optional*): Accuracy per answer type; present when `answer_types` is given.
    perQuestionType (`dict` of `str` to `float`, *optional*): Accuracy per question type; present when `question_types` is given.
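Examples:
    >>> vqa_accuracy = VQAAccuracy()  # one way to instantiate; the script can also be loaded via `evaluate.load`
    >>> results = vqa_accuracy.compute(
    ...     predictions=["2"],
    ...     references=[["2", "two", "2", "2", "3", "2", "2", "2", "2", "2"]],
    ...     answer_types=["number"],
    ...     question_types=["how many"],
    ... )
    >>> print(results["overall"])
    100.0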
""" | |

_CITATION = """
@InProceedings{VQA,
  author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh},
  title = {{VQA}: {V}isual {Q}uestion {A}nswering},
  booktitle = {International Conference on Computer Vision (ICCV)},
  year = {2015},
}
"""

contractions = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "lets": "let's",
"maam": "ma'am", | |
"mightnt": "mightn't", | |
"mightnt've": "mightn't've", | |
"mightn'tve": "mightn't've", | |
"mightve": "might've", | |
"mustnt": "mustn't", | |
"mustve": "must've", | |
"neednt": "needn't", | |
"notve": "not've", | |
"oclock": "o'clock", | |
"oughtnt": "oughtn't", | |
"ow's'at": "'ow's'at", | |
"'ows'at": "'ow's'at", | |
"'ow'sat": "'ow's'at", | |
"shant": "shan't", | |
"shed've": "she'd've", | |
"she'dve": "she'd've", | |
"she's": "she's", | |
"shouldve": "should've", | |
"shouldnt": "shouldn't", | |
"shouldnt've": "shouldn't've", | |
"shouldn'tve": "shouldn't've", | |
"somebody'd": "somebodyd", | |
"somebodyd've": "somebody'd've", | |
"somebody'dve": "somebody'd've", | |
"somebodyll": "somebody'll", | |
"somebodys": "somebody's", | |
"someoned": "someone'd", | |
"someoned've": "someone'd've", | |
"someone'dve": "someone'd've", | |
"someonell": "someone'll", | |
"someones": "someone's", | |
"somethingd": "something'd", | |
"somethingd've": "something'd've", | |
"something'dve": "something'd've", | |
"somethingll": "something'll", | |
"thats": "that's", | |
"thered": "there'd", | |
"thered've": "there'd've", | |
"there'dve": "there'd've", | |
"therere": "there're", | |
"theres": "there's", | |
"theyd": "they'd", | |
"theyd've": "they'd've", | |
"they'dve": "they'd've", | |
"theyll": "they'll", | |
"theyre": "they're", | |
"theyve": "they've", | |
"twas": "'twas", | |
"wasnt": "wasn't", | |
"wed've": "we'd've", | |
"we'dve": "we'd've", | |
"weve": "we've", | |
"werent": "weren't", | |
"whatll": "what'll", | |
"whatre": "what're", | |
"whats": "what's", | |
"whatve": "what've", | |
"whens": "when's", | |
"whered": "where'd", | |
"wheres": "where's", | |
"whereve": "where've", | |
"whod": "who'd", | |
"whod've": "who'd've", | |
"who'dve": "who'd've", | |
"wholl": "who'll", | |
"whos": "who's", | |
"whove": "who've", | |
"whyll": "why'll", | |
"whyre": "why're", | |
"whys": "why's", | |
"wont": "won't", | |
"wouldve": "would've", | |
"wouldnt": "wouldn't", | |
"wouldnt've": "wouldn't've", | |
"wouldn'tve": "wouldn't've", | |
"yall": "y'all", | |
"yall'll": "y'all'll", | |
"y'allll": "y'all'll", | |
"yall'd've": "y'all'd've", | |
"y'alld've": "y'all'd've", | |
"y'all'dve": "y'all'd've", | |
"youd": "you'd", | |
"youd've": "you'd've", | |
"you'dve": "you'd've", | |
"youll": "you'll", | |
"youre": "you're", | |
"youve": "you've", | |
} | |
manualMap = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
articles = ["a", "an", "the"]
# Strips periods that are not followed by a digit (so "1.5" keeps its dot);
# the pattern is kept verbatim from the official VQA evaluation code.
periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
# Matches a comma sandwiched between digits, e.g. the "1,2" in "1,234".
commaStrip = re.compile(r"(\d)(\,)(\d)")
punct = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]


def processPunctuation(inText):
    outText = inText
    for p in punct:
        # Delete the mark when it abuts a space in the original text or the
        # text contains a digit-comma-digit pattern; otherwise replace it
        # with a space.
        if (p + " " in inText or " " + p in inText) or (
            re.search(commaStrip, inText) is not None
        ):
            outText = outText.replace(p, "")
        else:
            outText = outText.replace(p, " ")
    # The third argument of `sub` is `count`; passing re.UNICODE here is kept
    # to mirror the official VQA evaluation code.
    outText = periodStrip.sub("", outText, re.UNICODE)
    return outText


def processDigitArticle(inText):
    outText = []
    tempText = inText.lower().split()
    for word in tempText:
        # Map number words to digits ("two" -> "2"); `get` is used instead of
        # `setdefault` so the lookup does not mutate manualMap. Bare articles
        # are dropped.
        word = manualMap.get(word, word)
        if word not in articles:
            outText.append(word)
    # Restore apostrophes in contractions typed without them.
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = " ".join(outText)
    return outText
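
# Illustrative examples of the normalization pipeline (not part of the metric):
# processDigitArticle(processPunctuation("Two dogs!")) -> "2 dogs"
# processDigitArticle(processPunctuation("A dog."))    -> "dog"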


class VQAAccuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="references"
                    ),
                    "answer_types": datasets.Value("string", id="sequence"),
                    "question_types": datasets.Value("string", id="sequence"),
                }
            ),
            reference_urls=[
                "https://visualqa.org/evaluation.html",
                "https://github.com/GT-Vision-Lab/VQA/blob/master",
            ],
        )

    def _compute(self, predictions, references, answer_types=None, question_types=None):
        if answer_types is None:
            answer_types = [None] * len(predictions)
        if question_types is None:
            question_types = [None] * len(predictions)
        if not len(predictions) == len(references) == len(answer_types) == len(question_types):
            raise ValueError(
                "The lengths of predictions, references, answer_types and question_types don't match."
            )
        total, ans_type_dict, ques_type_dict = [], {}, {}
        for pred, gts, ans_type, ques_type in zip(
            predictions, references, answer_types, question_types
        ):
            # normalize answers to align with the official data postprocessing
            pred = pred.replace("\n", " ").replace("\t", " ").strip()
            pred = processDigitArticle(processPunctuation(pred))
            gts = [processDigitArticle(processPunctuation(gt_ans)) for gt_ans in gts]
            # calculate VQA accuracy, averaged over all leave-one-out subsets
            # of human annotators
            accuracy = []
            for i in range(len(gts)):
                other_gt = gts[:i] + gts[i + 1 :]
                matching_ans = [item for item in other_gt if item == pred]
                accuracy.append(min(1, len(matching_ans) / 3))
            vqa_acc = sum(accuracy) / len(accuracy)
            total.append(vqa_acc)
            if ans_type is not None:
                if ans_type not in ans_type_dict:
                    ans_type_dict[ans_type] = []
                ans_type_dict[ans_type].append(vqa_acc)
            if ques_type is not None:
                if ques_type not in ques_type_dict:
                    ques_type_dict[ques_type] = []
                ques_type_dict[ques_type].append(vqa_acc)
        # the key names below follow the naming of the official evaluation results
        result = {"overall": 100 * sum(total) / len(total)}
        if len(ans_type_dict) > 0:
            result["perAnswerType"] = {
                ans_type: 100 * sum(accuracy_list) / len(accuracy_list)
                for ans_type, accuracy_list in ans_type_dict.items()
            }
        if len(ques_type_dict) > 0:
            result["perQuestionType"] = {
                ques_type: 100 * sum(accuracy_list) / len(accuracy_list)
                for ques_type, accuracy_list in ques_type_dict.items()
            }
        return result
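

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; the sample answers below are made up).
    # The second prediction matches only 2 of the 10 annotators, so its
    # leave-one-out accuracy is 0.6 (see the worked example in _DESCRIPTION).
    metric = VQAAccuracy()
    results = metric.compute(
        predictions=["2", "yes"],
        references=[
            ["2", "two", "2", "2", "3", "2", "2", "2", "2", "2"],
            ["no", "no", "no", "yes", "no", "no", "yes", "no", "no", "no"],
        ],
        answer_types=["number", "yes/no"],
        question_types=["how many", "is this"],
    )
    print(results)  # overall = (1.0 + 0.6) / 2 * 100 = 80.0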