Spaces:
Sleeping
Sleeping
# Natural Language Toolkit
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
"""Language Model Vocabulary""" | |
import sys | |
from collections import Counter | |
from collections.abc import Iterable | |
from functools import singledispatch | |
from itertools import chain | |
@singledispatch
def _dispatched_lookup(words, vocab):
    """Fallback lookup for types without a registered implementation.

    Dispatches on the type of ``words``. Strings and iterables have
    dedicated implementations registered below; anything else ends up here.

    :param words: Object that was passed in for lookup.
    :param vocab: Vocabulary to look the word(s) up in.
    :raises TypeError: always, because the type of ``words`` is unsupported.
    """
    raise TypeError(f"Unsupported type for looking up in vocabulary: {type(words)}")


@_dispatched_lookup.register(Iterable)
def _(words, vocab):
    """Look up a sequence of words in the vocabulary.

    Every element is dispatched again, so nested iterables are looked up
    recursively.

    :param words: Iterable of words (or of nested iterables of words).
    :param vocab: Vocabulary to look the words up in.
    :return: Tuple of the looked up words.
    """
    return tuple(_dispatched_lookup(w, vocab) for w in words)


# ``str`` is itself an Iterable; registering it explicitly makes single
# dispatch pick this implementation instead of iterating over characters.
@_dispatched_lookup.register(str)
def _string_lookup(word, vocab):
    """Looks up one word in the vocabulary.

    :param str word: Word to look up.
    :param vocab: Object supporting ``in`` and exposing ``unk_label``.
    :return: ``word`` if it is in ``vocab``, otherwise ``vocab.unk_label``.
    """
    return word if word in vocab else vocab.unk_label
