|
import os |
|
import kenlm |
|
import sentencepiece as spm |
|
from tokenizers import normalizers |
|
|
|
|
|
class KenlmModel:
    """Score text with a KenLM n-gram model over SentencePiece tokens.

    The KenLM binary and the SentencePiece model are both loaded from the
    relative directory ``files`` (resolved against the current working
    directory). Input text is normalized with a ``tokenizers`` normalizer
    sequence before it is tokenized and scored.
    """

    # Persian digits that are all collapsed onto the Persian zero.
    _NONZERO_DIGITS = ("۱", "۲", "۳", "۴", "۵", "۶", "۷", "۸", "۹")
    _ZERO_DIGIT = "۰"
    # Punctuation marks that are deleted outright.
    _PUNCTUATION = (".", "!", "؛", "،", "؟")

    def __init__(
        self,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_puctuation: bool = True,
    ):
        """Load the model files and build the text normalizer.

        Args:
            vocabulary_size: Vocabulary size tag used in both file names.
            ngram: N-gram order tag used in the KenLM file name.
            pruning: Pruning tag used in the KenLM file name.
            normalize_nfd: Apply Unicode NFD normalization.
            normalize_numbers: Replace every non-zero Persian digit with
                the Persian zero and delete ".".
            normalize_puctuation: Delete ".", "!", "؛", "،" and "؟".
                (The misspelled name is kept for backward compatibility.)
        """
        lm_file = f"jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing"
        sp_file = f"jomleh-sp-{vocabulary_size}.model"
        self.model = kenlm.Model(os.path.join("files", lm_file))
        self.tokenizer = spm.SentencePieceProcessor(os.path.join("files", sp_file))

        steps = []
        if normalize_numbers:
            steps += [
                normalizers.Replace(digit, self._ZERO_DIGIT)
                for digit in self._NONZERO_DIGITS
            ]
            # The digit branch also removes "." (independently of the
            # punctuation switch), exactly as the original pipeline did.
            steps.append(normalizers.Replace(".", ""))
        if normalize_puctuation:
            steps += [normalizers.Replace(mark, "") for mark in self._PUNCTUATION]
        if normalize_nfd:
            steps.append(normalizers.NFD())
        # Surrounding whitespace is always stripped, whatever the switches.
        steps.append(normalizers.Strip())

        self.normalizer = normalizers.Sequence(steps)

    @classmethod
    def from_pretrained(
        cls,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        *,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_puctuation: bool = True,
    ) -> "KenlmModel":
        """Alternate constructor that forwards every option to ``__init__``.

        The normalization switches are keyword-only and default to the same
        values as in ``__init__``, so existing three-argument calls behave
        as before.
        """
        return cls(
            vocabulary_size,
            ngram,
            pruning,
            normalize_nfd=normalize_nfd,
            normalize_numbers=normalize_numbers,
            normalize_puctuation=normalize_puctuation,
        )

    def _tokenize(self, doc: str) -> str:
        """Normalize *doc* and return its token pieces joined by spaces."""
        normalized = self.normalizer.normalize_str(doc)
        pieces = self.tokenizer.encode(normalized, out_type=str)
        return ' '.join(pieces)

    def score(self, doc: str) -> float:
        """Return the language-model score of *doc*.

        The value is whatever ``self.model.score`` returns for the
        normalized, tokenized text (presumably a log10 probability, given
        the base-10 exponent used in ``perplexity`` -- confirm against the
        KenLM documentation).
        """
        return self.model.score(self._tokenize(doc))

    def perplexity(self, doc: str) -> float:
        """Return the perplexity of *doc*, rounded to one decimal place.

        Computed as ``10 ** (-score / length)`` where ``length`` is the
        number of whitespace-separated tokens plus one.
        """
        tokens = self._tokenize(doc)
        log_score = self.model.score(tokens)
        # NOTE(review): the extra 1 presumably accounts for the
        # end-of-sentence token KenLM scores by default -- confirm.
        length = len(tokens.split()) + 1
        return round(10.0 ** (-log_score / length), 1)
|
|