restructured space

- README.md +3 -3
- app.py +86 -28
- data_utils.py +1 -1
- diac_utils.py +35 -4
- gradio_cached_examples/16/log.csv +2 -0
- gradio_cached_examples/6/log.csv +2 -0
- predict.py +53 -28
README.md
CHANGED
@@ -1,12 +1,12 @@
 ---
-title: Partial
+title: Partial Arabic Diacritization
-emoji:
+emoji: 🖋️
 colorFrom: blue
 colorTo: gray
 sdk: gradio
 sdk_version: 4.1.2
 app_file: app.py
-pinned:
+pinned: true
 license: cc-by-sa-3.0
 ---
app.py
CHANGED
@@ -9,12 +9,12 @@
 gdrive_templ = "https://drive.google.com/file/d/{}/view?usp=sharing"
 if not os.path.exists(output_path):
     model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
-    gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False)
+    gdown.download(gdrive_templ.format(model_gdrive_id), output=output_path, quiet=False, fuzzy=True)

 output_path = "vocab.vec"
 if not os.path.exists(output_path):
     vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
-    gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False)
+    gdown.download(gdrive_templ.format(vocab_gdrive_id), output=output_path, quiet=False, fuzzy=True)

 with open("config.yaml", 'r', encoding="utf-8") as file:
     config = yaml.load(file, Loader=yaml.FullLoader)
@@ -22,41 +22,99 @@
 config["train"]["max-sent-len"] = config["predictor"]["window"]
 config["train"]["max-token-count"] = config["predictor"]["window"] * 3

-[old lines 25-27: the original diacritze helper, not recoverable from the rendered diff]
+predictor = PredictTri(config)
+
+def diacritze_full(text):
+    do_hard_mask = None
+    threshold = None
+    predictor.create_dataloader(text, False, do_hard_mask, threshold)
+    diacritized_lines = predictor.predict_partial(do_partial=False, lines=text.split('\n'))
+    return diacritized_lines
+
+def diacritze_partial(text, mask_mode, threshold):
+    do_partial = True
+    predictor.create_dataloader(text, do_partial, mask_mode=="Hard", threshold)
+    diacritized_lines = predictor.predict_partial(do_partial=do_partial, lines=text.split('\n'))
     return diacritized_lines

-with gr.Blocks() as demo:
+with gr.Blocks(theme=gr.themes.Default(text_size="lg")) as demo:
     gr.Markdown(
     """
     # Partial Diacritization: A Context-Contrastive Inference Approach
-
+    ### Authors: Muhammad ElNokrashy, Badr AlKhamissi
+    ### Paper Link: TBD
     """)

-    with gr.[old lines 37-48: the original single-view input layout, truncated in the rendered diff]
-
-    output_txt = gr.Textbox(
-        lines=5,
-        label="Output",
-        type='text',
-        rtl=True,
-        text_align='right',
-    )
-
-    btn = gr.Button(value="Shakkel")
-    btn.click(diacritze, inputs=[input_txt, check_box], outputs=[output_txt])
+    with gr.Tab(label="Full Diacritization"):
+
+        full_input_txt = gr.Textbox(
+            placeholder="اكتب هنا",
+            lines=5,
+            label="Input",
+            type='text',
+            rtl=True,
+            text_align='right',
+        )
+
+        full_output_txt = gr.Textbox(
+            lines=5,
+            label="Output",
+            type='text',
+            rtl=True,
+            text_align='right',
+            show_copy_button=True,
+        )
+
+        full_btn = gr.Button(value="Shakkel")
+        full_btn.click(diacritze_full, inputs=[full_input_txt], outputs=[full_output_txt])
+
+        gr.Examples(
+            examples=[
+                "ولو حمل من مجلس الخيار ، ولم يمنع من الكلام"
+            ],
+            inputs=full_input_txt,
+            outputs=full_output_txt,
+            fn=diacritze_full,
+            cache_examples=True,
+        )
+
+    with gr.Tab(label="Partial Diacritization") as partial_settings:
+        with gr.Row():
+            masking_mode = gr.Radio(choices=["Hard", "Soft"], value="Hard", label="Masking Mode")
+            threshold_slider = gr.Slider(label="Soft Masking Threshold", minimum=0, maximum=1, value=0.1)
+
+        partial_input_txt = gr.Textbox(
+            placeholder="اكتب هنا",
+            lines=5,
+            label="Input",
+            type='text',
+            rtl=True,
+            text_align='right',
+        )
+
+        partial_output_txt = gr.Textbox(
+            lines=5,
+            label="Output",
+            type='text',
+            rtl=True,
+            text_align='right',
+            show_copy_button=True,
+        )
+
+        partial_btn = gr.Button(value="Shakkel")
+        partial_btn.click(diacritze_partial, inputs=[partial_input_txt, masking_mode, threshold_slider], outputs=[partial_output_txt])
+
+        gr.Examples(
+            examples=[
+                ["ولو حمل من مجلس الخيار ، ولم يمنع من الكلام", "Hard", 0],
+            ],
+            inputs=[partial_input_txt, masking_mode, threshold_slider],
+            outputs=partial_output_txt,
+            fn=diacritze_partial,
+            cache_examples=True,
+        )

 if __name__ == "__main__":
     demo.queue().launch(
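The only functional change in the download block is `fuzzy=True`. A minimal sketch of why it matters, assuming gdown 4.x where `gdown.download` accepts a `fuzzy` keyword: the template builds a `/file/d/<id>/view` share link rather than a direct download URL, and `fuzzy=True` lets gdown extract the file id from such links. The `fetch_checkpoint` helper below is hypothetical, not code from the repo.

import os
import gdown  # assumes gdown >= 4.x, where download() accepts fuzzy=

def fetch_checkpoint(file_id: str, output_path: str) -> str:
    """Download a Google Drive file by id unless it is already cached locally (hypothetical helper)."""
    share_url = f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
    if not os.path.exists(output_path):
        # fuzzy=True parses the file id out of a full share URL;
        # without it, a /file/d/<id>/view link is not a usable download link.
        gdown.download(share_url, output=output_path, quiet=False, fuzzy=True)
    return output_path

# Usage mirroring app.py (ids taken from the diff above):
# fetch_checkpoint("1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo", "tashkeela-d2.pt")
# fetch_checkpoint("1-0muGvcSYEf8RAVRcwXay4MRex6kmCii", "vocab.vec")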
data_utils.py
CHANGED
@@ -26,7 +26,7 @@
         self.max_sent_len = config["train"]["max-sent-len"]
         self.max_token_count = config["train"]["max-token-count"]
         self.pad_target_val = -100
-        self.pad_char_id = du.LETTER_LIST.index('<pad>')
+        self.pad_char_id = du.DIAC_PAD_IDX  # LETTER_LIST.index('<pad>')

         self.markov_signal = config['train'].get('markov-signal', False)
         self.batch_first = config['train'].get('batch-first', True)
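This one-line change swaps the padding id from `LETTER_LIST.index('<pad>')`, which is 0 and so collides with the first real class, to `DIAC_PAD_IDX = -1`, which under Python's negative indexing selects the dedicated `(0,0,0)` padding entry of `HARAKAT_MAP`. A minimal sketch of the effect; `pad_batch` is a hypothetical helper, not code from the repo.

import numpy as np

DIAC_PAD_IDX = -1  # from diac_utils: HARAKAT_MAP[-1] == (0,0,0), the padding entry

def pad_batch(seqs, pad_val=DIAC_PAD_IDX):
    """Right-pad variable-length id sequences into one array (illustrative only)."""
    width = max(len(s) for s in seqs)
    out = np.full((len(seqs), width), pad_val, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

print(pad_batch([[1, 2, 8], [3, 0]]))
# [[ 1  2  8]
#  [ 3  0 -1]]
# With pad_val = -1, HARAKAT_MAP[pad] hits the dedicated padding entry,
# whereas LETTER_LIST.index('<pad>') == 0 is indistinguishable from a real id.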
diac_utils.py
CHANGED
@@ -37,6 +37,8 @@
 (0,0,0), #< Padding == -1 (also for spaces)
 ]

+DIAC_PAD_IDX = -1
+
 SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
 LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
 CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
@@ -63,13 +65,13 @@
     return returned_text

 def diac_ids_of_line(line: str):
-    words = tokenize(line)
     diacs = []
+    words = tokenize(line)
     for word in words:
         word_chars = split_word_on_characters_with_diacritics(word)
-        [old line 70: not recoverable from the rendered diff]
+        _cx, cy, _cy_3head = create_label_for_word(word_chars)
         diacs.extend(cy)
-        diacs.append([truncated in the rendered diff]
+        diacs.append(DIAC_PAD_IDX)
     return np.array(diacs[:-1])

 def strip_unknown_tashkeel(word: str):
@@ -77,6 +79,23 @@
         return word
     return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)

+def create_gt_labels(lines):
+    gt_labels = []
+    for line in lines:
+        # gt_labels_line = []
+        # tokens = tokenize(line.strip())
+        # for w_idx, word in enumerate(tokens):
+        #     split_word = self.split_word_on_characters_with_diacritics(word)
+        #     _, cy_flat, _ = du.create_label_for_word(split_word)
+
+        #     gt_labels_line.extend(cy_flat)
+        #     if w_idx+1 < len(tokens):
+        #         gt_labels_line += [0]
+
+        gt_labels_line = diac_ids_of_line(line)
+        gt_labels.append(gt_labels_line)
+    return gt_labels
+
 def split_word_on_characters_with_diacritics(word: str):
     '''
     TODO! Make faster without deque and looping
@@ -100,6 +119,18 @@
     return chars_w_diac


+def load_lines(path: str, *, strip: bool):
+    with open(path, 'r', encoding="utf-8", newline='\n') as fin:
+        if strip:
+            original_lines = [strip_tashkeel(normalize_spaces(line)) for line in fin.readlines()]
+        else:
+            original_lines = [normalize_spaces(line) for line in fin.readlines()]
+    return original_lines
+
+def normalize_spaces(line: str):
+    return ' '.join(tokenize(line.strip()))
+
+
 def char_type(char: str):
     if char in LETTER_LIST:
         return LETTER_LIST.index(char)
@@ -220,4 +251,4 @@
     tanween += [c_out[1]]
     shadda += [c_out[2]]

-    return np.array(haraka), np.array(tanween), np.array(shadda)
+    return np.array(haraka), np.array(tanween), np.array(shadda)
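The new `load_lines`/`normalize_spaces` helpers centralize line cleanup before labeling. A minimal sketch of how they compose, using a plain whitespace split as a stand-in for the module's `tokenize` (an assumption; the real tokenizer may also split off punctuation):

# Stand-in for diac_utils.tokenize (hypothetical simplification).
def tokenize(line):
    return line.split()

def normalize_spaces(line: str):
    # Collapse runs of whitespace to single spaces and trim the ends.
    return ' '.join(tokenize(line.strip()))

assert normalize_spaces("  ولو   حمل \t من ") == "ولو حمل من"

# load_lines(path, strip=True) would then yield normalized, diacritic-stripped
# lines, and create_gt_labels() maps each line to the per-character diacritic
# ids from diac_ids_of_line(), with DIAC_PAD_IDX marking word boundaries.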
gradio_cached_examples/16/log.csv
ADDED
@@ -0,0 +1,2 @@
+Output,flag,username,timestamp
+ولو حمَل من مجلسِ الخيارِ ، ولم يُمنعْ من الكلام,,,2024-01-11 01:33:39.114395
gradio_cached_examples/6/log.csv
ADDED
@@ -0,0 +1,2 @@
+Output,flag,username,timestamp
+وَلَوْ حَمَلَ مِنْ مَجْلِسِ الْخِيَارِ ، وَلَمْ يُمْنَعْ مِنْ الْكَلَامِ,,,2024-01-11 01:30:56.446393
predict.py
CHANGED
@@ -12,7 +12,7 @@
 import torch as T
 from torch.utils.data import DataLoader

-from diac_utils import HARAKAT_MAP, shakkel_char,[truncated in the rendered diff]
+from diac_utils import HARAKAT_MAP, shakkel_char, flat2_3head
 from model_partial import PartialDD
 from model_dd import DiacritizerD2
 from data_utils import DatasetUtils
@@ -31,10 +31,21 @@
     diacs: Union[np.ndarray, T.Tensor]
 ):
     line_w_diacs = ""
-    [old lines 34-35: the original diacritic iteration, not recoverable from the rendered diff]
+    ts, tw = diacs.shape
+    diacs = diacs.flatten()
+    diacs_h3 = flat2_3head(diacs)
+    diacs_h3 = tuple(x.reshape(ts, tw) for x in diacs_h3)
+    diac_char_idx = 0
+    diac_word_idx = 0
+    for ch in line:
         line_w_diacs += ch
-    [old line 37: not recoverable from the rendered diff]
+        if ch == " ":
+            diac_char_idx = 0
+            diac_word_idx += 1
+        else:
+            tashkeel = (diacs_h3[0][diac_word_idx][diac_char_idx], diacs_h3[1][diac_word_idx][diac_char_idx], diacs_h3[2][diac_word_idx][diac_char_idx])
+            diac_char_idx += 1
+            line_w_diacs += shakkel_char(*tashkeel)
     return line_w_diacs

 def diac_text(data, model_output_base, model_output_ctxt, selection_mode='contrastive-hard', threshold=0.1):
@@ -80,29 +91,16 @@
         line = apply_tashkeel(line, line_diacs)
         output.append(line)

-    return
+    return output

 class Predictor:
-    def __init__(self, config[truncated in the rendered diff]
+    def __init__(self, config):

         self.data_utils = DatasetUtils(config)
         vocab_size = len(self.data_utils.letter_list)
         word_embeddings = self.data_utils.embeddings
+        self.config = config

-        stride = config["segment"]["stride"]
-        window = config["segment"]["window"]
-        min_window = config["segment"]["min-window"]
-
-        segments, mapping = segment([text], stride, window, min_window)
-
-        mapping_lines = []
-        for sent_idx, seg_idx, word_idx, char_idx in mapping:
-            mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]
-
-        self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
-        self.original_lines = [text]
-        self.segments = segments
-
         self.device = T.device(
             config['predictor'].get('device', 'cuda:0')
             if T.cuda.is_available() else 'cpu'
@@ -115,16 +113,39 @@
         self.model.to(self.device)
         self.model.eval()

+    def create_dataloader(self, text, do_partial, do_hard_mask, threshold):
+        self.threshold = threshold
+        self.do_hard_mask = do_hard_mask
+
+        stride = self.config["segment"]["stride"]
+        window = self.config["segment"]["window"]
+        min_window = self.config["segment"]["min-window"]
+        if self.do_hard_mask or not do_partial:
+            segments, mapping = segment([text], stride, window, min_window)
+
+            mapping_lines = []
+            for sent_idx, seg_idx, word_idx, char_idx in mapping:
+                mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]
+
+            self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
+            self.original_lines = [text]
+            self.segments = segments
+        else:
+            segments = text.split('\n')
+
+            self.segments = segments
+            self.original_lines = text.split('\n')
+
         self.data_loader = DataLoader(
             DataRetriever(self.data_utils, segments),
-            batch_size=config["predictor"].get("batch-size", 32),
+            batch_size=self.config["predictor"].get("batch-size", 32),
             shuffle=False,
-            num_workers=config['loader'].get('num-workers', 0),
+            num_workers=self.config['loader'].get('num-workers', 0),
         )
+
 class PredictTri(Predictor):
-    def __init__(self, config[truncated in the rendered diff]
-        super().__init__(config[truncated in the rendered diff]
+    def __init__(self, config):
+        super().__init__(config)
         self.diacritics = {
             "FATHA": 1,
             "KASRA": 2,
@@ -146,11 +167,15 @@
         diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
         return diacritized_lines

-    def predict_partial(self, do_partial):
+    def predict_partial(self, do_partial, lines):
         outputs = self.model.predict_partial(self.data_loader, return_extra=True, eval_only='both', do_partial=do_partial)
-        y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']

-        [old line 153: not recoverable from the rendered diff]
+        if self.do_hard_mask or not do_partial:
+            y_gen_diac, y_gen_tanween, y_gen_shadda = outputs['diacritics']
+            diac_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
+        else:
+            diac_lines = diac_text(lines, outputs["other"][1], outputs["other"][0], selection_mode='1', threshold=self.threshold)
+
         return '\n'.join(diac_lines)

     def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
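The rewritten `apply_tashkeel` first reshapes the flat diacritic ids into three per-head grids indexed by word and character position via `flat2_3head`, then walks a (word, char) cursor in lockstep with the characters of the line, resetting the char cursor at each space. A toy version of that cursor walk; `shakkel_char` here is a stub that renders the head triple as text rather than the real diacritic mapping:

# Stub for diac_utils.shakkel_char (illustrative only).
def shakkel_char(haraka, tanween, shadda):
    return f"<{haraka},{int(tanween)},{int(shadda)}>" if haraka else ""

def apply_marks(line, grid):
    # grid[word][char] = (haraka, tanween, shadda) for each non-space character.
    out, w, c = "", 0, 0
    for ch in line:
        out += ch
        if ch == " ":          # word boundary: advance word, restart char cursor
            w, c = w + 1, 0
        else:                  # consume one diacritic triple per character
            out += shakkel_char(*grid[w][c])
            c += 1
    return out

print(apply_marks("ab cd", [[(1, 0, 0), (0, 0, 0)], [(2, 1, 0), (7, 0, 1)]]))
# a<1,0,0>b c<2,1,0>d<7,0,1>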