# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" WIKI_SPLIT metric."""

import re
import string
from collections import Counter

import datasets
import sacrebleu
import sacremoses
from packaging import version

import evaluate


_CITATION = """
@inproceedings{xu-etal-2016-optimizing,
    title = {Optimizing Statistical Machine Translation for Text Simplification},
    authors={Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris},
    journal = {Transactions of the Association for Computational Linguistics},
    volume = {4},
    year={2016},
    url = {https://www.aclweb.org/anthology/Q16-1029},
    pages = {401--415},
}
@inproceedings{post-2018-call,
    title = "A Call for Clarity in Reporting {BLEU} Scores",
    author = "Post, Matt",
    booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
    month = oct,
    year = "2018",
    address = "Belgium, Brussels",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W18-6319",
    pages = "186--191",
}
"""

_DESCRIPTION = """\
WIKI_SPLIT is the combination of three metrics SARI, EXACT and SACREBLEU
It can be used to evaluate the quality of machine-generated texts.
"""


_KWARGS_DESCRIPTION = """
Calculates sari score (between 0 and 100) given a list of source and predicted
sentences, and a list of lists of reference sentences. It also computes the BLEU score as well as the exact match score.
Args:
    sources: list of source sentences where each sentence should be a string.
    predictions: list of predicted sentences where each sentence should be a string.
    references: list of lists of reference sentences where each sentence should be a string.
Returns:
    sari: sari score
    sacrebleu: sacrebleu score
    exact: exact score

Examples:
    >>> sources=["About 95 species are currently accepted ."]
    >>> predictions=["About 95 you now get in ."]
    >>> references=[["About 95 species are currently known ."]]
    >>> wiki_split = evaluate.load("wiki_split")
    >>> results = wiki_split.compute(sources=sources, predictions=predictions, references=references)
    >>> print(results)
    {'sari': 21.805555555555557, 'sacrebleu': 14.535768424205482, 'exact': 0.0}
"""

