# Natural Language Toolkit: Stack decoder
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Tah Wei Hoon <[email protected]>
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT

"""

A decoder that uses stacks to implement phrase-based translation.



In phrase-based translation, the source sentence is segmented into

phrases of one or more words, and translations for those phrases are

used to build the target sentence.



Hypothesis data structures are used to keep track of the source words

translated so far and the partial output. A hypothesis can be expanded

by selecting an untranslated phrase, looking up its translation in a

phrase table, and appending that translation to the partial output.

Translation is complete when a hypothesis covers all source words.



The search space is huge because the source sentence can be segmented

in different ways, the source phrases can be selected in any order,

and there could be multiple translations for the same source phrase in

the phrase table. To make decoding tractable, stacks are used to limit

the number of candidate hypotheses by doing histogram and/or threshold

pruning.



Hypotheses with the same number of words translated are placed in the

same stack. In histogram pruning, each stack has a size limit, and

the hypothesis with the lowest score is removed when the stack is full.

In threshold pruning, hypotheses that score below a certain threshold

of the best hypothesis in that stack are removed.



Hypothesis scoring can include various factors such as phrase

translation probability, language model probability, length of

translation, cost of remaining words to be translated, and so on.





References:

Philipp Koehn. 2010. Statistical Machine Translation.

Cambridge University Press, New York.

"""

import warnings
from collections import defaultdict
from math import log


class StackDecoder:
    """

    Phrase-based stack decoder for machine translation



    >>> from nltk.translate import PhraseTable

    >>> phrase_table = PhraseTable()

    >>> phrase_table.add(('niemand',), ('nobody',), log(0.8))

    >>> phrase_table.add(('niemand',), ('no', 'one'), log(0.2))

    >>> phrase_table.add(('erwartet',), ('expects',), log(0.8))

    >>> phrase_table.add(('erwartet',), ('expecting',), log(0.2))

    >>> phrase_table.add(('niemand', 'erwartet'), ('one', 'does', 'not', 'expect'), log(0.1))

    >>> phrase_table.add(('die', 'spanische', 'inquisition'), ('the', 'spanish', 'inquisition'), log(0.8))

    >>> phrase_table.add(('!',), ('!',), log(0.8))



    >>> #  nltk.model should be used here once it is implemented

    >>> from collections import defaultdict

    >>> language_prob = defaultdict(lambda: -999.0)

    >>> language_prob[('nobody',)] = log(0.5)

    >>> language_prob[('expects',)] = log(0.4)

    >>> language_prob[('the', 'spanish', 'inquisition')] = log(0.2)

    >>> language_prob[('!',)] = log(0.1)

    >>> language_model = type('',(object,),{'probability_change': lambda self, context, phrase: language_prob[phrase], 'probability': lambda self, phrase: language_prob[phrase]})()



    >>> stack_decoder = StackDecoder(phrase_table, language_model)



    >>> stack_decoder.translate(['niemand', 'erwartet', 'die', 'spanische', 'inquisition', '!'])

    ['nobody', 'expects', 'the', 'spanish', 'inquisition', '!']



    """

    def __init__(self, phrase_table, language_model):
        """

        :param phrase_table: Table of translations for source language

            phrases and the log probabilities for those translations.

        :type phrase_table: PhraseTable



        :param language_model: Target language model. Must define a

            ``probability_change`` method that calculates the change in

            log probability of a sentence, if a given string is appended

            to it.

            This interface is experimental and will likely be replaced

            with nltk.model once it is implemented.

        :type language_model: object

        """
        self.phrase_table = phrase_table
        self.language_model = language_model

        self.word_penalty = 0.0
        """

        float: Influences the translation length exponentially.

            If positive, shorter translations are preferred.

            If negative, longer translations are preferred.

            If zero, no penalty is applied.

        """

        self.beam_threshold = 0.0
        """

        float: Hypotheses that score below this factor of the best

            hypothesis in a stack are dropped from consideration.

            Value between 0.0 and 1.0.

        """

        self.stack_size = 100
        """

        int: Maximum number of hypotheses to consider in a stack.

            Higher values increase the likelihood of a good translation,

            but increases processing time.

        """

        self.__distortion_factor = 0.5
        self.__compute_log_distortion()

    @property
    def distortion_factor(self):
        """

        float: Amount of reordering of source phrases.

            Lower values favour monotone translation, suitable when

            word order is similar for both source and target languages.

            Value between 0.0 and 1.0. Default 0.5.

        """
        return self.__distortion_factor

    @distortion_factor.setter
    def distortion_factor(self, d):
        self.__distortion_factor = d
        self.__compute_log_distortion()

    def __compute_log_distortion(self):
        # cache log(distortion_factor) so we don't have to recompute it
        # when scoring hypotheses
        if self.__distortion_factor == 0.0:
            self.__log_distortion_factor = log(1e-9)  # 1e-9 is almost zero
        else:
            self.__log_distortion_factor = log(self.__distortion_factor)

    def translate(self, src_sentence):
        """

        :param src_sentence: Sentence to be translated

        :type src_sentence: list(str)



        :return: Translated sentence

        :rtype: list(str)

        """
        sentence = tuple(src_sentence)  # prevent accidental modification
        sentence_length = len(sentence)
        stacks = [
            _Stack(self.stack_size, self.beam_threshold)
            for _ in range(0, sentence_length + 1)
        ]
        empty_hypothesis = _Hypothesis()
        stacks[0].push(empty_hypothesis)

        all_phrases = self.find_all_src_phrases(sentence)
        future_score_table = self.compute_future_scores(sentence)
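        # Expand hypotheses stack by stack, in order of increasing number of
        # translated words. A new hypothesis covering n words is pushed onto
        # stacks[n], so it will itself be expanded in a later iteration.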
        for stack in stacks:
            for hypothesis in stack:
                possible_expansions = StackDecoder.valid_phrases(
                    all_phrases, hypothesis
                )
                for src_phrase_span in possible_expansions:
                    src_phrase = sentence[src_phrase_span[0] : src_phrase_span[1]]
                    for translation_option in self.phrase_table.translations_for(
                        src_phrase
                    ):
                        raw_score = self.expansion_score(
                            hypothesis, translation_option, src_phrase_span
                        )
                        new_hypothesis = _Hypothesis(
                            raw_score=raw_score,
                            src_phrase_span=src_phrase_span,
                            trg_phrase=translation_option.trg_phrase,
                            previous=hypothesis,
                        )
                        new_hypothesis.future_score = self.future_score(
                            new_hypothesis, future_score_table, sentence_length
                        )
                        total_words = new_hypothesis.total_translated_words()
                        stacks[total_words].push(new_hypothesis)

        if not stacks[sentence_length]:
            warnings.warn(
                "Unable to translate all words. "
                "The source sentence contains words not in "
                "the phrase table"
            )
            # Instead of returning empty output, perhaps a partial
            # translation could be returned
            return []

        best_hypothesis = stacks[sentence_length].best()
        return best_hypothesis.translation_so_far()

    def find_all_src_phrases(self, src_sentence):
        """

        Finds all subsequences in src_sentence that have a phrase

        translation in the translation table



        :type src_sentence: tuple(str)



        :return: Subsequences that have a phrase translation,

            represented as a table of lists of end positions.

            For example, if result[2] is [5, 6, 9], then there are

            three phrases starting from position 2 in ``src_sentence``,

            ending at positions 5, 6, and 9 exclusive. The list of

            ending positions are in ascending order.

        :rtype: list(list(int))

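        For illustration, a minimal sketch with a small hypothetical phrase
        table (the source words are arbitrary; this method does not use the
        language model, so ``None`` is passed in its place):

        >>> from nltk.translate import PhraseTable
        >>> table = PhraseTable()
        >>> table.add(('le',), ('the',), log(0.8))
        >>> table.add(('le', 'chat'), ('the', 'cat'), log(0.6))
        >>> table.add(('chat',), ('cat',), log(0.9))
        >>> decoder = StackDecoder(table, None)
        >>> decoder.find_all_src_phrases(('le', 'chat', 'noir'))
        [[1, 2], [2], []]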
        """
        sentence_length = len(src_sentence)
        phrase_indices = [[] for _ in src_sentence]
        for start in range(0, sentence_length):
            for end in range(start + 1, sentence_length + 1):
                potential_phrase = src_sentence[start:end]
                if potential_phrase in self.phrase_table:
                    phrase_indices[start].append(end)
        return phrase_indices

    def compute_future_scores(self, src_sentence):
        """

        Determines the approximate scores for translating every

        subsequence in ``src_sentence``



        Future scores can be used a look-ahead to determine the

        difficulty of translating the remaining parts of a src_sentence.



        :type src_sentence: tuple(str)



        :return: Scores of subsequences referenced by their start and

            end positions. For example, result[2][5] is the score of the

            subsequence covering positions 2, 3, and 4.

        :rtype: dict(int: (dict(int): float))

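        A minimal sketch with a hypothetical two-entry phrase table and a
        stub language model that returns a constant log probability; it
        shows that a subsequence with no translation of its own is scored
        by combining its child subsequences:

        >>> from nltk.translate import PhraseTable
        >>> table = PhraseTable()
        >>> table.add(('le',), ('the',), log(0.8))
        >>> table.add(('chat',), ('cat',), log(0.9))
        >>> stub_lm = type('', (object,), {'probability': lambda self, phrase: -10.0})()
        >>> decoder = StackDecoder(table, stub_lm)
        >>> scores = decoder.compute_future_scores(('le', 'chat'))
        >>> scores[0][2] == scores[0][1] + scores[1][2]
        True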
        """
        scores = defaultdict(lambda: defaultdict(lambda: float("-inf")))
        for seq_length in range(1, len(src_sentence) + 1):
            for start in range(0, len(src_sentence) - seq_length + 1):
                end = start + seq_length
                phrase = src_sentence[start:end]
                if phrase in self.phrase_table:
                    score = self.phrase_table.translations_for(phrase)[
                        0
                    ].log_prob  # pick best (first) translation
                    # Warning: API of language_model is subject to change
                    score += self.language_model.probability(phrase)
                    scores[start][end] = score

                # check if a better score can be obtained by combining
                # two child subsequences
                for mid in range(start + 1, end):
                    combined_score = scores[start][mid] + scores[mid][end]
                    if combined_score > scores[start][end]:
                        scores[start][end] = combined_score
        return scores

    def future_score(self, hypothesis, future_score_table, sentence_length):
        """

        Determines the approximate score for translating the

        untranslated words in ``hypothesis``

        """
        score = 0.0
        for span in hypothesis.untranslated_spans(sentence_length):
            score += future_score_table[span[0]][span[1]]
        return score

    def expansion_score(self, hypothesis, translation_option, src_phrase_span):
        """

        Calculate the score of expanding ``hypothesis`` with

        ``translation_option``



        :param hypothesis: Hypothesis being expanded

        :type hypothesis: _Hypothesis



        :param translation_option: Information about the proposed expansion

        :type translation_option: PhraseTableEntry



        :param src_phrase_span: Word position span of the source phrase

        :type src_phrase_span: tuple(int, int)

        """
        score = hypothesis.raw_score
        score += translation_option.log_prob
        # The API of language_model is subject to change; it could accept
        # a string, a list of words, and/or some other type
        score += self.language_model.probability_change(
            hypothesis, translation_option.trg_phrase
        )
        score += self.distortion_score(hypothesis, src_phrase_span)
        score -= self.word_penalty * len(translation_option.trg_phrase)
        return score

    def distortion_score(self, hypothesis, next_src_phrase_span):
        if not hypothesis.src_phrase_span:
            return 0.0
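        # Distortion is the jump in the source sentence between the end of
        # the previously translated phrase and the start of the next one.
        # For example, if the previous phrase covered span (0, 2) and the
        # next phrase starts at position 4, the distance is 2 and the
        # penalty is 2 * log(distortion_factor).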
        next_src_phrase_start = next_src_phrase_span[0]
        prev_src_phrase_end = hypothesis.src_phrase_span[1]
        distortion_distance = next_src_phrase_start - prev_src_phrase_end
        return abs(distortion_distance) * self.__log_distortion_factor

    @staticmethod
    def valid_phrases(all_phrases_from, hypothesis):
        """

        Extract phrases from ``all_phrases_from`` that contains words

        that have not been translated by ``hypothesis``



        :param all_phrases_from: Phrases represented by their spans, in

            the same format as the return value of

            ``find_all_src_phrases``

        :type all_phrases_from: list(list(int))



        :type hypothesis: _Hypothesis



        :return: A list of phrases, represented by their spans, that

            cover untranslated positions.

        :rtype: list(tuple(int, int))

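        A small sketch, building a hypothesis chain by hand, where only
        position 1 of a three-word sentence has been translated and the
        spans in ``all_phrases_from`` are hypothetical:

        >>> hypothesis = _Hypothesis(src_phrase_span=(1, 2), previous=_Hypothesis())
        >>> StackDecoder.valid_phrases([[1, 2], [2], [3]], hypothesis)
        [(0, 1), (2, 3)]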
        """
        untranslated_spans = hypothesis.untranslated_spans(len(all_phrases_from))
        valid_phrases = []
        for available_span in untranslated_spans:
            start = available_span[0]
            available_end = available_span[1]
            while start < available_end:
                for phrase_end in all_phrases_from[start]:
                    if phrase_end > available_end:
                        # Subsequent elements in all_phrases_from[start]
                        # will also be > available_end, since the
                        # elements are in ascending order
                        break
                    valid_phrases.append((start, phrase_end))
                start += 1
        return valid_phrases


class _Hypothesis:
    """

    Partial solution to a translation.



    Records the word positions of the phrase being translated, its

    translation, raw score, and the cost of the untranslated parts of

    the sentence. When the next phrase is selected to build upon the

    partial solution, a new _Hypothesis object is created, with a back

    pointer to the previous hypothesis.



    To find out which words have been translated so far, look at the

    ``src_phrase_span`` in the hypothesis chain. Similarly, the

    translation output can be found by traversing up the chain.

    """

    def __init__(
        self,
        raw_score=0.0,
        src_phrase_span=(),
        trg_phrase=(),
        previous=None,
        future_score=0.0,
    ):
        """

        :param raw_score: Likelihood of hypothesis so far.

            Higher is better. Does not account for untranslated words.

        :type raw_score: float



        :param src_phrase_span: Span of word positions covered by the

            source phrase in this hypothesis expansion. For example,

            (2, 5) means that the phrase is from the second word up to,

            but not including the fifth word in the source sentence.

        :type src_phrase_span: tuple(int)



        :param trg_phrase: Translation of the source phrase in this

            hypothesis expansion

        :type trg_phrase: tuple(str)



        :param previous: Previous hypothesis before expansion to this one

        :type previous: _Hypothesis



        :param future_score: Approximate score for translating the

            remaining words not covered by this hypothesis. Higher means

            that the remaining words are easier to translate.

        :type future_score: float

        """
        self.raw_score = raw_score
        self.src_phrase_span = src_phrase_span
        self.trg_phrase = trg_phrase
        self.previous = previous
        self.future_score = future_score

    def score(self):
        """

        Overall score of hypothesis after accounting for local and

        global features

        """
        return self.raw_score + self.future_score

    def untranslated_spans(self, sentence_length):
        """

        Starting from each untranslated word, find the longest

        continuous span of untranslated positions



        :param sentence_length: Length of source sentence being

            translated by the hypothesis

        :type sentence_length: int



        :rtype: list(tuple(int, int))

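        For example, a hypothesis built by hand that has translated only
        positions 2 and 3 of a six-word source sentence:

        >>> hypothesis = _Hypothesis(src_phrase_span=(2, 4), previous=_Hypothesis())
        >>> hypothesis.untranslated_spans(6)
        [(0, 2), (4, 6)]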
        """
        translated_positions = self.translated_positions()
        translated_positions.sort()
        translated_positions.append(sentence_length)  # add sentinel position

        untranslated_spans = []
        start = 0
        # each untranslated span must end in one of the translated_positions
        for end in translated_positions:
            if start < end:
                untranslated_spans.append((start, end))
            start = end + 1

        return untranslated_spans

    def translated_positions(self):
        """

        List of positions in the source sentence of words already

        translated. The list is not sorted.



        :rtype: list(int)

        """
        translated_positions = []
        current_hypothesis = self
        while current_hypothesis.previous is not None:
            translated_span = current_hypothesis.src_phrase_span
            translated_positions.extend(range(translated_span[0], translated_span[1]))
            current_hypothesis = current_hypothesis.previous
        return translated_positions

    def total_translated_words(self):
        return len(self.translated_positions())

    def translation_so_far(self):
        translation = []
        self.__build_translation(self, translation)
        return translation

    def __build_translation(self, hypothesis, output):
        if hypothesis.previous is None:
            return
        self.__build_translation(hypothesis.previous, output)
        output.extend(hypothesis.trg_phrase)


class _Stack:
    """

    Collection of _Hypothesis objects

    """

    def __init__(self, max_size=100, beam_threshold=0.0):
        """

        :param beam_threshold: Hypotheses that score less than this

            factor of the best hypothesis are discarded from the stack.

            Value must be between 0.0 and 1.0.

        :type beam_threshold: float

        """
        self.max_size = max_size
        self.items = []

        if beam_threshold == 0.0:
            self.__log_beam_threshold = float("-inf")
        else:
            self.__log_beam_threshold = log(beam_threshold)

    def push(self, hypothesis):
        """

        Add ``hypothesis`` to the stack.

        Removes lowest scoring hypothesis if the stack is full.

        After insertion, hypotheses that score less than

        ``beam_threshold`` times the score of the best hypothesis

        are removed.

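        A minimal sketch of histogram pruning with a stack of size 2; the
        raw scores are arbitrary log probabilities and ``future_score``
        defaults to 0, so it does not affect the ordering:

        >>> stack = _Stack(max_size=2)
        >>> stack.push(_Hypothesis(raw_score=log(0.5)))
        >>> stack.push(_Hypothesis(raw_score=log(0.2)))
        >>> stack.push(_Hypothesis(raw_score=log(0.7)))
        >>> [round(h.score(), 3) for h in stack]
        [-0.357, -0.693]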
        """
        self.items.append(hypothesis)
        self.items.sort(key=lambda h: h.score(), reverse=True)
        while len(self.items) > self.max_size:
            self.items.pop()
        self.threshold_prune()

    def threshold_prune(self):
        if not self.items:
            return
        #  log(score * beam_threshold) = log(score) + log(beam_threshold)
        threshold = self.items[0].score() + self.__log_beam_threshold
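        # self.items is kept sorted in descending order of score (see push),
        # so pruning from the end removes the weakest hypotheses first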
        for hypothesis in reversed(self.items):
            if hypothesis.score() < threshold:
                self.items.pop()
            else:
                break

    def best(self):
        """

        :return: Hypothesis with the highest score in the stack

        :rtype: _Hypothesis

        """
        if self.items:
            return self.items[0]
        return None

    def __iter__(self):
        return iter(self.items)

    def __contains__(self, hypothesis):
        return hypothesis in self.items

    def __bool__(self):
        return len(self.items) != 0

    __nonzero__ = __bool__