Spaces:
Sleeping
Sleeping
File size: 14,739 Bytes
d916065 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
# Natural Language Toolkit: Paice's Stemmer Evaluation Metrics
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Lauri Hallila <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
#
"""Counts Paice's performance statistics for evaluating stemming algorithms.
What is required:
- A dictionary of words grouped by their real lemmas
- A dictionary of words grouped by stems from a stemming algorithm
When these are given, Understemming Index (UI), Overstemming Index (OI),
Stemming Weight (SW) and Error-rate relative to truncation (ERRT) are counted.
References:
Chris D. Paice (1994). An evaluation method for stemming algorithms.
In Proceedings of SIGIR, 42--50.
"""
from math import sqrt
def get_words_from_dictionary(lemmas):
    """
    Collect every word that is listed under any lemma.

    :param lemmas: A dictionary where keys are lemmas and values are sets
        or lists of words corresponding to that lemma.
    :type lemmas: dict(str): list(str)
    :return: Set of all words found among the values of the dictionary;
        a word listed more than once appears only once in the result.
    :rtype: set(str)
    """
    # Flatten all word groups into one set; duplicates collapse automatically.
    return {word for lemma in lemmas for word in lemmas[lemma]}
def _truncate(words, cutlength):
    """Group words under the stems produced by cutting each word at a fixed length.

    :param words: Words used for analysis
    :param cutlength: Number of leading characters kept as the stem.
    :type words: set(str) or list(str)
    :type cutlength: int
    :return: Dictionary where keys are stems (the first ``cutlength``
        characters of a word) and values are sets of the words sharing
        that stem.
    :rtype: dict(str): set(str)
    """
    groups = {}
    for word in words:
        # Create the stem's set the first time it is seen, then add the word.
        groups.setdefault(word[:cutlength], set()).add(word)
    return groups
# Reference: https://en.wikipedia.org/wiki/Line-line_intersection
def _count_intersection(l1, l2):
    """Find the point where the lines through two segments cross.

    :param l1: Two coordinate pairs defining the first line segment
    :param l2: Two coordinate pairs defining the second line segment
    :type l1: tuple(tuple(float, float), tuple(float, float))
    :type l2: tuple(tuple(float, float), tuple(float, float))
    :return: Coordinates of the intersection. If the lines are parallel and
        all four x-coordinates are zero, ``(0.0, y4)`` is returned, where
        ``y4`` is the y-coordinate of the second point of ``l2``.
    :rtype: tuple(float, float)
    :raises ZeroDivisionError: If the lines are parallel and do not all lie
        on the y-axis.
    """
    (x1, y1), (x2, y2) = l1[0], l1[1]
    (x3, y3), (x4, y4) = l2[0], l2[1]

    # Coordinate differences along each segment.
    dx12, dy12 = x1 - x2, y1 - y2
    dx34, dy34 = x3 - x4, y3 - y4

    denominator = dx12 * dy34 - dy12 * dx34

    # A zero denominator means the lines are parallel. The only parallel
    # case handled here is both segments lying on the y-axis; the x-axis
    # can be ignored because counting of the truncation line stops there,
    # and UI (x-axis) grows while OI (y-axis) diminishes along that line.
    if denominator == 0.0 and x1 == x2 == x3 == x4 == 0.0:
        return (0.0, y4)

    # Cross terms shared by both coordinate formulas.
    cross12 = x1 * y2 - y1 * x2
    cross34 = x3 * y4 - y3 * x4

    x = (cross12 * dx34 - dx12 * cross34) / denominator
    y = (cross12 * dy34 - dy12 * cross34) / denominator
    return (x, y)
def _get_derivative(coordinates):
    """Return the slope of the line running from (0,0) to the given point.

    :param coordinates: A coordinate pair (x, y)
    :type coordinates: tuple(float, float)
    :return: y divided by x; inf if x is zero
    :rtype: float
    """
    x, y = coordinates[0], coordinates[1]
    try:
        slope = y / x
    except ZeroDivisionError:
        # The point lies on the y-axis, so the line is vertical.
        slope = float("inf")
    return slope
