"""Statistical truecaser.

Restores capitalization in lower-cased text by scoring every casing
variant of each token against unigram, bigram, and trigram frequency
distributions loaded from a pickled model file.
"""
import math
import pickle
import re
import string

from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer


class TrueCaser:
    """Restores token casing using pickled n-gram frequency statistics."""

    def __init__(self, dist_file_path):
        with open(dist_file_path, "rb") as distributions_file:
            pickle_dict = pickle.load(distributions_file)
            # Frequency distributions gathered from a cased training corpus.
            # They are indexed with arbitrary keys below, so they are expected
            # to return 0 for unseen entries (e.g. defaultdicts).
            self.uni_dist = pickle_dict["uni_dist"]
            self.backward_bi_dist = pickle_dict["backward_bi_dist"]
            self.forward_bi_dist = pickle_dict["forward_bi_dist"]
            self.trigram_dist = pickle_dict["trigram_dist"]
            # Maps a lower-cased word to the set of casings observed for it.
            self.word_casing_lookup = pickle_dict["word_casing_lookup"]
        self.detknzr = TreebankWordDetokenizer()

    def get_score(self, prev_token, possible_token, next_token):
        """Score one candidate casing of a token given its neighbours.

        Returns the sum of smoothed log-probabilities from the unigram,
        backward-bigram, forward-bigram, and trigram distributions.
        """
        pseudo_count = 5.0  # additive smoothing so unseen n-grams never zero out a score

        # Unigram score: frequency of this exact casing relative to all
        # casings observed for the same lower-cased word.
        numerator = self.uni_dist[possible_token] + pseudo_count
        denominator = 0
        for alternative_token in self.word_casing_lookup[possible_token.lower()]:
            denominator += self.uni_dist[alternative_token] + pseudo_count

        unigram_score = numerator / denominator

        # Backward bigram score: likelihood of this casing given the
        # (already truecased) previous token.
        bigram_backward_score = 1
        if prev_token is not None:
            key = prev_token + "_" + possible_token
            numerator = self.backward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = prev_token + "_" + alternative_token
                denominator += self.backward_bi_dist[key] + pseudo_count

            bigram_backward_score = numerator / denominator

        # Forward bigram score: likelihood of this casing given the next token.
        bigram_forward_score = 1
        if next_token is not None:
            next_token = next_token.lower()  # the forward distribution is keyed on lower case
            key = possible_token + "_" + next_token
            numerator = self.forward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = alternative_token + "_" + next_token
                denominator += self.forward_bi_dist[key] + pseudo_count

            bigram_forward_score = numerator / denominator

        # Trigram score: likelihood of this casing given both neighbours.
        trigram_score = 1
        if prev_token is not None and next_token is not None:
            next_token = next_token.lower()  # the trigram distribution is keyed on lower case
            trigram_key = prev_token + "_" + possible_token + "_" + next_token
            numerator = self.trigram_dist[trigram_key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                trigram_key = prev_token + "_" + alternative_token + "_" + next_token
                denominator += self.trigram_dist[trigram_key] + pseudo_count

            trigram_score = numerator / denominator

        # Combine the four scores under an independence assumption, in log space.
        return (
            math.log(unigram_score)
            + math.log(bigram_backward_score)
            + math.log(bigram_forward_score)
            + math.log(trigram_score)
        )

    @staticmethod
    def first_token_case(raw):
        """Capitalize the sentence-initial token."""
        return raw.capitalize()

    @staticmethod
    def upper_replacement(match):
        """Rewrite a detokenized " . x" match as ". X"."""
        return ". " + match.group(0)[-1].upper()

    def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
        """Truecase a raw sentence string and return the detokenized result."""
        tokens = word_tokenize(sentence)
        tokens_true_case = self.get_true_case_from_tokens(tokens, out_of_vocabulary_token_option)
        text = self.detknzr.detokenize(tokens_true_case)
        # The detokenizer can leave sentence-internal periods as " . x";
        # rewrite them as ". X" so each sentence starts with a capital.
        text = re.sub(r" \. .", self.upper_replacement, text)
        return text

    def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="title"):
        """Truecase a pre-tokenized sentence.

        out_of_vocabulary_token_option controls unknown tokens: "title",
        "capitalize", or anything else (including "lower") to keep the
        lower-cased form.
        """
        tokens_true_case = []

        if not tokens:
            return tokens_true_case

        for token_idx, token in enumerate(tokens):
            # Punctuation and numbers have no casing to restore.
            if token in string.punctuation or token.isdigit():
                tokens_true_case.append(token)
                continue

            token = token.lower()
            if token not in self.word_casing_lookup:  # Token out of vocabulary
                if out_of_vocabulary_token_option == "title":
                    tokens_true_case.append(token.title())
                elif out_of_vocabulary_token_option == "capitalize":
                    tokens_true_case.append(token.capitalize())
                else:  # "lower" or any other value keeps the lower-cased token
                    tokens_true_case.append(token)
                continue

            # Only one casing was ever observed: use it without scoring.
            if len(self.word_casing_lookup[token]) == 1:
                tokens_true_case.append(next(iter(self.word_casing_lookup[token])))
                continue

            # Several casings were observed: score each one in context. The
            # previous token comes from the already truecased output, the
            # next token from the raw input.
            prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None
            next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None

            best_token = None
            highest_score = float("-inf")

            for possible_token in self.word_casing_lookup[token]:
                score = self.get_score(prev_token, possible_token, next_token)

                if score > highest_score:
                    best_token = possible_token
                    highest_score = score

            tokens_true_case.append(best_token)

        # The first token of the sentence is always capitalized.
        tokens_true_case[0] = self.first_token_case(tokens_true_case[0])
        return tokens_true_case
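

# ---------------------------------------------------------------------------
# The code below is an illustrative addition, not part of the original
# module. build_distributions() is a hypothetical sketch of how a model file
# with the five keys read in TrueCaser.__init__ could be produced; the key
# formats deliberately mirror the ones used in get_score().
# ---------------------------------------------------------------------------
def build_distributions(corpus_sentences, out_path):
    """Sketch: count n-gram statistics from cased, pre-tokenized sentences."""
    from collections import defaultdict

    uni_dist = defaultdict(int)
    backward_bi_dist = defaultdict(int)
    forward_bi_dist = defaultdict(int)
    trigram_dist = defaultdict(int)
    word_casing_lookup = defaultdict(set)

    for tokens in corpus_sentences:
        for i, token in enumerate(tokens):
            uni_dist[token] += 1
            word_casing_lookup[token.lower()].add(token)
            if i > 0:
                # Cased previous token + cased current token.
                backward_bi_dist[tokens[i - 1] + "_" + token] += 1
            if i < len(tokens) - 1:
                # Cased current token + lower-cased next token.
                forward_bi_dist[token + "_" + tokens[i + 1].lower()] += 1
            if 0 < i < len(tokens) - 1:
                trigram_dist[tokens[i - 1] + "_" + token + "_" + tokens[i + 1].lower()] += 1

    with open(out_path, "wb") as out_file:
        pickle.dump(
            {
                "uni_dist": uni_dist,
                "backward_bi_dist": backward_bi_dist,
                "forward_bi_dist": forward_bi_dist,
                "trigram_dist": trigram_dist,
                "word_casing_lookup": word_casing_lookup,
            },
            out_file,
        )


if __name__ == "__main__":
    # Minimal usage sketch; the file name is a placeholder, and the tiny
    # one-sentence "corpus" is for demonstration only. NLTK's "punkt"
    # tokenizer data must be available for word_tokenize (run
    # nltk.download("punkt") once beforehand). Note that with the default
    # "title" option, out-of-vocabulary tokens are title-cased.
    build_distributions(
        [word_tokenize("John lives in New York City.")],
        "english_dist.pickle",
    )
    caser = TrueCaser("english_dist.pickle")
    print(caser.get_true_case("john lives in new york city."))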