Spaces:
Running
Running
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" WIKI_SPLIT metric.""" | |
import re | |
import string | |
from collections import Counter | |
import datasets | |
import sacrebleu | |
import sacremoses | |
from packaging import version | |
import evaluate | |
_CITATION = """ | |
@inproceedings{xu-etal-2016-optimizing, | |
title = {Optimizing Statistical Machine Translation for Text Simplification}, | |
authors={Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris}, | |
journal = {Transactions of the Association for Computational Linguistics}, | |
volume = {4}, | |
year={2016}, | |
url = {https://www.aclweb.org/anthology/Q16-1029}, | |
pages = {401--415 | |
}, | |
@inproceedings{post-2018-call, | |
title = "A Call for Clarity in Reporting {BLEU} Scores", | |
author = "Post, Matt", | |
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers", | |
month = oct, | |
year = "2018", | |
address = "Belgium, Brussels", | |
publisher = "Association for Computational Linguistics", | |
url = "https://www.aclweb.org/anthology/W18-6319", | |
pages = "186--191", | |
} | |
""" | |
_DESCRIPTION = """\ | |
WIKI_SPLIT is the combination of three metrics SARI, EXACT and SACREBLEU | |
It can be used to evaluate the quality of machine-generated texts. | |
""" | |
_KWARGS_DESCRIPTION = """ | |
Calculates sari score (between 0 and 100) given a list of source and predicted | |
sentences, and a list of lists of reference sentences. It also computes the BLEU score as well as the exact match score. | |
Args: | |
sources: list of source sentences where each sentence should be a string. | |
predictions: list of predicted sentences where each sentence should be a string. | |
references: list of lists of reference sentences where each sentence should be a string. | |
Returns: | |
sari: sari score | |
sacrebleu: sacrebleu score | |
exact: exact score | |
Examples: | |
>>> sources=["About 95 species are currently accepted ."] | |
>>> predictions=["About 95 you now get in ."] | |
>>> references=[["About 95 species are currently known ."]] | |
>>> wiki_split = evaluate.load("wiki_split") | |
>>> results = wiki_split.compute(sources=sources, predictions=predictions, references=references) | |
>>> print(results) | |
{'sari': 21.805555555555557, 'sacrebleu': 14.535768424205482, 'exact': 0.0} | |
""" | |
def normalize_answer(s): | |
"""Lower text and remove punctuation, articles and extra whitespace.""" | |
def remove_articles(text): | |
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) | |
return re.sub(regex, " ", text) | |
def white_space_fix(text): | |
return " ".join(text.split()) | |
def remove_punc(text): | |
exclude = set(string.punctuation) | |
return "".join(ch for ch in text if ch not in exclude) | |
def lower(text): | |
return text.lower() | |
return white_space_fix(remove_articles(remove_punc(lower(s)))) | |
def compute_exact(a_gold, a_pred): | |
return int(normalize_answer(a_gold) == normalize_answer(a_pred)) | |
def compute_em(predictions, references): | |
scores = [any([compute_exact(ref, pred) for ref in refs]) for pred, refs in zip(predictions, references)] | |
return (sum(scores) / len(scores)) * 100 | |
def SARIngram(sgrams, cgrams, rgramslist, numref): | |
rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams] | |
rgramcounter = Counter(rgramsall) | |
sgramcounter = Counter(sgrams) | |
sgramcounter_rep = Counter() | |
for sgram, scount in sgramcounter.items(): | |
sgramcounter_rep[sgram] = scount * numref | |
cgramcounter = Counter(cgrams) | |
cgramcounter_rep = Counter() | |
for cgram, ccount in cgramcounter.items(): | |
cgramcounter_rep[cgram] = ccount * numref | |
# KEEP | |
keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep | |
keepgramcountergood_rep = keepgramcounter_rep & rgramcounter | |
keepgramcounterall_rep = sgramcounter_rep & rgramcounter | |
keeptmpscore1 = 0 | |
keeptmpscore2 = 0 | |
for keepgram in keepgramcountergood_rep: | |
keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram] | |
# Fix an alleged bug [2] in the keep score computation. | |
# keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram] | |
keeptmpscore2 += keepgramcountergood_rep[keepgram] | |
# Define 0/0=1 instead of 0 to give higher scores for predictions that match | |
# a target exactly. | |
keepscore_precision = 1 | |
keepscore_recall = 1 | |
if len(keepgramcounter_rep) > 0: | |
keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep) | |
if len(keepgramcounterall_rep) > 0: | |
# Fix an alleged bug [2] in the keep score computation. | |
# keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep) | |
keepscore_recall = keeptmpscore2 / sum(keepgramcounterall_rep.values()) | |
keepscore = 0 | |
if keepscore_precision > 0 or keepscore_recall > 0: | |
keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall) | |
# DELETION | |
delgramcounter_rep = sgramcounter_rep - cgramcounter_rep | |
delgramcountergood_rep = delgramcounter_rep - rgramcounter | |
delgramcounterall_rep = sgramcounter_rep - rgramcounter | |
deltmpscore1 = 0 | |
deltmpscore2 = 0 | |
for delgram in delgramcountergood_rep: | |
deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram] | |
deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram] | |
# Define 0/0=1 instead of 0 to give higher scores for predictions that match | |
# a target exactly. | |
delscore_precision = 1 | |
if len(delgramcounter_rep) > 0: | |
delscore_precision = deltmpscore1 / len(delgramcounter_rep) | |
# ADDITION | |
addgramcounter = set(cgramcounter) - set(sgramcounter) | |
addgramcountergood = set(addgramcounter) & set(rgramcounter) | |
addgramcounterall = set(rgramcounter) - set(sgramcounter) | |
addtmpscore = 0 | |
for addgram in addgramcountergood: | |
addtmpscore += 1 | |
# Define 0/0=1 instead of 0 to give higher scores for predictions that match | |
# a target exactly. | |
addscore_precision = 1 | |
addscore_recall = 1 | |
if len(addgramcounter) > 0: | |
addscore_precision = addtmpscore / len(addgramcounter) | |
if len(addgramcounterall) > 0: | |
addscore_recall = addtmpscore / len(addgramcounterall) | |
addscore = 0 | |
if addscore_precision > 0 or addscore_recall > 0: | |
addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall) | |
return (keepscore, delscore_precision, addscore) | |
def SARIsent(ssent, csent, rsents): | |
numref = len(rsents) | |
s1grams = ssent.split(" ") | |
c1grams = csent.split(" ") | |
s2grams = [] | |
c2grams = [] | |
s3grams = [] | |
c3grams = [] | |
s4grams = [] | |
c4grams = [] | |
r1gramslist = [] | |
r2gramslist = [] | |
r3gramslist = [] | |
r4gramslist = [] | |
for rsent in rsents: | |
r1grams = rsent.split(" ") | |
r2grams = [] | |
r3grams = [] | |
r4grams = [] | |
r1gramslist.append(r1grams) | |
for i in range(0, len(r1grams) - 1): | |
if i < len(r1grams) - 1: | |
r2gram = r1grams[i] + " " + r1grams[i + 1] | |
r2grams.append(r2gram) | |
if i < len(r1grams) - 2: | |
r3gram = r1grams[i] + " " + r1grams[i + 1] + " " + r1grams[i + 2] | |
r3grams.append(r3gram) | |
if i < len(r1grams) - 3: | |
r4gram = r1grams[i] + " " + r1grams[i + 1] + " " + r1grams[i + 2] + " " + r1grams[i + 3] | |
r4grams.append(r4gram) | |
r2gramslist.append(r2grams) | |
r3gramslist.append(r3grams) | |
r4gramslist.append(r4grams) | |
for i in range(0, len(s1grams) - 1): | |
if i < len(s1grams) - 1: | |
s2gram = s1grams[i] + " " + s1grams[i + 1] | |
s2grams.append(s2gram) | |
if i < len(s1grams) - 2: | |
s3gram = s1grams[i] + " " + s1grams[i + 1] + " " + s1grams[i + 2] | |
s3grams.append(s3gram) | |
if i < len(s1grams) - 3: | |
s4gram = s1grams[i] + " " + s1grams[i + 1] + " " + s1grams[i + 2] + " " + s1grams[i + 3] | |
s4grams.append(s4gram) | |
for i in range(0, len(c1grams) - 1): | |
if i < len(c1grams) - 1: | |
c2gram = c1grams[i] + " " + c1grams[i + 1] | |
c2grams.append(c2gram) | |
if i < len(c1grams) - 2: | |
c3gram = c1grams[i] + " " + c1grams[i + 1] + " " + c1grams[i + 2] | |
c3grams.append(c3gram) | |
if i < len(c1grams) - 3: | |
c4gram = c1grams[i] + " " + c1grams[i + 1] + " " + c1grams[i + 2] + " " + c1grams[i + 3] | |
c4grams.append(c4gram) | |
(keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref) | |
(keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref) | |
(keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref) | |
(keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref) | |
avgkeepscore = sum([keep1score, keep2score, keep3score, keep4score]) / 4 | |
avgdelscore = sum([del1score, del2score, del3score, del4score]) / 4 | |
avgaddscore = sum([add1score, add2score, add3score, add4score]) / 4 | |
finalscore = (avgkeepscore + avgdelscore + avgaddscore) / 3 | |
return finalscore | |
def normalize(sentence, lowercase: bool = True, tokenizer: str = "13a", return_str: bool = True): | |
# Normalization is requried for the ASSET dataset (one of the primary | |
# datasets in sentence simplification) to allow using space | |
# to split the sentence. Even though Wiki-Auto and TURK datasets, | |
# do not require normalization, we do it for consistency. | |
# Code adapted from the EASSE library [1] written by the authors of the ASSET dataset. | |
# [1] https://github.com/feralvam/easse/blob/580bba7e1378fc8289c663f864e0487188fe8067/easse/utils/preprocessing.py#L7 | |
if lowercase: | |
sentence = sentence.lower() | |
if tokenizer in ["13a", "intl"]: | |
if version.parse(sacrebleu.__version__).major >= 2: | |
normalized_sent = sacrebleu.metrics.bleu._get_tokenizer(tokenizer)()(sentence) | |
else: | |
normalized_sent = sacrebleu.TOKENIZERS[tokenizer]()(sentence) | |
elif tokenizer == "moses": | |
normalized_sent = sacremoses.MosesTokenizer().tokenize(sentence, return_str=True, escape=False) | |
elif tokenizer == "penn": | |
normalized_sent = sacremoses.MosesTokenizer().penn_tokenize(sentence, return_str=True) | |
else: | |
normalized_sent = sentence | |
if not return_str: | |
normalized_sent = normalized_sent.split() | |
return normalized_sent | |
def compute_sari(sources, predictions, references): | |
if not (len(sources) == len(predictions) == len(references)): | |
raise ValueError("Sources length must match predictions and references lengths.") | |
sari_score = 0 | |
for src, pred, refs in zip(sources, predictions, references): | |
sari_score += SARIsent(normalize(src), normalize(pred), [normalize(sent) for sent in refs]) | |
sari_score = sari_score / len(predictions) | |
return 100 * sari_score | |
def compute_sacrebleu( | |
predictions, | |
references, | |
smooth_method="exp", | |
smooth_value=None, | |
force=False, | |
lowercase=False, | |
use_effective_order=False, | |
): | |
references_per_prediction = len(references[0]) | |
if any(len(refs) != references_per_prediction for refs in references): | |
raise ValueError("Sacrebleu requires the same number of references for each prediction") | |
transformed_references = [[refs[i] for refs in references] for i in range(references_per_prediction)] | |
output = sacrebleu.corpus_bleu( | |
predictions, | |
transformed_references, | |
smooth_method=smooth_method, | |
smooth_value=smooth_value, | |
force=force, | |
lowercase=lowercase, | |
use_effective_order=use_effective_order, | |
) | |
return output.score | |
class WikiSplit(evaluate.EvaluationModule): | |
def _info(self): | |
return evaluate.EvaluationModuleInfo( | |
description=_DESCRIPTION, | |
citation=_CITATION, | |
inputs_description=_KWARGS_DESCRIPTION, | |
features=datasets.Features( | |
{ | |
"predictions": datasets.Value("string", id="sequence"), | |
"references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"), | |
} | |
), | |
codebase_urls=[ | |
"https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py", | |
"https://github.com/cocoxu/simplification/blob/master/SARI.py", | |
"https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py", | |
"https://github.com/mjpost/sacreBLEU", | |
], | |
reference_urls=[ | |
"https://www.aclweb.org/anthology/Q16-1029.pdf", | |
"https://github.com/mjpost/sacreBLEU", | |
"https://en.wikipedia.org/wiki/BLEU", | |
"https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213", | |
], | |
) | |
def _compute(self, sources, predictions, references): | |
result = {} | |
result.update({"sari": compute_sari(sources=sources, predictions=predictions, references=references)}) | |
result.update({"sacrebleu": compute_sacrebleu(predictions=predictions, references=references)}) | |
result.update({"exact": compute_em(predictions=predictions, references=references)}) | |
return result | |