def _calculate_cut(lemmawords, stems):
    """Count understemmed and overstemmed pairs for (lemma, stem) pair with common words.

    For every stem group that shares at least one word with ``lemmawords``,
    the shared words contribute to two totals: pairs that should have been
    merged but were not (the lemma's other words are outside this stem), and
    pairs that were merged but should not have been (the stem's other words
    are outside this lemma).

    :param lemmawords: Set or list of words corresponding to certain lemma.
    :param stems: A dictionary where keys are stems and values are sets
        or lists of words corresponding to that stem.
    :type lemmawords: set(str) or list(str)
    :type stems: dict(str): set(str)
    :return: Amount of understemmed and overstemmed pairs contributed by words
        existing in both lemmawords and stems.
    :rtype: tuple(float, float)
    """
    # Both values are loop-invariant, so compute them once up front.
    lemma_set = set(lemmawords)
    lemmacount = len(lemmawords)

    umt, wmt = 0.0, 0.0
    for stemwords in stems.values():
        # Number of distinct words shared by this lemma and this stem group.
        cutcount = len(lemma_set.intersection(stemwords))
        if not cutcount:
            continue
        # Unachieved merge total
        umt += cutcount * (lemmacount - cutcount)
        # Wrongly merged total
        wmt += cutcount * (len(stemwords) - cutcount)
    return (umt, wmt)
def _calculate(lemmas, stems):
    """Calculate actual and maximum possible amounts of understemmed and overstemmed word pairs.

    :param lemmas: A dictionary where keys are lemmas and values are sets
        or lists of words corresponding to that lemma.
    :param stems: A dictionary where keys are stems and values are sets
        or lists of words corresponding to that stem.
    :type lemmas: dict(str): list(str)
    :type stems: dict(str): set(str)
    :return: Global unachieved merge total (gumt),
        global desired merge total (gdmt),
        global wrongly merged total (gwmt) and
        global desired non-merge total (gdnt).
    :rtype: tuple(float, float, float, float)
    """
    # Total number of words over all lemma groups.
    total_words = sum(len(group) for group in lemmas.values())

    unachieved = desired_merge = wrong = desired_nonmerge = 0.0
    for lemmawords in lemmas.values():
        size = len(lemmawords)

        # Ordered pairs inside the lemma group: these should be merged.
        desired_merge += size * (size - 1)
        # Ordered pairs between this group and every other word: these
        # should stay separate.
        desired_nonmerge += size * (total_words - size)

        # Understemmed and overstemmed pairs contributed by this lemma
        # against every stem group it shares words with.
        lemma_umt, lemma_wmt = _calculate_cut(lemmawords, stems)
        unachieved += lemma_umt
        wrong += lemma_wmt

    # Each pair was counted once from each of its members, so halve all totals.
    return tuple(
        total / 2
        for total in (unachieved, desired_merge, wrong, desired_nonmerge)
    )
def _indexes(gumt, gdmt, gwmt, gdnt):
    """Count Understemming Index (UI), Overstemming Index (OI) and Stemming Weight (SW).

    :param gumt, gdmt, gwmt, gdnt: Global unachieved merge total (gumt),
        global desired merge total (gdmt),
        global wrongly merged total (gwmt) and
        global desired non-merge total (gdnt).
    :type gumt, gdmt, gwmt, gdnt: float
    :return: Understemming Index (UI = GUMT / GDMT),
        Overstemming Index (OI = GWMT / GDNT) and
        Stemming Weight (SW = OI / UI). UI and OI are 0.0 when their
        denominator is zero; SW is nan when both UI and OI are zero and
        inf when only UI is zero.
    :rtype: tuple(float, float, float)
    """
    # A zero maximum merge total (GDMT) defines UI as 0.
    ui = gumt / gdmt if gdmt else 0.0
    # A zero maximum non-merge total (GDNT) defines OI as 0.
    oi = gwmt / gdnt if gdnt else 0.0

    if ui:
        sw = oi / ui
    elif oi == 0.0:
        # OI and UI are both 0: the ratio is undefined.
        sw = float("nan")
    else:
        # Only UI is 0: the ratio is unbounded.
        sw = float("inf")
    return (ui, oi, sw)
