import math
import pickle
import re
import string

from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer


class TrueCaser(object):
    def __init__(self, dist_file_path):
        """Load pickled casing statistics.

        The pickle holds unigram, backward-bigram, forward-bigram and
        trigram frequency distributions, plus a lookup from each
        lowercased word to the set of casings observed for it.
        """
        with open(dist_file_path, "rb") as distributions_file:
            pickle_dict = pickle.load(distributions_file)

        self.uni_dist = pickle_dict["uni_dist"]
        self.backward_bi_dist = pickle_dict["backward_bi_dist"]
        self.forward_bi_dist = pickle_dict["forward_bi_dist"]
        self.trigram_dist = pickle_dict["trigram_dist"]
        self.word_casing_lookup = pickle_dict["word_casing_lookup"]
        self.detknzr = TreebankWordDetokenizer()

    def get_score(self, prev_token, possible_token, next_token):
        # Additive smoothing constant; the distributions are expected to
        # behave like Counters and return 0 for unseen keys.
        pseudo_count = 5.0

        # Unigram score: the smoothed frequency of this exact casing,
        # normalised over all casings observed for the lowercased word.
        numerator = self.uni_dist[possible_token] + pseudo_count
        denominator = 0
        for alternative_token in self.word_casing_lookup[possible_token.lower()]:
            denominator += self.uni_dist[alternative_token] + pseudo_count

        unigram_score = numerator / denominator

        # Backward bigram score: how well this casing fits the previous
        # (already true-cased) token.
        bigram_backward_score = 1
        if prev_token is not None:
            key = prev_token + "_" + possible_token
            numerator = self.backward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = prev_token + "_" + alternative_token
                denominator += self.backward_bi_dist[key] + pseudo_count

            bigram_backward_score = numerator / denominator

        # Forward bigram score: how well this casing fits the (lowercased)
        # next token.
        bigram_forward_score = 1
        if next_token is not None:
            next_token = next_token.lower()
            key = possible_token + "_" + next_token
            numerator = self.forward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = alternative_token + "_" + next_token
                denominator += self.forward_bi_dist[key] + pseudo_count

            bigram_forward_score = numerator / denominator

        # Trigram score: this casing in the context of both neighbours.
        trigram_score = 1
        if prev_token is not None and next_token is not None:
            next_token = next_token.lower()
            trigram_key = prev_token + "_" + possible_token + "_" + next_token
            numerator = self.trigram_dist[trigram_key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                trigram_key = prev_token + "_" + alternative_token + "_" + next_token
                denominator += self.trigram_dist[trigram_key] + pseudo_count

            trigram_score = numerator / denominator

        # Combine the four scores in log space, which is equivalent to
        # multiplying the smoothed probabilities.
        result = (
            math.log(unigram_score)
            + math.log(bigram_backward_score)
            + math.log(bigram_forward_score)
            + math.log(trigram_score)
        )

        return result

    @staticmethod
    def first_token_case(raw):
        # Uppercase only the first character so a mixed-case choice such
        # as "iPhone" is not flattened (str.capitalize would lowercase
        # the rest of the token).
        return raw[0].upper() + raw[1:]

    @staticmethod
    def upper_replacement(match):
        # Rejoin a detokenized " . x" sequence as ". X": drop the stray
        # space before the period and uppercase the sentence-initial
        # character.
        return ". " + match.group(0)[-1].upper()

    def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
        tokens = word_tokenize(sentence)
        tokens_true_case = self.get_true_case_from_tokens(tokens, out_of_vocabulary_token_option)
        text = self.detknzr.detokenize(tokens_true_case)
        # The detokenizer can leave " . x" around sentence boundaries;
        # restore ". X" there.
        text = re.sub(r" \. .", self.upper_replacement, text)
        return text

    def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="title"):
        tokens_true_case = []

        if not tokens:
            return tokens_true_case

        for token_idx, token in enumerate(tokens):
            # Punctuation and numbers carry no casing information.
            if token in string.punctuation or token.isdigit():
                tokens_true_case.append(token)
                continue

            token = token.lower()
            if token not in self.word_casing_lookup:
                # Out-of-vocabulary token: fall back to the requested
                # casing strategy.
                if out_of_vocabulary_token_option == "title":
                    tokens_true_case.append(token.title())
                elif out_of_vocabulary_token_option == "capitalize":
                    tokens_true_case.append(token.capitalize())
                elif out_of_vocabulary_token_option == "lower":
                    tokens_true_case.append(token.lower())
                else:
                    tokens_true_case.append(token)
                continue

            # Only one casing was ever observed for this word: use it.
            if len(self.word_casing_lookup[token]) == 1:
                tokens_true_case.append(next(iter(self.word_casing_lookup[token])))
                continue

            # Otherwise score every observed casing in context. The previous
            # token comes from the already true-cased output; the next token
            # is still raw and is lowercased inside get_score.
            prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None
            next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None

            best_token = None
            highest_score = float("-inf")

            for possible_token in self.word_casing_lookup[token]:
                score = self.get_score(prev_token, possible_token, next_token)

                if score > highest_score:
                    best_token = possible_token
                    highest_score = score

            tokens_true_case.append(best_token)

        # Sentences start with an uppercase letter.
        tokens_true_case[0] = self.first_token_case(tokens_true_case[0])
        return tokens_true_case
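

# A minimal usage sketch, not part of the class itself. It assumes the NLTK
# "punkt" tokenizer data is installed and that "english.dist" is a
# hypothetical path to a pickled distributions file produced by a matching
# training step; substitute your own trained file.
if __name__ == "__main__":
    caser = TrueCaser("english.dist")  # hypothetical distributions file
    print(caser.get_true_case("i saw john in new york last tuesday."))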