bug fix
Browse files- model_partial.py +2 -3
- predict.py +2 -1
model_partial.py
CHANGED
@@ -178,7 +178,7 @@ class PartialDD(nn.Module):
|
|
178 |
# padding_mask: T.BoolTensor,
|
179 |
*,
|
180 |
eval_only: str = None,
|
181 |
-
subword_lengths: T.Tensor,
|
182 |
return_extra: bool = False
|
183 |
):
|
184 |
# assert self._built and not self.training
|
@@ -195,7 +195,6 @@ class PartialDD(nn.Module):
|
|
195 |
word_ids,
|
196 |
char_ids,
|
197 |
_labels,
|
198 |
-
subword_lengths=subword_lengths,
|
199 |
)
|
200 |
out_shape = y_ctxt.shape[:-1]
|
201 |
else:
|
@@ -328,7 +327,7 @@ class PartialDD(nn.Module):
|
|
328 |
for inputs, _ in tqdm(dataloader, total=len(dataloader)):
|
329 |
inputs[0] = inputs[0].to(self.device)
|
330 |
inputs[1] = inputs[1].to(self.device)
|
331 |
-
output = self(*inputs)
|
332 |
|
333 |
# output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
|
334 |
marks = output
|
|
|
178 |
# padding_mask: T.BoolTensor,
|
179 |
*,
|
180 |
eval_only: str = None,
|
181 |
+
subword_lengths: T.Tensor = None,
|
182 |
return_extra: bool = False
|
183 |
):
|
184 |
# assert self._built and not self.training
|
|
|
195 |
word_ids,
|
196 |
char_ids,
|
197 |
_labels,
|
|
|
198 |
)
|
199 |
out_shape = y_ctxt.shape[:-1]
|
200 |
else:
|
|
|
327 |
for inputs, _ in tqdm(dataloader, total=len(dataloader)):
|
328 |
inputs[0] = inputs[0].to(self.device)
|
329 |
inputs[1] = inputs[1].to(self.device)
|
330 |
+
output = self(*inputs, eval_only='ctxt')
|
331 |
|
332 |
# output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
|
333 |
marks = output
|
predict.py
CHANGED
@@ -14,6 +14,7 @@ from torch.utils.data import DataLoader
|
|
14 |
|
15 |
from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
|
16 |
from model_partial import PartialDD
|
|
|
17 |
from data_utils import DatasetUtils
|
18 |
from dataloader import DataRetriever
|
19 |
from segment import segment
|
@@ -44,7 +45,7 @@ class Predictor:
|
|
44 |
if T.cuda.is_available() else 'cpu'
|
45 |
)
|
46 |
|
47 |
-
self.model =
|
48 |
self.model.sentence_diac.build(word_embeddings, vocab_size)
|
49 |
state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
|
50 |
self.model.load_state_dict(state_dict)
|
|
|
14 |
|
15 |
from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
|
16 |
from model_partial import PartialDD
|
17 |
+
from model_dd import DiacritizerD2
|
18 |
from data_utils import DatasetUtils
|
19 |
from dataloader import DataRetriever
|
20 |
from segment import segment
|
|
|
45 |
if T.cuda.is_available() else 'cpu'
|
46 |
)
|
47 |
|
48 |
+
self.model = DiacritizerD2(config)
|
49 |
self.model.sentence_diac.build(word_embeddings, vocab_size)
|
50 |
state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
|
51 |
self.model.load_state_dict(state_dict)
|