initialized repo
- .gitignore +4 -0
- app.py +56 -0
- components/attention.py +130 -0
- components/k_lstm.py +218 -0
- components/linear_scheduler.py +24 -0
- components/rnn.py +0 -0
- components/rnn_base.py +199 -0
- config.yaml +117 -0
- data_utils.py +230 -0
- dataloader.py +64 -0
- diac_utils.py +223 -0
- model_dd.py +526 -0
- model_partial.py +348 -0
- predict.py +170 -0
- segment.py +89 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
*.pyc
*.pt
*.vec
.DS_Store
app.py
ADDED
@@ -0,0 +1,56 @@
import os
import yaml
import gdown
import gradio as gr
from predict import PredictTri

output_path = "tashkeela-d2.pt"
if not os.path.exists(output_path):
    model_gdrive_id = "1FGelqImFkESbTyRsx_elkKIOZ9VbhRuo"
    gdown.download(id=model_gdrive_id, output=output_path, quiet=False)

output_path = "vocab.vec"
if not os.path.exists(output_path):
    vocab_gdrive_id = "1-0muGvcSYEf8RAVRcwXay4MRex6kmCii"
    gdown.download(id=vocab_gdrive_id, output=output_path, quiet=False)

with open("config.yaml", 'r', encoding="utf-8") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

config["train"]["max-sent-len"] = config["predictor"]["window"]
config["train"]["max-token-count"] = config["predictor"]["window"] * 3

def diacritze(text):
    print(text)
    predictor = PredictTri(config, text)
    diacritized_lines = predictor.predict_majority_vote()
    return '\n'.join(diacritized_lines)

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Partial Diacritization
        TODO: put paper links here
        """)
    input_txt = gr.Textbox(
        placeholder="اكتب هنا",
        lines=5,
        label="Input",
        type='text',
        # rtl=True,
        # text_align='right',
    )

    output_txt = gr.Textbox(
        lines=5,
        label="Output",
        type='text',
        # rtl=True,
        # text_align='right',
    )

    btn = gr.Button(value="Shakkel")
    btn.click(diacritze, inputs=input_txt, outputs=output_txt)

if __name__ == "__main__":
    demo.launch()
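
For quick checks outside the Gradio UI, the same entry points can be driven directly from Python. The sketch below mirrors app.py's config handling; it assumes config.yaml, tashkeela-d2.pt, and vocab.vec are already present (app.py downloads the latter two), and it relies only on PredictTri's interface as it is used above. The input string is an invented example.

# Minimal sketch reusing the same entry points as app.py (not a repo file).
import yaml
from predict import PredictTri

with open("config.yaml", 'r', encoding="utf-8") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
config["train"]["max-sent-len"] = config["predictor"]["window"]
config["train"]["max-token-count"] = config["predictor"]["window"] * 3

predictor = PredictTri(config, "نص غير مشكول")
print('\n'.join(predictor.predict_majority_vote()))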
components/attention.py
ADDED
@@ -0,0 +1,130 @@
from typing import (
    Optional,
)
import math

import torch as T
from torch import nn
from torch.nn import functional as F

import opt_einsum as oe

from torch import Tensor

einsum = oe.contract


def masked_softmax(xs: Tensor, mask: Tensor, dim: int = -1, eps=1e-12):
    xs = xs.masked_fill(~mask, -1e9)
    xs = F.softmax(xs, dim=dim)
    return xs

class Attention(nn.Module):
    def __init__(
        self,
        kind: str,
        query_dim: int,
        input_dim: int,
        output_dim: int = None,
        activation: str = 'auto',
        scaled = True,
    ):
        super().__init__()
        assert kind in [
            'dot',
            'linear',
        ]

        self.kind = kind
        self.Dq = query_dim
        self.Din = input_dim
        self.Dout = output_dim or self.Din
        self.activation = 'auto'
        self.scaled = scaled

        self.Wq_ = nn.Linear(self.Dq, self.Din)
        self.Wk_ = nn.Linear(self.Din, self.Din)
        self.Wv_ = nn.Linear(self.Din, self.Dout)
        self.Wz_ = nn.Linear(self.Din, self.Dout)

    def forward(
        self,
        query: Tensor,
        data: Tensor,
        content_mask: Optional[Tensor] = None,
        prejudice_mask: Optional[Tensor] = None,
    ):
        #^ query: [b, ts, tw, dq]
        #^ data: [b, ts, di]
        #^ content_mask: [b, ts, tw]
        #^ prejudice_mask: [b, ts, ts]
        #^ => output: [b, ts, tw, dz]

        dimB, dimS, dimW, dimI = query.shape

        # TODO: Optimize out the [ts, ts, *] intermediate
        qs = self.Wq_(query)
        ks = self.Wk_(data)
        vs = self.Wv_(data)

        if content_mask is not None:
            words_mask = content_mask.any(2)
            #^ words_mask : [b, ts]
        else:
            words_mask = qs.new_ones((dimB, dimS))

        if self.kind == 'linear':
            # Ref: https://twitter.com/francoisfleuret/status/1267455240007188486
            assert prejudice_mask is None, "Linear mode does not support prejudice_mask."
            assert content_mask is not None, "Linear mode requires a content_mask."
            qs = T.relu(qs) * content_mask.unsqueeze(3)
            #^ qs: [bswi]
            ks = T.relu(ks) * words_mask.unsqueeze(2)
            #^ ks: [bsi]
            vks = einsum("bsi, bsz -> bzi", ks, vs)
            #^ vks : [b, dz, di]
            zs = einsum("bswi, bzi -> bswz", qs, vks)
            #^ zs : [b, ts, tw, dz]
            if self.scaled:
                ks = ks.sum(1)
                #^ ks: [bi]
                denom = einsum("bswi, bi -> bsw", qs, ks) + 1e-9
                zs = zs / denom

        elif self.kind == 'dot':
            # Ref: https://arxiv.org/abs/1706.03762
            # s=ts in q
            # S=ts in ks,vs
            att_map = einsum("bqwi, bki -> bqkw", qs, ks)
            #^ [b, ts:q, ts:k, tw]
            if self.scaled == 'seqlen':
                att_map_ndim = len(att_map.shape) - 1
                norm_coeff = words_mask.sum(1).view(-1, *([1] * att_map_ndim))
                #^ [b, _, _, _]
                att_map = att_map / T.sqrt(norm_coeff.float())
            else:
                att_map = att_map / math.sqrt(self.Din)

            if content_mask is None and prejudice_mask is None:
                att_map = F.softmax(att_map, dim=2)
            else:
                if content_mask is None:
                    assert prejudice_mask is not None  # !for mypy
                    qk_mask = prejudice_mask.unsqueeze(3)
                    #^ qk_mask : [b, ts:q, ts:k, tw^]
                elif prejudice_mask is None:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                    #^ qk_mask : [b, ts:q, ts:k^, tw]
                else:
                    qk_mask = words_mask.unsqueeze(1).unsqueeze(3)
                    # qk_mask = words_mask.unsqueeze(1).unsqueeze(3) * content_mask.unsqueeze(2)
                    qk_mask = qk_mask * prejudice_mask.unsqueeze(3)
                    #^ qk_mask : [b, ts:q^, ts:k, tw]

                att_map = masked_softmax(att_map, qk_mask.bool(), dim=2)

            #^ att_map : [b, ts:q, ts:k, tw]
            zs = einsum("bqkw, bkz -> bqwz", att_map, vs)

        zs = self.Wz_(zs)
        return zs, att_map
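
The shape comments above can be checked in isolation. The following is a small, self-contained sketch (batch, sentence, and word sizes are made up) that runs the dot-product path with only a content_mask; this is the same path DiacritizerD2 exercises, minus its extra prejudice_mask.

# Shape-check sketch for Attention (illustrative sizes, not from the config).
import torch as T
from components.attention import Attention

attn = Attention(kind='dot', query_dim=64, input_dim=32)

b, ts, tw = 2, 5, 7
query = T.randn(b, ts, tw, 64)                    # per-character queries
data = T.randn(b, ts, 32)                         # per-word encodings
content_mask = T.ones(b, ts, tw, dtype=T.bool)    # all positions valid

zs, att_map = attn(query, data, content_mask=content_mask)
print(zs.shape)       # [2, 5, 7, 32]  -> [b, ts, tw, dz]
print(att_map.shape)  # [2, 5, 5, 7]   -> [b, ts:q, ts:k, tw]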
components/k_lstm.py
ADDED
@@ -0,0 +1,218 @@
from typing import (
    Tuple,
    List,
    Optional,
    Dict,
    Callable,
    Union,
    cast,
)
from collections import namedtuple
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F

from torch import Tensor

from .rnn_base import (
    IRecurrentCell,
    IRecurrentCellBuilder,
    RecurrentLayer,
    RecurrentLayerStack,
)

__all__ = [
    'K_LSTM',
    'K_LSTM_Cell',
    'K_LSTM_Cell_Builder',
]

ACTIVATIONS = {
    'sigmoid': nn.Sigmoid(),
    'tanh': nn.Tanh(),
    'hard_tanh': nn.Hardtanh(),
    'relu': nn.ReLU(),
}

GateSpans = namedtuple('GateSpans', ['I', 'F', 'G', 'O'])

@dataclass
class K_LSTM_Cell_Builder(IRecurrentCellBuilder):
    vertical_dropout : float = 0.0
    recurrent_dropout : float = 0.0
    recurrent_dropout_mode : str = 'gal_tied'
    input_kernel_initialization : str = 'xavier_uniform'
    recurrent_activation : str = 'sigmoid'
    tied_forget_gate : bool = False

    def make(self, input_size: int):
        return K_LSTM_Cell(input_size, self)

class K_LSTM_Cell(IRecurrentCell):
    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            + ', '.join(
                [
                    f'in: {self.Dx}',
                    f'hid: {self.Dh}',
                    f'rdo: {self.recurrent_dropout_p} @{self.recurrent_dropout_mode}',
                    f'vdo: {self.vertical_dropout_p}'
                ]
            )
            + ')'
        )

    def __init__(
        self,
        input_size: int,
        args: K_LSTM_Cell_Builder,
    ):
        super().__init__()
        self._args = args
        self.Dx = input_size
        self.Dh = args.hidden_size
        self.recurrent_kernel = nn.Linear(self.Dh, self.Dh * 4)
        self.input_kernel = nn.Linear(self.Dx, self.Dh * 4)

        self.recurrent_dropout_p = args.recurrent_dropout or 0.0
        self.vertical_dropout_p = args.vertical_dropout or 0.0
        self.recurrent_dropout_mode = args.recurrent_dropout_mode

        self.recurrent_dropout = nn.Dropout(self.recurrent_dropout_p)
        self.vertical_dropout = nn.Dropout(self.vertical_dropout_p)

        self.tied_forget_gate = args.tied_forget_gate

        if isinstance(args.recurrent_activation, str):
            self.fun_rec = ACTIVATIONS[args.recurrent_activation]
        else:
            self.fun_rec = args.recurrent_activation

        self.reset_parameters_()

    # @T.jit.ignore
    def get_recurrent_weights(self):
        # type: () -> Tuple[GateSpans, GateSpans]
        W = self.recurrent_kernel.weight.chunk(4, 0)
        b = self.recurrent_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b

    # @T.jit.ignore
    def get_input_weights(self):
        # type: () -> Tuple[GateSpans, GateSpans]
        W = self.input_kernel.weight.chunk(4, 0)
        b = self.input_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b

    @T.jit.ignore
    def reset_parameters_(self):
        rw, rb = self.get_recurrent_weights()
        iw, ib = self.get_input_weights()

        nn.init.zeros_(self.input_kernel.bias)
        nn.init.zeros_(self.recurrent_kernel.bias)
        nn.init.ones_(rb.F)
        #^ forget bias

        for W in rw:
            nn.init.orthogonal_(W)
        for W in iw:
            nn.init.xavier_uniform_(W)

    @T.jit.export
    def get_init_state(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        batch_size = input.shape[1]
        h0 = T.zeros(batch_size, self.Dh, device=input.device)
        c0 = T.zeros(batch_size, self.Dh, device=input.device)
        return (h0, c0)

    def apply_input_kernel(self, xt: Tensor) -> List[Tensor]:
        xto = self.vertical_dropout(xt)
        out = self.input_kernel(xto).chunk(4, 1)
        # return cast(List[Tensor], out)
        return out

    def apply_recurrent_kernel(self, h_tm1: Tensor):
        #^ h_tm1 : [b h]
        mode = self.recurrent_dropout_mode
        if mode == 'gal_tied':
            hto = self.recurrent_dropout(h_tm1)
            out = self.recurrent_kernel(hto)
            #^ out : [b 4h]
            outs = out.chunk(4, -1)
        elif mode == 'gal_gates':
            outs = []
            WW, bb = self.get_recurrent_weights()
            for i in range(4):
                hto = self.recurrent_dropout(h_tm1)
                outs.append(F.linear(hto, WW[i], bb[i]))
        else:
            outs = self.recurrent_kernel(h_tm1).chunk(4, -1)
        return outs

    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        #^ input : [b i]
        #^ state.h : [b h]

        (h_tm1, c_tm1) = state

        Xi, Xf, Xg, Xo = self.apply_input_kernel(input)
        Hi, Hf, Hg, Ho = self.apply_recurrent_kernel(h_tm1)

        ft = self.fun_rec(Xf + Hf)
        ot = self.fun_rec(Xo + Ho)
        if self.tied_forget_gate:
            it = 1.0 - ft
        else:
            it = self.fun_rec(Xi + Hi)

        gt = T.tanh(Xg + Hg)  # * np.sqrt(3)
        if self.recurrent_dropout_mode == 'semeniuta':
            #* https://arxiv.org/abs/1603.05118
            gt = self.recurrent_dropout(gt)

        ct = (ft * c_tm1) + (it * gt)

        ht = ot * T.tanh(ct)

        return ht, (ht, ct)

    @T.jit.export
    def loop(self, inputs, state_t0, mask=None):
        # type: (List[Tensor], Tuple[Tensor, Tensor], Optional[List[Tensor]]) -> Tuple[List[Tensor], Tuple[Tensor, Tensor]]
        '''
        This loops over t (time) steps
        '''
        #^ inputs : t * [b i]
        #^ state_t0[i] : [b s]
        #^ out : [t b h]
        state = state_t0
        outs = []
        for xt in inputs:
            ht, state = self(xt, state)
            outs.append(ht)

        return outs, state

class K_LSTM(RecurrentLayerStack):
    def __init__(
        self,
        *args,
        **kargs,
    ):
        builder = K_LSTM_Cell_Builder
        super().__init__(
            builder,
            *args, **kargs
        )
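
A minimal construction sketch for the stack defined above, with illustrative sizes (the real values come from config.yaml via model_dd.py). With the default return_states=False the stack returns only the output sequence; bidirectional layers double the feature dimension.

# Construction sketch (illustrative sizes, not a repo file).
import torch as T
from components.k_lstm import K_LSTM

rnn = K_LSTM(
    input_size=300,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
    vertical_dropout=0.25,
    recurrent_dropout=0.25,
    batch_first=True,
)

x = T.randn(4, 10, 300)   # [batch, time, features]
out = rnn(x)              # return_states=False by default
print(out.shape)          # [4, 10, 512] = 2 * hidden_size (bidirectional)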
components/linear_scheduler.py
ADDED
@@ -0,0 +1,24 @@
class LinearSchedule:
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.
        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value"""
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)
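
A short usage example with made-up numbers: anneal a probability from 1.0 to 0.1 over 1000 steps, then hold it at 0.1.

# Usage sketch for LinearSchedule (invented numbers).
from components.linear_scheduler import LinearSchedule

sched = LinearSchedule(schedule_timesteps=1000, final_p=0.1, initial_p=1.0)
print(sched.value(0))     # 1.0
print(sched.value(500))   # 0.55
print(sched.value(2000))  # 0.1 (clamped once t exceeds schedule_timesteps)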
components/rnn.py
ADDED
File without changes
components/rnn_base.py
ADDED
@@ -0,0 +1,199 @@
from typing import (
    Tuple,
    List,
    Union,
    Dict,
    Optional,
    Callable,
)
from collections import namedtuple
from abc import ABC, abstractmethod

import torch as T
from torch import nn
from torch.nn import functional as F

from torch import Tensor

import pdb

from dataclasses import dataclass


class IRecurrentCell(ABC, nn.Module):
    @abstractmethod
    def get_init_state(self, input: Tensor):
        pass

    @abstractmethod
    def loop(self, inputs, state_t0, mask=None):
        pass

    # def forward(self, input, state, mask=None):
    #     pass

@dataclass
class IRecurrentCellBuilder(ABC):
    hidden_size: int

    def make(self, input_size: int) -> IRecurrentCell:
        pass

    def make_scripted(self, *p, **ks) -> IRecurrentCell:
        return T.jit.script(self.make(*p, **ks))

class RecurrentLayer(nn.Module):
    def reorder_inputs(self, inputs: Union[List[T.Tensor], T.Tensor]):
        #^ inputs : [t b i]
        if self.direction == 'backward':
            return inputs[::-1]
        return inputs

    def __init__(
        self,
        cell: IRecurrentCell,
        direction='forward',
        batch_first=False,
    ):
        super().__init__()
        if isinstance(batch_first, bool):
            batch_first = (batch_first, batch_first)
        self.batch_first = batch_first
        self.direction = direction
        self.cell_: IRecurrentCell = cell

    @T.jit.ignore
    def forward(self, input, state_t0, return_state=None):
        if self.batch_first[0]:
            #^ input : [b t i]
            input = input.transpose(1, 0)
        #^ input : [t b i]
        inputs = input.unbind(0)

        if state_t0 is None:
            state_t0 = self.cell_.get_init_state(input)

        inputs = self.reorder_inputs(inputs)

        if return_state:
            sequence, state = self.cell_.loop(inputs, state_t0)
        else:
            sequence, _ = self.cell_.loop(inputs, state_t0)
        #^ sequence : t * [b h]
        sequence = self.reorder_inputs(sequence)
        sequence = T.stack(sequence)
        #^ sequence : [t b h]

        if self.batch_first[1]:
            sequence = sequence.transpose(1, 0)
            #^ sequence : [b t h]

        if return_state:
            return sequence, state
        else:
            return sequence, None

class BidirectionalRecurrentLayer(nn.Module):
    def __init__(
        self,
        input_size: int,
        cell_builder: IRecurrentCellBuilder,
        batch_first=False,
        return_states=False
    ):
        super().__init__()
        self.batch_first = batch_first
        self.cell_builder = cell_builder
        self.batch_first = batch_first
        self.return_states = return_states
        self.fwd = RecurrentLayer(
            cell_builder.make_scripted(input_size),
            direction='forward',
            batch_first=batch_first
        )
        self.bwd = RecurrentLayer(
            cell_builder.make_scripted(input_size),
            direction='backward',
            batch_first=batch_first
        )

    @T.jit.ignore
    def forward(self, input, state_t0, is_last):
        return_states = is_last and self.return_states
        if return_states:
            fwd, state_fwd = self.fwd(input, state_t0, return_states)
            bwd, state_bwd = self.bwd(input, state_t0, return_states)
            return T.cat([fwd, bwd], dim=-1), (T.cat([state_fwd[0], state_bwd[0]], dim=-1), T.cat([state_fwd[1], state_bwd[1]], dim=-1))
        else:
            fwd, _ = self.fwd(input, state_t0, return_states)
            bwd, _ = self.bwd(input, state_t0, return_states)
            return T.cat([fwd, bwd], dim=-1), None

class RecurrentLayerStack(nn.Module):
    def __init__(
        self,
        cell_builder : Callable[..., IRecurrentCellBuilder],
        input_size : int,
        num_layers : int,
        bidirectional : bool = False,
        batch_first : bool = False,
        scripted : bool = True,
        return_states : bool = False,
        *args, **kargs,
    ):
        super().__init__()
        cell_builder_: IRecurrentCellBuilder = cell_builder(*args, **kargs)
        self._cell_builder = cell_builder_

        if bidirectional:
            Dh = cell_builder_.hidden_size * 2
            def make(isize: int, last=False):
                return BidirectionalRecurrentLayer(isize, cell_builder_,
                    batch_first=batch_first, return_states=return_states)
        else:
            Dh = cell_builder_.hidden_size
            def make(isize: int, last=False):
                cell = cell_builder_.make_scripted(isize)
                return RecurrentLayer(cell, isize,
                    batch_first=batch_first)


        if num_layers > 1:
            rnns = [
                make(input_size),
                *[
                    make(Dh)
                    for _ in range(num_layers - 2)
                ],
                make(Dh, last=True)
            ]
        else:
            rnns = [make(input_size, last=True)]

        self.rnn = nn.Sequential(*rnns)

        self.input_size = input_size
        self.hidden_size = self._cell_builder.hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.return_states = return_states

    def __repr__(self):
        return (
            f'${self.__class__.__name__}'
            + '('
            + f'in={self.input_size}, '
            + f'hid={self.hidden_size}, '
            + f'layers={self.num_layers}, '
            + f'bi={self.bidirectional}'
            + '; '
            + str(self._cell_builder)
        )

    def forward(self, input, state_t0=None):
        for layer_idx, rnn in enumerate(self.rnn):
            is_last = (layer_idx == (len(self.rnn) - 1))
            input, state = rnn(input, state_t0, is_last)
        if self.return_states:
            return input, state
        return input
config.yaml
ADDED
@@ -0,0 +1,117 @@
run-title: tashkeela-d2
debug: false

paths:
  base: ./dataset/ashaar
  save: ./models
  load: tashkeela-d2.pt
  resume: ./models/Tashkeela-D2/tashkeela-d2.pt
  constants: ./dataset/helpers/constants
  word-embs: vocab.vec
  test: test

loader:
  wembs-limit: -1
  num-workers: 0

train:
  epochs: 1000
  batch-size: 32
  char-embed-dim: 32
  resume: false
  resume-lr: false

  max-word-len: 13
  max-sent-len: 10

  rnn-cell: lstm
  sent-lstm-layers: 2
  word-lstm-layers: 2

  sent-lstm-units: 256
  word-lstm-units: 512
  decoder-units: 256

  sent-dropout: 0.2
  diac-dropout: 0
  final-dropout: 0.2

  sent-mask-zero: false

  lr-factor: 0.5
  lr-patience: 1
  lr-min: 1.e-7
  lr-init: 0.002

  weight-decay: 0
  vertical-dropout: 0.25
  recurrent-dropout: 0.25

  stopping-delta: 1.e-7
  stopping-patience: 3

predictor:
  batch-size: 75
  stride: 2
  window: 20
  gt-signal-prob: 0
  seed-idx: 0

sentence-break:
  stride: 2
  window: 10
  min-window: 1
  export-map: false
  files:
    - train/train.txt
    - val/val.txt
  delimeters:
    - ،
    - ؛
    - ','
    - ;
    - «
    - »
    - '{'
    - '}'
    - '('
    - ')'
    - '['
    - ']'
    - '.'
    - '*'
    - '-'
    - ':'
    - '?'
    - '!'
    - ؟


segment:
  stride: 2
  window: 10
  min-window: 1
  export-map: false
  files:
    - train/train.txt
    - val/val.txt
  delimeters:
    - ،
    - ؛
    - ','
    - ;
    - «
    - »
    - '{'
    - '}'
    - '('
    - ')'
    - '['
    - ']'
    - '.'
    - '*'
    - '-'
    - ':'
    - '?'
    - '!'
    - ؟
data_utils.py
ADDED
@@ -0,0 +1,230 @@
import os
import pickle
import numpy as np

from tqdm import tqdm
from prettytable import PrettyTable
from pyarabic.araby import tokenize, strip_tashkeel
import diac_utils as du

class DatasetUtils:
    def __init__(self, config):
        self.base_path = config["paths"]["base"]
        self.special_tokens = ['<pad>', '<unk>', '<num>', '<punc>']
        self.delimeters = config["sentence-break"]["delimeters"]
        self.load_constants(config["paths"]["constants"])
        self.debug = config["debug"]

        self.stride = config["sentence-break"]["stride"]
        self.window = config["sentence-break"]["window"]
        self.val_stride = config["sentence-break"].get("val-stride", self.stride)

        self.test_stride = config["predictor"]["stride"]
        self.test_window = config["predictor"]["window"]

        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.max_token_count = config["train"]["max-token-count"]
        self.pad_target_val = -100
        self.pad_char_id = du.LETTER_LIST.index('<pad>')

        self.markov_signal = config['train'].get('markov-signal', False)
        self.batch_first = config['train'].get('batch-first', True)

        self.gt_prob = config["predictor"]["gt-signal-prob"]
        if self.gt_prob > 0:
            self.s_idx = config["predictor"]["seed-idx"]
            subpath = f"test_gt_mask_{self.gt_prob}_{self.s_idx}.txt"
            mask_path = os.path.join(self.base_path, "test", subpath)
            with open(mask_path, 'r') as fin:
                self.gt_mask = fin.readlines()

        if "word-embs" in config["paths"] and config["paths"]["word-embs"].strip() != "":
            self.pad_val = self.special_tokens.index("<pad>")
            self.embeddings, self.vocab = self.load_embeddings(config["paths"]["word-embs"], config["loader"]["wembs-limit"])
            self.embeddings = self.normalize(self.embeddings, ["unit", "centeremb", "unit"])
            self.w2idx = {word: i for i, word in enumerate(self.vocab)}

    def load_file(self, path):
        with open(path, 'rb') as f:
            return list(pickle.load(f))

    def normalize(self, matrix, actions, mean=None):
        def length_normalize(matrix):
            norms = np.sqrt(np.sum(matrix**2, axis=1))
            norms[norms == 0] = 1
            matrix = matrix / norms[:, np.newaxis]
            return matrix

        def mean_center(matrix):
            return matrix - mean

        def length_normalize_dimensionwise(matrix):
            norms = np.sqrt(np.sum(matrix**2, axis=0))
            norms[norms == 0] = 1
            matrix = matrix / norms
            return matrix

        def mean_center_embeddingwise(matrix):
            avg = np.mean(matrix, axis=1)
            matrix = matrix - avg[:, np.newaxis]
            return matrix

        for action in actions:
            if action == 'unit':
                matrix = length_normalize(matrix)
            elif action == 'center':
                matrix = mean_center(matrix)
            elif action == 'unitdim':
                matrix = length_normalize_dimensionwise(matrix)
            elif action == 'centeremb':
                matrix = mean_center_embeddingwise(matrix)

        return matrix

    def load_constants(self, path):
        # self.numbers = [c for c in "0123456789"]
        # self.letter_list = self.special_tokens + self.load_file(os.path.join(path, 'ARABIC_LETTERS_LIST.pickle'))
        # self.diacritic_list = [' '] + self.load_file(os.path.join(path, 'DIACRITICS_LIST.pickle'))
        self.numbers = du.NUMBERS
        self.letter_list = du.LETTER_LIST
        self.diacritic_list = du.DIACRITICS_SHORT

    def split_word_on_characters_with_diacritics(self, word: str):
        return du.split_word_on_characters_with_diacritics(word)

    def load_mapping_v3(self, dtype, file_ext=None):
        mapping = {}
        if file_ext is None:
            file_ext = f"-{self.test_stride}-{self.test_window}.map"
        f_name = os.path.join(self.base_path, dtype, dtype + file_ext)
        with open(f_name, 'r') as fin:
            for line in fin:
                sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
                if sent_idx not in mapping:
                    mapping[sent_idx] = {}
                if seg_idx not in mapping[sent_idx]:
                    mapping[sent_idx][seg_idx] = {}
                if t_idx not in mapping[sent_idx][seg_idx]:
                    mapping[sent_idx][seg_idx][t_idx] = []
                mapping[sent_idx][seg_idx][t_idx] += [c_idx]
        return mapping

    def load_mapping_v3_from_list(self, mapping_list):
        mapping = {}
        for line in mapping_list:
            sent_idx, seg_idx, t_idx, c_idx = map(int, line.split(','))
            if sent_idx not in mapping:
                mapping[sent_idx] = {}
            if seg_idx not in mapping[sent_idx]:
                mapping[sent_idx][seg_idx] = {}
            if t_idx not in mapping[sent_idx][seg_idx]:
                mapping[sent_idx][seg_idx][t_idx] = []
            mapping[sent_idx][seg_idx][t_idx] += [c_idx]
        return mapping

    def load_embeddings(self, embs_path, limit=-1):
        if self.debug:
            return np.zeros((200+len(self.special_tokens),300)), self.special_tokens + ["c"] * 200

        words = [self.special_tokens[0]]
        print(f"[INFO] Reading Embeddings from {embs_path}")
        with open(embs_path, encoding='utf-8', mode='r') as fin:
            n, d = map(int, fin.readline().split())
            limit = n if limit <= 0 else limit
            embeddings = np.zeros((limit+1, d))
            for i, line in tqdm(enumerate(fin), total=limit):
                if i >= limit: break
                tokens = line.rstrip().split()
                words += [tokens[0]]
                embeddings[i+1] = list(map(float, tokens[1:]))
        return embeddings, words

    def load_file_clean(self, dtype, strip=False):
        f_name = os.path.join(self.base_path, dtype, dtype + ".txt")
        with open(f_name, 'r', encoding="utf-8", newline='\n') as fin:
            if strip:
                original_lines = [strip_tashkeel(self.preprocess(line)) for line in fin.readlines()]
            else:
                original_lines = [self.preprocess(line) for line in fin.readlines()]
        return original_lines

    def preprocess(self, line):
        return ' '.join(tokenize(line))

    def pad_and_truncate_sequence(self, tokens, max_len, pad=None):
        if pad is None:
            pad = self.special_tokens.index("<pad>")
        if len(tokens) < max_len:
            offset = max_len - len(tokens)
            return tokens + [pad] * offset
        else:
            return tokens[:max_len]

    def stats(self, freq, percentile=90, name="stats"):
        table = PrettyTable(["Dataset", "Mean", "Std", "Min", "Max", f"{percentile}th Percentile"])
        freq = np.array(sorted(freq))
        table.add_row([name, freq.mean(), freq.std(), freq.min(), freq.max(), np.percentile(freq, percentile)])
        print(table)

    def create_gt_mask(self, lines, prob, idx, seed=1111):
        np.random.seed(seed)

        gt_masks = []
        for line in lines:
            tokens = tokenize(line.strip())
            gt_mask_token = ""
            for t_idx, token in enumerate(tokens):
                gt_mask_token += ''.join(map(str, np.random.binomial(1, prob, len(token))))
                if t_idx+1 < len(tokens):
                    gt_mask_token += " "
            gt_masks += [gt_mask_token]

        subpath = f"test_gt_mask_{prob}_{idx}.txt"
        mask_path = os.path.join(self.base_path, "test", subpath)

        with open(mask_path, 'w') as fout:
            fout.write('\n'.join(gt_masks))

    def create_gt_labels(self, lines):
        gt_labels = []
        for line in lines:
            gt_labels_line = []
            tokens = tokenize(line.strip())
            for w_idx, word in enumerate(tokens):
                split_word = self.split_word_on_characters_with_diacritics(word)
                _, cy_flat, _ = du.create_label_for_word(split_word)

                gt_labels_line.extend(cy_flat)
                if w_idx+1 < len(tokens):
                    gt_labels_line += [0]

            gt_labels += [gt_labels_line]
        return gt_labels

    def get_ce(self, diac_word_y, e_idx=None, return_idx=False):
        #^ diac_word_y: [Tw 3]
        if e_idx is None: e_idx = len(diac_word_y)
        for c_idx in reversed(range(e_idx)):
            if diac_word_y[c_idx] != [0,0,0]:
                return diac_word_y[c_idx] if not return_idx else c_idx
        return diac_word_y[e_idx-1] if not return_idx else e_idx-1

    def create_decoder_input(self, diac_code_y, prob=0):
        #^ diac_code_y: [Ts Tw 3]
        diac_code_x = np.zeros((*np.array(diac_code_y).shape[:-1], 8))
        if not self.markov_signal:
            return list(diac_code_x)
        prev_ce = list(np.eye(6)[-1]) + [0,0]  # bos tag
        for w_idx, word in enumerate(diac_code_y):
            diac_code_x[w_idx, 0, :] = prev_ce
            for c_idx, char in enumerate(word[:-1]):
                # if np.random.rand() < prob:
                #     continue
                if char[0] == self.pad_target_val:
                    break
                haraka = list(np.eye(6)[char[0]])
                diac_code_x[w_idx, c_idx+1, :] = haraka + char[1:]
                ce = self.get_ce(diac_code_y[w_idx], c_idx)
                prev_ce = list(np.eye(6)[ce[0]]) + ce[1:]
        return list(diac_code_x)
dataloader.py
ADDED
@@ -0,0 +1,64 @@
import os

from pyarabic.araby import tokenize, strip_tashkeel

import numpy as np
import torch as T
from torch.utils.data import Dataset

from data_utils import DatasetUtils
import diac_utils as du

class DataRetriever(Dataset):
    def __init__(self, data_utils : DatasetUtils, lines: list):
        super(DataRetriever).__init__()

        self.data_utils = data_utils
        self.lines = lines

    def preprocess(self, data, dtype=T.long):
        return [T.tensor(np.array(x), dtype=dtype) for x in data]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        word_x, char_x, diac_x, diac_y = self.create_sentence(idx)
        return self.preprocess((word_x, char_x, diac_x)), T.tensor(diac_y, dtype=T.long), T.tensor(diac_y, dtype=T.long)

    def create_sentence(self, idx):
        line = self.lines[idx]
        tokens = tokenize(line.strip())

        word_x = []
        char_x = []
        diac_x = []
        diac_y = []
        diac_y_tmp = []

        for word in tokens:
            word = du.strip_unknown_tashkeel(word)
            word_chars = du.split_word_on_characters_with_diacritics(word)
            cx, cy, cy_3head = du.create_label_for_word(word_chars)

            word_strip = strip_tashkeel(word)
            word_x += [self.data_utils.w2idx[word_strip] if word_strip in self.data_utils.w2idx else self.data_utils.w2idx["<pad>"]]

            char_x += [self.data_utils.pad_and_truncate_sequence(cx, self.data_utils.max_word_len)]

            diac_y += [self.data_utils.pad_and_truncate_sequence(cy, self.data_utils.max_word_len, pad=self.data_utils.pad_target_val)]
            diac_y_tmp += [self.data_utils.pad_and_truncate_sequence(cy_3head, self.data_utils.max_word_len, pad=[self.data_utils.pad_target_val]*3)]

        diac_x = self.data_utils.create_decoder_input(diac_y_tmp)

        max_slen = self.data_utils.max_sent_len
        max_wlen = self.data_utils.max_word_len
        p_val = self.data_utils.pad_val
        pt_val = self.data_utils.pad_target_val

        word_x = self.data_utils.pad_and_truncate_sequence(word_x, max_slen)
        char_x = self.data_utils.pad_and_truncate_sequence(char_x, max_slen, pad=[p_val]*max_wlen)
        diac_x = self.data_utils.pad_and_truncate_sequence(diac_x, max_slen, pad=[[p_val]*8]*max_wlen)
        diac_y = self.data_utils.pad_and_truncate_sequence(diac_y, max_slen, pad=[pt_val]*max_wlen)

        return word_x, char_x, diac_x, diac_y
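
A sketch of wiring DataRetriever into a standard torch DataLoader. It mirrors app.py's config tweaks (max-sent-len and max-token-count are not set in config.yaml itself), assumes vocab.vec has already been downloaded (app.py does this), and uses two invented input lines.

# Sketch only: DatasetUtils + DataRetriever behind a DataLoader.
import yaml
from torch.utils.data import DataLoader

from data_utils import DatasetUtils
from dataloader import DataRetriever

with open("config.yaml", 'r', encoding="utf-8") as fin:
    config = yaml.load(fin, Loader=yaml.FullLoader)
config["train"]["max-sent-len"] = config["predictor"]["window"]
config["train"]["max-token-count"] = config["predictor"]["window"] * 3

data_utils = DatasetUtils(config)  # reads the vocab.vec embeddings
lines = ["ذهب الولد إلى المدرسة", "قرأ الطالب الكتاب"]  # invented examples
loader = DataLoader(
    DataRetriever(data_utils, lines),
    batch_size=config["predictor"]["batch-size"],
    num_workers=config["loader"]["num-workers"],
)
for (word_x, char_x, diac_x), diac_y, _ in loader:
    print(word_x.shape, char_x.shape, diac_y.shape)
    break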
diac_utils.py
ADDED
@@ -0,0 +1,223 @@
from typing import List

import torch as T
import numpy as np

from pyarabic.araby import (
    tokenize,
    strip_tashkeel,
    strip_tatweel,
    DIACRITICS
)

SEPARATE_DIACRITICS = {
    "FATHA": 1,
    "KASRA": 2,
    "DAMMA": 3,
    "SUKUN": 4
}

HARAKAT_MAP = [
    #^ (haraka, tanween, shadda)
    (0,0,0), #< No diacs on char
    (1,0,0),
    (1,1,0), #< Tanween on 2nd slot
    (2,0,0),
    (2,1,0),
    (3,0,0),
    (3,1,0),
    (4,0,0),
    (0,0,1), #< shadda on 3rd slot
    (1,0,1),
    (1,1,1),
    (2,0,1),
    (2,1,1),
    (3,0,1),
    (3,1,1),
    (0,0,0), #< Padding == -1 (also for spaces)
]

SPECIAL_TOKENS = ['<pad>', '<unk>', '<num>', '<punc>']
LETTER_LIST = SPECIAL_TOKENS + list("ءآأؤإئابةتثجحخدذرزسشصضطظعغفقكلمنهوىي")
CLASSES_LIST = [' ', 'َ', 'ً', 'ُ', 'ٌ', 'ِ', 'ٍ', 'ْ', 'ّ', 'َّ', 'ًّ', 'ُّ', 'ٌّ', 'ِّ', 'ٍّ']
DIACRITICS_SHORT = [' ', 'َ', 'ً', 'ِ', 'ٍ', 'ُ', 'ٌ', 'ْ', 'ّ']
NUMBERS = list("0123456789")
DELIMITERS = ["،","؛",",",";","«","»","{","}","(",")","[","]",".","*","-",":","?","!","؟"]

UNKNOWN_DIACRITICS = list(set(DIACRITICS).difference(set(DIACRITICS_SHORT)))

def shakkel_char(diac: int, tanween: bool, shadda: bool) -> str:
    returned_text = ""
    if shadda and diac != SEPARATE_DIACRITICS["SUKUN"]:
        returned_text += "\u0651"

    if diac == SEPARATE_DIACRITICS["FATHA"]:
        returned_text += "\u064E" if not tanween else "\u064B"
    elif diac == SEPARATE_DIACRITICS["KASRA"]:
        returned_text += "\u0650" if not tanween else "\u064D"
    elif diac == SEPARATE_DIACRITICS["DAMMA"]:
        returned_text += "\u064F" if not tanween else "\u064C"
    elif diac == SEPARATE_DIACRITICS["SUKUN"]:
        returned_text += "\u0652"

    return returned_text

def diac_ids_of_line(line: str):
    words = tokenize(line)
    diacs = []
    for word in words:
        word_chars = split_word_on_characters_with_diacritics(word)
        cx, cy, cy_3head = create_label_for_word(word_chars)
        diacs.extend(cy)
        diacs.append(-1)
    return np.array(diacs[:-1])

def strip_unknown_tashkeel(word: str):
    #! FIXME! warnings.warn("Stripping unknown tashkeel is disabled.")
    return word
    return ''.join(c for c in word if c not in UNKNOWN_DIACRITICS)

def split_word_on_characters_with_diacritics(word: str):
    '''
    TODO! Make faster without deque and looping
    Returns: List[List[char: "letter or diacritic"]]
    '''
    chars_w_diac = []
    i_start = 0
    for i_c, c in enumerate(word):
        #! FIXME! DIACRITICS_SHORT is missing a lot of less common diacritics ...
        #! which are then treated as letters during splitting.
        # if c not in DIACRITICS:
        if c not in DIACRITICS_SHORT:
            sub = list(word[i_start:i_c])
            chars_w_diac.append(sub)
            i_start = i_c
    sub = list(word[i_start:])
    if sub:
        chars_w_diac.append(sub)
    if not chars_w_diac[0]:
        chars_w_diac = chars_w_diac[1:]
    return chars_w_diac


def char_type(char: str):
    if char in LETTER_LIST:
        return LETTER_LIST.index(char)
    elif char in NUMBERS:
        return LETTER_LIST.index('<num>')
    elif char in DELIMITERS:
        return LETTER_LIST.index('<punc>')
    else:
        return LETTER_LIST.index('<unk>')

def create_labels(char_w_diac: str):
    remap_dict = {0: 0, 1: 1, 3: 2, 5: 3, 7: 4}
    char_w_diac = [char_w_diac[0]] + list(set(char_w_diac[1:]))
    if len(char_w_diac) > 3:
        char_w_diac = char_w_diac[:2] if DIACRITICS_SHORT[8] not in char_w_diac else char_w_diac[:3]

    char_idx = None
    diacritic_index = None
    head_3 = None

    char_idx = char_type(char_w_diac[0])
    diacs = set(char_w_diac[1:])
    diac_h3 = [0, 0, 0]
    for diac in diacs:
        if diac in DIACRITICS_SHORT:
            diac_idx = DIACRITICS_SHORT.index(diac)
            if diac_idx in [2, 4, 6]: #< Tanween
                diac_h3[0] = remap_dict[diac_idx - 1]
                diac_h3[1] = 1
            elif diac_idx == 8: #< shadda
                diac_h3[2] = 1
            else: #< Haraka or sukoon
                diac_h3[0] = remap_dict[diac_idx]
    assert not (diac_h3[0] == 4 and (diac_h3[1] or diac_h3[2]))
    diacritic_index = HARAKAT_MAP.index(tuple(diac_h3))
    return char_idx, diacritic_index, diac_h3
    if len(char_w_diac) == 1:
        return char_idx, 0, [remap_dict[0], 0, 0]
    elif len(char_w_diac) == 2: # If shadda OR diac
        diacritic_index = DIACRITICS_SHORT.index(char_w_diac[1])
        if diacritic_index in [2, 4, 6]: # list of tanween
            head_3 = [remap_dict[diacritic_index - 1], 1, 0]
        elif diacritic_index == 8:
            head_3 = [0, 0, 1]
        else:
            head_3 = [remap_dict[diacritic_index], 0, 0]
    elif len(char_w_diac) == 3: # If shadda AND diac
        if DIACRITICS_SHORT[8] == char_w_diac[1]:
            diacritic_index = DIACRITICS_SHORT.index(char_w_diac[2])
        else:
            diacritic_index = DIACRITICS_SHORT.index(char_w_diac[1])

        if diacritic_index in [2, 4, 6]: # list of tanween
            head_3 = [remap_dict[diacritic_index - 1], 1, 1]
        else:
            head_3 = [remap_dict[diacritic_index], 0, 1]
        diacritic_index = diacritic_index+8

    return char_idx, diacritic_index, head_3

def create_label_for_word(split_word: List[List[str]]):
    word_char_indices = []
    word_diac_indices = []
    word_diac_indices_h3 = []
    for char_w_diac in split_word:
        char_idx, diac_idx, diac_h3 = create_labels(char_w_diac)
        if char_idx == None:
            print(split_word)
            raise ValueError(char_idx)
        word_char_indices.append(char_idx)
        word_diac_indices.append(diac_idx)
        word_diac_indices_h3.append(diac_h3)
    return word_char_indices, word_diac_indices, word_diac_indices_h3


def flat_2_3head(output: T.Tensor):
    '''
    output: [b tw tc]
    '''
    haraka, tanween, shadda = [], [], []

    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD

    b, ts, tw = output.shape

    for b_idx in range(b):
        h_s, t_s, s_s = [], [], []
        for w_idx in range(ts):
            h_w, t_w, s_w = [], [], []
            for c_idx in range(tw):
                c = HARAKAT_MAP[int(output[b_idx, w_idx, c_idx])]
                h_w += [c[0]]
                t_w += [c[1]]
                s_w += [c[2]]
            h_s += [h_w]
            t_s += [t_w]
            s_s += [s_w]

        haraka += [h_s]
        tanween += [t_s]
        shadda += [s_s]


    return haraka, tanween, shadda

def flat2_3head(diac_idx):
    '''
    diac_idx: [tw]
    '''
    haraka, tanween, shadda = [], [], []
    # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
    # 0, F, FF, K, KK, D, DD, S, Sh, ShF, ShFF, ShK, ShKK, ShD, ShDD

    for diac in diac_idx:
        c_out = HARAKAT_MAP[diac]
        haraka += [c_out[0]]
        tanween += [c_out[1]]
        shadda += [c_out[2]]

    return np.array(haraka), np.array(tanween), np.array(shadda)
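
A small sanity-check sketch for the splitting and labelling helpers above; the example word is ours, not from the dataset.

# Sketch: inspect how a fully diacritized word is split and labelled.
import diac_utils as du

word = "كَتَبَ"
chars = du.split_word_on_characters_with_diacritics(word)
print(chars)            # [['ك', 'َ'], ['ت', 'َ'], ['ب', 'َ']]

cx, cy, cy_3head = du.create_label_for_word(chars)
print(cx)               # indices of the bare letters in du.LETTER_LIST
print(cy)               # flat class ids, i.e. indices into du.HARAKAT_MAP
print(cy_3head)         # per-character [haraka, tanween, shadda] triples
print(du.shakkel_char(1, False, False))  # '\u064E' (fatha) back from the ids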
model_dd.py
ADDED
@@ -0,0 +1,526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch as T
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from components.k_lstm import K_LSTM
|
9 |
+
from components.attention import Attention
|
10 |
+
from data_utils import DatasetUtils
|
11 |
+
from diac_utils import flat2_3head, flat_2_3head
|
12 |
+
|
13 |
+
class DiacritizerD2(nn.Module):
|
14 |
+
def __init__(self, config):
|
15 |
+
super(DiacritizerD2, self).__init__()
|
16 |
+
self.max_word_len = config["train"]["max-word-len"]
|
17 |
+
self.max_sent_len = config["train"]["max-sent-len"]
|
18 |
+
self.char_embed_dim = config["train"]["char-embed-dim"]
|
19 |
+
|
20 |
+
self.final_dropout_p = config["train"]["final-dropout"]
|
21 |
+
self.sent_dropout_p = config["train"]["sent-dropout"]
|
22 |
+
self.diac_dropout_p = config["train"]["diac-dropout"]
|
23 |
+
self.vertical_dropout = config['train']['vertical-dropout']
|
24 |
+
self.recurrent_dropout = config['train']['recurrent-dropout']
|
25 |
+
self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
|
26 |
+
self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')
|
27 |
+
|
28 |
+
self.sent_lstm_units = config["train"]["sent-lstm-units"]
|
29 |
+
self.word_lstm_units = config["train"]["word-lstm-units"]
|
30 |
+
self.decoder_units = config["train"]["decoder-units"]
|
31 |
+
|
32 |
+
self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
|
33 |
+
self.word_lstm_layers = config["train"]["word-lstm-layers"]
|
34 |
+
|
35 |
+
self.cell = config['train'].get('rnn-cell', 'lstm')
|
36 |
+
self.num_layers = config["train"].get("num-layers", 2)
|
37 |
+
self.RNN_Layer = K_LSTM
|
38 |
+
|
39 |
+
self.batch_first = config['train'].get('batch-first', True)
|
40 |
+
self.device = 'cuda' if T.cuda.is_available() else 'cpu'
|
41 |
+
self.num_classes = 15
|
42 |
+
|
43 |
+
def build(self, wembs: T.Tensor, abjad_size: int):
|
44 |
+
self.closs = F.cross_entropy
|
45 |
+
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        self.word_embs = T.tensor(wembs).clone().to(dtype=T.float32)
        self.word_embs = self.word_embs.to(self.device)

        self.classifier = nn.Linear(self.attention.Dout + self.word_lstm_units * 2, self.num_classes)
        self.dropout = nn.Dropout(self.final_dropout_p)

    def forward(self, sents, words, labels=None, subword_lengths=None):
        #^ sents : [b ts]
        #^ words : [b ts tw]
        #^ labels: [b ts tw]
        max_words = min(self.max_sent_len, sents.shape[1])

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, max_words, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(max_words)).unsqueeze(0).to(self.device)
        attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        attn_enc = attn_enc.reshape(-1, self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b*ts tw dae]

        final_vec = T.cat([attn_enc, char_enc], dim=-1)

        diac_out = self.classifier(self.dropout(final_vec))
        #^ diac_out: [b*ts tw 7]

        diac_out = diac_out.view(-1, max_words, self.max_word_len, self.num_classes)
        #^ diac_out: [b ts tw 7]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out

    def step(self, xt, yt, mask=None):
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)
        #^ yt: [b ts tw]

        diac = self(*xt)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))

        return loss

    def predict(self, dataloader):
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            diac = self(*inputs)

            output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
            #^ output: [b ts tw]

            haraka, tanween, shadda = flat_2_3head(output)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )

class DiacritizerD3(nn.Module):
    def __init__(self, config, device='cuda'):
        super(DiacritizerD3, self).__init__()
        self.max_word_len = config["train"]["max-word-len"]
        self.max_sent_len = config["train"]["max-sent-len"]
        self.char_embed_dim = config["train"]["char-embed-dim"]

        self.sent_dropout_p = config["train"]["sent-dropout"]
        self.diac_dropout_p = config["train"]["diac-dropout"]
        self.vertical_dropout = config['train']['vertical-dropout']
        self.recurrent_dropout = config['train']['recurrent-dropout']
        self.recurrent_dropout_mode = config['train'].get('recurrent-dropout-mode', 'gal_tied')
        self.recurrent_activation = config['train'].get('recurrent-activation', 'sigmoid')

        self.sent_lstm_units = config["train"]["sent-lstm-units"]
        self.word_lstm_units = config["train"]["word-lstm-units"]
        self.decoder_units = config["train"]["decoder-units"]

        self.sent_lstm_layers = config["train"]["sent-lstm-layers"]
        self.word_lstm_layers = config["train"]["word-lstm-layers"]

        self.cell = config['train'].get('rnn-cell', 'lstm')
        self.num_layers = config["train"].get("num-layers", 2)
        self.RNN_Layer = K_LSTM

        self.batch_first = config['train'].get('batch-first', True)

        self.baseline = config["train"].get("baseline", False)
        self.device = device

    def build(self, wembs: T.Tensor, abjad_size: int):
        self.closs = F.cross_entropy
        self.bloss = F.binary_cross_entropy_with_logits

        rnn_kargs = dict(
            recurrent_dropout_mode=self.recurrent_dropout_mode,
            recurrent_activation=self.recurrent_activation,
        )

        self.sent_lstm = self.RNN_Layer(
            input_size=300,
            hidden_size=self.sent_lstm_units,
            num_layers=self.sent_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            **rnn_kargs,
        )

        self.word_lstm = self.RNN_Layer(
            input_size=self.sent_lstm_units * 2 + self.char_embed_dim,
            hidden_size=self.word_lstm_units,
            num_layers=self.word_lstm_layers,
            bidirectional=True,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.char_embs = nn.Embedding(
            abjad_size,
            self.char_embed_dim,
            padding_idx=0,
        )

        self.attention = Attention(
            kind="dot",
            query_dim=self.word_lstm_units * 2,
            input_dim=self.sent_lstm_units * 2,
        )

        self.lstm_decoder = self.RNN_Layer(
            input_size=self.word_lstm_units * 2 + self.attention.Dout + 8,
            hidden_size=self.word_lstm_units * 2,
            num_layers=1,
            bidirectional=False,
            vertical_dropout=self.vertical_dropout,
            recurrent_dropout=self.recurrent_dropout,
            batch_first=self.batch_first,
            return_states=True,
            **rnn_kargs,
        )

        self.word_embs = T.tensor(wembs, dtype=T.float32)

        self.classifier = nn.Linear(self.lstm_decoder.hidden_size, 15)
        self.dropout = nn.Dropout(0.2)

    def forward(self, sents, words, labels):
        #^ sents : [b ts]
        #^ words : [b ts tw]
        #^ labels: [b ts tw]

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]

        char_enc_reshaped = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units * 2)
        #^ char_enc_reshaped: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
        attn_enc, attn_map = self.attention(char_enc_reshaped, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        attn_enc = attn_enc.view(-1, self.max_sent_len*self.max_word_len, self.attention.Dout)
        #^ attn_enc: [b*ts tw dae]

        if self.training and self.diac_dropout_p > 0:
            q = 1.0 - self.diac_dropout_p
            ddo = T.bernoulli(T.full(labels.shape[:-1], q))
            labels = labels * ddo.unsqueeze(-1).long().to(self.device)
            #^ labels : [b ts tw] ; DO(ts)

        labels = labels.view(-1, self.max_sent_len*self.max_word_len, 8).float()
        #^ labels: [b*ts tw 8]

        char_enc = char_enc.view(-1, self.max_sent_len*self.max_word_len, self.word_lstm_units * 2)

        final_vec = T.cat([attn_enc, char_enc, labels], dim=-1)
        #^ final_vec: [b ts*tw dae+8]

        dec_out, _ = self.lstm_decoder(final_vec)
        #^ dec_out: [b*ts tw du]

        dec_out = dec_out.reshape(-1, self.max_word_len, self.lstm_decoder.hidden_size)

        diac_out = self.classifier(self.dropout(dec_out))
        #^ diac_out: [b*ts tw 15]

        diac_out = diac_out.view(-1, self.max_sent_len, self.max_word_len, 15)
        #^ diac_out: [b ts tw 15]

        if not self.batch_first:
            diac_out = diac_out.swapaxes(1, 0)

        return diac_out, attn_map

    def predict_sample(self, sents, words, labels):

        word_mask = words.ne(0.).float()
        #^ word_mask: [b ts tw]

        if self.training:
            q = 1.0 - self.sent_dropout_p
            sdo = T.bernoulli(T.full(sents.shape, q))
            sents_do = sents * sdo.long()
            #^ sents_do : [b ts] ; DO(ts)
            wembs = self.word_embs[sents_do]
            #^ wembs : [b ts dw] ; DO(ts)
        else:
            wembs = self.word_embs[sents]
            #^ wembs : [b ts dw]

        sent_enc = self.sent_lstm(wembs.to(self.device))
        #^ sent_enc : [b ts dwe]

        sentword_do = sent_enc.unsqueeze(2)
        #^ sentword_do : [b ts _ dwe]

        sentword_do = self.dropout(sentword_do * word_mask.unsqueeze(-1))
        #^ sentword_do : [b ts tw dwe]

        word_index = words.view(-1, self.max_word_len)
        #^ word_index: [b*ts tw]

        cembs = self.char_embs(word_index)
        #^ cembs : [b*ts tw dc]

        sentword_do = sentword_do.view(-1, self.max_word_len, self.sent_lstm_units * 2)
        #^ sentword_do : [b*ts tw dwe]

        char_embs = T.cat([cembs, sentword_do], dim=-1)
        #^ char_embs : [b*ts tw dcw] ; dcw = dc + dwe

        char_enc, _ = self.word_lstm(char_embs)
        #^ char_enc: [b*ts tw dce]
        #^ word_states: ([b*ts dce], [b*ts dce])

        char_enc = char_enc.view(-1, self.max_sent_len, self.max_word_len, self.word_lstm_units*2)
        #^ char_enc: [b ts tw dce]

        omit_self_mask = (1.0 - T.eye(self.max_sent_len)).unsqueeze(0).to(self.device)
        attn_enc, _ = self.attention(char_enc, sent_enc, word_mask.bool(), prejudice_mask=omit_self_mask)
        #^ attn_enc: [b ts tw dae]

        all_out = T.zeros(*char_enc.size()[:-1], 15).to(self.device)
        #^ all_out: [b ts tw 15]

        batch_sz = char_enc.size()[0]
        #^ batch_sz: b

        zeros = T.zeros(1, batch_sz, self.lstm_decoder.hidden_size).to(self.device)
        #^ zeros: [1 b du]

        bos_tag = T.tensor([0,0,0,0,0,1,0,0]).unsqueeze(0)
        #^ bos_tag: [1 8]

        prev_label = T.cat([bos_tag]*batch_sz).to(self.device).float()
        # bos_vec = T.cat([bos_tag]*batch_sz).to(self.device).float()
        #^ prev_label: [b 8]

        for ts in range(self.max_sent_len):
            dec_hx = (zeros, zeros)
            #^ dec_hx: [1 b du]
            for tw in range(self.max_word_len):
                final_vec = T.cat([attn_enc[:,ts,tw,:], char_enc[:,ts,tw,:], prev_label], dim=-1).unsqueeze(1)
                #^ final_vec: [b 1 dce+dae+8]
                dec_out, dec_hx = self.lstm_decoder(final_vec, dec_hx)
                #^ dec_out: [b 1 du]
                dec_out = dec_out.squeeze(0)
                dec_out = dec_out.transpose(0,1)

                logits_raw = self.classifier(self.dropout(dec_out))
                #^ logits_raw: [b 1 15]

                out_idx = T.max(T.softmax(logits_raw.squeeze(), dim=-1), dim=-1)[1]

                haraka, tanween, shadda = flat_2_3head(out_idx.detach().cpu().numpy())

                haraka_onehot = T.eye(6)[haraka].float().to(self.device)
                #^ haraka_onehot: [b 6]

                tanween = T.tensor(tanween).float().unsqueeze(-1).to(self.device)
                shadda = T.tensor(shadda).float().unsqueeze(-1).to(self.device)

                prev_label = T.cat([haraka_onehot, tanween, shadda], dim=-1)

                all_out[:,ts,tw,:] = logits_raw.squeeze()

        if not self.batch_first:
            all_out = all_out.swapaxes(1, 0)

        return all_out

    def step(self, xt, yt, mask=None):
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)
        #^ yt: [b ts tw]
        yt = yt.to(self.device)

        if self.training:
            diac, _ = self(*xt)
        else:
            diac = self.predict_sample(*xt)
        #^ diac: [b ts tw 15]

        loss = self.closs(diac.view(-1, 15), yt.view(-1))
        return loss

    def predict(self, dataloader):
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[1] = inputs[1].to(self.device)
            inputs[2] = inputs[2].to(self.device)
            diac = self.predict_sample(*inputs)
            output = np.argmax(T.softmax(diac.detach(), dim=-1).cpu().numpy(), axis=-1)
            #^ output: [b ts tw]

            haraka, tanween, shadda = flat_2_3head(output)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )

if __name__ == "__main__":

    import yaml
    config_path = "configs/dd/config_d2.yaml"
    model_path = "models/tashkeela-d2.pt"
    with open(config_path, 'r', encoding="utf-8") as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    data_utils = DatasetUtils(config)
    vocab_size = len(data_utils.letter_list)
    word_embeddings = data_utils.embeddings

    model = DiacritizerD2(config, device='cpu')
    model.build(word_embeddings, vocab_size)
    model.load_state_dict(T.load(model_path, map_location=T.device('cpu'))["state_dict"])
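
Usage note (not part of the commit): both diacritizers rely on the same word-level dropout trick during training, where word ids are zeroed with probability sent_dropout_p so that the dropped positions fall back to the padding embedding and the model must lean on character context. A minimal, self-contained sketch with made-up shapes and probability:

import torch as T

p = 0.2                                 # stand-in for sent_dropout_p
sents = T.randint(1, 100, (2, 5))       # [b ts] word ids; id 0 is reserved for padding
keep = T.bernoulli(T.full(sents.shape, 1.0 - p)).long()
sents_do = sents * keep                 # dropped positions become id 0
word_embs = T.randn(100, 300)           # [vocab dw]; dw=300 matches the sent_lstm input size
wembs = word_embs[sents_do]             # [b ts 300]; dropped words look up the padding row
print(wembs.shape)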
model_partial.py
ADDED
@@ -0,0 +1,348 @@
from typing import NamedTuple
import yaml
from tqdm import tqdm
import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F
from diac_utils import flat_2_3head

from model_dd import DiacritizerD2

class Readout(nn.Module):
    def __init__(
        self,
        in_size: int,
        out_size: int,
    ):
        super().__init__()
        self.W1 = nn.Linear(in_size, in_size)
        self.W2 = nn.Linear(in_size, out_size)

    def forward(self, x: T.Tensor):
        z = self.W1(x)
        z = T.tanh(z)
        z = self.W2(x)
        return z

class WordDD_LSTM(nn.Module):
    def __init__(
        self,
        feature_size: int,
        num_classes: int = 13,
        return_logits: bool = True,
    ):
        super().__init__()
        self.feature_size = feature_size
        self.num_classes = num_classes
        self.return_logits = return_logits
        self.cell = nn.LSTM(feature_size)
        self.head = Readout(feature_size, num_classes)

    def forward(self, x: T.Tensor):
        #^ x: [b tc dc]
        z = self.cell(x)
        #^ z: [b tc @dc]
        y = self.head(z)
        #^ y: [b tc Classes]
        yhat = y
        if not self.return_logits:
            yhat = F.softmax(yhat, dim=1)
            #^ yhat: [b tc @Classes]
        return yhat

class PartialDiacOutput(NamedTuple):
    preds_hard: T.Tensor
    preds_ctxt_logit: T.Tensor
    preds_base_logit: T.Tensor


class PartialDD(nn.Module):
    def __init__(
        self,
        config: dict,
        # feature_size: int,
        # confidence_threshold: float,
        d2=False
    ):
        super().__init__()
        self._built = False
        self.no_diac_id = 0
        self._dummy = nn.Parameter(T.ones(1, 1))

        self.config = config
        self.sentence_diac = DiacritizerD2(self.config)
        self._use_d2 = d2  #< referenced by predict_partial

        self.eval()

    @property
    def device(self):
        return self._dummy.device

    @property
    def tokenizer(self):
        return self.sentence_diac.tokenizer

    def load_state_dict(
        self,
        state_dict: dict
    ):
        self.sentence_diac.load_state_dict(state_dict)

    def _slim_batch(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
    ):
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        #^ diac_ids: [b tw tc "13"]
        #^ subword_lengths: [b tw]
        token_nonpad_mask = toke_ids.ne(self.tokenizer.pad_token_id)
        Ttoken = token_nonpad_mask.sum(1).max()
        toke_ids = toke_ids[:, :Ttoken]

        char_nonpad_mask = char_ids.ne(0)
        Tword = char_nonpad_mask.any(2).sum(1).max()
        Tchar = char_nonpad_mask.sum(2).max()
        char_ids = char_ids[:, :Tword, :Tchar]
        diac_ids = diac_ids[:, :Tword, :Tchar]
        subword_lengths = subword_lengths[:, :Tword]

        return toke_ids, char_ids, diac_ids, subword_lengths

    def word_diac(
        self,
        toke_ids: T.Tensor,
        char_ids: T.Tensor,
        diac_ids: T.Tensor,
        subword_lengths: T.Tensor,
        *,
        shape: tuple = None,
    ):
        if shape is None:
            toke_ids, char_ids, diac_ids, subword_lengths = self._slim_batch(
                toke_ids, char_ids, diac_ids, subword_lengths
            )
        else:
            Nb, Tw, Tc = shape
            toke_ids = toke_ids[:, :]
            char_ids = char_ids[:, :Tw, :Tc]
            diac_ids = diac_ids[:, :Tw, :Tc, :]
            subword_lengths = subword_lengths[:, :Tw]
        Nb, Tw, Tc = char_ids.shape
        # Tw = min(Tw, word_ids.shape[1])
        #^ toke_ids: [b tt]
        #^ char_ids: [b tw tc]
        # wids_flat = word_ids[:, Tw].reshape(Nb * Tw, 1)
        # cids_flat = char_ids[:, Tw].reshape(Nb * Tw, 1, Tc)
        # z = self.sentence_diac(wids_flat, cids_flat)

        sent_word_strides = subword_lengths.cumsum(1)
        assert tuple(subword_lengths.shape) == (Nb, Tw), f"{subword_lengths.shape} != {(Nb, Tw)=}"
        max_tokens_per_word: int = subword_lengths.max().int().item()
        word_x = T.zeros(Nb, Tw, max_tokens_per_word).to(toke_ids)
        for i_b in range(toke_ids.shape[0]):
            sent_i = toke_ids[i_b]
            start_iw = 0
            for i_word, end_iw in enumerate(sent_word_strides[i_b]):
                if end_iw == start_iw: break
                word = sent_i[start_iw:end_iw]
                word_x[i_b, i_word, 0 : end_iw - start_iw] = word
                start_iw = end_iw
        #^ word_x: [b tw tt]
        word_x = word_x.reshape(Nb * Tw, max_tokens_per_word)
        cids_flat = char_ids.reshape(Nb * Tw, 1, Tc)
        word_lengths = subword_lengths.reshape(Nb * Tw, 1)

        z = self.sentence_diac(
            word_x,
            cids_flat,
            diac_ids.reshape(Nb*Tw, Tc, -1),
            subword_lengths=word_lengths,
        )
        # Nc = z.shape[-1]
        #^ z: [b*tw, 1, tc, "13"]
        z = z.reshape(Nb, Tw, Tc, -1)
        return z

    def forward(
        self,
        word_ids: T.Tensor,
        char_ids: T.Tensor,
        _labels: T.Tensor,
        # ground_truth: T.Tensor,
        # padding_mask: T.BoolTensor,
        *,
        eval_only: str = None,
        subword_lengths: T.Tensor,
        return_extra: bool = False
    ):
        # assert self._built and not self.training
        assert not self.training
        #^ word_ids: [b tw]
        #^ char_ids: [b tw tc]
        #^ ground_truth: [b tw tc]

        padding_mask = char_ids.eq(0)
        #^ padding_mask: [b tw tc]

        if True or eval_only != 'base':
            y_ctxt = self.sentence_diac(
                word_ids,
                char_ids,
                _labels,
                subword_lengths=subword_lengths,
            )
            out_shape = y_ctxt.shape[:-1]
        else:
            out_shape = self.sentence_diac._slim_batch_size(
                word_ids,
                char_ids,
                _labels,
                subword_lengths,
            )[1].shape
        #^ y_ctxt: [b tw tc "13"]
        if eval_only == 'ctxt':
            return y_ctxt.argmax(-1)

        y_base = self.word_diac(
            word_ids,
            char_ids,
            _labels,
            subword_lengths,
            shape=out_shape
        )
        #^ y_base: [b tw tc "13"]
        if eval_only == 'base':
            return y_base.argmax(-1)

        ypred_ctxt = y_ctxt.argmax(-1)
        ypred_base = y_base.argmax(-1)
        #^ ypred: [b tw tc]

        # Maybe for eval
        # ypred_ctxt[~((ypred_base == ground_truth) & (~padding_mask))] = self.no_diac_id
        # return ypred_ctxt
        ypred_ctxt[(padding_mask) | (ypred_base == ypred_ctxt)] = self.no_diac_id
        if not return_extra:
            return ypred_ctxt
        else:
            return PartialDiacOutput(ypred_ctxt, y_ctxt, y_base)

    def step(self, xt, yt, mask=None):
        raise NotImplementedError
        xt[1] = xt[1].to(self.device)
        xt[2] = xt[2].to(self.device)

        yt = yt.to(self.device)
        #^ yt: [b ts tw]

        diac, _ = self(*xt)  # xt: (word_ids, char_ids, _labels)
        loss = self.closs(diac.view(-1, self.num_classes), yt.view(-1))

        return loss

    def predict_partial(
        self,
        dataloader,
        return_extra=False,
        eval_only: str = None,
    ):
        training = self.training
        self.eval()

        preds = {
            'haraka': [],
            'shadda': [],
            'tanween': [],
            'diacs': [],
            'y_ctxt': [],
            'y_base': [],
        }
        print("> Predicting...")
        for i_batch, (inputs, _, subword_lengths) in enumerate(tqdm(dataloader)):
            #^ inputs: [toke_ids, char_ids, diac_ids]
            inputs[0] = inputs[0].to(self.device)  #< toke_ids
            inputs[1] = inputs[1].to(self.device)  #< char_ids
            # inputs[2] = inputs[2].to(self.device)  #< diac_ids

            if self._use_d2:
                subword_lengths = T.ones_like(inputs[0])
                subword_lengths[inputs[0] == 0] = 0

            with T.no_grad():
                output = self(
                    *inputs,
                    subword_lengths=subword_lengths,
                    return_extra=return_extra,
                    eval_only=eval_only,
                )

            if return_extra:
                assert isinstance(output, PartialDiacOutput)
                marks = output.preds_hard
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
                preds['y_ctxt'].extend(list(output.preds_ctxt_logit.detach().cpu().numpy()))
                preds['y_base'].extend(list(output.preds_base_logit.detach().cpu().numpy()))
            else:
                assert isinstance(output, T.Tensor)
                marks = output
                preds['diacs'].extend(list(marks.detach().cpu().numpy()))
            #^ marks: [b ts tw]

            haraka, tanween, shadda = flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return {
            'diacritics': (
                #! FIXME! Due to batch slimming, output diacritics may need padding.
                np.array(preds['haraka']),
                np.array(preds["tanween"]),
                np.array(preds["shadda"]),
            ),
            'other': (  # Would be empty when !return_extra
                preds['y_ctxt'],
                preds['y_base'],
                preds['diacs'],
            )
        }

    def predict(self, dataloader):
        training = self.training
        self.eval()

        preds = {'haraka': [], 'shadda': [], 'tanween': []}
        print("> Predicting...")
        for inputs, _ in tqdm(dataloader, total=len(dataloader)):
            inputs[0] = inputs[0].to(self.device)
            inputs[1] = inputs[1].to(self.device)
            output = self(*inputs)

            marks = output
            #^ marks: [b ts tw]

            haraka, tanween, shadda = flat_2_3head(marks)

            preds['haraka'].extend(haraka)
            preds['tanween'].extend(tanween)
            preds['shadda'].extend(shadda)

        self.train(training)
        return (
            np.array(preds['haraka']),
            np.array(preds["tanween"]),
            np.array(preds["shadda"]),
        )
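
Usage note (not part of the commit): the heart of PartialDD.forward is the selection rule at the end, which keeps the sentence-level (contextual) prediction only where it disagrees with the word-level baseline and resets padded or agreeing positions to the no-diacritic class. A toy sketch with invented class ids:

import torch as T

no_diac_id = 0
ypred_ctxt = T.tensor([[2, 3, 1, 4]])                  # argmax of the contextual model
ypred_base = T.tensor([[2, 1, 1, 4]])                  # argmax of the word-only baseline
padding_mask = T.tensor([[False, False, False, True]])

partial = ypred_ctxt.clone()
partial[padding_mask | (ypred_base == ypred_ctxt)] = no_diac_id
print(partial)                                         # tensor([[0, 3, 0, 0]])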
predict.py
ADDED
@@ -0,0 +1,170 @@
from typing import Iterable, Union, Tuple
from collections import Counter

import argparse
import os

import yaml
from pyarabic.araby import tokenize, strip_tatweel
from tqdm import tqdm

import numpy as np
import torch as T
from torch.utils.data import DataLoader

from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
from model_partial import PartialDD
from data_utils import DatasetUtils
from dataloader import DataRetriever
from segment import segment

class Predictor:
    def __init__(self, config, text):

        self.data_utils = DatasetUtils(config)
        vocab_size = len(self.data_utils.letter_list)
        word_embeddings = self.data_utils.embeddings

        stride = config["segment"]["stride"]
        window = config["segment"]["window"]
        min_window = config["segment"]["min-window"]

        segments, mapping = segment([text], stride, window, min_window)

        mapping_lines = []
        for sent_idx, seg_idx, word_idx, char_idx in mapping:
            mapping_lines += [f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}"]

        self.mapping = self.data_utils.load_mapping_v3_from_list(mapping_lines)
        self.original_lines = [text]
        self.segments = segments

        self.device = T.device(
            config['predictor'].get('device', 'cuda:0')
            if T.cuda.is_available() else 'cpu'
        )

        self.model = PartialDD(config, d2=True)
        self.model.sentence_diac.build(word_embeddings, vocab_size)
        state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        self.data_loader = DataLoader(
            DataRetriever(self.data_utils, segments),
            batch_size=config["predictor"].get("batch-size", 32),
            shuffle=False,
            num_workers=config['loader'].get('num-workers', 0),
        )

class PredictTri(Predictor):
    def __init__(self, config, text):
        super().__init__(config, text)
        self.diacritics = {
            "FATHA": 1,
            "KASRA": 2,
            "DAMMA": 3,
            "SUKUN": 4
        }
        self.votes: Union[Counter[int], Counter[bool]] = Counter()

    def count_votes(
        self,
        things: Union[Iterable[int], Iterable[bool]]
    ):
        self.votes.clear()
        self.votes.update(things)
        return self.votes.most_common(1)[0][0]

    def predict_majority_vote(self):
        y_gen_diac, y_gen_tanween, y_gen_shadda = self.model.predict(self.data_loader)
        diacritized_lines, _ = self.coalesce_votes_by_majority(y_gen_diac, y_gen_tanween, y_gen_shadda)
        return diacritized_lines

    def predict_majority_vote_context_contrastive(self, overwrite_cache=False):
        assert isinstance(self.model, PartialDD)
        if not os.path.exists("dataset/cache/y_gen_diac.npy") or overwrite_cache:
            if not os.path.exists("dataset/cache"):
                os.mkdir("dataset/cache")
            # segment_outputs = self.model.predict_partial(self.data_loader, return_extra=True)
            segment_outputs = self.model.predict_partial(self.data_loader, return_extra=False, eval_only='ctxt')
            T.save(segment_outputs, "dataset/cache/cache.pt")
        else:
            segment_outputs = T.load("dataset/cache/cache.pt")

        y_gen_diac, y_gen_tanween, y_gen_shadda = segment_outputs['diacritics']
        diacritized_lines, extra_for_lines = self.coalesce_votes_by_majority(
            y_gen_diac, y_gen_tanween, y_gen_shadda,
        )
        extra_out = {
            'line_data': {
                **extra_for_lines,
            },
            'segment_data': {
                **segment_outputs,
                # 'logits': segment_outputs['logits'],
            }
        }
        return diacritized_lines, extra_out

    def coalesce_votes_by_majority(
        self,
        y_gen_diac: np.ndarray,
        y_gen_tanween: np.ndarray,
        y_gen_shadda: np.ndarray,
    ):
        prepped_lines_og = [' '.join(tokenize(strip_tatweel(line))) for line in self.original_lines]
        max_line_chars = max(len(line) for line in prepped_lines_og)
        diacritics_pred = np.full((len(self.original_lines), max_line_chars), fill_value=-1, dtype=int)

        count_processed_sents = 0
        do_break = False
        diacritized_lines = []
        for sent_idx, line in enumerate(tqdm(prepped_lines_og)):
            count_processed_sents = sent_idx + 1
            line = line.strip()
            diacritized_line = ""
            for char_idx, char in enumerate(line):
                diacritized_line += char
                char_vote_diacritic = []
                # ? This is the voting part
                if sent_idx not in self.mapping:
                    continue

                mapping_s_i = self.mapping[sent_idx]
                for seg_idx in mapping_s_i:
                    if self.data_utils.debug and seg_idx >= 256:
                        do_break = True
                        break

                    mapping_g_i = mapping_s_i[seg_idx]
                    for t_idx in mapping_g_i:

                        mapping_t_i = mapping_g_i[t_idx]
                        if char_idx in mapping_t_i:
                            c_idx = mapping_t_i.index(char_idx)
                            output_idx = np.s_[seg_idx, t_idx, c_idx]
                            diac_h3 = (y_gen_diac[output_idx], y_gen_tanween[output_idx], y_gen_shadda[output_idx])
                            diac_char_i = HARAKAT_MAP.index(diac_h3)
                            if c_idx < 13 and diac_char_i != 0:
                                char_vote_diacritic.append(diac_char_i)

                if do_break:
                    break
                if len(char_vote_diacritic) > 0:
                    char_mv_diac = self.count_votes(char_vote_diacritic)
                    diacritized_line += shakkel_char(*HARAKAT_MAP[char_mv_diac])
                    diacritics_pred[sent_idx, char_idx] = char_mv_diac
                else:
                    diacritics_pred[sent_idx, char_idx] = 0
            if do_break:
                break

            diacritized_lines += [diacritized_line.strip()]

        print(f'[INFO] Cutting stats from {len(diacritics_pred)} to {count_processed_sents}')
        extra = {
            'diac_pred': diacritics_pred[:count_processed_sents],
        }
        return diacritized_lines, extra
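
Usage note (not part of the commit): because segments overlap, several windows can predict a diacritic for the same character; coalesce_votes_by_majority resolves each character with the same Counter-based majority vote as count_votes. A toy illustration with hypothetical class ids:

from collections import Counter

votes_per_char = [[3], [3, 3, 1], [1, 2]]   # votes gathered from overlapping segments
resolved = [Counter(v).most_common(1)[0][0] for v in votes_per_char if v]
print(resolved)                             # [3, 3, 1]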
segment.py
ADDED
@@ -0,0 +1,89 @@
import argparse
import yaml
import os
import pickle as pkl

from tqdm import tqdm
from pyarabic.araby import tokenize, strip_tashkeel, strip_tatweel

def export(path, text):
    with open(path, 'w', encoding="utf-8") as fout:
        fout.write('\n'.join(text))

def segment(lines, stride, window_sz, min_window_sz):
    segments, mapping = [], []
    real_seg_idx = 0

    for sent_idx, line in tqdm(enumerate(lines), total=len(lines)):
        line: str = strip_tatweel(line)
        line = line.strip()
        tokens = tokenize(line)
        if len(tokens) == 0: continue
        if tokens[-1] == '\n': tokens = tokens[:-1]
        seg_idx, idx = 0, 0
        while idx < len(tokens):
            window = tokens[idx:idx+window_sz]
            if window_sz == -1: window = tokens
            if len(window) < min_window_sz and seg_idx != 0: break

            segment = ' '.join(window)
            segments += [segment]
            char_offset = len(strip_tashkeel(' '.join(tokens[:idx])))

            if seg_idx > 0:
                char_offset += 1

            seg_tokens = tokenize(strip_tashkeel(segment))

            j = 0
            for st_idx, st in enumerate(seg_tokens):
                for _ in range(len(st)):
                    mapping += [(sent_idx, real_seg_idx, st_idx, j+char_offset)]
                    j += 1
                j += 1

            real_seg_idx += 1
            seg_idx += 1

            if stride == -1: break

            idx += (window_sz if stride >= window_sz else stride)

    return segments, mapping

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Sentence Breaker')
    parser.add_argument('-c', '--config', type=str,
                        default="config.yaml", help='Run Configs')
    parser.add_argument('-d', '--data_dir', type=str,
                        default=None, help='Override for data path')
    args = parser.parse_args()

    with open(args.config, 'r', encoding="utf-8") as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    BASE_PATH = args.data_dir or config["paths"].get("base")

    stride = config["segment"]["stride"]
    window = config["segment"]["window"]
    min_window = config["segment"]["min-window"]
    export_map = config["segment"]["export-map"]

    for fpath in tqdm(config["segment"]["files"]):
        FILE_PATH = os.path.join(BASE_PATH, fpath)
        SAVE_PATH = os.path.join(BASE_PATH, fpath[:-4] + f"-{stride}-{window}.txt")
        MAP_PATH = os.path.join(BASE_PATH, fpath[:-4] + f"-{stride}-{window}.map")

        with open(FILE_PATH, 'r', encoding="utf-8") as fin:
            lines = fin.readlines()

        segments, mapping = segment(lines, stride, window, min_window)

        with open(SAVE_PATH, 'w', encoding="utf-8") as fout:
            fout.write('\n'.join(segments))

        if not export_map: continue

        with open(MAP_PATH, 'w', encoding="utf-8") as fout:
            for sent_idx, seg_idx, word_idx, char_idx in mapping:
                fout.write(f"{sent_idx}, {seg_idx}, {word_idx}, {char_idx}\n")
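
Usage note (not part of the commit): a hypothetical call to segment(); the stride/window values and the Latin example text are only for illustration, the pipeline applies it to Arabic lines read from the configured files.

from segment import segment

lines = ["this is a small example line"]
segments, mapping = segment(lines, stride=2, window_sz=3, min_window_sz=1)
print(segments)       # expected: ['this is a', 'a small example', 'example line']
print(mapping[:3])    # (sent_idx, seg_idx, word_idx, char_idx) for the first characters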