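"""A statistical true-caser.

Restores capitalization to lower-cased text by scoring the casing variants
observed in a training corpus, using unigram, bigram, and trigram statistics
loaded from a pickled distributions file.
"""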
import math
import pickle
import re
import string

from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer


class TrueCaser(object):
    def __init__(self, dist_file_path):
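        """Load the pickled n-gram casing distributions from dist_file_path."""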
        with open(dist_file_path, "rb") as distributions_file:
            pickle_dict = pickle.load(distributions_file)
            self.uni_dist = pickle_dict["uni_dist"]
            self.backward_bi_dist = pickle_dict["backward_bi_dist"]
            self.forward_bi_dist = pickle_dict["forward_bi_dist"]
            self.trigram_dist = pickle_dict["trigram_dist"]
            self.word_casing_lookup = pickle_dict["word_casing_lookup"]

        self.detknzr = TreebankWordDetokenizer()

    def get_score(self, prev_token, possible_token, next_token):
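        """Score one casing variant of possible_token in its context.

        Combines smoothed unigram, backward-bigram, forward-bigram, and
        trigram counts into a sum of log scores. The distributions are
        assumed to be Counter-like (missing keys count as 0), and the
        pseudo-count provides additive smoothing so that unseen n-grams
        never produce a zero probability.
        """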
        pseudo_count = 5.0

        # Get Unigram Score
        numerator = self.uni_dist[possible_token] + pseudo_count
        denominator = 0
        for alternative_token in self.word_casing_lookup[possible_token.lower()]:
            denominator += self.uni_dist[alternative_token] + pseudo_count
        unigram_score = numerator / denominator

        # Get Backward Score
        bigram_backward_score = 1
        if prev_token is not None:
            key = prev_token + "_" + possible_token
            numerator = self.backward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = prev_token + "_" + alternative_token
                denominator += self.backward_bi_dist[key] + pseudo_count
            bigram_backward_score = numerator / denominator

        # Get Forward Score
        bigram_forward_score = 1
        if next_token is not None:
            next_token = next_token.lower()  # Ensure it is lower case
            key = possible_token + "_" + next_token
            numerator = self.forward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = alternative_token + "_" + next_token
                denominator += self.forward_bi_dist[key] + pseudo_count
            bigram_forward_score = numerator / denominator

        # Get Trigram Score
        trigram_score = 1
        if prev_token is not None and next_token is not None:
            next_token = next_token.lower()  # Ensure it is lower case
            trigram_key = prev_token + "_" + possible_token + "_" + next_token
            numerator = self.trigram_dist[trigram_key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                trigram_key = prev_token + "_" + alternative_token + "_" + next_token
                denominator += self.trigram_dist[trigram_key] + pseudo_count
            trigram_score = numerator / denominator

        result = (
            math.log(unigram_score)
            + math.log(bigram_backward_score)
            + math.log(bigram_forward_score)
            + math.log(trigram_score)
        )
        return result

    @staticmethod
    def first_token_case(raw):
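        """Capitalize the sentence-initial token (note that str.capitalize()
        also lower-cases the remaining characters)."""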
        return raw.capitalize()

    @staticmethod
    def upper_replacement(match):
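        """re.sub callback: turn a " . x" match into ". X"."""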
        return ". " + match.group(0)[-1].upper()

    def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
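        """Tokenize a sentence, true-case the tokens, and detokenize."""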
        tokens = word_tokenize(sentence)
        tokens_true_case = self.get_true_case_from_tokens(
            tokens, out_of_vocabulary_token_option)
        text = self.detknzr.detokenize(tokens_true_case)
        # The detokenizer can leave " . x" around sentence ends; restore ". X"
        text = re.sub(r" \. .", self.upper_replacement, text)
        return text

    def get_true_case_from_tokens(self, tokens,
                                  out_of_vocabulary_token_option="title"):
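        """True-case a list of tokens.

        Punctuation and digits are kept as-is; out-of-vocabulary tokens are
        handled per out_of_vocabulary_token_option ("title", "capitalize",
        "lower", or any other value to leave them unchanged); known tokens
        with a single observed casing use it directly; ambiguous tokens are
        resolved in context via get_score(). The first token is then
        capitalized.
        """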
        tokens_true_case = []
        if not tokens:
            return tokens_true_case

        for token_idx, token in enumerate(tokens):
            # Punctuation and digits pass through unchanged
            if token in string.punctuation or token.isdigit():
                tokens_true_case.append(token)
                continue
            token_lower = token.lower()

            if token_lower not in self.word_casing_lookup:  # Token out of vocabulary
                if out_of_vocabulary_token_option == "title":
                    tokens_true_case.append(token.title())
                elif out_of_vocabulary_token_option == "capitalize":
                    tokens_true_case.append(token.capitalize())
                elif out_of_vocabulary_token_option == "lower":
                    tokens_true_case.append(token_lower)
                else:
                    # Any other option keeps the token's original casing
                    tokens_true_case.append(token)
                continue
            if len(self.word_casing_lookup[token_lower]) == 1:
                # Only one casing variant was ever observed for this token
                tokens_true_case.append(
                    list(self.word_casing_lookup[token_lower])[0])
                continue
            # Several casings observed: score each variant in context and
            # keep the best one
            prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None
            next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None

            best_token = None
            highest_score = float("-inf")
            for possible_token in self.word_casing_lookup[token_lower]:
                score = self.get_score(prev_token, possible_token, next_token)
                if score > highest_score:
                    best_token = possible_token
                    highest_score = score

            tokens_true_case.append(best_token)

        tokens_true_case[0] = self.first_token_case(tokens_true_case[0])
        return tokens_true_case
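

if __name__ == "__main__":
    # Minimal usage sketch, assuming a distributions pickle produced by a
    # companion trainer is available. "english_dist.pickle" is a placeholder
    # path, and NLTK's "punkt" tokenizer data must be installed for
    # word_tokenize to work.
    caser = TrueCaser("english_dist.pickle")
    print(caser.get_true_case("i saw barack obama in new york yesterday"))
    # With a typical English training corpus this would print something like:
    # "I saw Barack Obama in New York yesterday"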