# Built once at import time so normalize_answer does not recreate them per call.
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b", re.UNICODE)
_PUNCTUATION = frozenset(string.punctuation)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    The steps run in this order: lowercase, strip ASCII punctuation, replace
    the standalone words "a", "an" and "the" with a space, then collapse runs
    of whitespace into single spaces.

    Args:
        s: the string to normalize.

    Returns:
        The normalized string.
    """
    text = s.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = _ARTICLES_RE.sub(" ", text)
    return " ".join(text.split())


def compute_exact(a_gold, a_pred):
    """Return 1 if both strings are equal after `normalize_answer`, else 0."""
    return int(normalize_answer(a_gold) == normalize_answer(a_pred))


def compute_em(predictions, references):
    """Compute the exact-match score as a percentage.

    A prediction counts as a match when it equals (after normalization) at
    least one of its references.

    Args:
        predictions: iterable of predicted strings.
        references: iterable of lists of reference strings, aligned with
            `predictions`.

    Returns:
        The share of matched predictions, scaled to the range 0-100.

    Raises:
        ValueError: if there is nothing to score (previously this surfaced as
            an uninformative ZeroDivisionError).
    """
    scores = [any(compute_exact(ref, pred) for ref in refs) for pred, refs in zip(predictions, references)]
    if not scores:
        raise ValueError("Exact match requires at least one prediction and one list of references.")
    return (sum(scores) / len(scores)) * 100


def SARIngram(sgrams, cgrams, rgramslist, numref):
    """Compute the SARI components for a single n-gram order.

    Args:
        sgrams: list of n-grams of the source sentence.
        cgrams: list of n-grams of the candidate (predicted) sentence.
        rgramslist: list with one n-gram list per reference sentence.
        numref: number of references; source and candidate counts are
            multiplied by it before being compared with the pooled
            reference counts.

    Returns:
        A tuple `(keepscore, delscore_precision, addscore)` where `keepscore`
        and `addscore` are harmonic means of their precision and recall and
        `delscore_precision` is the deletion precision.
    """
    rgramcounter = Counter(rgram for rgrams in rgramslist for rgram in rgrams)
    sgramcounter = Counter(sgrams)
    cgramcounter = Counter(cgrams)

    sgramcounter_rep = Counter({sgram: scount * numref for sgram, scount in sgramcounter.items()})
    cgramcounter_rep = Counter({cgram: ccount * numref for cgram, ccount in cgramcounter.items()})

    # KEEP
    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    # Explicit accumulation loops are kept on purpose: they fix the order of
    # the floating point additions so the scores stay reproducible.
    keeptmpscore1 = 0
    keeptmpscore2 = 0
    for keepgram in keepgramcountergood_rep:
        keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
        # Fix an alleged bug [2] in the keep score computation.
        # keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
        keeptmpscore2 += keepgramcountergood_rep[keepgram]
    # Define 0/0=1 instead of 0 to give higher scores for predictions that match
    # a target exactly.
    keepscore_precision = 1
    keepscore_recall = 1
    if len(keepgramcounter_rep) > 0:
        keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
    if len(keepgramcounterall_rep) > 0:
        # Fix an alleged bug [2] in the keep score computation.
        # keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
        keepscore_recall = keeptmpscore2 / sum(keepgramcounterall_rep.values())
    keepscore = 0
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)

    # DELETION
    # Only the deletion precision is returned, so no recall term is computed.
    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
    delgramcountergood_rep = delgramcounter_rep - rgramcounter
    deltmpscore1 = 0
    for delgram in delgramcountergood_rep:
        deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
    # Define 0/0=1 instead of 0 to give higher scores for predictions that match
    # a target exactly.
    delscore_precision = 1
    if len(delgramcounter_rep) > 0:
        delscore_precision = deltmpscore1 / len(delgramcounter_rep)

    # ADDITION
    # Additions are scored on distinct n-grams (sets), not on counts.
    addgramcounter = set(cgramcounter) - set(sgramcounter)
    addgramcountergood = addgramcounter & set(rgramcounter)
    addgramcounterall = set(rgramcounter) - set(sgramcounter)

    addtmpscore = len(addgramcountergood)

    # Define 0/0=1 instead of 0 to give higher scores for predictions that match
    # a target exactly.
    addscore_precision = 1
    addscore_recall = 1
    if len(addgramcounter) > 0:
        addscore_precision = addtmpscore / len(addgramcounter)
    if len(addgramcounterall) > 0:
        addscore_recall = addtmpscore / len(addgramcounterall)
    addscore = 0
    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)

    return (keepscore, delscore_precision, addscore)


def _ngrams(tokens, n):
    """Return the space-joined n-grams of `tokens`, in order of appearance."""
    return [" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def SARIsent(ssent, csent, rsents):
    """Compute the sentence-level SARI score.

    Sentences are split on single spaces and scored for n-gram orders 1 to 4;
    the keep, deletion and addition scores are each averaged over the four
    orders and then averaged together.

    Args:
        ssent: the source sentence.
        csent: the candidate (predicted) sentence.
        rsents: list of reference sentences.

    Returns:
        The SARI score of the candidate, as the mean of the three averaged
        component scores.
    """
    numref = len(rsents)

    stokens = ssent.split(" ")
    ctokens = csent.split(" ")
    rtokenslist = [rsent.split(" ") for rsent in rsents]

    keepscores = []
    delscores = []
    addscores = []
    for order in range(1, 5):
        keepscore, delscore, addscore = SARIngram(
            _ngrams(stokens, order),
            _ngrams(ctokens, order),
            [_ngrams(rtokens, order) for rtokens in rtokenslist],
            numref,
        )
        keepscores.append(keepscore)
        delscores.append(delscore)
        addscores.append(addscore)

    avgkeepscore = sum(keepscores) / 4
    avgdelscore = sum(delscores) / 4
    avgaddscore = sum(addscores) / 4
    finalscore = (avgkeepscore + avgdelscore + avgaddscore) / 3
    return finalscore


def normalize(sentence, lowercase: bool = True, tokenizer: str = "13a", return_str: bool = True):
    """Lowercase and tokenize a sentence so that it can be split on spaces.

    Args:
        sentence: the sentence to normalize.
        lowercase: whether to lowercase the sentence before tokenizing.
        tokenizer: "13a" or "intl" (sacrebleu tokenizers), "moses" or "penn"
            (sacremoses tokenizers); any other value leaves the sentence
            untokenized.
        return_str: if False, return the result of `str.split()` on the
            normalized sentence instead of the string itself.

    Returns:
        The normalized sentence, as a string or as a list of tokens.
    """
    # Normalization is required for the ASSET dataset (one of the primary
    # datasets in sentence simplification) to allow using space
    # to split the sentence. Even though Wiki-Auto and TURK datasets,
    # do not require normalization, we do it for consistency.
    # Code adapted from the EASSE library [1] written by the authors of the ASSET dataset.
    # [1] https://github.com/feralvam/easse/blob/580bba7e1378fc8289c663f864e0487188fe8067/easse/utils/preprocessing.py#L7

    if lowercase:
        sentence = sentence.lower()

    if tokenizer in ["13a", "intl"]:
        # The tokenizer lookup differs between sacrebleu 1.x and 2.x.
        if version.parse(sacrebleu.__version__).major >= 2:
            normalized_sent = sacrebleu.metrics.bleu._get_tokenizer(tokenizer)()(sentence)
        else:
            normalized_sent = sacrebleu.TOKENIZERS[tokenizer]()(sentence)
    elif tokenizer == "moses":
        normalized_sent = sacremoses.MosesTokenizer().tokenize(sentence, return_str=True, escape=False)
    elif tokenizer == "penn":
        normalized_sent = sacremoses.MosesTokenizer().penn_tokenize(sentence, return_str=True)
    else:
        normalized_sent = sentence

    if not return_str:
        normalized_sent = normalized_sent.split()

    return normalized_sent


def compute_sari(sources, predictions, references):
    """Compute the corpus-level SARI score as the mean sentence-level score.

    Args:
        sources: list of source sentences.
        predictions: list of predicted sentences.
        references: list of lists of reference sentences.

    Returns:
        The mean of the sentence-level SARI scores, multiplied by 100.

    Raises:
        ValueError: if the three inputs differ in length, or if they are empty
            (previously an uninformative ZeroDivisionError).
    """
    if not (len(sources) == len(predictions) == len(references)):
        raise ValueError("Sources length must match predictions and references lengths.")
    if len(predictions) == 0:
        raise ValueError("SARI requires at least one source, prediction and list of references.")
    sari_score = 0
    for src, pred, refs in zip(sources, predictions, references):
        sari_score += SARIsent(normalize(src), normalize(pred), [normalize(sent) for sent in refs])
    sari_score = sari_score / len(predictions)
    return 100 * sari_score


def compute_sacrebleu(
    predictions,
    references,
    smooth_method="exp",
    smooth_value=None,
    force=False,
    lowercase=False,
    use_effective_order=False,
):
    """Compute the corpus BLEU score with sacrebleu.

    Args:
        predictions: list of predicted sentences.
        references: list of lists of reference sentences; every prediction
            must have the same number of references.
        smooth_method: forwarded to `sacrebleu.corpus_bleu`.
        smooth_value: forwarded to `sacrebleu.corpus_bleu`.
        force: forwarded to `sacrebleu.corpus_bleu`.
        lowercase: forwarded to `sacrebleu.corpus_bleu`.
        use_effective_order: forwarded to `sacrebleu.corpus_bleu`.

    Returns:
        The `score` attribute of the sacrebleu result.

    Raises:
        ValueError: if `references` is empty (previously an IndexError) or if
            the number of references differs between predictions.
    """
    if len(references) == 0:
        raise ValueError("Sacrebleu requires at least one list of references.")
    references_per_prediction = len(references[0])
    if any(len(refs) != references_per_prediction for refs in references):
        raise ValueError("Sacrebleu requires the same number of references for each prediction")
    # sacrebleu expects one stream per reference position, i.e. the transpose
    # of the per-prediction reference lists.
    transformed_references = [[refs[i] for refs in references] for i in range(references_per_prediction)]
    output = sacrebleu.corpus_bleu(
        predictions,
        transformed_references,
        smooth_method=smooth_method,
        smooth_value=smooth_value,
        force=force,
        lowercase=lowercase,
        use_effective_order=use_effective_order,
    )
    return output.score


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class WikiSplit(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=[
                "https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py",
                "https://github.com/cocoxu/simplification/blob/master/SARI.py",
                "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py",
                "https://github.com/mjpost/sacreBLEU",
            ],
            reference_urls=[
                "https://www.aclweb.org/anthology/Q16-1029.pdf",
                "https://github.com/mjpost/sacreBLEU",
                "https://en.wikipedia.org/wiki/BLEU",
                "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
            ],
        )

    def _compute(self, sources, predictions, references):
        """Return a dict with the "sari", "sacrebleu" and "exact" scores."""
        # if only one reference is provided make sure we still use list of lists
        # (the emptiness check lets empty input reach the helpers' ValueError
        # instead of failing here with an IndexError)
        if len(references) > 0 and isinstance(references[0], str):
            references = [[ref] for ref in references]
        result = {}
        result.update({"sari": compute_sari(sources=sources, predictions=predictions, references=references)})
        result.update({"sacrebleu": compute_sacrebleu(predictions=predictions, references=references)})
        result.update({"exact": compute_em(predictions=predictions, references=references)})
        return result