File size: 3,891 Bytes
d916065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Ilia Kurenkov <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

import unittest

import pytest

from nltk import FreqDist
from nltk.lm import NgramCounter
from nltk.util import everygrams


class TestNgramCounter:
    """Tests for NgramCounter that only involve lookup, no modification.

    Two counters are built once per class from the same toy corpus: one from
    every n-gram up to order 3 and one from every n-gram up to order 2. Tests
    must not mutate these shared counters; a test that needs to mutate builds
    its own counter instead.
    """

    # Toy corpus shared by every test; each sentence is a list of characters.
    text = [list("abcd"), list("egdbe")]

    @classmethod
    def setup_class(cls):
        cls.trigram_counter = NgramCounter(
            everygrams(sent, max_len=3) for sent in cls.text
        )
        cls.bigram_counter = NgramCounter(
            everygrams(sent, max_len=2) for sent in cls.text
        )
        # Used only for its order-insensitive assertCountEqual helper.
        cls.case = unittest.TestCase()

    def test_N(self):
        assert self.bigram_counter.N() == 16
        assert self.trigram_counter.N() == 21

    def test_counter_len_changes_with_lookup(self):
        # Looking up an unseen order adds an entry (len grows by one), so use a
        # private counter rather than mutating the shared class fixture.
        counter = NgramCounter(everygrams(sent, max_len=2) for sent in self.text)
        assert len(counter) == 2
        counter[50]
        assert len(counter) == 3

    def test_ngram_order_access_unigrams(self):
        assert self.bigram_counter[1] == self.bigram_counter.unigrams

    def test_ngram_conditional_freqdist(self):
        expected_trigram_contexts = [
            ("a", "b"),
            ("b", "c"),
            ("e", "g"),
            ("g", "d"),
            ("d", "b"),
        ]
        expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]

        bigrams = self.trigram_counter[2]
        trigrams = self.trigram_counter[3]

        self.case.assertCountEqual(expected_bigram_contexts, bigrams.conditions())
        self.case.assertCountEqual(expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        assert self.bigram_counter[["a"]]["b"] == 1
        assert self.bigram_counter[["b"]]["c"] == 1

    def test_bigram_counts_unseen_ngrams(self):
        assert self.bigram_counter[["b"]]["z"] == 0

    def test_unigram_counts_seen_words(self):
        assert self.bigram_counter["b"] == 2

    def test_unigram_counts_completely_unseen_words(self):
        assert self.bigram_counter["z"] == 0


class TestNgramCounterTraining:
    """Tests for building (training) an NgramCounter from sentences of n-gram tuples."""

    @classmethod
    def setup_class(cls):
        # Used only for its order-insensitive assertCountEqual helper.
        cls.case = unittest.TestCase()

    @pytest.mark.parametrize("empty_input", ["", [], None])
    def test_empty_inputs(self, empty_input):
        test = NgramCounter(empty_input)
        assert 2 not in test
        assert test[1] == FreqDist()

    def test_train_on_unigrams(self):
        words = list("abcd")
        counter = NgramCounter([[(w,) for w in words]])

        assert not counter[3]
        assert not counter[2]
        self.case.assertCountEqual(words, counter[1].keys())

    def test_train_on_illegal_sentences(self):
        # Each n-gram must be a tuple; bare strings and lists are rejected.
        str_sent = ["Check", "this", "out", "!"]
        list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]

        with pytest.raises(TypeError):
            NgramCounter([str_sent])

        with pytest.raises(TypeError):
            NgramCounter([list_sent])

    def test_train_on_bigrams(self):
        bigram_sent = [("a", "b"), ("c", "d")]
        counter = NgramCounter([bigram_sent])
        assert not counter[3]
        # Each bigram is stored under its one-word context.
        self.case.assertCountEqual([("a",), ("c",)], counter[2].keys())

    def test_train_on_mix(self):
        mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
        counter = NgramCounter([mixed_sent])
        unigrams = ["h"]
        bigram_contexts = [("a",), ("c",)]
        trigram_contexts = [("e", "f")]

        self.case.assertCountEqual(unigrams, counter[1].keys())
        self.case.assertCountEqual(bigram_contexts, counter[2].keys())
        self.case.assertCountEqual(trigram_contexts, counter[3].keys())