# kenlm-sp-jomleh / model.py
# (Hugging Face file-page header — uploader: mehran, commit 83edfa9, 2.36 kB —
#  converted to a comment so the module parses as Python.)
import os
import kenlm
import sentencepiece as spm
from tokenizers import normalizers
class KenlmModel:
    """KenLM n-gram language model scored over SentencePiece tokens, with
    optional Persian text normalization (digit folding, punctuation
    stripping, Unicode NFD).

    Model and tokenizer files are loaded from the local ``files/`` directory
    using the naming scheme
    ``jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing`` and
    ``jomleh-sp-{vocabulary_size}.model``.
    """

    def __init__(
        self,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_puctuation: bool = True,  # (sic) misspelled name kept for backward compatibility
    ):
        """
        Args:
            vocabulary_size: vocabulary-size tag used in both file names.
            ngram: n-gram order tag used in the KenLM file name.
            pruning: pruning tag used in the KenLM file name.
            normalize_nfd: apply Unicode NFD decomposition.
            normalize_numbers: fold every Persian digit to '۰'.
            normalize_puctuation: strip '.', '!', '؛', '،' and '؟'.
        """
        self.model = kenlm.Model(os.path.join(
            "files", f"jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing"))
        self.tokenizer = spm.SentencePieceProcessor(os.path.join(
            "files", f"jomleh-sp-{vocabulary_size}.model"))
        norm_list = []
        if normalize_numbers:
            # Fold all Persian digits to '۰' so every number looks alike to the LM.
            # FIX: the original list also contained Replace(".", ""), duplicating the
            # punctuation rule below — period stripping belongs to the punctuation
            # option only, and is no longer applied when it is disabled.
            norm_list += [normalizers.Replace(digit, "۰") for digit in "۱۲۳۴۵۶۷۸۹"]
        if normalize_puctuation:
            # Strip sentence punctuation: Latin '.'/'!' plus Persian '؛', '،', '؟'.
            norm_list += [normalizers.Replace(mark, "") for mark in (".", "!", "؛", "،", "؟")]
        if normalize_nfd:
            norm_list += [normalizers.NFD()]
        norm_list += [normalizers.Strip()]
        self.normalizer = normalizers.Sequence(norm_list)

    @classmethod
    def from_pretrained(
        cls,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
    ):
        """Alternate constructor mirroring the Hugging Face `from_pretrained` convention."""
        return cls(vocabulary_size, ngram, pruning)

    def _preprocess(self, doc: str) -> str:
        """Normalize *doc* and return it as a space-joined SentencePiece token string."""
        doc = self.normalizer.normalize_str(doc)
        return ' '.join(self.tokenizer.encode(doc, out_type=str))

    def score(self, doc: str):
        """Return the KenLM total log10 probability of *doc* after preprocessing."""
        return self.model.score(self._preprocess(doc))

    def perplexity(self, doc: str):
        """Return the per-token perplexity of *doc*, rounded to one decimal place.

        The ``+ 1`` in the length presumably accounts for the implicit
        end-of-sentence token KenLM appends when scoring — TODO confirm.
        """
        doc = self._preprocess(doc)
        log_score = self.model.score(doc)
        length = len(doc.split()) + 1
        return round(10.0 ** (-log_score / length), 1)