class Paice:
    """Class for storing lemmas, stems and evaluation metrics.

    After construction (and after every call to :meth:`update`) the instance
    holds the merge totals ``gumt``, ``gdmt``, ``gwmt``, ``gdnt``, the indexes
    ``ui``, ``oi``, ``sw``, the error rate ``errt`` and ``coords``, the list
    of (UI, OI) points of the truncation line computed for ERRT.
    """

    def __init__(self, lemmas, stems):
        """
        :param lemmas: A dictionary where keys are lemmas and values are sets
            or lists of words corresponding to that lemma.
        :param stems: A dictionary where keys are stems and values are sets
            or lists of words corresponding to that stem.
        :type lemmas: dict(str): list(str)
        :type stems: dict(str): set(str)
        """
        self.lemmas = lemmas
        self.stems = stems
        # Placeholders; update() fills in all of the statistics below.
        self.coords = []
        self.gumt, self.gdmt, self.gwmt, self.gdnt = (None, None, None, None)
        self.ui, self.oi, self.sw = (None, None, None)
        self.errt = None
        self.update()

    def __str__(self):
        """Return a multi-line, human-readable report of all statistics."""
        text = [
            "Global Unachieved Merge Total (GUMT): %s\n" % self.gumt,
            "Global Desired Merge Total (GDMT): %s\n" % self.gdmt,
            "Global Wrongly-Merged Total (GWMT): %s\n" % self.gwmt,
            "Global Desired Non-merge Total (GDNT): %s\n" % self.gdnt,
            "Understemming Index (GUMT / GDMT): %s\n" % self.ui,
            "Overstemming Index (GWMT / GDNT): %s\n" % self.oi,
            "Stemming Weight (OI / UI): %s\n" % self.sw,
            # NOTE: this line ends with "\r\n" unlike the others; kept as-is
            # so the produced text does not change.
            "Error-Rate Relative to Truncation (ERRT): %s\r\n" % self.errt,
        ]
        coordinates = " ".join(["(%s, %s)" % item for item in self.coords])
        text.append("Truncation line: %s" % coordinates)
        return "".join(text)

    def _get_truncation_indexes(self, words, cutlength):
        """Count (UI, OI) when stemming is done by truncating words at 'cutlength'.

        :param words: Words used for the analysis
        :param cutlength: Words are stemmed by cutting them at this length
        :type words: set(str) or list(str)
        :type cutlength: int
        :return: Understemming and overstemming indexes
        :rtype: tuple(float, float)
        """
        truncated = _truncate(words, cutlength)
        gumt, gdmt, gwmt, gdnt = _calculate(self.lemmas, truncated)
        # Stemming weight is not needed for the truncation line.
        ui, oi = _indexes(gumt, gdmt, gwmt, gdnt)[:2]
        return (ui, oi)

    def _get_truncation_coordinates(self, cutlength=0):
        """Count (UI, OI) pairs for truncation points until we find the segment where (ui, oi) crosses the truncation line.

        :param cutlength: Optional parameter to start counting from (ui, oi)
            coordinates gotten by stemming at this length. Useful for speeding up
            the calculations when you know the approximate location of the
            intersection.
        :type cutlength: int
        :return: List of coordinate pairs that define the truncation line
        :rtype: list(tuple(float, float))
        """
        words = get_words_from_dictionary(self.lemmas)
        # default=0 keeps an empty lemma dictionary from raising ValueError;
        # the loop below then runs once (for cutlength 0) and yields (0.0, 0.0).
        maxlength = max((len(word) for word in words), default=0)

        # Truncate words from different points until (0, 0) - (ui, oi) segment
        # crosses the truncation line
        coords = []
        while cutlength <= maxlength:
            # Get (UI, OI) pair of current truncation point
            pair = self._get_truncation_indexes(words, cutlength)

            # Store only new coordinates so we'll have an actual
            # line segment when counting the intersection point
            if pair not in coords:
                coords.append(pair)
            if pair == (0.0, 0.0):
                # Stop counting if truncation line goes through origo;
                # length from origo to truncation line is 0
                return coords
            if len(coords) >= 2 and pair[0] > 0.0:
                derivative1 = _get_derivative(coords[-2])
                derivative2 = _get_derivative(coords[-1])
                # Derivative of the truncation line is a decreasing value;
                # when it passes Stemming Weight, we've found the segment
                # of truncation line intersecting with (0, 0) - (ui, oi) segment
                if derivative1 >= self.sw >= derivative2:
                    return coords
            cutlength += 1
        return coords

    def _errt(self):
        """Count Error-Rate Relative to Truncation (ERRT).

        As a side effect, stores the computed truncation line in
        ``self.coords``.

        :return: ERRT, length of the line from origo to (UI, OI) divided by
            the length of the line from origo to the point defined by the same
            line when extended until the truncation line. inf or nan when the
            truncation line goes through origo (nan if (UI, OI) is origo too),
            and 0.0 when only (UI, OI) is origo.
        :rtype: float
        """
        # Count (UI, OI) pairs for truncation points until we find the segment
        # where (ui, oi) crosses the truncation line
        self.coords = self._get_truncation_coordinates()
        if (0.0, 0.0) in self.coords:
            # Truncation line goes through origo, so ERRT cannot be counted
            if (self.ui, self.oi) != (0.0, 0.0):
                return float("inf")
            else:
                return float("nan")
        if (self.ui, self.oi) == (0.0, 0.0):
            # (ui, oi) is origo; define errt as 0.0
            return 0.0
        # Count the intersection point
        # Note that (self.ui, self.oi) cannot be (0.0, 0.0) and self.coords has
        # different coordinates so we have actual line segments instead of a
        # line segment and a point
        intersection = _count_intersection(
            ((0, 0), (self.ui, self.oi)), self.coords[-2:]
        )
        # Count OP (length of the line from origo to (ui, oi))
        op = sqrt(self.ui**2 + self.oi**2)
        # Count OT (length of the line from origo to truncation line that goes
        # through (ui, oi))
        ot = sqrt(intersection[0] ** 2 + intersection[1] ** 2)
        # OP / OT tells how well the stemming algorithm works compared to just
        # truncating words
        return op / ot

    def update(self):
        """Update statistics after lemmas and stems have been set."""
        self.gumt, self.gdmt, self.gwmt, self.gdnt = _calculate(self.lemmas, self.stems)
        self.ui, self.oi, self.sw = _indexes(self.gumt, self.gdmt, self.gwmt, self.gdnt)
        # ERRT depends on ui, oi and sw, so it must be computed last.
        self.errt = self._errt()
