mehran committed on
Commit
83edfa9
1 Parent(s): 6d91bc8

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +63 -0
model.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import kenlm
3
+ import sentencepiece as spm
4
+ from tokenizers import normalizers
5
+
6
+
7
class KenlmModel:
    """Score text with a KenLM model over SentencePiece tokens.

    The constructor loads two files from ``model_dir``:

    * ``jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing`` via
      ``kenlm.Model``
    * ``jomleh-sp-{vocabulary_size}.model`` via
      ``sentencepiece.SentencePieceProcessor``

    Input text is passed through a ``tokenizers`` normalizer sequence,
    encoded into SentencePiece pieces, joined with single spaces and then
    handed to the KenLM model.
    """

    # Digits that are collapsed onto the zero digit when number
    # normalization is enabled.
    _NONZERO_DIGITS = ("۱", "۲", "۳", "۴", "۵", "۶", "۷", "۸", "۹")
    _ZERO_DIGIT = "۰"

    # Characters deleted when punctuation normalization is enabled.
    _PUNCTUATION = (".", "!", "؛", "،", "؟")

    def __init__(
        self,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_puctuation: bool = True,
        model_dir: str = "files",
    ):
        """Load the model files and build the text normalizer.

        Args:
            vocabulary_size: Vocabulary-size component of both file names.
            ngram: N-gram order component of the KenLM file name.
            pruning: Pruning component of the KenLM file name.
            normalize_nfd: Apply Unicode NFD normalization.
            normalize_numbers: Replace every digit in ``_NONZERO_DIGITS``
                with ``_ZERO_DIGIT`` and delete ``"."``.
            normalize_puctuation: Delete the characters in
                ``_PUNCTUATION``. The misspelled name is kept so existing
                keyword callers continue to work.
            model_dir: Directory holding the model files. Defaults to
                ``"files"``, resolved against the current working directory.
        """
        kenlm_file = f"jomleh-sp-{vocabulary_size}-o{ngram}-prune{pruning}.probing"
        sp_file = f"jomleh-sp-{vocabulary_size}.model"
        self.model = kenlm.Model(os.path.join(model_dir, kenlm_file))
        self.tokenizer = spm.SentencePieceProcessor(os.path.join(model_dir, sp_file))

        replacements = []
        if normalize_numbers:
            replacements += [(digit, self._ZERO_DIGIT) for digit in self._NONZERO_DIGITS]
            # The period is removed here as well as by punctuation
            # normalization, so it is stripped even when only number
            # normalization is enabled.
            replacements.append((".", ""))
        if normalize_puctuation:
            replacements += [(mark, "") for mark in self._PUNCTUATION]

        steps = [normalizers.Replace(old, new) for old, new in replacements]
        if normalize_nfd:
            steps.append(normalizers.NFD())
        # Strip always runs last.
        steps.append(normalizers.Strip())

        self.normalizer = normalizers.Sequence(steps)

    @classmethod
    def from_pretrained(
        cls,
        vocabulary_size: str,
        ngram: str,
        pruning: str,
        normalize_nfd: bool = True,
        normalize_numbers: bool = True,
        normalize_puctuation: bool = True,
        model_dir: str = "files",
    ) -> "KenlmModel":
        """Alternate constructor; forwards every argument to ``__init__``.

        The normalization flags and ``model_dir`` default to the same values
        as in ``__init__``, so a call with only the three positional
        arguments behaves exactly as before.
        """
        return cls(
            vocabulary_size,
            ngram,
            pruning,
            normalize_nfd=normalize_nfd,
            normalize_numbers=normalize_numbers,
            normalize_puctuation=normalize_puctuation,
            model_dir=model_dir,
        )

    def _prepare(self, doc: str) -> str:
        """Normalize ``doc`` and return its pieces joined by single spaces."""
        normalized = self.normalizer.normalize_str(doc)
        pieces = self.tokenizer.encode(normalized, out_type=str)
        return " ".join(pieces)

    def score(self, doc: str) -> float:
        """Return ``self.model.score`` of the normalized, tokenized ``doc``."""
        return self.model.score(self._prepare(doc))

    def perplexity(self, doc: str) -> float:
        """Return the perplexity of ``doc``, rounded to one decimal place.

        Computed as ``10 ** (-score / length)``, i.e. the model score is
        treated as a base-10 log probability. ``length`` is the number of
        space-separated tokens plus one.
        """
        prepared = self._prepare(doc)
        log_score = self.model.score(prepared)
        # NOTE(review): the +1 presumably accounts for an end-of-sentence
        # token included in the model score — confirm against kenlm. It also
        # keeps the divisor non-zero for empty input.
        length = len(prepared.split()) + 1
        return round(10.0 ** (-log_score / length), 1)