class Vocabulary:
    """Stores language model vocabulary.

    Satisfies two common language modeling requirements for a vocabulary:

    - When checking membership and calculating its size, filters items
      by comparing their counts to a cutoff value.
    - Adds a special "unknown" token which unseen words are mapped to.

    >>> words = ['a', 'c', '-', 'd', 'c', 'a', 'b', 'r', 'a', 'c', 'd']
    >>> from nltk.lm import Vocabulary
    >>> vocab = Vocabulary(words, unk_cutoff=2)

    Tokens with counts greater than or equal to the cutoff value will
    be considered part of the vocabulary.

    >>> vocab['c']
    3
    >>> 'c' in vocab
    True
    >>> vocab['d']
    2
    >>> 'd' in vocab
    True

    Tokens with frequency counts less than the cutoff value will be considered not
    part of the vocabulary even though their entries in the count dictionary are
    preserved.

    >>> vocab['b']
    1
    >>> 'b' in vocab
    False
    >>> vocab['aliens']
    0
    >>> 'aliens' in vocab
    False

    Keeping the count entries for seen words allows us to change the cutoff value
    without having to recalculate the counts.

    >>> vocab2 = Vocabulary(vocab.counts, unk_cutoff=1)
    >>> "b" in vocab2
    True

    The cutoff value influences not only membership checking but also the result of
    getting the size of the vocabulary using the built-in `len`.
    Note that while the number of keys in the vocabulary's counter stays the same,
    the items in the vocabulary differ depending on the cutoff.
    We use `sorted` to demonstrate because it keeps the order consistent.

    >>> sorted(vocab2.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab2)
    ['-', '<UNK>', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab.counts)
    ['-', 'a', 'b', 'c', 'd', 'r']
    >>> sorted(vocab)
    ['<UNK>', 'a', 'c', 'd']

    In addition to items it gets populated with, the vocabulary stores a special
    token that stands in for so-called "unknown" items. By default it's "<UNK>".

    >>> "<UNK>" in vocab
    True

    We can look up words in a vocabulary using its `lookup` method.
    "Unseen" words (with counts less than cutoff) are looked up as the unknown label.
    If given one word (a string) as an input, this method will return a string.

    >>> vocab.lookup("a")
    'a'
    >>> vocab.lookup("aliens")
    '<UNK>'

    If given a sequence, it will return a tuple of the looked up words.

    >>> vocab.lookup(["p", 'a', 'r', 'd', 'b', 'c'])
    ('<UNK>', 'a', '<UNK>', 'd', '<UNK>', 'c')

    It's possible to update the counts after the vocabulary has been created.
    In general, the interface is the same as that of `collections.Counter`.

    >>> vocab['b']
    1
    >>> vocab.update(["b", "b", "c"])
    >>> vocab['b']
    3
    """

    def __init__(self, counts=None, unk_cutoff=1, unk_label="<UNK>"):
        """Create a new Vocabulary.

        :param counts: Optional iterable or `collections.Counter` instance to
                       pre-seed the Vocabulary. In case it is iterable, counts
                       are calculated.
        :param int unk_cutoff: Words that occur less frequently than this value
                               are not considered part of the vocabulary.
        :param unk_label: Label for marking words not part of vocabulary.
        :raises ValueError: if `unk_cutoff` is less than 1.
        """
        self.unk_label = unk_label
        if unk_cutoff < 1:
            raise ValueError(f"Cutoff value cannot be less than 1. Got: {unk_cutoff}")
        self._cutoff = unk_cutoff

        self.counts = Counter()
        # Always go through `update` so that the cached length gets initialised,
        # even when there is nothing to count yet.
        self.update(counts if counts is not None else "")

    @property
    def cutoff(self):
        """Cutoff value.

        Items with count below this value are not considered part of vocabulary.
        Exposed as a read-only property: membership checks, equality and the
        string representation all read it as a plain attribute.
        """
        return self._cutoff

    def update(self, *counter_args, **counter_kwargs):
        """Update vocabulary counts.

        Wraps `collections.Counter.update` method and refreshes the cached
        vocabulary size afterwards.
        """
        self.counts.update(*counter_args, **counter_kwargs)
        # The size depends on the cutoff, so recount the items that pass it.
        self._len = sum(1 for _ in self)

    def lookup(self, words):
        """Look up one or more words in the vocabulary.

        If passed one word as a string will return that word or `self.unk_label`.
        Otherwise will assume it was passed a sequence of words, will try to look
        each of them up and return a tuple of the looked up words. The actual
        work is delegated to the module-level `_dispatched_lookup`.

        :param words: Word(s) to look up.
        :type words: Iterable(str) or str
        :rtype: tuple or str
        :raises: TypeError for types other than strings or iterables

        >>> from nltk.lm import Vocabulary
        >>> vocab = Vocabulary(["a", "b", "c", "a", "b"], unk_cutoff=2)
        >>> vocab.lookup("a")
        'a'
        >>> vocab.lookup("aliens")
        '<UNK>'
        >>> vocab.lookup(["a", "b", "c", ["x", "b"]])
        ('a', 'b', '<UNK>', ('<UNK>', 'b'))
        """
        return _dispatched_lookup(words, self)

    def __getitem__(self, item):
        """Return the count of `item`.

        The unknown label reports the cutoff itself as its count, which makes
        it pass the membership check regardless of the stored counts.
        """
        return self._cutoff if item == self.unk_label else self.counts[item]

    def __contains__(self, item):
        """Only consider items with counts GE to cutoff as being in the
        vocabulary."""
        return self[item] >= self.cutoff

    def __iter__(self):
        """Building on membership check define how to iterate over
        vocabulary.

        Yields the counted items that pass the cutoff, followed by the unknown
        label. The label is omitted while no counts have been recorded.
        """
        return chain(
            (item for item in self.counts if item in self),
            [self.unk_label] if self.counts else [],
        )

    def __len__(self):
        """Computing size of vocabulary reflects the cutoff."""
        return self._len

    def __eq__(self, other):
        """Vocabularies are equal when label, cutoff and counts all match."""
        if not isinstance(other, Vocabulary):
            # Let Python try the reflected comparison instead of failing with
            # AttributeError on objects that are not vocabularies.
            return NotImplemented
        return (
            self.unk_label == other.unk_label
            and self.cutoff == other.cutoff
            and self.counts == other.counts
        )

    def __str__(self):
        """Return a short human-readable summary of the vocabulary."""
        return (
            f"<{self.__class__.__name__} with cutoff={self.cutoff} "
            f"unk_label='{self.unk_label}' and {len(self)} items>"
        )