bkhmsi commited on
Commit
5d95780
1 Parent(s): 5314058
Files changed (2) hide show
  1. model_partial.py +2 -3
  2. predict.py +2 -1
model_partial.py CHANGED
@@ -178,7 +178,7 @@ class PartialDD(nn.Module):
178
  # padding_mask: T.BoolTensor,
179
  *,
180
  eval_only: str = None,
181
- subword_lengths: T.Tensor,
182
  return_extra: bool = False
183
  ):
184
  # assert self._built and not self.training
@@ -195,7 +195,6 @@ class PartialDD(nn.Module):
195
  word_ids,
196
  char_ids,
197
  _labels,
198
- subword_lengths=subword_lengths,
199
  )
200
  out_shape = y_ctxt.shape[:-1]
201
  else:
@@ -328,7 +327,7 @@ class PartialDD(nn.Module):
328
  for inputs, _ in tqdm(dataloader, total=len(dataloader)):
329
  inputs[0] = inputs[0].to(self.device)
330
  inputs[1] = inputs[1].to(self.device)
331
- output = self(*inputs)
332
 
333
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
334
  marks = output
 
178
  # padding_mask: T.BoolTensor,
179
  *,
180
  eval_only: str = None,
181
+ subword_lengths: T.Tensor = None,
182
  return_extra: bool = False
183
  ):
184
  # assert self._built and not self.training
 
195
  word_ids,
196
  char_ids,
197
  _labels,
 
198
  )
199
  out_shape = y_ctxt.shape[:-1]
200
  else:
 
327
  for inputs, _ in tqdm(dataloader, total=len(dataloader)):
328
  inputs[0] = inputs[0].to(self.device)
329
  inputs[1] = inputs[1].to(self.device)
330
+ output = self(*inputs, eval_only='ctxt')
331
 
332
  # output = np.argmax(T.softmax(output.detach(), dim=-1).cpu().numpy(), axis=-1)
333
  marks = output
predict.py CHANGED
@@ -14,6 +14,7 @@ from torch.utils.data import DataLoader
14
 
15
  from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
16
  from model_partial import PartialDD
 
17
  from data_utils import DatasetUtils
18
  from dataloader import DataRetriever
19
  from segment import segment
@@ -44,7 +45,7 @@ class Predictor:
44
  if T.cuda.is_available() else 'cpu'
45
  )
46
 
47
- self.model = PartialDD(config, d2=True)
48
  self.model.sentence_diac.build(word_embeddings, vocab_size)
49
  state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
50
  self.model.load_state_dict(state_dict)
 
14
 
15
  from diac_utils import HARAKAT_MAP, shakkel_char, diac_ids_of_line
16
  from model_partial import PartialDD
17
+ from model_dd import DiacritizerD2
18
  from data_utils import DatasetUtils
19
  from dataloader import DataRetriever
20
  from segment import segment
 
45
  if T.cuda.is_available() else 'cpu'
46
  )
47
 
48
+ self.model = DiacritizerD2(config)
49
  self.model.sentence_diac.build(word_embeddings, vocab_size)
50
  state_dict = T.load(config["paths"]["load"], map_location=T.device(self.device))['state_dict']
51
  self.model.load_state_dict(state_dict)