def demo():
    """Print Paice statistics for a small example, before and after regrouping the stems."""

    def show(groups):
        # Print each key with its words in key order, followed by a blank line.
        for key in sorted(groups):
            print("{} => {}".format(key, " ".join(groups[key])))
        print()

    # Some words with their real lemmas
    lemmas = {
        "kneel": ["kneel", "knelt"],
        "range": ["range", "ranged"],
        "ring": ["ring", "rang", "rung"],
    }
    # Same words with stems from a stemming algorithm
    stems = {
        "kneel": ["kneel"],
        "knelt": ["knelt"],
        "rang": ["rang", "range", "ranged"],
        "ring": ["ring"],
        "rung": ["rung"],
    }

    print("Words grouped by their lemmas:")
    show(lemmas)
    print("Same words grouped by a stemming algorithm:")
    show(stems)

    paice = Paice(lemmas, stems)
    print(paice)
    print()

    # Let's "change" results from a stemming algorithm
    stems = {
        "kneel": ["kneel"],
        "knelt": ["knelt"],
        "rang": ["rang"],
        "range": ["range", "ranged"],
        "ring": ["ring"],
        "rung": ["rung"],
    }
    print("Counting stats after changing stemming results:")
    show(stems)

    # Swap in the new grouping and recompute every statistic.
    paice.stems = stems
    paice.update()
    print(paice)
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo()
|