# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
import math
from operator import itemgetter

import pytest

from nltk.lm import (
    MLE,
    AbsoluteDiscountingInterpolated,
    KneserNeyInterpolated,
    Laplace,
    Lidstone,
    StupidBackoff,
    Vocabulary,
    WittenBellInterpolated,
)
from nltk.lm.preprocessing import padded_everygrams


@pytest.fixture
def vocabulary():
    return Vocabulary(["a", "b", "c", "d", "z", "<s>", "</s>"], unk_cutoff=1)


@pytest.fixture
def training_data():
    return [["a", "b", "c", "d"], ["e", "g", "a", "d", "b", "e"]]


@pytest.fixture
def bigram_training_data(training_data):
    return [list(padded_everygrams(2, sent)) for sent in training_data]


@pytest.fixture
def trigram_training_data(training_data):
    return [list(padded_everygrams(3, sent)) for sent in training_data]


@pytest.fixture
def mle_bigram_model(vocabulary, bigram_training_data):
    model = MLE(2, vocabulary=vocabulary)
    model.fit(bigram_training_data)
    return model
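

# NOTE: an illustrative parametrization (an assumption, not necessarily the
# original set of cases); the expected scores follow from the bigram counts of
# the padded training data above, where the OOV words "e" and "g" map to "<UNK>".
@pytest.mark.parametrize(
    "word, context, expected_score",
    [
        # count('c d') = 1, count('c *') = 1
        ("d", ["c"], 1),
        # Unseen ngram: '<UNK> d' never occurs ("e" is OOV)
        ("d", ["e"], 0),
        # Unigram score: count('a') = 2 out of 14 tokens
        ("a", None, 2.0 / 14),
        # OOV word maps to <UNK>: count('<UNK>') = 3 out of 14 tokens
        ("y", None, 3.0 / 14),
    ],
)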
def test_mle_bigram_scores(mle_bigram_model, word, context, expected_score):
    assert pytest.approx(mle_bigram_model.score(word, context), 1e-4) == expected_score


def test_mle_bigram_logscore_for_zero_score(mle_bigram_model):
    assert math.isinf(mle_bigram_model.logscore("d", ["e"]))


def test_mle_bigram_entropy_perplexity_seen(mle_bigram_model):
    # ngrams seen during training
    trained = [
        ("<s>", "a"),
        ("a", "b"),
        ("b", "<UNK>"),
        ("<UNK>", "a"),
        ("a", "d"),
        ("d", "</s>"),
    ]
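    # entropy() is the negated average log2 score of these ngrams, and
    # perplexity() is 2 ** entropy; the values below are computed by hand.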
    # Ngram = Log score
    # <s>, a = -1
    # a, b = -1
    # b, UNK = -1
    # UNK, a = -1.585
    # a, d = -1
    # d, </s> = -1
    # TOTAL logscores = -6.585
    # - AVG logscores = 1.0975
    H = 1.0975
    perplexity = 2.1398
    assert pytest.approx(mle_bigram_model.entropy(trained), 1e-4) == H
    assert pytest.approx(mle_bigram_model.perplexity(trained), 1e-4) == perplexity


def test_mle_bigram_entropy_perplexity_unseen(mle_bigram_model):
    # In MLE, even one unseen ngram should make entropy and perplexity infinite
    untrained = [("<s>", "a"), ("a", "c"), ("c", "d"), ("d", "</s>")]

    assert math.isinf(mle_bigram_model.entropy(untrained))
    assert math.isinf(mle_bigram_model.perplexity(untrained))


def test_mle_bigram_entropy_perplexity_unigrams(mle_bigram_model):
    # word = score, log score
    # <s> = 0.1429, -2.8074
    # a = 0.1429, -2.8074
    # c = 0.0714, -3.8073
    # UNK = 0.2143, -2.2224
    # d = 0.1429, -2.8074
    # c = 0.0714, -3.8073
    # </s> = 0.1429, -2.8074
    # TOTAL logscores = -21.0666
    # - AVG logscores = 3.0095
    H = 3.0095
    perplexity = 8.0529
    text = [("<s>",), ("a",), ("c",), ("-",), ("d",), ("c",), ("</s>",)]
    assert pytest.approx(mle_bigram_model.entropy(text), 1e-4) == H
    assert pytest.approx(mle_bigram_model.perplexity(text), 1e-4) == perplexity


@pytest.fixture
def mle_trigram_model(trigram_training_data, vocabulary):
    model = MLE(order=3, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model
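

# NOTE: an illustrative parametrization (an assumption, not necessarily the
# original set of cases); scores follow from the trigram training data above,
# where sentences are padded with two "<s>"/"</s>" symbols.
@pytest.mark.parametrize(
    "word, context, expected_score",
    [
        # count('b c d') = 1, count('b c *') = 1
        ("d", ("b", "c"), 1),
        # count('<s> <s> a') = 1, count('<s> <s> *') = 2
        ("a", ("<s>", "<s>"), 0.5),
        # A one-word context uses the bigram counts: count('c d') = 1, count('c *') = 1
        ("d", ["c"], 1),
        # Unigram score: count('a') = 2 out of 18 tokens
        ("a", None, 2.0 / 18),
    ],
)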
def test_mle_trigram_scores(mle_trigram_model, word, context, expected_score):
    assert pytest.approx(mle_trigram_model.score(word, context), 1e-4) == expected_score


@pytest.fixture
def lidstone_bigram_model(bigram_training_data, vocabulary):
    model = Lidstone(0.1, order=2, vocabulary=vocabulary)
    model.fit(bigram_training_data)
    return model
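

# NOTE: an illustrative parametrization (an assumption, not necessarily the
# original set of cases); with gamma = 0.1 and len(vocab) = 8, each score is
# (count + 0.1) / (context_count + 0.8).
@pytest.mark.parametrize(
    "word, context, expected_score",
    [
        # (1 + 0.1) / (2 + 0.8)
        ("a", ["<s>"], 1.1 / 2.8),
        # (0 + 0.1) / (2 + 0.8) -- 'a c' is unseen
        ("c", ["a"], 0.1 / 2.8),
        # (1 + 0.1) / (1 + 0.8)
        ("d", ["c"], 1.1 / 1.8),
        # Unigram score: (2 + 0.1) / (14 + 0.8)
        ("a", None, 2.1 / 14.8),
    ],
)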
def test_lidstone_bigram_score(lidstone_bigram_model, word, context, expected_score):
    assert (
        pytest.approx(lidstone_bigram_model.score(word, context), 1e-4)
        == expected_score
    )


def test_lidstone_entropy_perplexity(lidstone_bigram_model):
    text = [
        ("<s>", "a"),
        ("a", "c"),
        ("c", "<UNK>"),
        ("<UNK>", "d"),
        ("d", "c"),
        ("c", "</s>"),
    ]
    # Unlike MLE this should be able to handle completely novel ngrams
    # Ngram = score, log score
    # <s>, a = 0.3929, -1.3479
    # a, c = 0.0357, -4.8074
    # c, UNK = 0.0(5), -4.1699
    # UNK, d = 0.0263, -5.2479
    # d, c = 0.0357, -4.8074
    # c, </s> = 0.0(5), -4.1699
    # TOTAL logscore: -24.5504
    # - AVG logscore: 4.0917
    H = 4.0917
    perplexity = 17.0504
    assert pytest.approx(lidstone_bigram_model.entropy(text), 1e-4) == H
    assert pytest.approx(lidstone_bigram_model.perplexity(text), 1e-4) == perplexity


@pytest.fixture
def lidstone_trigram_model(trigram_training_data, vocabulary):
    model = Lidstone(0.1, order=3, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model


def test_lidstone_trigram_score(lidstone_trigram_model, word, context, expected_score):
    assert (
        pytest.approx(lidstone_trigram_model.score(word, context), 1e-4)
        == expected_score
    )


@pytest.fixture
def laplace_bigram_model(bigram_training_data, vocabulary):
    model = Laplace(2, vocabulary=vocabulary)
    model.fit(bigram_training_data)
    return model
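

# NOTE: an illustrative parametrization (an assumption, not necessarily the
# original set of cases); Laplace is Lidstone with gamma = 1, so each score is
# (count + 1) / (context_count + 8).
@pytest.mark.parametrize(
    "word, context, expected_score",
    [
        # (1 + 1) / (2 + 8)
        ("a", ["<s>"], 0.2),
        # (0 + 1) / (2 + 8) -- 'a c' is unseen
        ("c", ["a"], 0.1),
        # (1 + 1) / (1 + 8)
        ("d", ["c"], 2.0 / 9),
        # Unigram score: (2 + 1) / (14 + 8)
        ("a", None, 3.0 / 22),
    ],
)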
def test_laplace_bigram_score(laplace_bigram_model, word, context, expected_score):
    assert (
        pytest.approx(laplace_bigram_model.score(word, context), 1e-4) == expected_score
    )


def test_laplace_bigram_entropy_perplexity(laplace_bigram_model):
    text = [
        ("<s>", "a"),
        ("a", "c"),
        ("c", "<UNK>"),
        ("<UNK>", "d"),
        ("d", "c"),
        ("c", "</s>"),
    ]
    # Unlike MLE this should be able to handle completely novel ngrams
    # Ngram = score, log score
    # <s>, a = 0.2, -2.3219
    # a, c = 0.1, -3.3219
    # c, UNK = 0.(1), -3.1699
    # UNK, d = 0.(09), -3.4594
    # d, c = 0.1, -3.3219
    # c, </s> = 0.(1), -3.1699
    # Total logscores: -18.7651
    # - AVG logscores: 3.1275
    H = 3.1275
    perplexity = 8.7393
    assert pytest.approx(laplace_bigram_model.entropy(text), 1e-4) == H
    assert pytest.approx(laplace_bigram_model.perplexity(text), 1e-4) == perplexity


def test_laplace_gamma(laplace_bigram_model):
    assert laplace_bigram_model.gamma == 1


@pytest.fixture
def wittenbell_trigram_model(trigram_training_data, vocabulary):
    model = WittenBellInterpolated(3, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model


def test_wittenbell_trigram_score(
    wittenbell_trigram_model, word, context, expected_score
):
    assert (
        pytest.approx(wittenbell_trigram_model.score(word, context), 1e-4)
        == expected_score
    )


###############################################################################
# Notation Explained                                                          #
###############################################################################
# For all subsequent calculations we use the following notation:
# 1. '*': Placeholder for any word/character. E.g. '*b' stands for
#    all bigrams that end in 'b'. '*b*' stands for all trigrams that
#    contain 'b' in the middle.
# 2. count(ngram): Count all instances (tokens) of an ngram.
# 3. unique(ngram): Count unique instances (types) of an ngram.
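# For example, in the bigram training data above (sentences padded with
# "<s>"/"</s>" and the OOV words "e"/"g" mapped to "<UNK>"), count('*b') --
# bigram tokens ending in 'b' -- is 2 ("a b" and "d b"), and unique('*b') is
# also 2.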


@pytest.fixture
def kneserney_trigram_model(trigram_training_data, vocabulary):
    model = KneserNeyInterpolated(order=3, discount=0.75, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model


def test_kneserney_trigram_score(
    kneserney_trigram_model, word, context, expected_score
):
    assert (
        pytest.approx(kneserney_trigram_model.score(word, context), 1e-4)
        == expected_score
    )


@pytest.fixture
def absolute_discounting_trigram_model(trigram_training_data, vocabulary):
    model = AbsoluteDiscountingInterpolated(order=3, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model


def test_absolute_discounting_trigram_score(
    absolute_discounting_trigram_model, word, context, expected_score
):
    assert (
        pytest.approx(absolute_discounting_trigram_model.score(word, context), 1e-4)
        == expected_score
    )


@pytest.fixture
def stupid_backoff_trigram_model(trigram_training_data, vocabulary):
    model = StupidBackoff(order=3, vocabulary=vocabulary)
    model.fit(trigram_training_data)
    return model


def test_stupid_backoff_trigram_score(
    stupid_backoff_trigram_model, word, context, expected_score
):
    assert (
        pytest.approx(stupid_backoff_trigram_model.score(word, context), 1e-4)
        == expected_score
    )


###############################################################################
# Probability Distributions Should Sum up to Unity                            #
###############################################################################


@pytest.fixture
def kneserney_bigram_model(bigram_training_data, vocabulary):
    model = KneserNeyInterpolated(order=2, vocabulary=vocabulary)
    model.fit(bigram_training_data)
    return model
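

# NOTE: an illustrative parametrization (an assumption, not necessarily the
# original set of cases); each pair names a model fixture defined above and a
# context seen during training, for which the conditional distribution should
# sum to 1 over the vocabulary.
@pytest.mark.parametrize(
    "model_fixture, context",
    [
        ("mle_bigram_model", ("c",)),
        ("mle_trigram_model", ("b", "c")),
        ("lidstone_bigram_model", ("a",)),
        ("laplace_bigram_model", ("c",)),
        ("kneserney_bigram_model", ("c",)),
    ],
)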
def test_sums_to_1(model_fixture, context, request):
    model = request.getfixturevalue(model_fixture)
    scores_for_context = sum(model.score(w, context) for w in model.vocab)
    assert pytest.approx(scores_for_context, 1e-7) == 1.0


###############################################################################
# Generating Text                                                             #
###############################################################################


def test_generate_one_no_context(mle_trigram_model):
    assert mle_trigram_model.generate(random_seed=3) == "<UNK>"


def test_generate_one_from_limiting_context(mle_trigram_model):
    # We don't need random_seed for contexts with only one continuation
    assert mle_trigram_model.generate(text_seed=["c"]) == "d"
    assert mle_trigram_model.generate(text_seed=["b", "c"]) == "d"
    assert mle_trigram_model.generate(text_seed=["a", "c"]) == "d"


def test_generate_one_from_varied_context(mle_trigram_model):
    # When context doesn't limit our options enough, seed the random choice
    assert mle_trigram_model.generate(text_seed=("a", "<s>"), random_seed=2) == "a"


def test_generate_cycle(mle_trigram_model):
    # Add a cycle to the model: bd -> b, db -> d
    more_training_text = [padded_everygrams(mle_trigram_model.order, list("bdbdbd"))]
    mle_trigram_model.fit(more_training_text)
    # Test that we can escape the cycle
    assert mle_trigram_model.generate(7, text_seed=("b", "d"), random_seed=5) == [
        "b",
        "d",
        "b",
        "d",
        "b",
        "d",
        "</s>",
    ]


def test_generate_with_text_seed(mle_trigram_model):
    assert mle_trigram_model.generate(5, text_seed=("<s>", "e"), random_seed=3) == [
        "<UNK>",
        "a",
        "d",
        "b",
        "<UNK>",
    ]


def test_generate_oov_text_seed(mle_trigram_model):
    assert mle_trigram_model.generate(
        text_seed=("aliens",), random_seed=3
    ) == mle_trigram_model.generate(text_seed=("<UNK>",), random_seed=3)


def test_generate_None_text_seed(mle_trigram_model):
    # should crash with type error when we try to look it up in vocabulary
    with pytest.raises(TypeError):
        mle_trigram_model.generate(text_seed=(None,))

    # This will work
    assert mle_trigram_model.generate(
        text_seed=None, random_seed=3
    ) == mle_trigram_model.generate(random_seed=3)