Spaces:
Sleeping
Sleeping
# Natural Language Toolkit: evaluation of dependency parser
#
# Author: Long Duong <[email protected]>
#
# Copyright (C) 2001-2023 NLTK Project
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
import unicodedata | |
class DependencyEvaluator:
    """
    Class for measuring labelled and unlabelled attachment score for
    dependency parsing. Note that the evaluation ignores punctuation.

    >>> from nltk.parse import DependencyGraph, DependencyEvaluator

    >>> gold_sent = DependencyGraph(\"""
    ... Pierre  NNP     2       NMOD
    ... Vinken  NNP     8       SUB
    ... ,       ,       2       P
    ... 61      CD      5       NMOD
    ... years   NNS     6       AMOD
    ... old     JJ      2       NMOD
    ... ,       ,       2       P
    ... will    MD      0       ROOT
    ... join    VB      8       VC
    ... the     DT      11      NMOD
    ... board   NN      9       OBJ
    ... as      IN      9       VMOD
    ... a       DT      15      NMOD
    ... nonexecutive    JJ      15      NMOD
    ... director        NN      12      PMOD
    ... Nov.    NNP     9       VMOD
    ... 29      CD      16      NMOD
    ... .       .       9       VMOD
    ... \""")

    >>> parsed_sent = DependencyGraph(\"""
    ... Pierre  NNP     8       NMOD
    ... Vinken  NNP     1       SUB
    ... ,       ,       3       P
    ... 61      CD      6       NMOD
    ... years   NNS     6       AMOD
    ... old     JJ      2       NMOD
    ... ,       ,       3       AMOD
    ... will    MD      0       ROOT
    ... join    VB      8       VC
    ... the     DT      11      AMOD
    ... board   NN      9       OBJECT
    ... as      IN      9       NMOD
    ... a       DT      15      NMOD
    ... nonexecutive    JJ      15      NMOD
    ... director        NN      12      PMOD
    ... Nov.    NNP     9       VMOD
    ... 29      CD      16      NMOD
    ... .       .       9       VMOD
    ... \""")

    >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
    >>> las, uas = de.eval()
    >>> las
    0.6
    >>> uas
    0.8
    >>> abs(uas - 0.8) < 0.00001
    True
    """

    # Unicode general categories treated as punctuation. Built once here
    # instead of on every token inside the scoring loop.
    _PUNCT_CATEGORIES = frozenset({"Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"})

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the list of parsed_sents as the output of parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the list of reference (gold standard) sentences,
            aligned one-to-one with ``parsed_sents``
        :type gold_sents: list(DependencyGraph)
        """
        self._parsed_sents = parsed_sents
        self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Function to remove punctuation from Unicode string.

        :param inStr: the input string
        :return: Unicode string after remove all punctuation
        """
        punc_cat = self._PUNCT_CATEGORIES
        return "".join(x for x in inStr if unicodedata.category(x) not in punc_cat)

    def eval(self):
        """
        Return the Labeled Attachment Score (LAS) and Unlabeled Attachment Score (UAS)

        Nodes whose word is ``None`` and tokens made up solely of punctuation
        are excluded from the scores. A token counts towards UAS when its
        head matches the gold head, and towards LAS when its relation label
        matches as well.

        If there is nothing to score (no sentences, or only punctuation),
        both scores are reported as ``0.0`` rather than dividing by zero.

        :return: the pair ``(LAS, UAS)``
        :rtype: tuple(float, float)
        :raise ValueError: if the number of parsed and gold sentences differ,
            if a sentence pair differs in its number of nodes, or if the
            words at the same address differ.
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(
                " Number of parsed sentence is different with number of gold sentence."
            )

        corr = 0
        corrL = 0
        total = 0

        for parsed_sent, gold_sent in zip(self._parsed_sents, self._gold_sents):
            parsed_sent_nodes = parsed_sent.nodes
            gold_sent_nodes = gold_sent.nodes

            if len(parsed_sent_nodes) != len(gold_sent_nodes):
                raise ValueError("Sentences must have equal length.")

            for parsed_node_address, parsed_node in parsed_sent_nodes.items():
                gold_node = gold_sent_nodes[parsed_node_address]

                # Nodes without a word carry no token to score.
                if parsed_node["word"] is None:
                    continue
                if parsed_node["word"] != gold_node["word"]:
                    raise ValueError("Sentence sequence is not matched.")

                # Punctuation is excluded: a token that is empty once all
                # punctuation characters are stripped is pure punctuation.
                if self._remove_punct(parsed_node["word"]) == "":
                    continue

                total += 1
                if parsed_node["head"] == gold_node["head"]:
                    corr += 1
                    # The label only counts when the attachment is right.
                    if parsed_node["rel"] == gold_node["rel"]:
                        corrL += 1

        # Nothing scorable: avoid ZeroDivisionError.
        if total == 0:
            return 0.0, 0.0

        return corrL / total, corr / total