Spaces:
Sleeping
Sleeping
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
import unittest | |
import pytest | |
from nltk import FreqDist | |
from nltk.lm import NgramCounter | |
from nltk.util import everygrams | |
class TestNgramCounter:
    """Tests for NgramCounter that only involve lookup, no modification."""

    @classmethod
    def setup_class(cls):
        """Build the shared bigram and trigram counters once for the class."""
        # Two tokenized "sentences"; kept on the class so that a test needing
        # a private counter can rebuild one from the same data.
        cls.text = [list("abcd"), list("egdbe")]
        cls.trigram_counter = NgramCounter(
            everygrams(sent, max_len=3) for sent in cls.text
        )
        cls.bigram_counter = NgramCounter(
            everygrams(sent, max_len=2) for sent in cls.text
        )
        # Bare TestCase instance used only for its assertCountEqual helper.
        cls.case = unittest.TestCase()

    def test_N(self):
        """N() totals the ngrams of every order stored in the counter."""
        # 9 unigrams + 7 bigrams
        assert self.bigram_counter.N() == 16
        # 9 unigrams + 7 bigrams + 5 trigrams
        assert self.trigram_counter.N() == 21

    def test_counter_len_changes_with_lookup(self):
        """Looking up an unseen ngram order adds an entry to the counter."""
        # The lookup below mutates the counter, so use a private instance
        # rather than the class-wide one: this keeps the test repeatable and
        # stops it from leaking state into the other lookup tests.
        counter = NgramCounter(everygrams(sent, max_len=2) for sent in self.text)
        assert len(counter) == 2
        counter[50]
        assert len(counter) == 3

    def test_ngram_order_access_unigrams(self):
        """Indexing with order 1 yields the same object as ``.unigrams``."""
        assert self.bigram_counter[1] == self.bigram_counter.unigrams

    def test_ngram_conditional_freqdist(self):
        """Orders 2 and 3 expose the expected contexts via ``conditions()``."""
        expected_trigram_contexts = [
            ("a", "b"),
            ("b", "c"),
            ("e", "g"),
            ("g", "d"),
            ("d", "b"),
        ]
        expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]

        bigrams = self.trigram_counter[2]
        trigrams = self.trigram_counter[3]

        # Order of conditions is irrelevant; only the multiset must match.
        self.case.assertCountEqual(expected_bigram_contexts, bigrams.conditions())
        self.case.assertCountEqual(expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        """A context list followed by a word returns the bigram's count."""
        assert self.bigram_counter[["a"]]["b"] == 1
        assert self.bigram_counter[["b"]]["c"] == 1

    def test_bigram_counts_unseen_ngrams(self):
        """A word never seen after a known context counts as zero."""
        assert self.bigram_counter[["b"]]["z"] == 0

    def test_unigram_counts_seen_words(self):
        """Indexing with a plain string returns that word's unigram count."""
        # "b" occurs once in each of the two sentences.
        assert self.bigram_counter["b"] == 2

    def test_unigram_counts_completely_unseen_words(self):
        """A word absent from the training text counts as zero."""
        assert self.bigram_counter["z"] == 0
class TestNgramCounterTraining:
    """Tests for how NgramCounter ingests (or rejects) training input."""

    @classmethod
    def setup_class(cls):
        """Create the class-wide fixtures once."""
        cls.counter = NgramCounter()
        # Bare TestCase instance used only for its assertCountEqual helper.
        cls.case = unittest.TestCase()

    # ``case`` must be supplied by parametrization; without it pytest treats
    # the argument as a fixture name and errors out before the test runs.
    @pytest.mark.parametrize("case", ["", [], None])
    def test_empty_inputs(self, case):
        """An empty training input leaves the counter without any counts."""
        test = NgramCounter(case)
        assert 2 not in test
        assert test[1] == FreqDist()

    def test_train_on_unigrams(self):
        """Training on 1-tuples only fills the unigram counts."""
        words = list("abcd")
        counter = NgramCounter([[(w,) for w in words]])

        assert not counter[3]
        assert not counter[2]
        self.case.assertCountEqual(words, counter[1].keys())

    def test_train_on_illegal_sentences(self):
        """Sentences whose items are not ngram tuples raise TypeError."""
        # Items are bare strings rather than tuples.
        str_sent = ["Check", "this", "out", "!"]
        # Items are lists rather than tuples.
        list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]

        with pytest.raises(TypeError):
            NgramCounter([str_sent])

        with pytest.raises(TypeError):
            NgramCounter([list_sent])

    def test_train_on_bigrams(self):
        """Training on bigrams only leaves the trigram counts empty."""
        bigram_sent = [("a", "b"), ("c", "d")]
        counter = NgramCounter([bigram_sent])
        assert not counter[3]

    def test_train_on_mix(self):
        """Ngrams of mixed length are each filed under their own order."""
        mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
        counter = NgramCounter([mixed_sent])
        unigrams = ["h"]
        bigram_contexts = [("a",), ("c",)]
        trigram_contexts = [("e", "f")]

        self.case.assertCountEqual(unigrams, counter[1].keys())
        self.case.assertCountEqual(bigram_contexts, counter[2].keys())
        self.case.assertCountEqual(trigram_contexts, counter[3].keys())