willwade committed
Commit e2c1e0f · 1 parent: b817428

First push

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. configs/LRS3_V_WER19.1.ini +18 -0
  2. espnet/.DS_Store +0 -0
  3. espnet/asr/asr_utils.py +990 -0
  4. espnet/nets/.DS_Store +0 -0
  5. espnet/nets/batch_beam_search.py +349 -0
  6. espnet/nets/beam_search.py +516 -0
  7. espnet/nets/ctc_prefix_score.py +359 -0
  8. espnet/nets/e2e_asr_common.py +249 -0
  9. espnet/nets/lm_interface.py +86 -0
  10. espnet/nets/pytorch_backend/backbones/conv1d_extractor.py +25 -0
  11. espnet/nets/pytorch_backend/backbones/conv3d_extractor.py +47 -0
  12. espnet/nets/pytorch_backend/backbones/modules/resnet.py +178 -0
  13. espnet/nets/pytorch_backend/backbones/modules/resnet1d.py +213 -0
  14. espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py +165 -0
  15. espnet/nets/pytorch_backend/ctc.py +283 -0
  16. espnet/nets/pytorch_backend/e2e_asr_transformer.py +320 -0
  17. espnet/nets/pytorch_backend/e2e_asr_transformer_av.py +352 -0
  18. espnet/nets/pytorch_backend/lm/__init__.py +1 -0
  19. espnet/nets/pytorch_backend/lm/default.py +431 -0
  20. espnet/nets/pytorch_backend/lm/seq_rnn.py +178 -0
  21. espnet/nets/pytorch_backend/lm/transformer.py +252 -0
  22. espnet/nets/pytorch_backend/nets_utils.py +526 -0
  23. espnet/nets/pytorch_backend/transformer/__init__.py +1 -0
  24. espnet/nets/pytorch_backend/transformer/add_sos_eos.py +31 -0
  25. espnet/nets/pytorch_backend/transformer/attention.py +280 -0
  26. espnet/nets/pytorch_backend/transformer/convolution.py +73 -0
  27. espnet/nets/pytorch_backend/transformer/decoder.py +229 -0
  28. espnet/nets/pytorch_backend/transformer/decoder_layer.py +121 -0
  29. espnet/nets/pytorch_backend/transformer/embedding.py +217 -0
  30. espnet/nets/pytorch_backend/transformer/encoder.py +283 -0
  31. espnet/nets/pytorch_backend/transformer/encoder_layer.py +149 -0
  32. espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py +63 -0
  33. espnet/nets/pytorch_backend/transformer/layer_norm.py +33 -0
  34. espnet/nets/pytorch_backend/transformer/mask.py +51 -0
  35. espnet/nets/pytorch_backend/transformer/multi_layer_conv.py +105 -0
  36. espnet/nets/pytorch_backend/transformer/optimizer.py +75 -0
  37. espnet/nets/pytorch_backend/transformer/plot.py +134 -0
  38. espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py +30 -0
  39. espnet/nets/pytorch_backend/transformer/raw_embeddings.py +77 -0
  40. espnet/nets/pytorch_backend/transformer/repeat.py +30 -0
  41. espnet/nets/pytorch_backend/transformer/subsampling.py +52 -0
  42. espnet/nets/scorer_interface.py +188 -0
  43. espnet/nets/scorers/__init__.py +1 -0
  44. espnet/nets/scorers/ctc.py +158 -0
  45. espnet/nets/scorers/length_bonus.py +61 -0
  46. espnet/utils/cli_utils.py +65 -0
  47. espnet/utils/dynamic_import.py +23 -0
  48. espnet/utils/fill_missing_args.py +46 -0
  49. pipelines/.DS_Store +0 -0
  50. pipelines/data/.DS_Store +0 -0
configs/LRS3_V_WER19.1.ini ADDED
@@ -0,0 +1,18 @@
+ [input]
+ modality=video
+ v_fps=25
+
+ [model]
+ v_fps=25
+ model_path=benchmarks/LRS3/models/LRS3_V_WER19.1/model.pth
+ model_conf=benchmarks/LRS3/models/LRS3_V_WER19.1/model.json
+ rnnlm=benchmarks/LRS3/language_models/lm_en_subword/model.pth
+ rnnlm_conf=benchmarks/LRS3/language_models/lm_en_subword/model.json
+
+ [decode]
+ beam_size=40
+ penalty=0.0
+ maxlenratio=0.0
+ minlenratio=0.0
+ ctc_weight=0.1
+ lm_weight=0.3
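
For reference, a config in this .ini format can be read with Python's standard configparser. The snippet below is an illustrative sketch only: the loading code actually used by the pipeline is not part of this commit, and the keys simply mirror the file added above.

    # Illustrative only: reading back the keys from configs/LRS3_V_WER19.1.ini.
    import configparser

    config = configparser.ConfigParser()
    config.read("configs/LRS3_V_WER19.1.ini")

    modality = config["input"]["modality"]                # "video"
    v_fps = config["model"].getint("v_fps")               # 25
    beam_size = config["decode"].getint("beam_size")      # 40
    ctc_weight = config["decode"].getfloat("ctc_weight")  # 0.1
    print(modality, v_fps, beam_size, ctc_weight)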
espnet/.DS_Store ADDED
Binary file (6.15 kB).
 
espnet/asr/asr_utils.py ADDED
@@ -0,0 +1,990 @@
1
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import argparse
5
+ import copy
6
+ import json
7
+ import logging
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ # * -------------------- training iterator related -------------------- *
17
+
18
+
19
+ class CompareValueTrigger(object):
20
+ """Trigger invoked when key value getting bigger or lower than before.
21
+
22
+ Args:
23
+ key (str) : Key of value.
24
+ compare_fn ((float, float) -> bool) : Function to compare the values.
25
+ trigger (tuple(int, str)) : Trigger that decide the comparison interval.
26
+
27
+ """
28
+
29
+ def __init__(self, key, compare_fn, trigger=(1, "epoch")):
30
+ from chainer import training
31
+
32
+ self._key = key
33
+ self._best_value = None
34
+ self._interval_trigger = training.util.get_trigger(trigger)
35
+ self._init_summary()
36
+ self._compare_fn = compare_fn
37
+
38
+ def __call__(self, trainer):
39
+ """Get value related to the key and compare with current value."""
40
+ observation = trainer.observation
41
+ summary = self._summary
42
+ key = self._key
43
+ if key in observation:
44
+ summary.add({key: observation[key]})
45
+
46
+ if not self._interval_trigger(trainer):
47
+ return False
48
+
49
+ stats = summary.compute_mean()
50
+ value = float(stats[key]) # copy to CPU
51
+ self._init_summary()
52
+
53
+ if self._best_value is None:
54
+ # initialize best value
55
+ self._best_value = value
56
+ return False
57
+ elif self._compare_fn(self._best_value, value):
58
+ return True
59
+ else:
60
+ self._best_value = value
61
+ return False
62
+
63
+ def _init_summary(self):
64
+ import chainer
65
+
66
+ self._summary = chainer.reporter.DictSummary()
67
+
68
+
69
+ try:
70
+ from chainer.training import extension
71
+ except ImportError:
72
+ PlotAttentionReport = None
73
+ else:
74
+
75
+ class PlotAttentionReport(extension.Extension):
76
+ """Plot attention reporter.
77
+
78
+ Args:
79
+ att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
80
+ Function of attention visualization.
81
+ data (list[tuple(str, dict[str, list[Any]])]): List of json utterance key items.
82
+ outdir (str): Directory to save figures.
83
+ converter (espnet.asr.*_backend.asr.CustomConverter):
84
+ Function to convert data.
85
+ device (int | torch.device): Device.
86
+ reverse (bool): If True, input and output length are reversed.
87
+ ikey (str): Key to access input
88
+ (for ASR/ST ikey="input", for MT ikey="output".)
89
+ iaxis (int): Dimension to access input
90
+ (for ASR/ST iaxis=0, for MT iaxis=1.)
91
+ okey (str): Key to access output
92
+ (for ASR/ST okey="input", for MT okey="output".)
93
+ oaxis (int): Dimension to access output
94
+ (for ASR/ST oaxis=0, for MT oaxis=0.)
95
+ subsampling_factor (int): subsampling factor in encoder
96
+
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ att_vis_fn,
102
+ data,
103
+ outdir,
104
+ converter,
105
+ transform,
106
+ device,
107
+ reverse=False,
108
+ ikey="input",
109
+ iaxis=0,
110
+ okey="output",
111
+ oaxis=0,
112
+ subsampling_factor=1,
113
+ ):
114
+ self.att_vis_fn = att_vis_fn
115
+ self.data = copy.deepcopy(data)
116
+ self.data_dict = {k: v for k, v in copy.deepcopy(data)}
117
+ # key is utterance ID
118
+ self.outdir = outdir
119
+ self.converter = converter
120
+ self.transform = transform
121
+ self.device = device
122
+ self.reverse = reverse
123
+ self.ikey = ikey
124
+ self.iaxis = iaxis
125
+ self.okey = okey
126
+ self.oaxis = oaxis
127
+ self.factor = subsampling_factor
128
+ if not os.path.exists(self.outdir):
129
+ os.makedirs(self.outdir)
130
+
131
+ def __call__(self, trainer):
132
+ """Plot and save image file of att_ws matrix."""
133
+ att_ws, uttid_list = self.get_attention_weights()
134
+ if isinstance(att_ws, list): # multi-encoder case
135
+ num_encs = len(att_ws) - 1
136
+ # atts
137
+ for i in range(num_encs):
138
+ for idx, att_w in enumerate(att_ws[i]):
139
+ filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
140
+ self.outdir,
141
+ uttid_list[idx],
142
+ i + 1,
143
+ )
144
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
145
+ np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
146
+ self.outdir,
147
+ uttid_list[idx],
148
+ i + 1,
149
+ )
150
+ np.save(np_filename.format(trainer), att_w)
151
+ self._plot_and_save_attention(att_w, filename.format(trainer))
152
+ # han
153
+ for idx, att_w in enumerate(att_ws[num_encs]):
154
+ filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
155
+ self.outdir,
156
+ uttid_list[idx],
157
+ )
158
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
159
+ np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
160
+ self.outdir,
161
+ uttid_list[idx],
162
+ )
163
+ np.save(np_filename.format(trainer), att_w)
164
+ self._plot_and_save_attention(
165
+ att_w, filename.format(trainer), han_mode=True
166
+ )
167
+ else:
168
+ for idx, att_w in enumerate(att_ws):
169
+ filename = "%s/%s.ep.{.updater.epoch}.png" % (
170
+ self.outdir,
171
+ uttid_list[idx],
172
+ )
173
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
174
+ np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
175
+ self.outdir,
176
+ uttid_list[idx],
177
+ )
178
+ np.save(np_filename.format(trainer), att_w)
179
+ self._plot_and_save_attention(att_w, filename.format(trainer))
180
+
181
+ def log_attentions(self, logger, step):
182
+ """Add image files of att_ws matrix to the tensorboard."""
183
+ att_ws, uttid_list = self.get_attention_weights()
184
+ if isinstance(att_ws, list): # multi-encoder case
185
+ num_encs = len(att_ws) - 1
186
+ # atts
187
+ for i in range(num_encs):
188
+ for idx, att_w in enumerate(att_ws[i]):
189
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
190
+ plot = self.draw_attention_plot(att_w)
191
+ logger.add_figure(
192
+ "%s_att%d" % (uttid_list[idx], i + 1),
193
+ plot.gcf(),
194
+ step,
195
+ )
196
+ # han
197
+ for idx, att_w in enumerate(att_ws[num_encs]):
198
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
199
+ plot = self.draw_han_plot(att_w)
200
+ logger.add_figure(
201
+ "%s_han" % (uttid_list[idx]),
202
+ plot.gcf(),
203
+ step,
204
+ )
205
+ else:
206
+ for idx, att_w in enumerate(att_ws):
207
+ att_w = self.trim_attention_weight(uttid_list[idx], att_w)
208
+ plot = self.draw_attention_plot(att_w)
209
+ logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
210
+
211
+ def get_attention_weights(self):
212
+ """Return attention weights.
213
+
214
+ Returns:
215
+ numpy.ndarray: attention weights. float. Its shape
216
+ differs depending on the backend.
217
+ * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
218
+ other case => (B, Lmax, Tmax).
219
+ * chainer-> (B, Lmax, Tmax)
220
+
221
+ """
222
+ return_batch, uttid_list = self.transform(self.data, return_uttid=True)
223
+ batch = self.converter([return_batch], self.device)
224
+ if isinstance(batch, tuple):
225
+ att_ws = self.att_vis_fn(*batch)
226
+ else:
227
+ att_ws = self.att_vis_fn(**batch)
228
+ return att_ws, uttid_list
229
+
230
+ def trim_attention_weight(self, uttid, att_w):
231
+ """Transform attention matrix with regard to self.reverse."""
232
+ if self.reverse:
233
+ enc_key, enc_axis = self.okey, self.oaxis
234
+ dec_key, dec_axis = self.ikey, self.iaxis
235
+ else:
236
+ enc_key, enc_axis = self.ikey, self.iaxis
237
+ dec_key, dec_axis = self.okey, self.oaxis
238
+ dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
239
+ enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
240
+ if self.factor > 1:
241
+ enc_len //= self.factor
242
+ if len(att_w.shape) == 3:
243
+ att_w = att_w[:, :dec_len, :enc_len]
244
+ else:
245
+ att_w = att_w[:dec_len, :enc_len]
246
+ return att_w
247
+
248
+ def draw_attention_plot(self, att_w):
249
+ """Plot the att_w matrix.
250
+
251
+ Returns:
252
+ matplotlib.pyplot: pyplot object with attention matrix image.
253
+
254
+ """
255
+ import matplotlib
256
+
257
+ matplotlib.use("Agg")
258
+ import matplotlib.pyplot as plt
259
+
260
+ plt.clf()
261
+ att_w = att_w.astype(np.float32)
262
+ if len(att_w.shape) == 3:
263
+ for h, aw in enumerate(att_w, 1):
264
+ plt.subplot(1, len(att_w), h)
265
+ plt.imshow(aw, aspect="auto")
266
+ plt.xlabel("Encoder Index")
267
+ plt.ylabel("Decoder Index")
268
+ else:
269
+ plt.imshow(att_w, aspect="auto")
270
+ plt.xlabel("Encoder Index")
271
+ plt.ylabel("Decoder Index")
272
+ plt.tight_layout()
273
+ return plt
274
+
275
+ def draw_han_plot(self, att_w):
276
+ """Plot the att_w matrix for hierarchical attention.
277
+
278
+ Returns:
279
+ matplotlib.pyplot: pyplot object with attention matrix image.
280
+
281
+ """
282
+ import matplotlib
283
+
284
+ matplotlib.use("Agg")
285
+ import matplotlib.pyplot as plt
286
+
287
+ plt.clf()
288
+ if len(att_w.shape) == 3:
289
+ for h, aw in enumerate(att_w, 1):
290
+ legends = []
291
+ plt.subplot(1, len(att_w), h)
292
+ for i in range(aw.shape[1]):
293
+ plt.plot(aw[:, i])
294
+ legends.append("Att{}".format(i))
295
+ plt.ylim([0, 1.0])
296
+ plt.xlim([0, aw.shape[0]])
297
+ plt.grid(True)
298
+ plt.ylabel("Attention Weight")
299
+ plt.xlabel("Decoder Index")
300
+ plt.legend(legends)
301
+ else:
302
+ legends = []
303
+ for i in range(att_w.shape[1]):
304
+ plt.plot(att_w[:, i])
305
+ legends.append("Att{}".format(i))
306
+ plt.ylim([0, 1.0])
307
+ plt.xlim([0, att_w.shape[0]])
308
+ plt.grid(True)
309
+ plt.ylabel("Attention Weight")
310
+ plt.xlabel("Decoder Index")
311
+ plt.legend(legends)
312
+ plt.tight_layout()
313
+ return plt
314
+
315
+ def _plot_and_save_attention(self, att_w, filename, han_mode=False):
316
+ if han_mode:
317
+ plt = self.draw_han_plot(att_w)
318
+ else:
319
+ plt = self.draw_attention_plot(att_w)
320
+ plt.savefig(filename)
321
+ plt.close()
322
+
323
+
324
+ try:
325
+ from chainer.training import extension
326
+ except ImportError:
327
+ PlotCTCReport = None
328
+ else:
329
+
330
+ class PlotCTCReport(extension.Extension):
331
+ """Plot CTC reporter.
332
+
333
+ Args:
334
+ ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
335
+ Function of CTC visualization.
336
+ data (list[tuple(str, dict[str, list[Any]])]): List of json utterance key items.
337
+ outdir (str): Directory to save figures.
338
+ converter (espnet.asr.*_backend.asr.CustomConverter):
339
+ Function to convert data.
340
+ device (int | torch.device): Device.
341
+ reverse (bool): If True, input and output length are reversed.
342
+ ikey (str): Key to access input
343
+ (for ASR/ST ikey="input", for MT ikey="output".)
344
+ iaxis (int): Dimension to access input
345
+ (for ASR/ST iaxis=0, for MT iaxis=1.)
346
+ okey (str): Key to access output
347
+ (for ASR/ST okey="input", for MT okey="output".)
348
+ oaxis (int): Dimension to access output
349
+ (for ASR/ST oaxis=0, for MT oaxis=0.)
350
+ subsampling_factor (int): subsampling factor in encoder
351
+
352
+ """
353
+
354
+ def __init__(
355
+ self,
356
+ ctc_vis_fn,
357
+ data,
358
+ outdir,
359
+ converter,
360
+ transform,
361
+ device,
362
+ reverse=False,
363
+ ikey="input",
364
+ iaxis=0,
365
+ okey="output",
366
+ oaxis=0,
367
+ subsampling_factor=1,
368
+ ):
369
+ self.ctc_vis_fn = ctc_vis_fn
370
+ self.data = copy.deepcopy(data)
371
+ self.data_dict = {k: v for k, v in copy.deepcopy(data)}
372
+ # key is utterance ID
373
+ self.outdir = outdir
374
+ self.converter = converter
375
+ self.transform = transform
376
+ self.device = device
377
+ self.reverse = reverse
378
+ self.ikey = ikey
379
+ self.iaxis = iaxis
380
+ self.okey = okey
381
+ self.oaxis = oaxis
382
+ self.factor = subsampling_factor
383
+ if not os.path.exists(self.outdir):
384
+ os.makedirs(self.outdir)
385
+
386
+ def __call__(self, trainer):
387
+ """Plot and save image file of ctc prob."""
388
+ ctc_probs, uttid_list = self.get_ctc_probs()
389
+ if isinstance(ctc_probs, list): # multi-encoder case
390
+ num_encs = len(ctc_probs) - 1
391
+ for i in range(num_encs):
392
+ for idx, ctc_prob in enumerate(ctc_probs[i]):
393
+ filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
394
+ self.outdir,
395
+ uttid_list[idx],
396
+ i + 1,
397
+ )
398
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
399
+ np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
400
+ self.outdir,
401
+ uttid_list[idx],
402
+ i + 1,
403
+ )
404
+ np.save(np_filename.format(trainer), ctc_prob)
405
+ self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
406
+ else:
407
+ for idx, ctc_prob in enumerate(ctc_probs):
408
+ filename = "%s/%s.ep.{.updater.epoch}.png" % (
409
+ self.outdir,
410
+ uttid_list[idx],
411
+ )
412
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
413
+ np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
414
+ self.outdir,
415
+ uttid_list[idx],
416
+ )
417
+ np.save(np_filename.format(trainer), ctc_prob)
418
+ self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
419
+
420
+ def log_ctc_probs(self, logger, step):
421
+ """Add image files of ctc probs to the tensorboard."""
422
+ ctc_probs, uttid_list = self.get_ctc_probs()
423
+ if isinstance(ctc_probs, list): # multi-encoder case
424
+ num_encs = len(ctc_probs) - 1
425
+ for i in range(num_encs):
426
+ for idx, ctc_prob in enumerate(ctc_probs[i]):
427
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
428
+ plot = self.draw_ctc_plot(ctc_prob)
429
+ logger.add_figure(
430
+ "%s_ctc%d" % (uttid_list[idx], i + 1),
431
+ plot.gcf(),
432
+ step,
433
+ )
434
+ else:
435
+ for idx, ctc_prob in enumerate(ctc_probs):
436
+ ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
437
+ plot = self.draw_ctc_plot(ctc_prob)
438
+ logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
439
+
440
+ def get_ctc_probs(self):
441
+ """Return CTC probs.
442
+
443
+ Returns:
444
+ numpy.ndarray: CTC probs. float. Its shape
445
+ differs depending on the backend. (B, Tmax, vocab).
446
+
447
+ """
448
+ return_batch, uttid_list = self.transform(self.data, return_uttid=True)
449
+ batch = self.converter([return_batch], self.device)
450
+ if isinstance(batch, tuple):
451
+ probs = self.ctc_vis_fn(*batch)
452
+ else:
453
+ probs = self.ctc_vis_fn(**batch)
454
+ return probs, uttid_list
455
+
456
+ def trim_ctc_prob(self, uttid, prob):
457
+ """Trim CTC posteriors accoding to input lengths."""
458
+ enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
459
+ if self.factor > 1:
460
+ enc_len //= self.factor
461
+ prob = prob[:enc_len]
462
+ return prob
463
+
464
+ def draw_ctc_plot(self, ctc_prob):
465
+ """Plot the ctc_prob matrix.
466
+
467
+ Returns:
468
+ matplotlib.pyplot: pyplot object with CTC prob matrix image.
469
+
470
+ """
471
+ import matplotlib
472
+
473
+ matplotlib.use("Agg")
474
+ import matplotlib.pyplot as plt
475
+
476
+ ctc_prob = ctc_prob.astype(np.float32)
477
+
478
+ plt.clf()
479
+ topk_ids = np.argsort(ctc_prob, axis=1)
480
+ n_frames, vocab = ctc_prob.shape
481
+ times_probs = np.arange(n_frames)
482
+
483
+ plt.figure(figsize=(20, 8))
484
+
485
+ # NOTE: index 0 is reserved for blank
486
+ for idx in set(topk_ids.reshape(-1).tolist()):
487
+ if idx == 0:
488
+ plt.plot(
489
+ times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
490
+ )
491
+ else:
492
+ plt.plot(times_probs, ctc_prob[:, idx])
493
+ plt.xlabel("Input [frame]", fontsize=12)
494
+ plt.ylabel("Posteriors", fontsize=12)
495
+ plt.xticks(list(range(0, int(n_frames) + 1, 10)))
496
+ plt.yticks(list(range(0, 2, 1)))
497
+ plt.tight_layout()
498
+ return plt
499
+
500
+ def _plot_and_save_ctc(self, ctc_prob, filename):
501
+ plt = self.draw_ctc_plot(ctc_prob)
502
+ plt.savefig(filename)
503
+ plt.close()
504
+
505
+
506
+ def restore_snapshot(model, snapshot, load_fn=None):
507
+ """Extension to restore snapshot.
508
+
509
+ Returns:
510
+ An extension function.
511
+
512
+ """
513
+ import chainer
514
+ from chainer import training
515
+
516
+ if load_fn is None:
517
+ load_fn = chainer.serializers.load_npz
518
+
519
+ @training.make_extension(trigger=(1, "epoch"))
520
+ def restore_snapshot(trainer):
521
+ _restore_snapshot(model, snapshot, load_fn)
522
+
523
+ return restore_snapshot
524
+
525
+
526
+ def _restore_snapshot(model, snapshot, load_fn=None):
527
+ if load_fn is None:
528
+ import chainer
529
+
530
+ load_fn = chainer.serializers.load_npz
531
+
532
+ load_fn(snapshot, model)
533
+ logging.info("restored from " + str(snapshot))
534
+
535
+
536
+ def adadelta_eps_decay(eps_decay):
537
+ """Extension to perform adadelta eps decay.
538
+
539
+ Args:
540
+ eps_decay (float): Decay rate of eps.
541
+
542
+ Returns:
543
+ An extension function.
544
+
545
+ """
546
+ from chainer import training
547
+
548
+ @training.make_extension(trigger=(1, "epoch"))
549
+ def adadelta_eps_decay(trainer):
550
+ _adadelta_eps_decay(trainer, eps_decay)
551
+
552
+ return adadelta_eps_decay
553
+
554
+
555
+ def _adadelta_eps_decay(trainer, eps_decay):
556
+ optimizer = trainer.updater.get_optimizer("main")
557
+ # for chainer
558
+ if hasattr(optimizer, "eps"):
559
+ current_eps = optimizer.eps
560
+ setattr(optimizer, "eps", current_eps * eps_decay)
561
+ logging.info("adadelta eps decayed to " + str(optimizer.eps))
562
+ # pytorch
563
+ else:
564
+ for p in optimizer.param_groups:
565
+ p["eps"] *= eps_decay
566
+ logging.info("adadelta eps decayed to " + str(p["eps"]))
567
+
568
+
569
+ def adam_lr_decay(eps_decay):
570
+ """Extension to perform adam lr decay.
571
+
572
+ Args:
573
+ eps_decay (float): Decay rate of lr.
574
+
575
+ Returns:
576
+ An extension function.
577
+
578
+ """
579
+ from chainer import training
580
+
581
+ @training.make_extension(trigger=(1, "epoch"))
582
+ def adam_lr_decay(trainer):
583
+ _adam_lr_decay(trainer, eps_decay)
584
+
585
+ return adam_lr_decay
586
+
587
+
588
+ def _adam_lr_decay(trainer, eps_decay):
589
+ optimizer = trainer.updater.get_optimizer("main")
590
+ # for chainer
591
+ if hasattr(optimizer, "lr"):
592
+ current_lr = optimizer.lr
593
+ setattr(optimizer, "lr", current_lr * eps_decay)
594
+ logging.info("adam lr decayed to " + str(optimizer.lr))
595
+ # pytorch
596
+ else:
597
+ for p in optimizer.param_groups:
598
+ p["lr"] *= eps_decay
599
+ logging.info("adam lr decayed to " + str(p["lr"]))
600
+
601
+
602
+ def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
603
+ """Extension to take snapshot of the trainer for pytorch.
604
+
605
+ Returns:
606
+ An extension function.
607
+
608
+ """
609
+ from chainer.training import extension
610
+
611
+ @extension.make_extension(trigger=(1, "epoch"), priority=-100)
612
+ def torch_snapshot(trainer):
613
+ _torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
614
+
615
+ return torch_snapshot
616
+
617
+
618
+ def _torch_snapshot_object(trainer, target, filename, savefun):
619
+ from chainer.serializers import DictionarySerializer
620
+
621
+ # make snapshot_dict dictionary
622
+ s = DictionarySerializer()
623
+ s.save(trainer)
624
+ if hasattr(trainer.updater.model, "model"):
625
+ # (for TTS)
626
+ if hasattr(trainer.updater.model.model, "module"):
627
+ model_state_dict = trainer.updater.model.model.module.state_dict()
628
+ else:
629
+ model_state_dict = trainer.updater.model.model.state_dict()
630
+ else:
631
+ # (for ASR)
632
+ if hasattr(trainer.updater.model, "module"):
633
+ model_state_dict = trainer.updater.model.module.state_dict()
634
+ else:
635
+ model_state_dict = trainer.updater.model.state_dict()
636
+ snapshot_dict = {
637
+ "trainer": s.target,
638
+ "model": model_state_dict,
639
+ "optimizer": trainer.updater.get_optimizer("main").state_dict(),
640
+ }
641
+
642
+ # save snapshot dictionary
643
+ fn = filename.format(trainer)
644
+ prefix = "tmp" + fn
645
+ tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
646
+ tmppath = os.path.join(tmpdir, fn)
647
+ try:
648
+ savefun(snapshot_dict, tmppath)
649
+ shutil.move(tmppath, os.path.join(trainer.out, fn))
650
+ finally:
651
+ shutil.rmtree(tmpdir)
652
+
653
+
654
+ def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
655
+ """Adds noise from a standard normal distribution to the gradients.
656
+
657
+ The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
658
+ `sigma` goes to zero (no noise) with more iterations.
659
+
660
+ Args:
661
+ model (torch.nn.model): Model.
662
+ iteration (int): Number of iterations.
663
+ duration (int) {100, 1000}:
664
+ Number of durations to control the interval of the `sigma` change.
665
+ eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
666
+ scale_factor (float) {0.55}: The scale of `sigma`.
667
+ """
668
+ interval = (iteration // duration) + 1
669
+ sigma = eta / interval**scale_factor
670
+ for param in model.parameters():
671
+ if param.grad is not None:
672
+ _shape = param.grad.size()
673
+ noise = sigma * torch.randn(_shape).to(param.device)
674
+ param.grad += noise
675
+
676
+
677
+ # * -------------------- general -------------------- *
678
+ def get_model_conf(model_path, conf_path=None):
679
+ """Get model config information by reading a model config file (model.json).
680
+
681
+ Args:
682
+ model_path (str): Model path.
683
+ conf_path (str): Optional model config path.
684
+
685
+ Returns:
686
+ list[int, int, dict[str, Any]]: Config information loaded from json file.
687
+
688
+ """
689
+ if conf_path is None:
690
+ model_conf = os.path.dirname(model_path) + "/model.json"
691
+ else:
692
+ model_conf = conf_path
693
+ with open(model_conf, "rb") as f:
694
+ logging.info("reading a config file from " + model_conf)
695
+ confs = json.load(f)
696
+ if isinstance(confs, dict):
697
+ # for lm
698
+ args = confs
699
+ return argparse.Namespace(**args)
700
+ else:
701
+ # for asr, tts, mt
702
+ idim, odim, args = confs
703
+ return idim, odim, argparse.Namespace(**args)
704
+
705
+
706
+ def chainer_load(path, model):
707
+ """Load chainer model parameters.
708
+
709
+ Args:
710
+ path (str): Model path or snapshot file path to be loaded.
711
+ model (chainer.Chain): Chainer model.
712
+
713
+ """
714
+ import chainer
715
+
716
+ if "snapshot" in os.path.basename(path):
717
+ chainer.serializers.load_npz(path, model, path="updater/model:main/")
718
+ else:
719
+ chainer.serializers.load_npz(path, model)
720
+
721
+
722
+ def torch_save(path, model):
723
+ """Save torch model states.
724
+
725
+ Args:
726
+ path (str): Model path to be saved.
727
+ model (torch.nn.Module): Torch model.
728
+
729
+ """
730
+ if hasattr(model, "module"):
731
+ torch.save(model.module.state_dict(), path)
732
+ else:
733
+ torch.save(model.state_dict(), path)
734
+
735
+
736
+ def snapshot_object(target, filename):
737
+ """Returns a trainer extension to take snapshots of a given object.
738
+
739
+ Args:
740
+ target (model): Object to serialize.
741
+ filename (str): Name of the file into which the object is serialized. It can
742
+ be a format string, where the trainer object is passed to
743
+ the :meth: `str.format` method. For example,
744
+ ``'snapshot_{.updater.iteration}'`` is converted to
745
+ ``'snapshot_10000'`` at the 10,000th iteration.
746
+
747
+ Returns:
748
+ An extension function.
749
+
750
+ """
751
+ from chainer.training import extension
752
+
753
+ @extension.make_extension(trigger=(1, "epoch"), priority=-100)
754
+ def snapshot_object(trainer):
755
+ torch_save(os.path.join(trainer.out, filename.format(trainer)), target)
756
+
757
+ return snapshot_object
758
+
759
+
760
+ def torch_load(path, model):
761
+ """Load torch model states.
762
+
763
+ Args:
764
+ path (str): Model path or snapshot file path to be loaded.
765
+ model (torch.nn.Module): Torch model.
766
+
767
+ """
768
+ if "snapshot" in os.path.basename(path):
769
+ model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
770
+ "model"
771
+ ]
772
+ else:
773
+ model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
774
+
775
+ if hasattr(model, "module"):
776
+ model.module.load_state_dict(model_state_dict)
777
+ else:
778
+ model.load_state_dict(model_state_dict)
779
+
780
+ del model_state_dict
781
+
782
+
783
+ def torch_resume(snapshot_path, trainer):
784
+ """Resume from snapshot for pytorch.
785
+
786
+ Args:
787
+ snapshot_path (str): Snapshot file path.
788
+ trainer (chainer.training.Trainer): Chainer's trainer instance.
789
+
790
+ """
791
+ from chainer.serializers import NpzDeserializer
792
+
793
+ # load snapshot
794
+ snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
795
+
796
+ # restore trainer states
797
+ d = NpzDeserializer(snapshot_dict["trainer"])
798
+ d.load(trainer)
799
+
800
+ # restore model states
801
+ if hasattr(trainer.updater.model, "model"):
802
+ # (for TTS model)
803
+ if hasattr(trainer.updater.model.model, "module"):
804
+ trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
805
+ else:
806
+ trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
807
+ else:
808
+ # (for ASR model)
809
+ if hasattr(trainer.updater.model, "module"):
810
+ trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
811
+ else:
812
+ trainer.updater.model.load_state_dict(snapshot_dict["model"])
813
+
814
+ # restore optimizer states
815
+ trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
816
+
817
+ # delete opened snapshot
818
+ del snapshot_dict
819
+
820
+
821
+ # * ------------------ recognition related ------------------ *
822
+ def parse_hypothesis(hyp, char_list):
823
+ """Parse hypothesis.
824
+
825
+ Args:
826
+ hyp (list[dict[str, Any]]): Recognition hypothesis.
827
+ char_list (list[str]): List of characters.
828
+
829
+ Returns:
830
+ tuple(str, str, str, float)
831
+
832
+ """
833
+ # remove sos and get results
834
+ tokenid_as_list = list(map(int, hyp["yseq"][1:]))
835
+ token_as_list = [char_list[idx] for idx in tokenid_as_list]
836
+ score = float(hyp["score"])
837
+
838
+ # convert to string
839
+ tokenid = " ".join([str(idx) for idx in tokenid_as_list])
840
+ token = " ".join(token_as_list)
841
+ text = "".join(token_as_list).replace("<space>", " ")
842
+
843
+ return text, token, tokenid, score
844
+
845
+
846
+ def add_results_to_json(nbest_hyps, char_list):
847
+ """Add N-best results to json.
848
+ Args:
849
+ nbest_hyps (list[dict[str, Any]]): List of N-best hypotheses.
852
+ char_list (list[str]): List of characters.
853
+ Returns:
854
+ str: 1-best result
855
+ """
856
+ assert len(nbest_hyps) == 1, "only 1-best result is supported."
857
+ # parse hypothesis
858
+ rec_text, rec_token, rec_tokenid, score = parse_hypothesis(nbest_hyps[0], char_list)
859
+ return rec_text
860
+
861
+
862
+ def plot_spectrogram(
863
+ plt,
864
+ spec,
865
+ mode="db",
866
+ fs=None,
867
+ frame_shift=None,
868
+ bottom=True,
869
+ left=True,
870
+ right=True,
871
+ top=False,
872
+ labelbottom=True,
873
+ labelleft=True,
874
+ labelright=True,
875
+ labeltop=False,
876
+ cmap="inferno",
877
+ ):
878
+ """Plot spectrogram using matplotlib.
879
+
880
+ Args:
881
+ plt (matplotlib.pyplot): pyplot object.
882
+ spec (numpy.ndarray): Input stft (Freq, Time)
883
+ mode (str): db or linear.
884
+ fs (int): Sample frequency. To convert y-axis to kHz unit.
885
+ frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
886
+ bottom (bool): Whether to draw the respective ticks.
887
+ left (bool):
888
+ right (bool):
889
+ top (bool):
890
+ labelbottom (bool): Whether to draw the respective tick labels.
891
+ labelleft (bool):
892
+ labelright (bool):
893
+ labeltop (bool):
894
+ cmap (str): Colormap defined in matplotlib.
895
+
896
+ """
897
+ spec = np.abs(spec)
898
+ if mode == "db":
899
+ x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
900
+ elif mode == "linear":
901
+ x = spec
902
+ else:
903
+ raise ValueError(mode)
904
+
905
+ if fs is not None:
906
+ ytop = fs / 2000
907
+ ylabel = "kHz"
908
+ else:
909
+ ytop = x.shape[0]
910
+ ylabel = "bin"
911
+
912
+ if frame_shift is not None and fs is not None:
913
+ xtop = x.shape[1] * frame_shift / fs
914
+ xlabel = "s"
915
+ else:
916
+ xtop = x.shape[1]
917
+ xlabel = "frame"
918
+
919
+ extent = (0, xtop, 0, ytop)
920
+ plt.imshow(x[::-1], cmap=cmap, extent=extent)
921
+
922
+ if labelbottom:
923
+ plt.xlabel("time [{}]".format(xlabel))
924
+ if labelleft:
925
+ plt.ylabel("freq [{}]".format(ylabel))
926
+ plt.colorbar().set_label("{}".format(mode))
927
+
928
+ plt.tick_params(
929
+ bottom=bottom,
930
+ left=left,
931
+ right=right,
932
+ top=top,
933
+ labelbottom=labelbottom,
934
+ labelleft=labelleft,
935
+ labelright=labelright,
936
+ labeltop=labeltop,
937
+ )
938
+ plt.axis("auto")
939
+
940
+
941
+ # * ------------------ recognition related ------------------ *
942
+ def format_mulenc_args(args):
943
+ """Format args for multi-encoder setup.
944
+
945
+ It deals with following situations: (when args.num_encs=2):
946
+ 1. args.elayers = None -> args.elayers = [4, 4];
947
+ 2. args.elayers = 4 -> args.elayers = [4, 4];
948
+ 3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].
949
+
950
+ """
951
+ # default values when None is assigned.
952
+ default_dict = {
953
+ "etype": "blstmp",
954
+ "elayers": 4,
955
+ "eunits": 300,
956
+ "subsample": "1",
957
+ "dropout_rate": 0.0,
958
+ "atype": "dot",
959
+ "adim": 320,
960
+ "awin": 5,
961
+ "aheads": 4,
962
+ "aconv_chans": -1,
963
+ "aconv_filts": 100,
964
+ }
965
+ for k in default_dict.keys():
966
+ if isinstance(vars(args)[k], list):
967
+ if len(vars(args)[k]) != args.num_encs:
968
+ logging.warning(
969
+ "Length mismatch {}: Convert {} to {}.".format(
970
+ k, vars(args)[k], vars(args)[k][: args.num_encs]
971
+ )
972
+ )
973
+ vars(args)[k] = vars(args)[k][: args.num_encs]
974
+ else:
975
+ if not vars(args)[k]:
976
+ # assign default value if it is None
977
+ vars(args)[k] = default_dict[k]
978
+ logging.warning(
979
+ "{} is not specified, use default value {}.".format(
980
+ k, default_dict[k]
981
+ )
982
+ )
983
+ # duplicate
984
+ logging.warning(
985
+ "Type mismatch {}: Convert {} to {}.".format(
986
+ k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
987
+ )
988
+ )
989
+ vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
990
+ return args
espnet/nets/.DS_Store ADDED
Binary file (6.15 kB).
 
espnet/nets/batch_beam_search.py ADDED
@@ -0,0 +1,349 @@
1
+ """Parallel beam search module."""
2
+
3
+ import logging
4
+ from typing import Any
5
+ from typing import Dict
6
+ from typing import List
7
+ from typing import NamedTuple
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+ from espnet.nets.beam_search import BeamSearch
14
+ from espnet.nets.beam_search import Hypothesis
15
+
16
+
17
+ class BatchHypothesis(NamedTuple):
18
+ """Batchfied/Vectorized hypothesis data type."""
19
+
20
+ yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen)
21
+ score: torch.Tensor = torch.tensor([]) # (batch,)
22
+ length: torch.Tensor = torch.tensor([]) # (batch,)
23
+ scores: Dict[str, torch.Tensor] = dict() # values: (batch,)
24
+ states: Dict[str, Dict] = dict()
25
+
26
+ def __len__(self) -> int:
27
+ """Return a batch size."""
28
+ return len(self.length)
29
+
30
+
31
+ class BatchBeamSearch(BeamSearch):
32
+ """Batch beam search implementation."""
33
+
34
+ def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
35
+ """Convert list to batch."""
36
+ if len(hyps) == 0:
37
+ return BatchHypothesis()
38
+ yseq = pad_sequence(
39
+ [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
40
+ )
41
+ return BatchHypothesis(
42
+ yseq=yseq,
43
+ length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64, device=yseq.device),
44
+ score=torch.tensor([h.score for h in hyps]).to(yseq.device),
45
+ scores={k: torch.tensor([h.scores[k] for h in hyps], device=yseq.device) for k in self.scorers},
46
+ states={k: [h.states[k] for h in hyps] for k in self.scorers},
47
+ )
48
+
49
+ def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
50
+ return BatchHypothesis(
51
+ yseq=hyps.yseq[ids],
52
+ score=hyps.score[ids],
53
+ length=hyps.length[ids],
54
+ scores={k: v[ids] for k, v in hyps.scores.items()},
55
+ states={
56
+ k: [self.scorers[k].select_state(v, i) for i in ids]
57
+ for k, v in hyps.states.items()
58
+ },
59
+ )
60
+
61
+ def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
62
+ return Hypothesis(
63
+ yseq=hyps.yseq[i, : hyps.length[i]],
64
+ score=hyps.score[i],
65
+ scores={k: v[i] for k, v in hyps.scores.items()},
66
+ states={
67
+ k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
68
+ },
69
+ )
70
+
71
+ def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
72
+ """Revert batch to list."""
73
+ return [
74
+ Hypothesis(
75
+ yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
76
+ score=batch_hyps.score[i],
77
+ scores={k: batch_hyps.scores[k][i] for k in self.scorers},
78
+ states={
79
+ k: v.select_state(batch_hyps.states[k], i)
80
+ for k, v in self.scorers.items()
81
+ },
82
+ )
83
+ for i in range(len(batch_hyps.length))
84
+ ]
85
+
86
+ def batch_beam(
87
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
88
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
89
+ """Batch-compute topk full token ids and partial token ids.
90
+
91
+ Args:
92
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
93
+ Its shape is `(n_beam, self.vocab_size)`.
94
+ ids (torch.Tensor): The partial token ids to compute topk.
95
+ Its shape is `(n_beam, self.pre_beam_size)`.
96
+
97
+ Returns:
98
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
99
+ The topk full (prev_hyp, new_token) ids
100
+ and partial (prev_hyp, new_token) ids.
101
+ Their shapes are all `(self.beam_size,)`
102
+
103
+ """
104
+ top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
105
+ # Because of the flatten above, `top_ids` is organized as:
106
+ # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
107
+ # where V is `self.n_vocab` and K is `self.beam_size`
108
+ prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode='trunc')
109
+ new_token_ids = top_ids % self.n_vocab
110
+ return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
111
+
112
+ def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
113
+ """Get an initial hypothesis data.
114
+
115
+ Args:
116
+ x (torch.Tensor): The encoder output feature
117
+
118
+ Returns:
119
+ Hypothesis: The initial hypothesis.
120
+
121
+ """
122
+ init_states = dict()
123
+ init_scores = dict()
124
+ for k, d in self.scorers.items():
125
+ init_states[k] = d.batch_init_state(x)
126
+ init_scores[k] = 0.0
127
+ return self.batchfy(
128
+ [
129
+ Hypothesis(
130
+ score=0.0,
131
+ scores=init_scores,
132
+ states=init_states,
133
+ yseq=torch.tensor([self.sos], device=x.device),
134
+ )
135
+ ]
136
+ )
137
+
138
+ def score_full(
139
+ self, hyp: BatchHypothesis, x: torch.Tensor
140
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
141
+ """Score new hypothesis by `self.full_scorers`.
142
+
143
+ Args:
144
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
145
+ x (torch.Tensor): Corresponding input feature
146
+
147
+ Returns:
148
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
149
+ score dict of `hyp` that has string keys of `self.full_scorers`
150
+ and tensor score values of shape: `(self.n_vocab,)`,
151
+ and state dict that has string keys
152
+ and state values of `self.full_scorers`
153
+
154
+ """
155
+ scores = dict()
156
+ states = dict()
157
+ for k, d in self.full_scorers.items():
158
+ scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
159
+ return scores, states
160
+
161
+ def score_partial(
162
+ self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
163
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
164
+ """Score new hypothesis by `self.full_scorers`.
165
+
166
+ Args:
167
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
168
+ ids (torch.Tensor): 2D tensor of new partial tokens to score
169
+ x (torch.Tensor): Corresponding input feature
170
+
171
+ Returns:
172
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
173
+ score dict of `hyp` that has string keys of `self.full_scorers`
174
+ and tensor score values of shape: `(self.n_vocab,)`,
175
+ and state dict that has string keys
176
+ and state values of `self.full_scorers`
177
+
178
+ """
179
+ scores = dict()
180
+ states = dict()
181
+ for k, d in self.part_scorers.items():
182
+ scores[k], states[k] = d.batch_score_partial(
183
+ hyp.yseq, ids, hyp.states[k], x
184
+ )
185
+ return scores, states
186
+
187
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
188
+ """Merge states for new hypothesis.
189
+
190
+ Args:
191
+ states: states of `self.full_scorers`
192
+ part_states: states of `self.part_scorers`
193
+ part_idx (int): The new token id for `part_scores`
194
+
195
+ Returns:
196
+ Dict[str, torch.Tensor]: The new score dict.
197
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
198
+ Its values are states of the scorers.
199
+
200
+ """
201
+ new_states = dict()
202
+ for k, v in states.items():
203
+ new_states[k] = v
204
+ for k, v in part_states.items():
205
+ new_states[k] = v
206
+ return new_states
207
+
208
+ def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
209
+ """Search new tokens for running hypotheses and encoded speech x.
210
+
211
+ Args:
212
+ running_hyps (BatchHypothesis): Running hypotheses on beam
213
+ x (torch.Tensor): Encoded speech feature (T, D)
214
+
215
+ Returns:
216
+ BatchHypothesis: Best sorted hypotheses
217
+
218
+ """
219
+ n_batch = len(running_hyps)
220
+ part_ids = None # no pre-beam
221
+ # batch scoring
222
+ weighted_scores = torch.zeros(
223
+ n_batch, self.n_vocab, dtype=x.dtype, device=x.device
224
+ )
225
+ scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
226
+ for k in self.full_scorers:
227
+ weighted_scores += self.weights[k] * scores[k]
228
+ # partial scoring
229
+ if self.do_pre_beam:
230
+ pre_beam_scores = (
231
+ weighted_scores
232
+ if self.pre_beam_score_key == "full"
233
+ else scores[self.pre_beam_score_key]
234
+ )
235
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
236
+ # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
237
+ # full-size score matrices, which has non-zero scores for part_ids and zeros
238
+ # for others.
239
+ part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
240
+ for k in self.part_scorers:
241
+ weighted_scores += self.weights[k] * part_scores[k]
242
+ # add previous hyp scores
243
+ weighted_scores += running_hyps.score.to(
244
+ dtype=x.dtype, device=x.device
245
+ ).unsqueeze(1)
246
+
247
+ # TODO(karita): do not use list. use batch instead
248
+ # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
249
+ # update hyps
250
+ best_hyps = []
251
+ prev_hyps = self.unbatchfy(running_hyps)
252
+ for (
253
+ full_prev_hyp_id,
254
+ full_new_token_id,
255
+ part_prev_hyp_id,
256
+ part_new_token_id,
257
+ ) in zip(*self.batch_beam(weighted_scores, part_ids)):
258
+ prev_hyp = prev_hyps[full_prev_hyp_id]
259
+ best_hyps.append(
260
+ Hypothesis(
261
+ score=weighted_scores[full_prev_hyp_id, full_new_token_id],
262
+ yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
263
+ scores=self.merge_scores(
264
+ prev_hyp.scores,
265
+ {k: v[full_prev_hyp_id] for k, v in scores.items()},
266
+ full_new_token_id,
267
+ {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
268
+ part_new_token_id,
269
+ ),
270
+ states=self.merge_states(
271
+ {
272
+ k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
273
+ for k, v in states.items()
274
+ },
275
+ {
276
+ k: self.part_scorers[k].select_state(
277
+ v, part_prev_hyp_id, part_new_token_id
278
+ )
279
+ for k, v in part_states.items()
280
+ },
281
+ part_new_token_id,
282
+ ),
283
+ )
284
+ )
285
+ return self.batchfy(best_hyps)
286
+
287
+ def post_process(
288
+ self,
289
+ i: int,
290
+ maxlen: int,
291
+ maxlenratio: float,
292
+ running_hyps: BatchHypothesis,
293
+ ended_hyps: List[Hypothesis],
294
+ ) -> BatchHypothesis:
295
+ """Perform post-processing of beam search iterations.
296
+
297
+ Args:
298
+ i (int): The length of hypothesis tokens.
299
+ maxlen (int): The maximum length of tokens in beam search.
300
+ maxlenratio (int): The maximum length ratio in beam search.
301
+ running_hyps (BatchHypothesis): The running hypotheses in beam search.
302
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
303
+
304
+ Returns:
305
+ BatchHypothesis: The new running hypotheses.
306
+
307
+ """
308
+ n_batch = running_hyps.yseq.shape[0]
309
+ logging.debug(f"the number of running hypothes: {n_batch}")
310
+ if self.token_list is not None:
311
+ logging.debug(
312
+ "best hypo: "
313
+ + "".join(
314
+ [
315
+ self.token_list[x]
316
+ for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
317
+ ]
318
+ )
319
+ )
320
+ # add eos in the final loop so that at least one hypothesis ends
321
+ if i == maxlen - 1:
322
+ logging.info("adding <eos> in the last position in the loop")
323
+ yseq_eos = torch.cat(
324
+ (
325
+ running_hyps.yseq,
326
+ torch.full(
327
+ (n_batch, 1),
328
+ self.eos,
329
+ device=running_hyps.yseq.device,
330
+ dtype=torch.int64,
331
+ ),
332
+ ),
333
+ 1,
334
+ )
335
+ running_hyps.yseq.resize_as_(yseq_eos)
336
+ running_hyps.yseq[:] = yseq_eos
337
+ running_hyps.length[:] = yseq_eos.shape[1]
338
+
339
+ # add ended hypotheses to a final list, and remove them from the current hypotheses
340
+ # (this can be a problem when the number of hyps < beam)
341
+ is_eos = (
342
+ running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
343
+ == self.eos
344
+ )
345
+ for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
346
+ hyp = self._select(running_hyps, b)
347
+ ended_hyps.append(hyp)
348
+ remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
349
+ return self._batch_select(running_hyps, remained_ids)
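
The flattened top-k decomposition used in batch_beam above (recovering hypothesis and token indices from indices into a flattened (n_beam, n_vocab) score matrix via integer division and modulo) can be checked in isolation. The snippet below is a standalone sketch with arbitrary sizes, not code from this commit.

    # Standalone check of the index decomposition in BatchBeamSearch.batch_beam.
    # Sizes are arbitrary; only the // and % arithmetic mirrors the code above.
    import torch

    n_beam, n_vocab, beam_size = 3, 7, 4
    weighted_scores = torch.randn(n_beam, n_vocab)

    top_ids = weighted_scores.view(-1).topk(beam_size)[1]
    prev_hyp_ids = torch.div(top_ids, n_vocab, rounding_mode="trunc")
    new_token_ids = top_ids % n_vocab

    # Each selected flat index maps back to the same score entry.
    assert torch.equal(
        weighted_scores[prev_hyp_ids, new_token_ids],
        weighted_scores.view(-1)[top_ids],
    )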
espnet/nets/beam_search.py ADDED
@@ -0,0 +1,516 @@
1
+ """Beam search module."""
2
+
3
+ from itertools import chain
4
+ import logging
5
+ from typing import Any
6
+ from typing import Dict
7
+ from typing import List
8
+ from typing import NamedTuple
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+
14
+ from espnet.nets.e2e_asr_common import end_detect
15
+ from espnet.nets.scorer_interface import PartialScorerInterface
16
+ from espnet.nets.scorer_interface import ScorerInterface
17
+
18
+
19
+ class Hypothesis(NamedTuple):
20
+ """Hypothesis data type."""
21
+
22
+ yseq: torch.Tensor
23
+ score: Union[float, torch.Tensor] = 0
24
+ scores: Dict[str, Union[float, torch.Tensor]] = dict()
25
+ states: Dict[str, Any] = dict()
26
+
27
+ def asdict(self) -> dict:
28
+ """Convert data to JSON-friendly dict."""
29
+ return self._replace(
30
+ yseq=self.yseq.tolist(),
31
+ score=float(self.score),
32
+ scores={k: float(v) for k, v in self.scores.items()},
33
+ )._asdict()
34
+
35
+
36
+ class BeamSearch(torch.nn.Module):
37
+ """Beam search implementation."""
38
+
39
+ def __init__(
40
+ self,
41
+ scorers: Dict[str, ScorerInterface],
42
+ weights: Dict[str, float],
43
+ beam_size: int,
44
+ vocab_size: int,
45
+ sos: int,
46
+ eos: int,
47
+ token_list: List[str] = None,
48
+ pre_beam_ratio: float = 1.5,
49
+ pre_beam_score_key: str = None,
50
+ ):
51
+ """Initialize beam search.
52
+
53
+ Args:
54
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
55
+ e.g., Decoder, CTCPrefixScorer, LM
56
+ The scorer will be ignored if it is `None`
57
+ weights (dict[str, float]): Dict of weights for each scorers
58
+ The scorer will be ignored if its weight is 0
59
+ beam_size (int): The number of hypotheses kept during search
60
+ vocab_size (int): The number of vocabulary
61
+ sos (int): Start of sequence id
62
+ eos (int): End of sequence id
63
+ token_list (list[str]): List of tokens for debug log
64
+ pre_beam_score_key (str): key of scores to perform pre-beam search
65
+ pre_beam_ratio (float): beam size in the pre-beam search
66
+ will be `int(pre_beam_ratio * beam_size)`
67
+
68
+ """
69
+ super().__init__()
70
+ # set scorers
71
+ self.weights = weights
72
+ self.scorers = dict()
73
+ self.full_scorers = dict()
74
+ self.part_scorers = dict()
75
+ # this module dict is required for recursive cast
76
+ # `self.to(device, dtype)` in `recog.py`
77
+ self.nn_dict = torch.nn.ModuleDict()
78
+ for k, v in scorers.items():
79
+ w = weights.get(k, 0)
80
+ if w == 0 or v is None:
81
+ continue
82
+ assert isinstance(
83
+ v, ScorerInterface
84
+ ), f"{k} ({type(v)}) does not implement ScorerInterface"
85
+ self.scorers[k] = v
86
+ if isinstance(v, PartialScorerInterface):
87
+ self.part_scorers[k] = v
88
+ else:
89
+ self.full_scorers[k] = v
90
+ if isinstance(v, torch.nn.Module):
91
+ self.nn_dict[k] = v
92
+
93
+ # set configurations
94
+ self.sos = sos
95
+ self.eos = eos
96
+ self.token_list = token_list
97
+ self.pre_beam_size = int(pre_beam_ratio * beam_size)
98
+ self.beam_size = beam_size
99
+ self.n_vocab = vocab_size
100
+ if (
101
+ pre_beam_score_key is not None
102
+ and pre_beam_score_key != "full"
103
+ and pre_beam_score_key not in self.full_scorers
104
+ ):
105
+ raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
106
+ self.pre_beam_score_key = pre_beam_score_key
107
+ self.do_pre_beam = (
108
+ self.pre_beam_score_key is not None
109
+ and self.pre_beam_size < self.n_vocab
110
+ and len(self.part_scorers) > 0
111
+ )
112
+
113
+ def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
114
+ """Get an initial hypothesis data.
115
+
116
+ Args:
117
+ x (torch.Tensor): The encoder output feature
118
+
119
+ Returns:
120
+ Hypothesis: The initial hypothesis.
121
+
122
+ """
123
+ init_states = dict()
124
+ init_scores = dict()
125
+ for k, d in self.scorers.items():
126
+ init_states[k] = d.init_state(x)
127
+ init_scores[k] = 0.0
128
+ return [
129
+ Hypothesis(
130
+ score=0.0,
131
+ scores=init_scores,
132
+ states=init_states,
133
+ yseq=torch.tensor([self.sos], device=x.device),
134
+ )
135
+ ]
136
+
137
+ @staticmethod
138
+ def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
139
+ """Append new token to prefix tokens.
140
+
141
+ Args:
142
+ xs (torch.Tensor): The prefix token
143
+ x (int): The new token to append
144
+
145
+ Returns:
146
+ torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
147
+
148
+ """
149
+ x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
150
+ return torch.cat((xs, x))
151
+
152
+ def score_full(
153
+ self, hyp: Hypothesis, x: torch.Tensor
154
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
155
+ """Score new hypothesis by `self.full_scorers`.
156
+
157
+ Args:
158
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
159
+ x (torch.Tensor): Corresponding input feature
160
+
161
+ Returns:
162
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
163
+ score dict of `hyp` that has string keys of `self.full_scorers`
164
+ and tensor score values of shape: `(self.n_vocab,)`,
165
+ and state dict that has string keys
166
+ and state values of `self.full_scorers`
167
+
168
+ """
169
+ scores = dict()
170
+ states = dict()
171
+ for k, d in self.full_scorers.items():
172
+ scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
173
+ return scores, states
174
+
175
+ def score_partial(
176
+ self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
177
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
178
+ """Score new hypothesis by `self.part_scorers`.
179
+
180
+ Args:
181
+ hyp (Hypothesis): Hypothesis with prefix tokens to score
182
+ ids (torch.Tensor): 1D tensor of new partial tokens to score
183
+ x (torch.Tensor): Corresponding input feature
184
+
185
+ Returns:
186
+ Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
187
+ score dict of `hyp` that has string keys of `self.part_scorers`
188
+ and tensor score values of shape: `(len(ids),)`,
189
+ and state dict that has string keys
190
+ and state values of `self.part_scorers`
191
+
192
+ """
193
+ scores = dict()
194
+ states = dict()
195
+ for k, d in self.part_scorers.items():
196
+ scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
197
+ return scores, states
198
+
199
+ def beam(
200
+ self, weighted_scores: torch.Tensor, ids: torch.Tensor
201
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
202
+ """Compute topk full token ids and partial token ids.
203
+
204
+ Args:
205
+ weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
206
+ Its shape is `(self.n_vocab,)`.
207
+ ids (torch.Tensor): The partial token ids to compute topk
208
+
209
+ Returns:
210
+ Tuple[torch.Tensor, torch.Tensor]:
211
+ The topk full token ids and partial token ids.
212
+ Their shapes are `(self.beam_size,)`
213
+
214
+ """
215
+ # no pre beam performed
216
+ if weighted_scores.size(0) == ids.size(0):
217
+ top_ids = weighted_scores.topk(self.beam_size)[1]
218
+ return top_ids, top_ids
219
+
220
+ # mask pruned in pre-beam not to select in topk
221
+ tmp = weighted_scores[ids]
222
+ weighted_scores[:] = -float("inf")
223
+ weighted_scores[ids] = tmp
224
+ top_ids = weighted_scores.topk(self.beam_size)[1]
225
+ local_ids = weighted_scores[ids].topk(self.beam_size)[1]
226
+ return top_ids, local_ids
227
+
228
+ @staticmethod
229
+ def merge_scores(
230
+ prev_scores: Dict[str, float],
231
+ next_full_scores: Dict[str, torch.Tensor],
232
+ full_idx: int,
233
+ next_part_scores: Dict[str, torch.Tensor],
234
+ part_idx: int,
235
+ ) -> Dict[str, torch.Tensor]:
236
+ """Merge scores for new hypothesis.
237
+
238
+ Args:
239
+ prev_scores (Dict[str, float]):
240
+ The previous hypothesis scores by `self.scorers`
241
+ next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
242
+ full_idx (int): The next token id for `next_full_scores`
243
+ next_part_scores (Dict[str, torch.Tensor]):
244
+ scores of partial tokens by `self.part_scorers`
245
+ part_idx (int): The new token id for `next_part_scores`
246
+
247
+ Returns:
248
+ Dict[str, torch.Tensor]: The new score dict.
249
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
250
+ Its values are scalar tensors by the scorers.
251
+
252
+ """
253
+ new_scores = dict()
254
+ for k, v in next_full_scores.items():
255
+ new_scores[k] = prev_scores[k] + v[full_idx]
256
+ for k, v in next_part_scores.items():
257
+ new_scores[k] = prev_scores[k] + v[part_idx]
258
+ return new_scores
259
+
260
+ def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
261
+ """Merge states for new hypothesis.
262
+
263
+ Args:
264
+ states: states of `self.full_scorers`
265
+ part_states: states of `self.part_scorers`
266
+ part_idx (int): The new token id for `part_scores`
267
+
268
+ Returns:
269
+ Dict[str, torch.Tensor]: The new score dict.
270
+ Its keys are names of `self.full_scorers` and `self.part_scorers`.
271
+ Its values are states of the scorers.
272
+
273
+ """
274
+ new_states = dict()
275
+ for k, v in states.items():
276
+ new_states[k] = v
277
+ for k, d in self.part_scorers.items():
278
+ new_states[k] = d.select_state(part_states[k], part_idx)
279
+ return new_states
280
+
281
+ def search(
282
+ self, running_hyps: List[Hypothesis], x: torch.Tensor
283
+ ) -> List[Hypothesis]:
284
+ """Search new tokens for running hypotheses and encoded speech x.
285
+
286
+ Args:
287
+ running_hyps (List[Hypothesis]): Running hypotheses on beam
288
+ x (torch.Tensor): Encoded speech feature (T, D)
289
+
290
+ Returns:
291
+ List[Hypotheses]: Best sorted hypotheses
292
+
293
+ """
294
+ best_hyps = []
295
+ part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
296
+ for hyp in running_hyps:
297
+ # scoring
298
+ weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
299
+ scores, states = self.score_full(hyp, x)
300
+ for k in self.full_scorers:
301
+ weighted_scores += self.weights[k] * scores[k]
302
+ # partial scoring
303
+ if self.do_pre_beam:
304
+ pre_beam_scores = (
305
+ weighted_scores
306
+ if self.pre_beam_score_key == "full"
307
+ else scores[self.pre_beam_score_key]
308
+ )
309
+ part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
310
+ part_scores, part_states = self.score_partial(hyp, part_ids, x)
311
+ for k in self.part_scorers:
312
+ weighted_scores[part_ids] += self.weights[k] * part_scores[k]
313
+ # add previous hyp score
314
+ weighted_scores += hyp.score
315
+
316
+ # update hyps
317
+ for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
318
+ # will be (2 x beam at most)
319
+ best_hyps.append(
320
+ Hypothesis(
321
+ score=weighted_scores[j],
322
+ yseq=self.append_token(hyp.yseq, j),
323
+ scores=self.merge_scores(
324
+ hyp.scores, scores, j, part_scores, part_j
325
+ ),
326
+ states=self.merge_states(states, part_states, part_j),
327
+ )
328
+ )
329
+
330
+ # sort and prune 2 x beam -> beam
331
+ best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
332
+ : min(len(best_hyps), self.beam_size)
333
+ ]
334
+ return best_hyps
335
+
336
+ def forward(
337
+ self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
338
+ ) -> List[Hypothesis]:
339
+ """Perform beam search.
340
+
341
+ Args:
342
+ x (torch.Tensor): Encoded speech feature (T, D)
343
+ maxlenratio (float): Input length ratio to obtain max output length.
344
+ If maxlenratio=0.0 (default), it uses a end-detect function
345
+ to automatically find maximum hypothesis lengths
346
+ If maxlenratio<0.0, its absolute value is interpreted
347
+ as a constant max output length.
348
+ minlenratio (float): Input length ratio to obtain min output length.
349
+
350
+ Returns:
351
+ list[Hypothesis]: N-best decoding results
352
+
353
+ """
354
+ # set length bounds
355
+ if maxlenratio == 0:
356
+ maxlen = x.shape[0]
357
+ elif maxlenratio < 0:
358
+ maxlen = -1 * int(maxlenratio)
359
+ else:
360
+ maxlen = max(1, int(maxlenratio * x.size(0)))
361
+ minlen = int(minlenratio * x.size(0))
362
+ logging.info("decoder input length: " + str(x.shape[0]))
363
+ logging.info("max output length: " + str(maxlen))
364
+ logging.info("min output length: " + str(minlen))
365
+
366
+ # main loop of prefix search
367
+ running_hyps = self.init_hyp(x)
368
+ ended_hyps = []
369
+ for i in range(maxlen):
370
+ logging.debug("position " + str(i))
371
+ best = self.search(running_hyps, x)
372
+ # post process of one iteration
373
+ running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
374
+ # end detection
375
+ if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
376
+ logging.info(f"end detected at {i}")
377
+ break
378
+ if len(running_hyps) == 0:
379
+ logging.info("no hypothesis. Finish decoding.")
380
+ break
381
+ else:
382
+ logging.debug(f"remained hypotheses: {len(running_hyps)}")
383
+
384
+ nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
385
+ # check the number of hypotheses reaching to eos
386
+ if len(nbest_hyps) == 0:
387
+ logging.warning(
388
+ "there is no N-best results, perform recognition "
389
+ "again with smaller minlenratio."
390
+ )
391
+ return (
392
+ []
393
+ if minlenratio < 0.1
394
+ else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
395
+ )
396
+
397
+ # report the best result
398
+ best = nbest_hyps[0]
399
+ for k, v in best.scores.items():
400
+ logging.info(
401
+ f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
402
+ )
403
+ logging.info(f"total log probability: {best.score:.2f}")
404
+ logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
405
+ logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
406
+ if self.token_list is not None:
407
+ logging.info(
408
+ "best hypo: "
409
+ + "".join([self.token_list[x] for x in best.yseq[1:-1]])
410
+ + "\n"
411
+ )
412
+ return nbest_hyps
413
+
414
+ def post_process(
415
+ self,
416
+ i: int,
417
+ maxlen: int,
418
+ maxlenratio: float,
419
+ running_hyps: List[Hypothesis],
420
+ ended_hyps: List[Hypothesis],
421
+ ) -> List[Hypothesis]:
422
+ """Perform post-processing of beam search iterations.
423
+
424
+ Args:
425
+ i (int): The length of hypothesis tokens.
426
+ maxlen (int): The maximum length of tokens in beam search.
427
+ maxlenratio (int): The maximum length ratio in beam search.
428
+ running_hyps (List[Hypothesis]): The running hypotheses in beam search.
429
+ ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
430
+
431
+ Returns:
432
+ List[Hypothesis]: The new running hypotheses.
433
+
434
+ """
435
+ logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
436
+ if self.token_list is not None:
437
+ logging.debug(
438
+ "best hypo: "
439
+ + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
440
+ )
441
+ # add eos in the final loop to avoid that there are no ended hyps
442
+ if i == maxlen - 1:
443
+ logging.info("adding <eos> in the last position in the loop")
444
+ running_hyps = [
445
+ h._replace(yseq=self.append_token(h.yseq, self.eos))
446
+ for h in running_hyps
447
+ ]
448
+
449
+ # add ended hypotheses to a final list, and removed them from current hypotheses
450
+ # (this will be a problem, number of hyps < beam)
451
+ remained_hyps = []
452
+ for hyp in running_hyps:
453
+ if hyp.yseq[-1] == self.eos:
454
+ # e.g., Word LM needs to add final <eos> score
455
+ for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
456
+ s = d.final_score(hyp.states[k])
457
+ hyp.scores[k] += s
458
+ hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
459
+ ended_hyps.append(hyp)
460
+ else:
461
+ remained_hyps.append(hyp)
462
+ return remained_hyps
463
+
464
+
465
+ def beam_search(
466
+ x: torch.Tensor,
467
+ sos: int,
468
+ eos: int,
469
+ beam_size: int,
470
+ vocab_size: int,
471
+ scorers: Dict[str, ScorerInterface],
472
+ weights: Dict[str, float],
473
+ token_list: List[str] = None,
474
+ maxlenratio: float = 0.0,
475
+ minlenratio: float = 0.0,
476
+ pre_beam_ratio: float = 1.5,
477
+ pre_beam_score_key: str = "full",
478
+ ) -> list:
479
+ """Perform beam search with scorers.
480
+
481
+ Args:
482
+ x (torch.Tensor): Encoded speech feature (T, D)
483
+ sos (int): Start of sequence id
484
+ eos (int): End of sequence id
485
+ beam_size (int): The number of hypotheses kept during search
486
+ vocab_size (int): The number of vocabulary
487
+ scorers (dict[str, ScorerInterface]): Dict of decoder modules
488
+ e.g., Decoder, CTCPrefixScorer, LM
489
+ The scorer will be ignored if it is `None`
490
+ weights (dict[str, float]): Dict of weights for each scorers
491
+ The scorer will be ignored if its weight is 0
492
+ token_list (list[str]): List of tokens for debug log
493
+ maxlenratio (float): Input length ratio to obtain max output length.
494
+ If maxlenratio=0.0 (default), it uses a end-detect function
495
+ to automatically find maximum hypothesis lengths
496
+ minlenratio (float): Input length ratio to obtain min output length.
497
+ pre_beam_score_key (str): key of scores to perform pre-beam search
498
+ pre_beam_ratio (float): beam size in the pre-beam search
499
+ will be `int(pre_beam_ratio * beam_size)`
500
+
501
+ Returns:
502
+ list: N-best decoding results
503
+
504
+ """
505
+ ret = BeamSearch(
506
+ scorers,
507
+ weights,
508
+ beam_size=beam_size,
509
+ vocab_size=vocab_size,
510
+ pre_beam_ratio=pre_beam_ratio,
511
+ pre_beam_score_key=pre_beam_score_key,
512
+ sos=sos,
513
+ eos=eos,
514
+ token_list=token_list,
515
+ ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
516
+ return [h.asdict() for h in ret]
espnet/nets/ctc_prefix_score.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ import torch
7
+
8
+ import numpy as np
9
+ import six
10
+
11
+
12
+ class CTCPrefixScoreTH(object):
13
+ """Batch processing of CTCPrefixScore
14
+
15
+ which is based on Algorithm 2 in WATANABE et al.
16
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
17
+ but extended to efficiently compute the label probablities for multiple
18
+ hypotheses simultaneously
19
+ See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
20
+ Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
21
+ """
22
+
23
+ def __init__(self, x, xlens, blank, eos, margin=0):
24
+ """Construct CTC prefix scorer
25
+
26
+ :param torch.Tensor x: input label posterior sequences (B, T, O)
27
+ :param torch.Tensor xlens: input lengths (B,)
28
+ :param int blank: blank label id
29
+ :param int eos: end-of-sequence id
30
+ :param int margin: margin parameter for windowing (0 means no windowing)
31
+ """
32
+ # In the comment lines,
33
+ # we assume T: input_length, B: batch size, W: beam width, O: output dim.
34
+ self.logzero = -10000000000.0
35
+ self.blank = blank
36
+ self.eos = eos
37
+ self.batch = x.size(0)
38
+ self.input_length = x.size(1)
39
+ self.odim = x.size(2)
40
+ self.dtype = x.dtype
41
+ self.device = (
42
+ torch.device("cuda:%d" % x.get_device())
43
+ if x.is_cuda
44
+ else torch.device("cpu")
45
+ )
46
+ # Pad the rest of posteriors in the batch
47
+ # TODO(takaaki-hori): need a better way without for-loops
48
+ for i, l in enumerate(xlens):
49
+ if l < self.input_length:
50
+ x[i, l:, :] = self.logzero
51
+ x[i, l:, blank] = 0
52
+ # Reshape input x
53
+ xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
54
+ xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
55
+ self.x = torch.stack([xn, xb]) # (2, T, B, O)
56
+ self.end_frames = torch.as_tensor(xlens) - 1
57
+
58
+ # Setup CTC windowing
59
+ self.margin = margin
60
+ if margin > 0:
61
+ self.frame_ids = torch.arange(
62
+ self.input_length, dtype=self.dtype, device=self.device
63
+ )
64
+ # Base indices for index conversion
65
+ self.idx_bh = None
66
+ self.idx_b = torch.arange(self.batch, device=self.device)
67
+ self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
68
+
69
+ def __call__(self, y, state, scoring_ids=None, att_w=None):
70
+ """Compute CTC prefix scores for next labels
71
+
72
+ :param list y: prefix label sequences
73
+ :param tuple state: previous CTC state
74
+ :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
75
+ :param torch.Tensor att_w: attention weights to decide CTC window
76
+ :return new_state, ctc_local_scores (BW, O)
77
+ """
78
+ output_length = len(y[0]) - 1 # ignore sos
79
+ last_ids = [yi[-1] for yi in y] # last output label ids
80
+ n_bh = len(last_ids) # batch * hyps
81
+ n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
82
+ self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
83
+ # prepare state info
84
+ if state is None:
85
+ r_prev = torch.full(
86
+ (self.input_length, 2, self.batch, n_hyps),
87
+ self.logzero,
88
+ dtype=self.dtype,
89
+ device=self.device,
90
+ )
91
+ r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
92
+ r_prev = r_prev.view(-1, 2, n_bh)
93
+ s_prev = 0.0
94
+ f_min_prev = 0
95
+ f_max_prev = 1
96
+ else:
97
+ r_prev, s_prev, f_min_prev, f_max_prev = state
98
+
99
+ # select input dimensions for scoring
100
+ if self.scoring_num > 0:
101
+ scoring_idmap = torch.full(
102
+ (n_bh, self.odim), -1, dtype=torch.long, device=self.device
103
+ )
104
+ snum = self.scoring_num
105
+ if self.idx_bh is None or n_bh > len(self.idx_bh):
106
+ self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
107
+ scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
108
+ snum, device=self.device
109
+ )
110
+ scoring_idx = (
111
+ scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
112
+ ).view(-1)
113
+ x_ = torch.index_select(
114
+ self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
115
+ ).view(2, -1, n_bh, snum)
116
+ else:
117
+ scoring_ids = None
118
+ scoring_idmap = None
119
+ snum = self.odim
120
+ x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
121
+
122
+ # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
123
+ # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
124
+ r = torch.full(
125
+ (self.input_length, 2, n_bh, snum),
126
+ self.logzero,
127
+ dtype=self.dtype,
128
+ device=self.device,
129
+ )
130
+ if output_length == 0:
131
+ r[0, 0] = x_[0, 0]
132
+
133
+ r_sum = torch.logsumexp(r_prev, 1)
134
+ log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
135
+ if scoring_ids is not None:
136
+ for idx in range(n_bh):
137
+ pos = scoring_idmap[idx, last_ids[idx]]
138
+ if pos >= 0:
139
+ log_phi[:, idx, pos] = r_prev[:, 1, idx]
140
+ else:
141
+ for idx in range(n_bh):
142
+ log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
143
+
144
+ # decide start and end frames based on attention weights
145
+ if att_w is not None and self.margin > 0:
146
+ f_arg = torch.matmul(att_w, self.frame_ids)
147
+ f_min = max(int(f_arg.min().cpu()), f_min_prev)
148
+ f_max = max(int(f_arg.max().cpu()), f_max_prev)
149
+ start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
150
+ end = min(f_max + self.margin, self.input_length)
151
+ else:
152
+ f_min = f_max = 0
153
+ start = max(output_length, 1)
154
+ end = self.input_length
155
+
156
+ # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
157
+ for t in range(start, end):
158
+ rp = r[t - 1]
159
+ rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
160
+ 2, 2, n_bh, snum
161
+ )
162
+ r[t] = torch.logsumexp(rr, 1) + x_[:, t]
163
+
164
+ # compute log prefix probabilities log(psi)
165
+ log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
166
+ if scoring_ids is not None:
167
+ log_psi = torch.full(
168
+ (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
169
+ )
170
+ log_psi_ = torch.logsumexp(
171
+ torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
172
+ dim=0,
173
+ )
174
+ for si in range(n_bh):
175
+ log_psi[si, scoring_ids[si]] = log_psi_[si]
176
+ else:
177
+ log_psi = torch.logsumexp(
178
+ torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
179
+ dim=0,
180
+ )
181
+
182
+ for si in range(n_bh):
183
+ log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
184
+
185
+ # exclude blank probs
186
+ log_psi[:, self.blank] = self.logzero
187
+
188
+ return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
189
+
190
+ def index_select_state(self, state, best_ids):
191
+ """Select CTC states according to best ids
192
+
193
+ :param state : CTC state
194
+ :param best_ids : index numbers selected by beam pruning (B, W)
195
+ :return selected_state
196
+ """
197
+ r, s, f_min, f_max, scoring_idmap = state
198
+ # convert ids to BHO space
199
+ n_bh = len(s)
200
+ n_hyps = n_bh // self.batch
201
+ vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
202
+ # select hypothesis scores
203
+ s_new = torch.index_select(s.view(-1), 0, vidx)
204
+ s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
205
+ # convert ids to BHS space (S: scoring_num)
206
+ if scoring_idmap is not None:
207
+ snum = self.scoring_num
208
+ hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
209
+ -1
210
+ )
211
+ label_ids = torch.fmod(best_ids, self.odim).view(-1)
212
+ score_idx = scoring_idmap[hyp_idx, label_ids]
213
+ score_idx[score_idx == -1] = 0
214
+ vidx = score_idx + hyp_idx * snum
215
+ else:
216
+ snum = self.odim
217
+ # select forward probabilities
218
+ r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
219
+ -1, 2, n_bh
220
+ )
221
+ return r_new, s_new, f_min, f_max
222
+
223
+ def extend_prob(self, x):
224
+ """Extend CTC prob.
225
+
226
+ :param torch.Tensor x: input label posterior sequences (B, T, O)
227
+ """
228
+
229
+ if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
230
+ # Pad the rest of posteriors in the batch
231
+ # TODO(takaaki-hori): need a better way without for-loops
232
+ xlens = [x.size(1)]
233
+ for i, l in enumerate(xlens):
234
+ if l < self.input_length:
235
+ x[i, l:, :] = self.logzero
236
+ x[i, l:, self.blank] = 0
237
+ tmp_x = self.x
238
+ xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
239
+ xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
240
+ self.x = torch.stack([xn, xb]) # (2, T, B, O)
241
+ self.x[:, : tmp_x.shape[1], :, :] = tmp_x
242
+ self.input_length = x.size(1)
243
+ self.end_frames = torch.as_tensor(xlens) - 1
244
+
245
+ def extend_state(self, state):
246
+ """Compute CTC prefix state.
247
+
248
+
249
+ :param state : CTC state
250
+ :return ctc_state
251
+ """
252
+
253
+ if state is None:
254
+ # nothing to do
255
+ return state
256
+ else:
257
+ r_prev, s_prev, f_min_prev, f_max_prev = state
258
+
259
+ r_prev_new = torch.full(
260
+ (self.input_length, 2),
261
+ self.logzero,
262
+ dtype=self.dtype,
263
+ device=self.device,
264
+ )
265
+ start = max(r_prev.shape[0], 1)
266
+ r_prev_new[0:start] = r_prev
267
+ for t in six.moves.range(start, self.input_length):
268
+ r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
269
+
270
+ return (r_prev_new, s_prev, f_min_prev, f_max_prev)
271
+
272
+
273
+ class CTCPrefixScore(object):
274
+ """Compute CTC label sequence scores
275
+
276
+ which is based on Algorithm 2 in WATANABE et al.
277
+ "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
278
+ but extended to efficiently compute the probablities of multiple labels
279
+ simultaneously
280
+ """
281
+
282
+ def __init__(self, x, blank, eos, xp):
283
+ self.xp = xp
284
+ self.logzero = -10000000000.0
285
+ self.blank = blank
286
+ self.eos = eos
287
+ self.input_length = len(x)
288
+ self.x = x
289
+
290
+ def initial_state(self):
291
+ """Obtain an initial CTC state
292
+
293
+ :return: CTC state
294
+ """
295
+ # initial CTC state is made of a frame x 2 tensor that corresponds to
296
+ # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
297
+ # superscripts n and b (non-blank and blank), respectively.
298
+ r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
299
+ r[0, 1] = self.x[0, self.blank]
300
+ for i in six.moves.range(1, self.input_length):
301
+ r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
302
+ return r
303
+
304
+ def __call__(self, y, cs, r_prev):
305
+ """Compute CTC prefix scores for next labels
306
+
307
+ :param y : prefix label sequence
308
+ :param cs : array of next labels
309
+ :param r_prev: previous CTC state
310
+ :return ctc_scores, ctc_states
311
+ """
312
+ # initialize CTC states
313
+ output_length = len(y) - 1 # ignore sos
314
+ # new CTC states are prepared as a frame x (n or b) x n_labels tensor
315
+ # that corresponds to r_t^n(h) and r_t^b(h).
316
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
317
+ xs = self.x[:, cs]
318
+ if output_length == 0:
319
+ r[0, 0] = xs[0]
320
+ r[0, 1] = self.logzero
321
+ else:
322
+ r[output_length - 1] = self.logzero
323
+
324
+ # prepare forward probabilities for the last label
325
+ r_sum = self.xp.logaddexp(
326
+ r_prev[:, 0], r_prev[:, 1]
327
+ ) # log(r_t^n(g) + r_t^b(g))
328
+ last = y[-1]
329
+ if output_length > 0 and last in cs:
330
+ log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
331
+ for i in six.moves.range(len(cs)):
332
+ log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
333
+ else:
334
+ log_phi = r_sum
335
+
336
+ # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
337
+ # and log prefix probabilities log(psi)
338
+ start = max(output_length, 1)
339
+ log_psi = r[start - 1, 0]
340
+ for t in six.moves.range(start, self.input_length):
341
+ r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
342
+ r[t, 1] = (
343
+ self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
344
+ )
345
+ log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
346
+
347
+ # get P(...eos|X) that ends with the prefix itself
348
+ eos_pos = self.xp.where(cs == self.eos)[0]
349
+ if len(eos_pos) > 0:
350
+ log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
351
+
352
+ # exclude blank probs
353
+ blank_pos = self.xp.where(cs == self.blank)[0]
354
+ if len(blank_pos) > 0:
355
+ log_psi[blank_pos] = self.logzero
356
+
357
+ # return the log prefix probability and CTC states, where the label axis
358
+ # of the CTC states is moved to the first axis to slice it easily
359
+ return log_psi, self.xp.rollaxis(r, 2)
espnet/nets/e2e_asr_common.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # encoding: utf-8
3
+
4
+ # Copyright 2017 Johns Hopkins University (Shinji Watanabe)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Common functions for ASR."""
8
+
9
+ import json
10
+ import logging
11
+ import sys
12
+
13
+ from itertools import groupby
14
+ import numpy as np
15
+ import six
16
+
17
+
18
+ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
19
+ """End detection.
20
+
21
+ described in Eq. (50) of S. Watanabe et al
22
+ "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
23
+
24
+ :param ended_hyps:
25
+ :param i:
26
+ :param M:
27
+ :param D_end:
28
+ :return:
29
+ """
30
+ if len(ended_hyps) == 0:
31
+ return False
32
+ count = 0
33
+ best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
34
+ for m in six.moves.range(M):
35
+ # get ended_hyps with their length is i - m
36
+ hyp_length = i - m
37
+ hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
38
+ if len(hyps_same_length) > 0:
39
+ best_hyp_same_length = sorted(
40
+ hyps_same_length, key=lambda x: x["score"], reverse=True
41
+ )[0]
42
+ if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
43
+ count += 1
44
+
45
+ if count == M:
46
+ return True
47
+ else:
48
+ return False
49
+
50
+
51
+ # TODO(takaaki-hori): add different smoothing methods
52
+ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
53
+ """Obtain label distribution for loss smoothing.
54
+
55
+ :param odim:
56
+ :param lsm_type:
57
+ :param blank:
58
+ :param transcript:
59
+ :return:
60
+ """
61
+ if transcript is not None:
62
+ with open(transcript, "rb") as f:
63
+ trans_json = json.load(f)["utts"]
64
+
65
+ if lsm_type == "unigram":
66
+ assert transcript is not None, (
67
+ "transcript is required for %s label smoothing" % lsm_type
68
+ )
69
+ labelcount = np.zeros(odim)
70
+ for k, v in trans_json.items():
71
+ ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
72
+ # to avoid an error when there is no text in an uttrance
73
+ if len(ids) > 0:
74
+ labelcount[ids] += 1
75
+ labelcount[odim - 1] = len(transcript) # count <eos>
76
+ labelcount[labelcount == 0] = 1 # flooring
77
+ labelcount[blank] = 0 # remove counts for blank
78
+ labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
79
+ else:
80
+ logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
81
+ sys.exit()
82
+
83
+ return labeldist
84
+
85
+
86
+ def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
87
+ """Return the output size of the VGG frontend.
88
+
89
+ :param in_channel: input channel size
90
+ :param out_channel: output channel size
91
+ :return: output size
92
+ :rtype int
93
+ """
94
+ idim = idim / in_channel
95
+ idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
96
+ idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
97
+ return int(idim) * out_channel # numer of channels
98
+
99
+
100
+ class ErrorCalculator(object):
101
+ """Calculate CER and WER for E2E_ASR and CTC models during training.
102
+
103
+ :param y_hats: numpy array with predicted text
104
+ :param y_pads: numpy array with true (target) text
105
+ :param char_list:
106
+ :param sym_space:
107
+ :param sym_blank:
108
+ :return:
109
+ """
110
+
111
+ def __init__(
112
+ self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
113
+ ):
114
+ """Construct an ErrorCalculator object."""
115
+ super(ErrorCalculator, self).__init__()
116
+
117
+ self.report_cer = report_cer
118
+ self.report_wer = report_wer
119
+
120
+ self.char_list = char_list
121
+ self.space = sym_space
122
+ self.blank = sym_blank
123
+ self.idx_blank = self.char_list.index(self.blank)
124
+ if self.space in self.char_list:
125
+ self.idx_space = self.char_list.index(self.space)
126
+ else:
127
+ self.idx_space = None
128
+
129
+ def __call__(self, ys_hat, ys_pad, is_ctc=False):
130
+ """Calculate sentence-level WER/CER score.
131
+
132
+ :param torch.Tensor ys_hat: prediction (batch, seqlen)
133
+ :param torch.Tensor ys_pad: reference (batch, seqlen)
134
+ :param bool is_ctc: calculate CER score for CTC
135
+ :return: sentence-level WER score
136
+ :rtype float
137
+ :return: sentence-level CER score
138
+ :rtype float
139
+ """
140
+ cer, wer = None, None
141
+ if is_ctc:
142
+ return self.calculate_cer_ctc(ys_hat, ys_pad)
143
+ elif not self.report_cer and not self.report_wer:
144
+ return cer, wer
145
+
146
+ seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
147
+ if self.report_cer:
148
+ cer = self.calculate_cer(seqs_hat, seqs_true)
149
+
150
+ if self.report_wer:
151
+ wer = self.calculate_wer(seqs_hat, seqs_true)
152
+ return cer, wer
153
+
154
+ def calculate_cer_ctc(self, ys_hat, ys_pad):
155
+ """Calculate sentence-level CER score for CTC.
156
+
157
+ :param torch.Tensor ys_hat: prediction (batch, seqlen)
158
+ :param torch.Tensor ys_pad: reference (batch, seqlen)
159
+ :return: average sentence-level CER score
160
+ :rtype float
161
+ """
162
+ import editdistance
163
+
164
+ cers, char_ref_lens = [], []
165
+ for i, y in enumerate(ys_hat):
166
+ y_hat = [x[0] for x in groupby(y)]
167
+ y_true = ys_pad[i]
168
+ seq_hat, seq_true = [], []
169
+ for idx in y_hat:
170
+ idx = int(idx)
171
+ if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
172
+ seq_hat.append(self.char_list[int(idx)])
173
+
174
+ for idx in y_true:
175
+ idx = int(idx)
176
+ if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
177
+ seq_true.append(self.char_list[int(idx)])
178
+
179
+ hyp_chars = "".join(seq_hat)
180
+ ref_chars = "".join(seq_true)
181
+ if len(ref_chars) > 0:
182
+ cers.append(editdistance.eval(hyp_chars, ref_chars))
183
+ char_ref_lens.append(len(ref_chars))
184
+
185
+ cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
186
+ return cer_ctc
187
+
188
+ def convert_to_char(self, ys_hat, ys_pad):
189
+ """Convert index to character.
190
+
191
+ :param torch.Tensor seqs_hat: prediction (batch, seqlen)
192
+ :param torch.Tensor seqs_true: reference (batch, seqlen)
193
+ :return: token list of prediction
194
+ :rtype list
195
+ :return: token list of reference
196
+ :rtype list
197
+ """
198
+ seqs_hat, seqs_true = [], []
199
+ for i, y_hat in enumerate(ys_hat):
200
+ y_true = ys_pad[i]
201
+ eos_true = np.where(y_true == -1)[0]
202
+ ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
203
+ # NOTE: padding index (-1) in y_true is used to pad y_hat
204
+ seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
205
+ seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
206
+ seq_hat_text = "".join(seq_hat).replace(self.space, " ")
207
+ seq_hat_text = seq_hat_text.replace(self.blank, "")
208
+ seq_true_text = "".join(seq_true).replace(self.space, " ")
209
+ seqs_hat.append(seq_hat_text)
210
+ seqs_true.append(seq_true_text)
211
+ return seqs_hat, seqs_true
212
+
213
+ def calculate_cer(self, seqs_hat, seqs_true):
214
+ """Calculate sentence-level CER score.
215
+
216
+ :param list seqs_hat: prediction
217
+ :param list seqs_true: reference
218
+ :return: average sentence-level CER score
219
+ :rtype float
220
+ """
221
+ import editdistance
222
+
223
+ char_eds, char_ref_lens = [], []
224
+ for i, seq_hat_text in enumerate(seqs_hat):
225
+ seq_true_text = seqs_true[i]
226
+ hyp_chars = seq_hat_text.replace(" ", "")
227
+ ref_chars = seq_true_text.replace(" ", "")
228
+ char_eds.append(editdistance.eval(hyp_chars, ref_chars))
229
+ char_ref_lens.append(len(ref_chars))
230
+ return float(sum(char_eds)) / sum(char_ref_lens)
231
+
232
+ def calculate_wer(self, seqs_hat, seqs_true):
233
+ """Calculate sentence-level WER score.
234
+
235
+ :param list seqs_hat: prediction
236
+ :param list seqs_true: reference
237
+ :return: average sentence-level WER score
238
+ :rtype float
239
+ """
240
+ import editdistance
241
+
242
+ word_eds, word_ref_lens = [], []
243
+ for i, seq_hat_text in enumerate(seqs_hat):
244
+ seq_true_text = seqs_true[i]
245
+ hyp_words = seq_hat_text.split()
246
+ ref_words = seq_true_text.split()
247
+ word_eds.append(editdistance.eval(hyp_words, ref_words))
248
+ word_ref_lens.append(len(ref_words))
249
+ return float(sum(word_eds)) / sum(word_ref_lens)
espnet/nets/lm_interface.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Language model interface."""
2
+
3
+ import argparse
4
+
5
+ from espnet.nets.scorer_interface import ScorerInterface
6
+ from espnet.utils.dynamic_import import dynamic_import
7
+ from espnet.utils.fill_missing_args import fill_missing_args
8
+
9
+
10
+ class LMInterface(ScorerInterface):
11
+ """LM Interface for ESPnet model implementation."""
12
+
13
+ @staticmethod
14
+ def add_arguments(parser):
15
+ """Add arguments to command line argument parser."""
16
+ return parser
17
+
18
+ @classmethod
19
+ def build(cls, n_vocab: int, **kwargs):
20
+ """Initialize this class with python-level args.
21
+
22
+ Args:
23
+ idim (int): The number of vocabulary.
24
+
25
+ Returns:
26
+ LMinterface: A new instance of LMInterface.
27
+
28
+ """
29
+ # local import to avoid cyclic import in lm_train
30
+ from espnet.bin.lm_train import get_parser
31
+
32
+ def wrap(parser):
33
+ return get_parser(parser, required=False)
34
+
35
+ args = argparse.Namespace(**kwargs)
36
+ args = fill_missing_args(args, wrap)
37
+ args = fill_missing_args(args, cls.add_arguments)
38
+ return cls(n_vocab, args)
39
+
40
+ def forward(self, x, t):
41
+ """Compute LM loss value from buffer sequences.
42
+
43
+ Args:
44
+ x (torch.Tensor): Input ids. (batch, len)
45
+ t (torch.Tensor): Target ids. (batch, len)
46
+
47
+ Returns:
48
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
49
+ loss to backward (scalar),
50
+ negative log-likelihood of t: -log p(t) (scalar) and
51
+ the number of elements in x (scalar)
52
+
53
+ Notes:
54
+ The last two return values are used
55
+ in perplexity: p(t)^{-n} = exp(-log p(t) / n)
56
+
57
+ """
58
+ raise NotImplementedError("forward method is not implemented")
59
+
60
+
61
+ predefined_lms = {
62
+ "pytorch": {
63
+ "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM",
64
+ "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM",
65
+ "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM",
66
+ },
67
+ "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"},
68
+ }
69
+
70
+
71
+ def dynamic_import_lm(module, backend):
72
+ """Import LM class dynamically.
73
+
74
+ Args:
75
+ module (str): module_name:class_name or alias in `predefined_lms`
76
+ backend (str): NN backend. e.g., pytorch, chainer
77
+
78
+ Returns:
79
+ type: LM class
80
+
81
+ """
82
+ model_class = dynamic_import(module, predefined_lms.get(backend, dict()))
83
+ assert issubclass(
84
+ model_class, LMInterface
85
+ ), f"{module} does not implement LMInterface"
86
+ return model_class
espnet/nets/pytorch_backend/backbones/conv1d_extractor.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2021 Imperial College London (Pingchuan Ma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+ import torch
7
+ from espnet.nets.pytorch_backend.backbones.modules.resnet1d import ResNet1D, BasicBlock1D
8
+
9
+ class Conv1dResNet(torch.nn.Module):
10
+ def __init__(self, relu_type="swish", a_upsample_ratio=1):
11
+ super().__init__()
12
+ self.a_upsample_ratio = a_upsample_ratio
13
+ self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
14
+
15
+
16
+ def forward(self, xs_pad):
17
+ """forward.
18
+
19
+ :param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
20
+ """
21
+ B, T, C = xs_pad.size()
22
+ xs_pad = xs_pad[:, :T // 640 * 640, :]
23
+ xs_pad = xs_pad.transpose(1, 2)
24
+ xs_pad = self.trunk(xs_pad)
25
+ return xs_pad.transpose(1, 2)
espnet/nets/pytorch_backend/backbones/conv3d_extractor.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2021 Imperial College London (Pingchuan Ma)
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from espnet.nets.pytorch_backend.backbones.modules.resnet import ResNet, BasicBlock
10
+ from espnet.nets.pytorch_backend.transformer.convolution import Swish
11
+
12
+
13
+ def threeD_to_2D_tensor(x):
14
+ n_batch, n_channels, s_time, sx, sy = x.shape
15
+ x = x.transpose(1, 2)
16
+ return x.reshape(n_batch * s_time, n_channels, sx, sy)
17
+
18
+
19
+
20
+ class Conv3dResNet(torch.nn.Module):
21
+ """Conv3dResNet module
22
+ """
23
+
24
+ def __init__(self, backbone_type="resnet", relu_type="swish"):
25
+ """__init__.
26
+
27
+ :param backbone_type: str, the type of a visual front-end.
28
+ :param relu_type: str, activation function used in an audio front-end.
29
+ """
30
+ super(Conv3dResNet, self).__init__()
31
+ self.frontend_nout = 64
32
+ self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
33
+ self.frontend3D = nn.Sequential(
34
+ nn.Conv3d(1, self.frontend_nout, (5, 7, 7), (1, 2, 2), (2, 3, 3), bias=False),
35
+ nn.BatchNorm3d(self.frontend_nout),
36
+ Swish(),
37
+ nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1))
38
+ )
39
+
40
+
41
+ def forward(self, xs_pad):
42
+ B, C, T, H, W = xs_pad.size()
43
+ xs_pad = self.frontend3D(xs_pad)
44
+ Tnew = xs_pad.shape[2]
45
+ xs_pad = threeD_to_2D_tensor(xs_pad)
46
+ xs_pad = self.trunk(xs_pad)
47
+ return xs_pad.view(B, Tnew, xs_pad.size(1))
espnet/nets/pytorch_backend/backbones/modules/resnet.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch.nn as nn
3
+ import pdb
4
+
5
+ from espnet.nets.pytorch_backend.transformer.convolution import Swish
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1):
9
+ """conv3x3.
10
+
11
+ :param in_planes: int, number of channels in the input sequence.
12
+ :param out_planes: int, number of channels produced by the convolution.
13
+ :param stride: int, size of the convolving kernel.
14
+ """
15
+ return nn.Conv2d(
16
+ in_planes,
17
+ out_planes,
18
+ kernel_size=3,
19
+ stride=stride,
20
+ padding=1,
21
+ bias=False,
22
+ )
23
+
24
+
25
+ def downsample_basic_block(inplanes, outplanes, stride):
26
+ """downsample_basic_block.
27
+
28
+ :param inplanes: int, number of channels in the input sequence.
29
+ :param outplanes: int, number of channels produced by the convolution.
30
+ :param stride: int, size of the convolving kernel.
31
+ """
32
+ return nn.Sequential(
33
+ nn.Conv2d(
34
+ inplanes,
35
+ outplanes,
36
+ kernel_size=1,
37
+ stride=stride,
38
+ bias=False,
39
+ ),
40
+ nn.BatchNorm2d(outplanes),
41
+ )
42
+
43
+
44
+ class BasicBlock(nn.Module):
45
+ expansion = 1
46
+
47
+ def __init__(
48
+ self,
49
+ inplanes,
50
+ planes,
51
+ stride=1,
52
+ downsample=None,
53
+ relu_type="swish",
54
+ ):
55
+ """__init__.
56
+
57
+ :param inplanes: int, number of channels in the input sequence.
58
+ :param planes: int, number of channels produced by the convolution.
59
+ :param stride: int, size of the convolving kernel.
60
+ :param downsample: boolean, if True, the temporal resolution is downsampled.
61
+ :param relu_type: str, type of activation function.
62
+ """
63
+ super(BasicBlock, self).__init__()
64
+
65
+ assert relu_type in ["relu", "prelu", "swish"]
66
+
67
+ self.conv1 = conv3x3(inplanes, planes, stride)
68
+ self.bn1 = nn.BatchNorm2d(planes)
69
+
70
+ if relu_type == "relu":
71
+ self.relu1 = nn.ReLU(inplace=True)
72
+ self.relu2 = nn.ReLU(inplace=True)
73
+ elif relu_type == "prelu":
74
+ self.relu1 = nn.PReLU(num_parameters=planes)
75
+ self.relu2 = nn.PReLU(num_parameters=planes)
76
+ elif relu_type == "swish":
77
+ self.relu1 = Swish()
78
+ self.relu2 = Swish()
79
+ else:
80
+ raise NotImplementedError
81
+ # --------
82
+
83
+ self.conv2 = conv3x3(planes, planes)
84
+ self.bn2 = nn.BatchNorm2d(planes)
85
+
86
+ self.downsample = downsample
87
+ self.stride = stride
88
+
89
+ def forward(self, x):
90
+ """forward.
91
+
92
+ :param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
93
+ """
94
+ residual = x
95
+ out = self.conv1(x)
96
+ out = self.bn1(out)
97
+ out = self.relu1(out)
98
+ out = self.conv2(out)
99
+ out = self.bn2(out)
100
+ if self.downsample is not None:
101
+ residual = self.downsample(x)
102
+
103
+ out += residual
104
+ out = self.relu2(out)
105
+
106
+ return out
107
+
108
+
109
+ class ResNet(nn.Module):
110
+
111
+ def __init__(
112
+ self,
113
+ block,
114
+ layers,
115
+ relu_type="swish",
116
+ ):
117
+ super(ResNet, self).__init__()
118
+ self.inplanes = 64
119
+ self.relu_type = relu_type
120
+ self.downsample_block = downsample_basic_block
121
+
122
+ self.layer1 = self._make_layer(block, 64, layers[0])
123
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
124
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
125
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
126
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
127
+
128
+
129
+ def _make_layer(self, block, planes, blocks, stride=1):
130
+ """_make_layer.
131
+
132
+ :param block: torch.nn.Module, class of blocks.
133
+ :param planes: int, number of channels produced by the convolution.
134
+ :param blocks: int, number of layers in a block.
135
+ :param stride: int, size of the convolving kernel.
136
+ """
137
+ downsample = None
138
+ if stride != 1 or self.inplanes != planes * block.expansion:
139
+ downsample = self.downsample_block(
140
+ inplanes=self.inplanes,
141
+ outplanes=planes*block.expansion,
142
+ stride=stride,
143
+ )
144
+
145
+ layers = []
146
+ layers.append(
147
+ block(
148
+ self.inplanes,
149
+ planes,
150
+ stride,
151
+ downsample,
152
+ relu_type=self.relu_type,
153
+ )
154
+ )
155
+ self.inplanes = planes * block.expansion
156
+ for i in range(1, blocks):
157
+ layers.append(
158
+ block(
159
+ self.inplanes,
160
+ planes,
161
+ relu_type=self.relu_type,
162
+ )
163
+ )
164
+
165
+ return nn.Sequential(*layers)
166
+
167
+ def forward(self, x):
168
+ """forward.
169
+
170
+ :param x: torch.Tensor, input tensor with input size (B, C, T, H, W).
171
+ """
172
+ x = self.layer1(x)
173
+ x = self.layer2(x)
174
+ x = self.layer3(x)
175
+ x = self.layer4(x)
176
+ x = self.avgpool(x)
177
+ x = x.view(x.size(0), -1)
178
+ return x
espnet/nets/pytorch_backend/backbones/modules/resnet1d.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch.nn as nn
3
+ import pdb
4
+
5
+ from espnet.nets.pytorch_backend.transformer.convolution import Swish
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1):
9
+ """conv3x3.
10
+
11
+ :param in_planes: int, number of channels in the input sequence.
12
+ :param out_planes: int, number of channels produced by the convolution.
13
+ :param stride: int, size of the convolving kernel.
14
+ """
15
+ return nn.Conv1d(
16
+ in_planes,
17
+ out_planes,
18
+ kernel_size=3,
19
+ stride=stride,
20
+ padding=1,
21
+ bias=False,
22
+ )
23
+
24
+
25
+ def downsample_basic_block(inplanes, outplanes, stride):
26
+ """downsample_basic_block.
27
+
28
+ :param inplanes: int, number of channels in the input sequence.
29
+ :param outplanes: int, number of channels produced by the convolution.
30
+ :param stride: int, size of the convolving kernel.
31
+ """
32
+ return nn.Sequential(
33
+ nn.Conv1d(
34
+ inplanes,
35
+ outplanes,
36
+ kernel_size=1,
37
+ stride=stride,
38
+ bias=False,
39
+ ),
40
+ nn.BatchNorm1d(outplanes),
41
+ )
42
+
43
+
44
+ class BasicBlock1D(nn.Module):
45
+ expansion = 1
46
+
47
+ def __init__(
48
+ self,
49
+ inplanes,
50
+ planes,
51
+ stride=1,
52
+ downsample=None,
53
+ relu_type="relu",
54
+ ):
55
+ """__init__.
56
+
57
+ :param inplanes: int, number of channels in the input sequence.
58
+ :param planes: int, number of channels produced by the convolution.
59
+ :param stride: int, size of the convolving kernel.
60
+ :param downsample: boolean, if True, the temporal resolution is downsampled.
61
+ :param relu_type: str, type of activation function.
62
+ """
63
+ super(BasicBlock1D, self).__init__()
64
+
65
+ assert relu_type in ["relu","prelu", "swish"]
66
+
67
+ self.conv1 = conv3x3(inplanes, planes, stride)
68
+ self.bn1 = nn.BatchNorm1d(planes)
69
+
70
+ # type of ReLU is an input option
71
+ if relu_type == "relu":
72
+ self.relu1 = nn.ReLU(inplace=True)
73
+ self.relu2 = nn.ReLU(inplace=True)
74
+ elif relu_type == "prelu":
75
+ self.relu1 = nn.PReLU(num_parameters=planes)
76
+ self.relu2 = nn.PReLU(num_parameters=planes)
77
+ elif relu_type == "swish":
78
+ self.relu1 = Swish()
79
+ self.relu2 = Swish()
80
+ else:
81
+ raise NotImplementedError
82
+ # --------
83
+
84
+ self.conv2 = conv3x3(planes, planes)
85
+ self.bn2 = nn.BatchNorm1d(planes)
86
+
87
+ self.downsample = downsample
88
+ self.stride = stride
89
+
90
+ def forward(self, x):
91
+ """forward.
92
+
93
+ :param x: torch.Tensor, input tensor with input size (B, C, T)
94
+ """
95
+ residual = x
96
+ out = self.conv1(x)
97
+ out = self.bn1(out)
98
+ out = self.relu1(out)
99
+ out = self.conv2(out)
100
+ out = self.bn2(out)
101
+ if self.downsample is not None:
102
+ residual = self.downsample(x)
103
+
104
+ out += residual
105
+ out = self.relu2(out)
106
+
107
+ return out
108
+
109
+
110
+ class ResNet1D(nn.Module):
111
+
112
+ def __init__(self,
113
+ block,
114
+ layers,
115
+ relu_type="swish",
116
+ a_upsample_ratio=1,
117
+ ):
118
+ """__init__.
119
+
120
+ :param block: torch.nn.Module, class of blocks.
121
+ :param layers: List, customised layers in each block.
122
+ :param relu_type: str, type of activation function.
123
+ :param a_upsample_ratio: int, The ratio related to the \
124
+ temporal resolution of output features of the frontend. \
125
+ a_upsample_ratio=1 produce features with a fps of 25.
126
+ """
127
+ super(ResNet1D, self).__init__()
128
+ self.inplanes = 64
129
+ self.relu_type = relu_type
130
+ self.downsample_block = downsample_basic_block
131
+ self.a_upsample_ratio = a_upsample_ratio
132
+
133
+ self.conv1 = nn.Conv1d(
134
+ in_channels=1,
135
+ out_channels=self.inplanes,
136
+ kernel_size=80,
137
+ stride=4,
138
+ padding=38,
139
+ bias=False,
140
+ )
141
+ self.bn1 = nn.BatchNorm1d(self.inplanes)
142
+
143
+ if relu_type == "relu":
144
+ self.relu = nn.ReLU(inplace=True)
145
+ elif relu_type == "prelu":
146
+ self.relu = nn.PReLU(num_parameters=self.inplanes)
147
+ elif relu_type == "swish":
148
+ self.relu = Swish()
149
+
150
+ self.layer1 = self._make_layer(block, 64, layers[0])
151
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
152
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
153
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
154
+ self.avgpool = nn.AvgPool1d(
155
+ kernel_size=20//self.a_upsample_ratio,
156
+ stride=20//self.a_upsample_ratio,
157
+ )
158
+
159
+
160
+ def _make_layer(self, block, planes, blocks, stride=1):
161
+ """_make_layer.
162
+
163
+ :param block: torch.nn.Module, class of blocks.
164
+ :param planes: int, number of channels produced by the convolution.
165
+ :param blocks: int, number of layers in a block.
166
+ :param stride: int, size of the convolving kernel.
167
+ """
168
+
169
+ downsample = None
170
+ if stride != 1 or self.inplanes != planes * block.expansion:
171
+ downsample = self.downsample_block(
172
+ inplanes=self.inplanes,
173
+ outplanes=planes*block.expansion,
174
+ stride=stride,
175
+ )
176
+
177
+ layers = []
178
+ layers.append(
179
+ block(
180
+ self.inplanes,
181
+ planes,
182
+ stride,
183
+ downsample,
184
+ relu_type=self.relu_type,
185
+ )
186
+ )
187
+ self.inplanes = planes * block.expansion
188
+ for i in range(1, blocks):
189
+ layers.append(
190
+ block(
191
+ self.inplanes,
192
+ planes,
193
+ relu_type=self.relu_type,
194
+ )
195
+ )
196
+
197
+ return nn.Sequential(*layers)
198
+
199
+ def forward(self, x):
200
+ """forward.
201
+
202
+ :param x: torch.Tensor, input tensor with input size (B, C, T)
203
+ """
204
+ x = self.conv1(x)
205
+ x = self.bn1(x)
206
+ x = self.relu(x)
207
+
208
+ x = self.layer1(x)
209
+ x = self.layer2(x)
210
+ x = self.layer3(x)
211
+ x = self.layer4(x)
212
+ x = self.avgpool(x)
213
+ return x
espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from collections import OrderedDict
6
+ from torch.nn import init
7
+ import math
8
+
9
+ import pdb
10
+
11
+ def conv_bn(inp, oup, stride):
12
+ return nn.Sequential(
13
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
14
+ nn.BatchNorm2d(oup),
15
+ nn.ReLU(inplace=True)
16
+ )
17
+
18
+
19
+ def conv_1x1_bn(inp, oup):
20
+ return nn.Sequential(
21
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
22
+ nn.BatchNorm2d(oup),
23
+ nn.ReLU(inplace=True)
24
+ )
25
+
26
+ def channel_shuffle(x, groups):
27
+ batchsize, num_channels, height, width = x.data.size()
28
+
29
+ channels_per_group = num_channels // groups
30
+
31
+ # reshape
32
+ x = x.view(batchsize, groups,
33
+ channels_per_group, height, width)
34
+
35
+ x = torch.transpose(x, 1, 2).contiguous()
36
+
37
+ # flatten
38
+ x = x.view(batchsize, -1, height, width)
39
+
40
+ return x
41
+
42
+ class InvertedResidual(nn.Module):
43
+ def __init__(self, inp, oup, stride, benchmodel):
44
+ super(InvertedResidual, self).__init__()
45
+ self.benchmodel = benchmodel
46
+ self.stride = stride
47
+ assert stride in [1, 2]
48
+
49
+ oup_inc = oup//2
50
+
51
+ if self.benchmodel == 1:
52
+ #assert inp == oup_inc
53
+ self.banch2 = nn.Sequential(
54
+ # pw
55
+ nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
56
+ nn.BatchNorm2d(oup_inc),
57
+ nn.ReLU(inplace=True),
58
+ # dw
59
+ nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
60
+ nn.BatchNorm2d(oup_inc),
61
+ # pw-linear
62
+ nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
63
+ nn.BatchNorm2d(oup_inc),
64
+ nn.ReLU(inplace=True),
65
+ )
66
+ else:
67
+ self.banch1 = nn.Sequential(
68
+ # dw
69
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
70
+ nn.BatchNorm2d(inp),
71
+ # pw-linear
72
+ nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
73
+ nn.BatchNorm2d(oup_inc),
74
+ nn.ReLU(inplace=True),
75
+ )
76
+
77
+ self.banch2 = nn.Sequential(
78
+ # pw
79
+ nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
80
+ nn.BatchNorm2d(oup_inc),
81
+ nn.ReLU(inplace=True),
82
+ # dw
83
+ nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
84
+ nn.BatchNorm2d(oup_inc),
85
+ # pw-linear
86
+ nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
87
+ nn.BatchNorm2d(oup_inc),
88
+ nn.ReLU(inplace=True),
89
+ )
90
+
91
+ @staticmethod
92
+ def _concat(x, out):
93
+ # concatenate along channel axis
94
+ return torch.cat((x, out), 1)
95
+
96
+ def forward(self, x):
97
+ if 1==self.benchmodel:
98
+ x1 = x[:, :(x.shape[1]//2), :, :]
99
+ x2 = x[:, (x.shape[1]//2):, :, :]
100
+ out = self._concat(x1, self.banch2(x2))
101
+ elif 2==self.benchmodel:
102
+ out = self._concat(self.banch1(x), self.banch2(x))
103
+
104
+ return channel_shuffle(out, 2)
105
+
106
+
107
+ class ShuffleNetV2(nn.Module):
108
+ def __init__(self, n_class=1000, input_size=224, width_mult=2.):
109
+ super(ShuffleNetV2, self).__init__()
110
+
111
+ assert input_size % 32 == 0, "Input size needs to be divisible by 32"
112
+
113
+ self.stage_repeats = [4, 8, 4]
114
+ # index 0 is invalid and should never be called.
115
+ # only used for indexing convenience.
116
+ if width_mult == 0.5:
117
+ self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
118
+ elif width_mult == 1.0:
119
+ self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
120
+ elif width_mult == 1.5:
121
+ self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
122
+ elif width_mult == 2.0:
123
+ self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
124
+ else:
125
+ raise ValueError(
126
+ """Width multiplier should be in [0.5, 1.0, 1.5, 2.0]. Current value: {}""".format(width_mult))
127
+
128
+ # building first layer
129
+ input_channel = self.stage_out_channels[1]
130
+ self.conv1 = conv_bn(3, input_channel, 2)
131
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
132
+
133
+ self.features = []
134
+ # building inverted residual blocks
135
+ for idxstage in range(len(self.stage_repeats)):
136
+ numrepeat = self.stage_repeats[idxstage]
137
+ output_channel = self.stage_out_channels[idxstage+2]
138
+ for i in range(numrepeat):
139
+ if i == 0:
140
+ #inp, oup, stride, benchmodel):
141
+ self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
142
+ else:
143
+ self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
144
+ input_channel = output_channel
145
+
146
+
147
+ # make it nn.Sequential
148
+ self.features = nn.Sequential(*self.features)
149
+
150
+ # building last several layers
151
+ self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1])
152
+ self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
153
+
154
+ # building classifier
155
+ self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class))
156
+
157
+ def forward(self, x):
158
+ x = self.conv1(x)
159
+ x = self.maxpool(x)
160
+ x = self.features(x)
161
+ x = self.conv_last(x)
162
+ x = self.globalpool(x)
163
+ x = x.view(-1, self.stage_out_channels[-1])
164
+ x = self.classifier(x)
165
+ return x
espnet/nets/pytorch_backend/ctc.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.version import LooseVersion
2
+ import logging
3
+
4
+ import numpy as np
5
+ import six
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from espnet.nets.pytorch_backend.nets_utils import to_device
10
+
11
+
12
+ class CTC(torch.nn.Module):
13
+ """CTC module
14
+
15
+ :param int odim: dimension of outputs
16
+ :param int eprojs: number of encoder projection units
17
+ :param float dropout_rate: dropout rate (0.0 ~ 1.0)
18
+ :param str ctc_type: builtin or warpctc
19
+ :param bool reduce: reduce the CTC loss into a scalar
20
+ """
21
+
22
+ def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
23
+ super().__init__()
24
+ self.dropout_rate = dropout_rate
25
+ self.loss = None
26
+ self.ctc_lo = torch.nn.Linear(eprojs, odim)
27
+ self.dropout = torch.nn.Dropout(dropout_rate)
28
+ self.probs = None # for visualization
29
+
30
+ # With PyTorch >= 1.7.0, the built-in CTC implementation is always used
31
+ self.ctc_type = (
32
+ ctc_type
33
+ if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
34
+ else "builtin"
35
+ )
36
+
37
+ if self.ctc_type == "builtin":
38
+ reduction_type = "sum" if reduce else "none"
39
+ self.ctc_loss = torch.nn.CTCLoss(
40
+ reduction=reduction_type, zero_infinity=True
41
+ )
42
+ elif self.ctc_type == "cudnnctc":
43
+ reduction_type = "sum" if reduce else "none"
44
+ self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
45
+ elif self.ctc_type == "warpctc":
46
+ import warpctc_pytorch as warp_ctc
47
+
48
+ self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
49
+ elif self.ctc_type == "gtnctc":
50
+ from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
51
+
52
+ self.ctc_loss = GTNCTCLossFunction.apply
53
+ else:
54
+ raise ValueError(
55
+ 'ctc_type must be "builtin", "cudnnctc", "warpctc", or "gtnctc": {}'.format(self.ctc_type)
56
+ )
57
+
58
+ self.ignore_id = -1
59
+ self.reduce = reduce
60
+
61
+ def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
62
+ if self.ctc_type in ["builtin", "cudnnctc"]:
63
+ th_pred = th_pred.log_softmax(2)
64
+ # Use the deterministic CuDNN implementation of CTC loss to avoid
65
+ # [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
66
+ with torch.backends.cudnn.flags(deterministic=True):
67
+ loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
68
+ # Batch-size average
69
+ loss = loss / th_pred.size(1)
70
+ return loss
71
+ elif self.ctc_type == "warpctc":
72
+ return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
73
+ elif self.ctc_type == "gtnctc":
74
+ targets = [t.tolist() for t in th_target]
75
+ log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
76
+ return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
77
+ else:
78
+ raise NotImplementedError
79
+
80
+ def forward(self, hs_pad, hlens, ys_pad):
81
+ """CTC forward
82
+
83
+ :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
84
+ :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
85
+ :param torch.Tensor ys_pad:
86
+ batch of padded character id sequence tensor (B, Lmax)
87
+ :return: ctc loss value
88
+ :rtype: torch.Tensor
89
+ """
90
+ # TODO(kan-bayashi): need to make more smart way
91
+ ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
92
+
93
+ # zero padding for hs
94
+ ys_hat = self.ctc_lo(self.dropout(hs_pad))
95
+ if self.ctc_type != "gtnctc":
96
+ ys_hat = ys_hat.transpose(0, 1)
97
+
98
+ if self.ctc_type == "builtin":
99
+ olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
100
+ hlens = hlens.long()
101
+ ys_pad = torch.cat(ys) # without this the code breaks for asr_mix
102
+ self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
103
+ else:
104
+ self.loss = None
105
+ hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
106
+ olens = torch.from_numpy(
107
+ np.fromiter((x.size(0) for x in ys), dtype=np.int32)
108
+ )
109
+ # zero padding for ys
110
+ ys_true = torch.cat(ys).cpu().int() # batch x olen
111
+ # get ctc loss
112
+ # expected shape of seqLength x batchSize x alphabet_size
113
+ dtype = ys_hat.dtype
114
+ if self.ctc_type == "warpctc" or dtype == torch.float16:
115
+ # warpctc only supports float32
116
+ # torch.ctc does not support float16 (#1751)
117
+ ys_hat = ys_hat.to(dtype=torch.float32)
118
+ if self.ctc_type == "cudnnctc":
119
+ # use GPU when using the cuDNN implementation
120
+ ys_true = to_device(hs_pad, ys_true)
121
+ if self.ctc_type == "gtnctc":
122
+ # keep as list for gtn
123
+ ys_true = ys
124
+ self.loss = to_device(
125
+ hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
126
+ ).to(dtype=dtype)
127
+
128
+ # get length info
129
+ logging.info(
130
+ self.__class__.__name__
131
+ + " input lengths: "
132
+ + "".join(str(hlens).split("\n"))
133
+ )
134
+ logging.info(
135
+ self.__class__.__name__
136
+ + " output lengths: "
137
+ + "".join(str(olens).split("\n"))
138
+ )
139
+
140
+ if self.reduce:
141
+ # NOTE: sum() is needed to keep consistency
142
+ # since warpctc return as tensor w/ shape (1,)
143
+ # but builtin return as tensor w/o shape (scalar).
144
+ self.loss = self.loss.sum()
145
+ logging.info("ctc loss:" + str(float(self.loss)))
146
+
147
+ return self.loss
148
+
149
+ def softmax(self, hs_pad):
150
+ """softmax of frame activations
151
+
152
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
153
+ :return: log softmax applied 3d tensor (B, Tmax, odim)
154
+ :rtype: torch.Tensor
155
+ """
156
+ self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
157
+ return self.probs
158
+
159
+ def log_softmax(self, hs_pad):
160
+ """log_softmax of frame activations
161
+
162
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
163
+ :return: log softmax applied 3d tensor (B, Tmax, odim)
164
+ :rtype: torch.Tensor
165
+ """
166
+ return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
167
+
168
+ def argmax(self, hs_pad):
169
+ """argmax of frame activations
170
+
171
+ :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
172
+ :return: argmax applied 2d tensor (B, Tmax)
173
+ :rtype: torch.Tensor
174
+ """
175
+ return torch.argmax(self.ctc_lo(hs_pad), dim=2)
176
+
177
+ def forced_align(self, h, y, blank_id=0):
178
+ """forced alignment.
179
+
180
+ :param torch.Tensor h: hidden state sequence, 2d tensor (T, D)
181
+ :param torch.Tensor y: id sequence tensor 1d tensor (L)
182
+ :param int blank_id: blank symbol index
183
+ :return: best alignment results
184
+ :rtype: list
185
+ """
186
+
187
+ def interpolate_blank(label, blank_id=0):
188
+ """Insert blank token between every two label token."""
189
+ label = np.expand_dims(label, 1)
190
+ blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
191
+ label = np.concatenate([blanks, label], axis=1)
192
+ label = label.reshape(-1)
193
+ label = np.append(label, label[0])
194
+ return label
195
+
196
+ lpz = self.log_softmax(h)
197
+ lpz = lpz.squeeze(0)
198
+
199
+ y_int = interpolate_blank(y, blank_id)
200
+
201
+ logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero
202
+ state_path = (
203
+ np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
204
+ ) # state path
205
+
206
+ logdelta[0, 0] = lpz[0][y_int[0]]
207
+ logdelta[0, 1] = lpz[0][y_int[1]]
208
+
209
+ for t in six.moves.range(1, lpz.size(0)):
210
+ for s in six.moves.range(len(y_int)):
211
+ if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
212
+ candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
213
+ prev_state = [s, s - 1]
214
+ else:
215
+ candidates = np.array(
216
+ [
217
+ logdelta[t - 1, s],
218
+ logdelta[t - 1, s - 1],
219
+ logdelta[t - 1, s - 2],
220
+ ]
221
+ )
222
+ prev_state = [s, s - 1, s - 2]
223
+ logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
224
+ state_path[t, s] = prev_state[np.argmax(candidates)]
225
+
226
+ state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)
227
+
228
+ candidates = np.array(
229
+ [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
230
+ )
231
+ prev_state = [len(y_int) - 1, len(y_int) - 2]
232
+ state_seq[-1] = prev_state[np.argmax(candidates)]
233
+ for t in six.moves.range(lpz.size(0) - 2, -1, -1):
234
+ state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
235
+
236
+ output_state_seq = []
237
+ for t in six.moves.range(0, lpz.size(0)):
238
+ output_state_seq.append(y_int[state_seq[t, 0]])
239
+
240
+ return output_state_seq
241
+
242
+
243
+ def ctc_for(args, odim, reduce=True):
244
+ """Returns the CTC module for the given args and output dimension
245
+
246
+ :param Namespace args: the program args
247
+ :param int odim : The output dimension
248
+ :param bool reduce : return the CTC loss in a scalar
249
+ :return: the corresponding CTC module
250
+ """
251
+ num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
252
+ if num_encs == 1:
253
+ # compatible with single encoder asr mode
254
+ return CTC(
255
+ odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
256
+ )
257
+ elif num_encs >= 1:
258
+ ctcs_list = torch.nn.ModuleList()
259
+ if args.share_ctc:
260
+ # use dropout_rate of the first encoder
261
+ ctc = CTC(
262
+ odim,
263
+ args.eprojs,
264
+ args.dropout_rate[0],
265
+ ctc_type=args.ctc_type,
266
+ reduce=reduce,
267
+ )
268
+ ctcs_list.append(ctc)
269
+ else:
270
+ for idx in range(num_encs):
271
+ ctc = CTC(
272
+ odim,
273
+ args.eprojs,
274
+ args.dropout_rate[idx],
275
+ ctc_type=args.ctc_type,
276
+ reduce=reduce,
277
+ )
278
+ ctcs_list.append(ctc)
279
+ return ctcs_list
280
+ else:
281
+ raise ValueError(
282
+ "Number of encoders needs to be more than one. {}".format(num_encs)
283
+ )
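As a usage sketch (not part of the commit), the CTC module above can be exercised with random encoder outputs. Target ids use -1 as the padding id, matching self.ignore_id, and are drawn above 0 so the blank index is never a label.

import torch
from espnet.nets.pytorch_backend.ctc import CTC

odim, eprojs = 10, 32                        # vocab size (blank at index 0), encoder projection dim
ctc = CTC(odim, eprojs, dropout_rate=0.1, ctc_type="builtin", reduce=True)

hs_pad = torch.randn(2, 50, eprojs)          # (B, Tmax, eprojs) padded encoder states
hlens = torch.tensor([50, 40])               # valid frame counts per utterance
ys_pad = torch.full((2, 6), -1, dtype=torch.long)
ys_pad[0, :6] = torch.randint(1, odim, (6,))
ys_pad[1, :4] = torch.randint(1, odim, (4,))

loss = ctc(hs_pad, hlens, ys_pad)            # summed batch loss (reduce=True)
print(float(loss))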
espnet/nets/pytorch_backend/e2e_asr_transformer.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2019 Shigeki Karita
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """Transformer speech recognition model (pytorch)."""
5
+
6
+ from argparse import Namespace
7
+ from distutils.util import strtobool
8
+ import logging
9
+ import math
10
+
11
+ import numpy
12
+ import torch
13
+
14
+ from espnet.nets.ctc_prefix_score import CTCPrefixScore
15
+ from espnet.nets.e2e_asr_common import end_detect
16
+ from espnet.nets.e2e_asr_common import ErrorCalculator
17
+ from espnet.nets.pytorch_backend.ctc import CTC
18
+ from espnet.nets.pytorch_backend.nets_utils import get_subsample
19
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
20
+ from espnet.nets.pytorch_backend.nets_utils import th_accuracy
21
+ from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
22
+ from espnet.nets.pytorch_backend.transformer.attention import (
23
+ MultiHeadedAttention, # noqa: H301
24
+ RelPositionMultiHeadedAttention, # noqa: H301
25
+ )
26
+ from espnet.nets.pytorch_backend.transformer.decoder import Decoder
27
+ from espnet.nets.pytorch_backend.transformer.encoder import Encoder
28
+ from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
29
+ LabelSmoothingLoss, # noqa: H301
30
+ )
31
+ from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
32
+ from espnet.nets.pytorch_backend.transformer.mask import target_mask
33
+ from espnet.nets.scorers.ctc import CTCPrefixScorer
34
+
35
+
36
+ class E2E(torch.nn.Module):
37
+ """E2E module.
38
+
39
+ :param int idim: dimension of inputs
40
+ :param int odim: dimension of outputs
41
+ :param Namespace args: argument Namespace containing options
42
+
43
+ """
44
+
45
+ @staticmethod
46
+ def add_arguments(parser):
47
+ """Add arguments."""
48
+ group = parser.add_argument_group("transformer model setting")
49
+
50
+ group.add_argument(
51
+ "--transformer-init",
52
+ type=str,
53
+ default="pytorch",
54
+ choices=[
55
+ "pytorch",
56
+ "xavier_uniform",
57
+ "xavier_normal",
58
+ "kaiming_uniform",
59
+ "kaiming_normal",
60
+ ],
61
+ help="how to initialize transformer parameters",
62
+ )
63
+ group.add_argument(
64
+ "--transformer-input-layer",
65
+ type=str,
66
+ default="conv2d",
67
+ choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
68
+ help="transformer input layer type",
69
+ )
70
+ group.add_argument(
71
+ "--transformer-encoder-attn-layer-type",
72
+ type=str,
73
+ default="mha",
74
+ choices=["mha", "rel_mha", "legacy_rel_mha"],
75
+ help="transformer encoder attention layer type",
76
+ )
77
+ group.add_argument(
78
+ "--transformer-attn-dropout-rate",
79
+ default=None,
80
+ type=float,
81
+ help="dropout in transformer attention. use --dropout-rate if None is set",
82
+ )
83
+ group.add_argument(
84
+ "--transformer-lr",
85
+ default=10.0,
86
+ type=float,
87
+ help="Initial value of learning rate",
88
+ )
89
+ group.add_argument(
90
+ "--transformer-warmup-steps",
91
+ default=25000,
92
+ type=int,
93
+ help="optimizer warmup steps",
94
+ )
95
+ group.add_argument(
96
+ "--transformer-length-normalized-loss",
97
+ default=True,
98
+ type=strtobool,
99
+ help="normalize loss by length",
100
+ )
101
+ group.add_argument(
102
+ "--dropout-rate",
103
+ default=0.0,
104
+ type=float,
105
+ help="Dropout rate for the encoder",
106
+ )
107
+ group.add_argument(
108
+ "--macaron-style",
109
+ default=False,
110
+ type=strtobool,
111
+ help="Whether to use macaron style for positionwise layer",
112
+ )
113
+ # -- input
114
+ group.add_argument(
115
+ "--a-upsample-ratio",
116
+ default=1,
117
+ type=int,
118
+ help="Upsample rate for audio",
119
+ )
120
+ group.add_argument(
121
+ "--relu-type",
122
+ default="swish",
123
+ type=str,
124
+ help="the type of activation layer",
125
+ )
126
+ # Encoder
127
+ group.add_argument(
128
+ "--elayers",
129
+ default=4,
130
+ type=int,
131
+ help="Number of encoder layers (for shared recognition part "
132
+ "in multi-speaker asr mode)",
133
+ )
134
+ group.add_argument(
135
+ "--eunits",
136
+ "-u",
137
+ default=300,
138
+ type=int,
139
+ help="Number of encoder hidden units",
140
+ )
141
+ group.add_argument(
142
+ "--use-cnn-module",
143
+ default=False,
144
+ type=strtobool,
145
+ help="Use convolution module or not",
146
+ )
147
+ group.add_argument(
148
+ "--cnn-module-kernel",
149
+ default=31,
150
+ type=int,
151
+ help="Kernel size of convolution module.",
152
+ )
153
+ # Attention
154
+ group.add_argument(
155
+ "--adim",
156
+ default=320,
157
+ type=int,
158
+ help="Number of attention transformation dimensions",
159
+ )
160
+ group.add_argument(
161
+ "--aheads",
162
+ default=4,
163
+ type=int,
164
+ help="Number of heads for multi head attention",
165
+ )
166
+ group.add_argument(
167
+ "--zero-triu",
168
+ default=False,
169
+ type=strtobool,
170
+ help="If true, zero the uppper triangular part of attention matrix.",
171
+ )
172
+ # Relative positional encoding
173
+ group.add_argument(
174
+ "--rel-pos-type",
175
+ type=str,
176
+ default="legacy",
177
+ choices=["legacy", "latest"],
178
+ help="Whether to use the latest relative positional encoding or the legacy one."
179
+ "The legacy relative positional encoding will be deprecated in the future."
180
+ "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
181
+ )
182
+ # Decoder
183
+ group.add_argument(
184
+ "--dlayers", default=1, type=int, help="Number of decoder layers"
185
+ )
186
+ group.add_argument(
187
+ "--dunits", default=320, type=int, help="Number of decoder hidden units"
188
+ )
189
+ # -- pretrain
190
+ group.add_argument("--pretrain-dataset",
191
+ default="",
192
+ type=str,
193
+ help='pre-trained dataset for encoder'
194
+ )
195
+ # -- custom name
196
+ group.add_argument("--custom-pretrain-name",
197
+ default="",
198
+ type=str,
199
+ help='pre-trained model for encoder'
200
+ )
201
+ return parser
202
+
203
+ @property
204
+ def attention_plot_class(self):
205
+ """Return PlotAttentionReport."""
206
+ return PlotAttentionReport
207
+
208
+ def __init__(self, odim, args, ignore_id=-1):
209
+ """Construct an E2E object.
210
+ :param int odim: dimension of outputs
211
+ :param Namespace args: argument Namespace containing options
212
+ """
213
+ torch.nn.Module.__init__(self)
214
+ if args.transformer_attn_dropout_rate is None:
215
+ args.transformer_attn_dropout_rate = args.dropout_rate
216
+ # Check the relative positional encoding type
217
+ self.rel_pos_type = getattr(args, "rel_pos_type", None)
218
+ if self.rel_pos_type is None and args.transformer_encoder_attn_layer_type == "rel_mha":
219
+ args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
220
+ logging.warning(
221
+ "Using legacy_rel_pos and it will be deprecated in the future."
222
+ )
223
+
224
+ idim = 80
225
+
226
+ self.encoder = Encoder(
227
+ idim=idim,
228
+ attention_dim=args.adim,
229
+ attention_heads=args.aheads,
230
+ linear_units=args.eunits,
231
+ num_blocks=args.elayers,
232
+ input_layer=args.transformer_input_layer,
233
+ dropout_rate=args.dropout_rate,
234
+ positional_dropout_rate=args.dropout_rate,
235
+ attention_dropout_rate=args.transformer_attn_dropout_rate,
236
+ encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
237
+ macaron_style=args.macaron_style,
238
+ use_cnn_module=args.use_cnn_module,
239
+ cnn_module_kernel=args.cnn_module_kernel,
240
+ zero_triu=getattr(args, "zero_triu", False),
241
+ a_upsample_ratio=args.a_upsample_ratio,
242
+ relu_type=getattr(args, "relu_type", "swish"),
243
+ )
244
+
245
+ self.transformer_input_layer = args.transformer_input_layer
246
+ self.a_upsample_ratio = args.a_upsample_ratio
247
+
248
+ if args.mtlalpha < 1:
249
+ self.decoder = Decoder(
250
+ odim=odim,
251
+ attention_dim=args.adim,
252
+ attention_heads=args.aheads,
253
+ linear_units=args.dunits,
254
+ num_blocks=args.dlayers,
255
+ dropout_rate=args.dropout_rate,
256
+ positional_dropout_rate=args.dropout_rate,
257
+ self_attention_dropout_rate=args.transformer_attn_dropout_rate,
258
+ src_attention_dropout_rate=args.transformer_attn_dropout_rate,
259
+ )
260
+ else:
261
+ self.decoder = None
262
+ self.blank = 0
263
+ self.sos = odim - 1
264
+ self.eos = odim - 1
265
+ self.odim = odim
266
+ self.ignore_id = ignore_id
267
+ self.subsample = get_subsample(args, mode="asr", arch="transformer")
268
+
269
+ # self.lsm_weight = a
270
+ self.criterion = LabelSmoothingLoss(
271
+ self.odim,
272
+ self.ignore_id,
273
+ args.lsm_weight,
274
+ args.transformer_length_normalized_loss,
275
+ )
276
+
277
+ self.adim = args.adim
278
+ self.mtlalpha = args.mtlalpha
279
+ if args.mtlalpha > 0.0:
280
+ self.ctc = CTC(
281
+ odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
282
+ )
283
+ else:
284
+ self.ctc = None
285
+
286
+ if args.report_cer or args.report_wer:
287
+ self.error_calculator = ErrorCalculator(
288
+ args.char_list,
289
+ args.sym_space,
290
+ args.sym_blank,
291
+ args.report_cer,
292
+ args.report_wer,
293
+ )
294
+ else:
295
+ self.error_calculator = None
296
+ self.rnnlm = None
297
+
298
+ def scorers(self):
299
+ """Scorers."""
300
+ return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
301
+
302
+ def encode(self, x, extract_resnet_feats=False):
303
+ """Encode acoustic features.
304
+
305
+ :param ndarray x: source acoustic feature (T, D)
306
+ :return: encoder outputs
307
+ :rtype: torch.Tensor
308
+ """
309
+ self.eval()
310
+ x = torch.as_tensor(x).unsqueeze(0)
311
+ if extract_resnet_feats:
312
+ resnet_feats = self.encoder(
313
+ x,
314
+ None,
315
+ extract_resnet_feats=extract_resnet_feats,
316
+ )
317
+ return resnet_feats.squeeze(0)
318
+ else:
319
+ enc_output, _ = self.encoder(x, None)
320
+ return enc_output.squeeze(0)
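For decoding, scorers() is meant to be wired into the beam search shipped in this commit. The following is a hypothetical sketch rather than the project's inference script: `model`, `token_list`, and `feats` are assumed to be loaded elsewhere, and the scorer weights are illustrative.

import torch
from espnet.nets.beam_search import BeamSearch
from espnet.nets.scorers.length_bonus import LengthBonus

scorers = model.scorers()                          # {"decoder": ..., "ctc": ...}
scorers["length_bonus"] = LengthBonus(len(token_list))
weights = {"decoder": 0.9, "ctc": 0.1, "length_bonus": 0.0}

beam_search = BeamSearch(
    beam_size=40,
    vocab_size=len(token_list),
    weights=weights,
    scorers=scorers,
    sos=model.sos,
    eos=model.eos,
    token_list=token_list,
)

with torch.no_grad():
    enc = model.encode(feats)                      # (T, adim) encoder memory
    nbest = beam_search(x=enc)
hyp = "".join(token_list[int(y)] for y in nbest[0].yseq[1:-1])   # strip sos/eos
print(hyp)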
espnet/nets/pytorch_backend/e2e_asr_transformer_av.py ADDED
@@ -0,0 +1,352 @@
1
+ # Copyright 2019 Shigeki Karita
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ """Transformer speech recognition model (pytorch)."""
5
+
6
+ from argparse import Namespace
7
+ from distutils.util import strtobool
8
+ import logging
9
+ import math
10
+
11
+ import numpy
12
+ import torch
13
+
14
+ from espnet.nets.ctc_prefix_score import CTCPrefixScore
15
+ from espnet.nets.e2e_asr_common import end_detect
16
+ from espnet.nets.e2e_asr_common import ErrorCalculator
17
+ from espnet.nets.pytorch_backend.ctc import CTC
18
+ from espnet.nets.pytorch_backend.nets_utils import get_subsample
19
+ from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
20
+ from espnet.nets.pytorch_backend.nets_utils import th_accuracy
21
+ from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
22
+ from espnet.nets.pytorch_backend.transformer.attention import (
23
+ MultiHeadedAttention, # noqa: H301
24
+ RelPositionMultiHeadedAttention, # noqa: H301
25
+ )
26
+ from espnet.nets.pytorch_backend.transformer.decoder import Decoder
27
+ from espnet.nets.pytorch_backend.transformer.encoder import Encoder
28
+ from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
29
+ LabelSmoothingLoss, # noqa: H301
30
+ )
31
+ from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
32
+ from espnet.nets.pytorch_backend.transformer.mask import target_mask
33
+ from espnet.nets.scorers.ctc import CTCPrefixScorer
34
+ from espnet.nets.pytorch_backend.nets_utils import MLPHead
35
+
36
+
37
+ class E2E(torch.nn.Module):
38
+ """E2E module.
39
+
40
+ :param int idim: dimension of inputs
41
+ :param int odim: dimension of outputs
42
+ :param Namespace args: argument Namespace containing options
43
+
44
+ """
45
+
46
+ @staticmethod
47
+ def add_arguments(parser):
48
+ """Add arguments."""
49
+ group = parser.add_argument_group("transformer model setting")
50
+
51
+ group.add_argument(
52
+ "--transformer-init",
53
+ type=str,
54
+ default="pytorch",
55
+ choices=[
56
+ "pytorch",
57
+ "xavier_uniform",
58
+ "xavier_normal",
59
+ "kaiming_uniform",
60
+ "kaiming_normal",
61
+ ],
62
+ help="how to initialize transformer parameters",
63
+ )
64
+ group.add_argument(
65
+ "--transformer-input-layer",
66
+ type=str,
67
+ default="conv2d",
68
+ choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
69
+ help="transformer input layer type",
70
+ )
71
+ group.add_argument(
72
+ "--transformer-encoder-attn-layer-type",
73
+ type=str,
74
+ default="mha",
75
+ choices=["mha", "rel_mha", "legacy_rel_mha"],
76
+ help="transformer encoder attention layer type",
77
+ )
78
+ group.add_argument(
79
+ "--transformer-attn-dropout-rate",
80
+ default=None,
81
+ type=float,
82
+ help="dropout in transformer attention. use --dropout-rate if None is set",
83
+ )
84
+ group.add_argument(
85
+ "--transformer-lr",
86
+ default=10.0,
87
+ type=float,
88
+ help="Initial value of learning rate",
89
+ )
90
+ group.add_argument(
91
+ "--transformer-warmup-steps",
92
+ default=25000,
93
+ type=int,
94
+ help="optimizer warmup steps",
95
+ )
96
+ group.add_argument(
97
+ "--transformer-length-normalized-loss",
98
+ default=True,
99
+ type=strtobool,
100
+ help="normalize loss by length",
101
+ )
102
+ group.add_argument(
103
+ "--dropout-rate",
104
+ default=0.0,
105
+ type=float,
106
+ help="Dropout rate for the encoder",
107
+ )
108
+ group.add_argument(
109
+ "--macaron-style",
110
+ default=False,
111
+ type=strtobool,
112
+ help="Whether to use macaron style for positionwise layer",
113
+ )
114
+ # -- input
115
+ group.add_argument(
116
+ "--a-upsample-ratio",
117
+ default=1,
118
+ type=int,
119
+ help="Upsample rate for audio",
120
+ )
121
+ group.add_argument(
122
+ "--relu-type",
123
+ default="swish",
124
+ type=str,
125
+ help="the type of activation layer",
126
+ )
127
+ # Encoder
128
+ group.add_argument(
129
+ "--elayers",
130
+ default=4,
131
+ type=int,
132
+ help="Number of encoder layers (for shared recognition part "
133
+ "in multi-speaker asr mode)",
134
+ )
135
+ group.add_argument(
136
+ "--eunits",
137
+ "-u",
138
+ default=300,
139
+ type=int,
140
+ help="Number of encoder hidden units",
141
+ )
142
+ group.add_argument(
143
+ "--use-cnn-module",
144
+ default=False,
145
+ type=strtobool,
146
+ help="Use convolution module or not",
147
+ )
148
+ group.add_argument(
149
+ "--cnn-module-kernel",
150
+ default=31,
151
+ type=int,
152
+ help="Kernel size of convolution module.",
153
+ )
154
+ # Attention
155
+ group.add_argument(
156
+ "--adim",
157
+ default=320,
158
+ type=int,
159
+ help="Number of attention transformation dimensions",
160
+ )
161
+ group.add_argument(
162
+ "--aheads",
163
+ default=4,
164
+ type=int,
165
+ help="Number of heads for multi head attention",
166
+ )
167
+ group.add_argument(
168
+ "--zero-triu",
169
+ default=False,
170
+ type=strtobool,
171
+ help="If true, zero the uppper triangular part of attention matrix.",
172
+ )
173
+ # Relative positional encoding
174
+ group.add_argument(
175
+ "--rel-pos-type",
176
+ type=str,
177
+ default="legacy",
178
+ choices=["legacy", "latest"],
179
+ help="Whether to use the latest relative positional encoding or the legacy one."
180
+ "The legacy relative positional encoding will be deprecated in the future."
181
+ "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
182
+ )
183
+ # Decoder
184
+ group.add_argument(
185
+ "--dlayers", default=1, type=int, help="Number of decoder layers"
186
+ )
187
+ group.add_argument(
188
+ "--dunits", default=320, type=int, help="Number of decoder hidden units"
189
+ )
190
+ # -- pretrain
191
+ group.add_argument("--pretrain-dataset",
192
+ default="",
193
+ type=str,
194
+ help='pre-trained dataset for encoder'
195
+ )
196
+ # -- custom name
197
+ group.add_argument("--custom-pretrain-name",
198
+ default="",
199
+ type=str,
200
+ help='pre-trained model for encoder'
201
+ )
202
+ return parser
203
+
204
+ @property
205
+ def attention_plot_class(self):
206
+ """Return PlotAttentionReport."""
207
+ return PlotAttentionReport
208
+
209
+ def __init__(self, odim, args, ignore_id=-1):
210
+ """Construct an E2E object.
211
+ :param int odim: dimension of outputs
212
+ :param Namespace args: argument Namespace containing options
213
+ """
214
+ torch.nn.Module.__init__(self)
215
+ if args.transformer_attn_dropout_rate is None:
216
+ args.transformer_attn_dropout_rate = args.dropout_rate
217
+ # Check the relative positional encoding type
218
+ self.rel_pos_type = getattr(args, "rel_pos_type", None)
219
+ if self.rel_pos_type is None and args.transformer_encoder_attn_layer_type == "rel_mha":
220
+ args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
221
+ logging.warning(
222
+ "Using legacy_rel_pos and it will be deprecated in the future."
223
+ )
224
+
225
+ idim = 80
226
+
227
+ self.encoder = Encoder(
228
+ idim=idim,
229
+ attention_dim=args.adim,
230
+ attention_heads=args.aheads,
231
+ linear_units=args.eunits,
232
+ num_blocks=args.elayers,
233
+ input_layer=args.transformer_input_layer,
234
+ dropout_rate=args.dropout_rate,
235
+ positional_dropout_rate=args.dropout_rate,
236
+ attention_dropout_rate=args.transformer_attn_dropout_rate,
237
+ encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
238
+ macaron_style=args.macaron_style,
239
+ use_cnn_module=args.use_cnn_module,
240
+ cnn_module_kernel=args.cnn_module_kernel,
241
+ zero_triu=getattr(args, "zero_triu", False),
242
+ a_upsample_ratio=args.a_upsample_ratio,
243
+ relu_type=getattr(args, "relu_type", "swish"),
244
+ )
245
+
246
+ self.transformer_input_layer = args.transformer_input_layer
247
+ self.a_upsample_ratio = args.a_upsample_ratio
248
+
249
+ self.aux_encoder = Encoder(
250
+ idim=idim,
251
+ attention_dim=args.aux_adim,
252
+ attention_heads=args.aux_aheads,
253
+ linear_units=args.aux_eunits,
254
+ num_blocks=args.aux_elayers,
255
+ input_layer=args.aux_transformer_input_layer,
256
+ dropout_rate=args.aux_dropout_rate,
257
+ positional_dropout_rate=args.aux_dropout_rate,
258
+ attention_dropout_rate=args.aux_transformer_attn_dropout_rate,
259
+ encoder_attn_layer_type=args.aux_transformer_encoder_attn_layer_type,
260
+ macaron_style=args.aux_macaron_style,
261
+ use_cnn_module=args.aux_use_cnn_module,
262
+ cnn_module_kernel=args.aux_cnn_module_kernel,
263
+ zero_triu=getattr(args, "aux_zero_triu", False),
264
+ a_upsample_ratio=args.aux_a_upsample_ratio,
265
+ relu_type=getattr(args, "aux_relu_type", "swish"),
266
+ )
267
+ self.aux_transformer_input_layer = args.aux_transformer_input_layer
268
+
269
+ self.fusion = MLPHead(
270
+ idim=args.adim + args.aux_adim,
271
+ hdim=args.fusion_hdim,
272
+ odim=args.adim,
273
+ norm=args.fusion_norm,
274
+ )
275
+
276
+ if args.mtlalpha < 1:
277
+ self.decoder = Decoder(
278
+ odim=odim,
279
+ attention_dim=args.adim,
280
+ attention_heads=args.aheads,
281
+ linear_units=args.dunits,
282
+ num_blocks=args.dlayers,
283
+ dropout_rate=args.dropout_rate,
284
+ positional_dropout_rate=args.dropout_rate,
285
+ self_attention_dropout_rate=args.transformer_attn_dropout_rate,
286
+ src_attention_dropout_rate=args.transformer_attn_dropout_rate,
287
+ )
288
+ else:
289
+ self.decoder = None
290
+ self.blank = 0
291
+ self.sos = odim - 1
292
+ self.eos = odim - 1
293
+ self.odim = odim
294
+ self.ignore_id = ignore_id
295
+ self.subsample = get_subsample(args, mode="asr", arch="transformer")
296
+
297
+ # self.lsm_weight = a
298
+ self.criterion = LabelSmoothingLoss(
299
+ self.odim,
300
+ self.ignore_id,
301
+ args.lsm_weight,
302
+ args.transformer_length_normalized_loss,
303
+ )
304
+
305
+ self.adim = args.adim
306
+ self.mtlalpha = args.mtlalpha
307
+ if args.mtlalpha > 0.0:
308
+ self.ctc = CTC(
309
+ odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
310
+ )
311
+ else:
312
+ self.ctc = None
313
+
314
+ if args.report_cer or args.report_wer:
315
+ self.error_calculator = ErrorCalculator(
316
+ args.char_list,
317
+ args.sym_space,
318
+ args.sym_blank,
319
+ args.report_cer,
320
+ args.report_wer,
321
+ )
322
+ else:
323
+ self.error_calculator = None
324
+ self.rnnlm = None
325
+
326
+ def scorers(self):
327
+ """Scorers."""
328
+ return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
329
+
330
+ def encode(self, x, aux_x, extract_resnet_feats=False):
331
+ """Encode acoustic features.
332
+
333
+ :param ndarray x: source acoustic feature (T, D)
334
+ :return: encoder outputs
335
+ :rtype: torch.Tensor
336
+ """
337
+ self.eval()
338
+ if extract_resnet_feats:
339
+ x = torch.as_tensor(x).unsqueeze(0)
340
+ resnet_feats = self.encoder(
341
+ x,
342
+ None,
343
+ extract_resnet_feats=extract_resnet_feats,
344
+ )
345
+ return resnet_feats.squeeze(0)
346
+ else:
347
+ x = torch.as_tensor(x).unsqueeze(0)
348
+ aux_x = torch.as_tensor(aux_x).unsqueeze(0)
349
+ feat, _ = self.encoder(x, None)
350
+ aux_feat, _ = self.aux_encoder(aux_x, None)
351
+ fus_output = self.fusion(torch.cat((feat, aux_feat), dim=-1))
352
+ return fus_output.squeeze(0)
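The audiovisual variant differs at inference only in that encode() takes both streams and fuses them with the MLP head. A hypothetical call sketch, assuming `model` is a loaded instance of this class and `video_feats`/`audio_feats` are preprocessed arrays matching the configured input layers:

import torch

with torch.no_grad():
    # concat(video, audio) along the feature axis, then the fusion MLP -> (T, adim)
    enc = model.encode(video_feats, audio_feats)
print(enc.shape)

The fused memory can then be fed to the same beam-search wiring shown after the single-stream model above.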
espnet/nets/pytorch_backend/lm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/nets/pytorch_backend/lm/default.py ADDED
@@ -0,0 +1,431 @@
1
+ """Default Recurrent Neural Network Languge Model in `lm_train.py`."""
2
+
3
+ from typing import Any
4
+ from typing import List
5
+ from typing import Tuple
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from espnet.nets.lm_interface import LMInterface
13
+ from espnet.nets.pytorch_backend.e2e_asr import to_device
14
+ from espnet.nets.scorer_interface import BatchScorerInterface
15
+ from espnet.utils.cli_utils import strtobool
16
+
17
+
18
+ class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
19
+ """Default RNNLM for `LMInterface` Implementation.
20
+
21
+ Note:
22
+ PyTorch seems to have a memory leak when a single GPU computes this after data parallel.
23
+ If parallel GPUs compute this, it seems to be fine.
24
+ See also https://github.com/espnet/espnet/issues/1075
25
+
26
+ """
27
+
28
+ @staticmethod
29
+ def add_arguments(parser):
30
+ """Add arguments to command line argument parser."""
31
+ parser.add_argument(
32
+ "--type",
33
+ type=str,
34
+ default="lstm",
35
+ nargs="?",
36
+ choices=["lstm", "gru"],
37
+ help="Which type of RNN to use",
38
+ )
39
+ parser.add_argument(
40
+ "--layer", "-l", type=int, default=2, help="Number of hidden layers"
41
+ )
42
+ parser.add_argument(
43
+ "--unit", "-u", type=int, default=650, help="Number of hidden units"
44
+ )
45
+ parser.add_argument(
46
+ "--embed-unit",
47
+ default=None,
48
+ type=int,
49
+ help="Number of hidden units in embedding layer, "
50
+ "if it is not specified, it keeps the same number with hidden units.",
51
+ )
52
+ parser.add_argument(
53
+ "--dropout-rate", type=float, default=0.5, help="dropout probability"
54
+ )
55
+ parser.add_argument(
56
+ "--emb-dropout-rate",
57
+ type=float,
58
+ default=0.0,
59
+ help="emb dropout probability",
60
+ )
61
+ parser.add_argument(
62
+ "--tie-weights",
63
+ type=strtobool,
64
+ default=False,
65
+ help="Tie input and output embeddings",
66
+ )
67
+ return parser
68
+
69
+ def __init__(self, n_vocab, args):
70
+ """Initialize class.
71
+
72
+ Args:
73
+ n_vocab (int): The size of the vocabulary
74
+ args (argparse.Namespace): configurations. see py:method:`add_arguments`
75
+
76
+ """
77
+ nn.Module.__init__(self)
78
+ # NOTE: for a compatibility with less than 0.5.0 version models
79
+ dropout_rate = getattr(args, "dropout_rate", 0.0)
80
+ # NOTE: for a compatibility with less than 0.6.1 version models
81
+ embed_unit = getattr(args, "embed_unit", None)
82
+ # NOTE: for a compatibility with less than 0.9.7 version models
83
+ emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
84
+ # NOTE: for a compatibility with less than 0.9.7 version models
85
+ tie_weights = getattr(args, "tie_weights", False)
86
+
87
+ self.model = ClassifierWithState(
88
+ RNNLM(
89
+ n_vocab,
90
+ args.layer,
91
+ args.unit,
92
+ embed_unit,
93
+ args.type,
94
+ dropout_rate,
95
+ emb_dropout_rate,
96
+ tie_weights,
97
+ )
98
+ )
99
+
100
+ def state_dict(self):
101
+ """Dump state dict."""
102
+ return self.model.state_dict()
103
+
104
+ def load_state_dict(self, d):
105
+ """Load state dict."""
106
+ self.model.load_state_dict(d)
107
+
108
+ def forward(self, x, t):
109
+ """Compute LM loss value from buffer sequences.
110
+
111
+ Args:
112
+ x (torch.Tensor): Input ids. (batch, len)
113
+ t (torch.Tensor): Target ids. (batch, len)
114
+
115
+ Returns:
116
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
117
+ loss to backward (scalar),
118
+ negative log-likelihood of t: -log p(t) (scalar) and
119
+ the number of elements in x (scalar)
120
+
121
+ Notes:
122
+ The last two return values are used
123
+ in perplexity: p(t)^{-n} = exp(-log p(t) / n)
124
+
125
+ """
126
+ loss = 0
127
+ logp = 0
128
+ count = torch.tensor(0).long()
129
+ state = None
130
+ batch_size, sequence_length = x.shape
131
+ for i in range(sequence_length):
132
+ # Compute the loss at this time step and accumulate it
133
+ state, loss_batch = self.model(state, x[:, i], t[:, i])
134
+ non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
135
+ loss += loss_batch.mean() * non_zeros
136
+ logp += torch.sum(loss_batch * non_zeros)
137
+ count += int(non_zeros)
138
+ return loss / batch_size, loss, count.to(loss.device)
139
+
140
+ def score(self, y, state, x):
141
+ """Score new token.
142
+
143
+ Args:
144
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
145
+ state: Scorer state for prefix tokens
146
+ x (torch.Tensor): 2D encoder feature that generates ys.
147
+
148
+ Returns:
149
+ tuple[torch.Tensor, Any]: Tuple of
150
+ torch.float32 scores for next token (n_vocab)
151
+ and next state for ys
152
+
153
+ """
154
+ new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
155
+ return scores.squeeze(0), new_state
156
+
157
+ def final_score(self, state):
158
+ """Score eos.
159
+
160
+ Args:
161
+ state: Scorer state for prefix tokens
162
+
163
+ Returns:
164
+ float: final score
165
+
166
+ """
167
+ return self.model.final(state)
168
+
169
+ # batch beam search API (see BatchScorerInterface)
170
+ def batch_score(
171
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
172
+ ) -> Tuple[torch.Tensor, List[Any]]:
173
+ """Score new token batch.
174
+
175
+ Args:
176
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
177
+ states (List[Any]): Scorer states for prefix tokens.
178
+ xs (torch.Tensor):
179
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
180
+
181
+ Returns:
182
+ tuple[torch.Tensor, List[Any]]: Tuple of
183
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
184
+ and next state list for ys.
185
+
186
+ """
187
+ # merge states
188
+ n_batch = len(ys)
189
+ n_layers = self.model.predictor.n_layers
190
+ if self.model.predictor.typ == "lstm":
191
+ keys = ("c", "h")
192
+ else:
193
+ keys = ("h",)
194
+
195
+ if states[0] is None:
196
+ states = None
197
+ else:
198
+ # transpose state of [batch, key, layer] into [key, layer, batch]
199
+ states = {
200
+ k: [
201
+ torch.stack([states[b][k][i] for b in range(n_batch)])
202
+ for i in range(n_layers)
203
+ ]
204
+ for k in keys
205
+ }
206
+ states, logp = self.model.predict(states, ys[:, -1])
207
+
208
+ # transpose state of [key, layer, batch] into [batch, key, layer]
209
+ return (
210
+ logp,
211
+ [
212
+ {k: [states[k][i][b] for i in range(n_layers)] for k in keys}
213
+ for b in range(n_batch)
214
+ ],
215
+ )
216
+
217
+
218
+ class ClassifierWithState(nn.Module):
219
+ """A wrapper for pytorch RNNLM."""
220
+
221
+ def __init__(
222
+ self, predictor, lossfun=nn.CrossEntropyLoss(reduction="none"), label_key=-1
223
+ ):
224
+ """Initialize class.
225
+
226
+ :param torch.nn.Module predictor : The RNNLM
227
+ :param function lossfun : The loss function to use
228
+ :param int/str label_key :
229
+
230
+ """
231
+ if not (isinstance(label_key, (int, str))):
232
+ raise TypeError("label_key must be int or str, but is %s" % type(label_key))
233
+ super(ClassifierWithState, self).__init__()
234
+ self.lossfun = lossfun
235
+ self.y = None
236
+ self.loss = None
237
+ self.label_key = label_key
238
+ self.predictor = predictor
239
+
240
+ def forward(self, state, *args, **kwargs):
241
+ """Compute the loss value for an input and label pair.
242
+
243
+ Notes:
244
+ It also computes accuracy and stores it to the attribute.
245
+ When ``label_key`` is ``int``, the corresponding element in ``args``
246
+ is treated as ground truth labels. And when it is ``str``, the
247
+ element in ``kwargs`` is used.
248
+ The all elements of ``args`` and ``kwargs`` except the groundtruth
249
+ labels are features.
250
+ It feeds features to the predictor and compare the result
251
+ with ground truth labels.
252
+
253
+ :param torch.Tensor state : the LM state
254
+ :param list[torch.Tensor] args : Input minibatch
255
+ :param dict[torch.Tensor] kwargs : Input minibatch
256
+ :return loss value
257
+ :rtype torch.Tensor
258
+
259
+ """
260
+ if isinstance(self.label_key, int):
261
+ if not (-len(args) <= self.label_key < len(args)):
262
+ msg = "Label key %d is out of bounds" % self.label_key
263
+ raise ValueError(msg)
264
+ t = args[self.label_key]
265
+ if self.label_key == -1:
266
+ args = args[:-1]
267
+ else:
268
+ args = args[: self.label_key] + args[self.label_key + 1 :]
269
+ elif isinstance(self.label_key, str):
270
+ if self.label_key not in kwargs:
271
+ msg = 'Label key "%s" is not found' % self.label_key
272
+ raise ValueError(msg)
273
+ t = kwargs[self.label_key]
274
+ del kwargs[self.label_key]
275
+
276
+ self.y = None
277
+ self.loss = None
278
+ state, self.y = self.predictor(state, *args, **kwargs)
279
+ self.loss = self.lossfun(self.y, t)
280
+ return state, self.loss
281
+
282
+ def predict(self, state, x):
283
+ """Predict log probabilities for given state and input x using the predictor.
284
+
285
+ :param torch.Tensor state : The current state
286
+ :param torch.Tensor x : The input
287
+ :return a tuple (new state, log prob vector)
288
+ :rtype (torch.Tensor, torch.Tensor)
289
+ """
290
+ if hasattr(self.predictor, "normalized") and self.predictor.normalized:
291
+ return self.predictor(state, x)
292
+ else:
293
+ state, z = self.predictor(state, x)
294
+ return state, F.log_softmax(z, dim=1)
295
+
296
+ def buff_predict(self, state, x, n):
297
+ """Predict new tokens from buffered inputs."""
298
+ if self.predictor.__class__.__name__ == "RNNLM":
299
+ return self.predict(state, x)
300
+
301
+ new_state = []
302
+ new_log_y = []
303
+ for i in range(n):
304
+ state_i = None if state is None else state[i]
305
+ state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
306
+ new_state.append(state_i)
307
+ new_log_y.append(log_y)
308
+
309
+ return new_state, torch.cat(new_log_y)
310
+
311
+ def final(self, state, index=None):
312
+ """Predict final log probabilities for given state using the predictor.
313
+
314
+ :param state: The state
315
+ :return The final log probabilities
316
+ :rtype torch.Tensor
317
+ """
318
+ if hasattr(self.predictor, "final"):
319
+ if index is not None:
320
+ return self.predictor.final(state[index])
321
+ else:
322
+ return self.predictor.final(state)
323
+ else:
324
+ return 0.0
325
+
326
+
327
+ # Definition of a recurrent net for language modeling
328
+ class RNNLM(nn.Module):
329
+ """A pytorch RNNLM."""
330
+
331
+ def __init__(
332
+ self,
333
+ n_vocab,
334
+ n_layers,
335
+ n_units,
336
+ n_embed=None,
337
+ typ="lstm",
338
+ dropout_rate=0.5,
339
+ emb_dropout_rate=0.0,
340
+ tie_weights=False,
341
+ ):
342
+ """Initialize class.
343
+
344
+ :param int n_vocab: The size of the vocabulary
345
+ :param int n_layers: The number of layers to create
346
+ :param int n_units: The number of units per layer
347
+ :param str typ: The RNN type
348
+ """
349
+ super(RNNLM, self).__init__()
350
+ if n_embed is None:
351
+ n_embed = n_units
352
+
353
+ self.embed = nn.Embedding(n_vocab, n_embed)
354
+
355
+ if emb_dropout_rate == 0.0:
356
+ self.embed_drop = None
357
+ else:
358
+ self.embed_drop = nn.Dropout(emb_dropout_rate)
359
+
360
+ if typ == "lstm":
361
+ self.rnn = nn.ModuleList(
362
+ [nn.LSTMCell(n_embed, n_units)]
363
+ + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
364
+ )
365
+ else:
366
+ self.rnn = nn.ModuleList(
367
+ [nn.GRUCell(n_embed, n_units)]
368
+ + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
369
+ )
370
+
371
+ self.dropout = nn.ModuleList(
372
+ [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
373
+ )
374
+ self.lo = nn.Linear(n_units, n_vocab)
375
+ self.n_layers = n_layers
376
+ self.n_units = n_units
377
+ self.typ = typ
378
+
379
+ logging.info("Tie weights set to {}".format(tie_weights))
380
+ logging.info("Dropout set to {}".format(dropout_rate))
381
+ logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
382
+
383
+ if tie_weights:
384
+ assert (
385
+ n_embed == n_units
386
+ ), "Tie Weights: True need embedding and final dimensions to match"
387
+ self.lo.weight = self.embed.weight
388
+
389
+ # initialize parameters from uniform distribution
390
+ for param in self.parameters():
391
+ param.data.uniform_(-0.1, 0.1)
392
+
393
+ def zero_state(self, batchsize):
394
+ """Initialize state."""
395
+ p = next(self.parameters())
396
+ return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)
397
+
398
+ def forward(self, state, x):
399
+ """Forward neural networks."""
400
+ if state is None:
401
+ h = [to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers)]
402
+ state = {"h": h}
403
+ if self.typ == "lstm":
404
+ c = [
405
+ to_device(x, self.zero_state(x.size(0)))
406
+ for n in range(self.n_layers)
407
+ ]
408
+ state = {"c": c, "h": h}
409
+
410
+ h = [None] * self.n_layers
411
+ if self.embed_drop is not None:
412
+ emb = self.embed_drop(self.embed(x))
413
+ else:
414
+ emb = self.embed(x)
415
+ if self.typ == "lstm":
416
+ c = [None] * self.n_layers
417
+ h[0], c[0] = self.rnn[0](
418
+ self.dropout[0](emb), (state["h"][0], state["c"][0])
419
+ )
420
+ for n in range(1, self.n_layers):
421
+ h[n], c[n] = self.rnn[n](
422
+ self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
423
+ )
424
+ state = {"c": c, "h": h}
425
+ else:
426
+ h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
427
+ for n in range(1, self.n_layers):
428
+ h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
429
+ state = {"h": h}
430
+ y = self.lo(self.dropout[-1](h[-1]))
431
+ return state, y
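A minimal sketch (not in the commit) of building this RNNLM from its declared defaults and scoring a prefix; the vocabulary size and token ids are made up, and it assumes the module's own imports (e.g. to_device from espnet.nets.pytorch_backend.e2e_asr) resolve in this tree.

import argparse
import torch
from espnet.nets.pytorch_backend.lm.default import DefaultRNNLM

parser = argparse.ArgumentParser()
DefaultRNNLM.add_arguments(parser)
args = parser.parse_args([])            # lstm, 2 layers, 650 units by default

lm = DefaultRNNLM(n_vocab=100, args=args)
lm.eval()

prefix = torch.tensor([1, 23, 42])      # arbitrary token ids; score() consumes only the last one
with torch.no_grad():
    scores, state = lm.score(prefix, None, x=None)
print(scores.shape)                     # torch.Size([100]) log-probabilities for the next token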
espnet/nets/pytorch_backend/lm/seq_rnn.py ADDED
@@ -0,0 +1,178 @@
1
+ """Sequential implementation of Recurrent Neural Network Language Model."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from espnet.nets.lm_interface import LMInterface
8
+
9
+
10
+ class SequentialRNNLM(LMInterface, torch.nn.Module):
11
+ """Sequential RNNLM.
12
+
13
+ See also:
14
+ https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
15
+
16
+ """
17
+
18
+ @staticmethod
19
+ def add_arguments(parser):
20
+ """Add arguments to command line argument parser."""
21
+ parser.add_argument(
22
+ "--type",
23
+ type=str,
24
+ default="lstm",
25
+ nargs="?",
26
+ choices=["lstm", "gru"],
27
+ help="Which type of RNN to use",
28
+ )
29
+ parser.add_argument(
30
+ "--layer", "-l", type=int, default=2, help="Number of hidden layers"
31
+ )
32
+ parser.add_argument(
33
+ "--unit", "-u", type=int, default=650, help="Number of hidden units"
34
+ )
35
+ parser.add_argument(
36
+ "--dropout-rate", type=float, default=0.5, help="dropout probability"
37
+ )
38
+ return parser
39
+
40
+ def __init__(self, n_vocab, args):
41
+ """Initialize class.
42
+
43
+ Args:
44
+ n_vocab (int): The size of the vocabulary
45
+ args (argparse.Namespace): configurations. see py:method:`add_arguments`
46
+
47
+ """
48
+ torch.nn.Module.__init__(self)
49
+ self._setup(
50
+ rnn_type=args.type.upper(),
51
+ ntoken=n_vocab,
52
+ ninp=args.unit,
53
+ nhid=args.unit,
54
+ nlayers=args.layer,
55
+ dropout=args.dropout_rate,
56
+ )
57
+
58
+ def _setup(
59
+ self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
60
+ ):
61
+ self.drop = nn.Dropout(dropout)
62
+ self.encoder = nn.Embedding(ntoken, ninp)
63
+ if rnn_type in ["LSTM", "GRU"]:
64
+ self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
65
+ else:
66
+ try:
67
+ nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
68
+ except KeyError:
69
+ raise ValueError(
70
+ "An invalid option for `--model` was supplied, "
71
+ "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
72
+ )
73
+ self.rnn = nn.RNN(
74
+ ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
75
+ )
76
+ self.decoder = nn.Linear(nhid, ntoken)
77
+
78
+ # Optionally tie weights as in:
79
+ # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
80
+ # https://arxiv.org/abs/1608.05859
81
+ # and
82
+ # "Tying Word Vectors and Word Classifiers:
83
+ # A Loss Framework for Language Modeling" (Inan et al. 2016)
84
+ # https://arxiv.org/abs/1611.01462
85
+ if tie_weights:
86
+ if nhid != ninp:
87
+ raise ValueError(
88
+ "When using the tied flag, nhid must be equal to emsize"
89
+ )
90
+ self.decoder.weight = self.encoder.weight
91
+
92
+ self._init_weights()
93
+
94
+ self.rnn_type = rnn_type
95
+ self.nhid = nhid
96
+ self.nlayers = nlayers
97
+
98
+ def _init_weights(self):
99
+ # NOTE: original init in pytorch/examples
100
+ # initrange = 0.1
101
+ # self.encoder.weight.data.uniform_(-initrange, initrange)
102
+ # self.decoder.bias.data.zero_()
103
+ # self.decoder.weight.data.uniform_(-initrange, initrange)
104
+ # NOTE: our default.py:RNNLM init
105
+ for param in self.parameters():
106
+ param.data.uniform_(-0.1, 0.1)
107
+
108
+ def forward(self, x, t):
109
+ """Compute LM loss value from buffer sequences.
110
+
111
+ Args:
112
+ x (torch.Tensor): Input ids. (batch, len)
113
+ t (torch.Tensor): Target ids. (batch, len)
114
+
115
+ Returns:
116
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
117
+ loss to backward (scalar),
118
+ negative log-likelihood of t: -log p(t) (scalar) and
119
+ the number of elements in x (scalar)
120
+
121
+ Notes:
122
+ The last two return values are used
123
+ in perplexity: p(t)^{-n} = exp(-log p(t) / n)
124
+
125
+ """
126
+ y = self._before_loss(x, None)[0]
127
+ mask = (x != 0).to(y.dtype)
128
+ loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
129
+ logp = loss * mask.view(-1)
130
+ logp = logp.sum()
131
+ count = mask.sum()
132
+ return logp / count, logp, count
133
+
134
+ def _before_loss(self, input, hidden):
135
+ emb = self.drop(self.encoder(input))
136
+ output, hidden = self.rnn(emb, hidden)
137
+ output = self.drop(output)
138
+ decoded = self.decoder(
139
+ output.view(output.size(0) * output.size(1), output.size(2))
140
+ )
141
+ return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
142
+
143
+ def init_state(self, x):
144
+ """Get an initial state for decoding.
145
+
146
+ Args:
147
+ x (torch.Tensor): The encoded feature tensor
148
+
149
+ Returns: initial state
150
+
151
+ """
152
+ bsz = 1
153
+ weight = next(self.parameters())
154
+ if self.rnn_type == "LSTM":
155
+ return (
156
+ weight.new_zeros(self.nlayers, bsz, self.nhid),
157
+ weight.new_zeros(self.nlayers, bsz, self.nhid),
158
+ )
159
+ else:
160
+ return weight.new_zeros(self.nlayers, bsz, self.nhid)
161
+
162
+ def score(self, y, state, x):
163
+ """Score new token.
164
+
165
+ Args:
166
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
167
+ state: Scorer state for prefix tokens
168
+ x (torch.Tensor): 2D encoder feature that generates ys.
169
+
170
+ Returns:
171
+ tuple[torch.Tensor, Any]: Tuple of
172
+ torch.float32 scores for next token (n_vocab)
173
+ and next state for ys
174
+
175
+ """
176
+ y, new_state = self._before_loss(y[-1].view(1, 1), state)
177
+ logp = y.log_softmax(dim=-1).view(-1)
178
+ return logp, new_state
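The sequential variant exposes the same score() contract but keeps an explicit (h, c) state. A minimal sketch with the same made-up vocabulary:

import argparse
import torch
from espnet.nets.pytorch_backend.lm.seq_rnn import SequentialRNNLM

parser = argparse.ArgumentParser()
SequentialRNNLM.add_arguments(parser)
args = parser.parse_args([])            # LSTM, 2 layers, 650 units by default

lm = SequentialRNNLM(n_vocab=100, args=args)
lm.eval()

state = lm.init_state(None)             # (h0, c0), each (nlayers, 1, nhid)
y = torch.tensor([1, 5, 7])             # prefix; score() feeds y[-1] into the RNN
with torch.no_grad():
    logp, state = lm.score(y, state, x=None)
print(logp.shape)                       # torch.Size([100])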
espnet/nets/pytorch_backend/lm/transformer.py ADDED
@@ -0,0 +1,252 @@
1
+ """Transformer language model."""
2
+
3
+ from typing import Any
4
+ from typing import List
5
+ from typing import Tuple
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from espnet.nets.lm_interface import LMInterface
13
+ from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
14
+ from espnet.nets.pytorch_backend.transformer.encoder import Encoder
15
+ from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
16
+ from espnet.nets.scorer_interface import BatchScorerInterface
17
+ from espnet.utils.cli_utils import strtobool
18
+
19
+
20
+ class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
21
+ """Transformer language model."""
22
+
23
+ @staticmethod
24
+ def add_arguments(parser):
25
+ """Add arguments to command line argument parser."""
26
+ parser.add_argument(
27
+ "--layer", type=int, default=4, help="Number of hidden layers"
28
+ )
29
+ parser.add_argument(
30
+ "--unit",
31
+ type=int,
32
+ default=1024,
33
+ help="Number of hidden units in feedforward layer",
34
+ )
35
+ parser.add_argument(
36
+ "--att-unit",
37
+ type=int,
38
+ default=256,
39
+ help="Number of hidden units in attention layer",
40
+ )
41
+ parser.add_argument(
42
+ "--embed-unit",
43
+ type=int,
44
+ default=128,
45
+ help="Number of hidden units in embedding layer",
46
+ )
47
+ parser.add_argument(
48
+ "--head", type=int, default=2, help="Number of multi head attention"
49
+ )
50
+ parser.add_argument(
51
+ "--dropout-rate", type=float, default=0.5, help="dropout probability"
52
+ )
53
+ parser.add_argument(
54
+ "--att-dropout-rate",
55
+ type=float,
56
+ default=0.0,
57
+ help="att dropout probability",
58
+ )
59
+ parser.add_argument(
60
+ "--emb-dropout-rate",
61
+ type=float,
62
+ default=0.0,
63
+ help="emb dropout probability",
64
+ )
65
+ parser.add_argument(
66
+ "--tie-weights",
67
+ type=strtobool,
68
+ default=False,
69
+ help="Tie input and output embeddings",
70
+ )
71
+ parser.add_argument(
72
+ "--pos-enc",
73
+ default="sinusoidal",
74
+ choices=["sinusoidal", "none"],
75
+ help="positional encoding",
76
+ )
77
+ return parser
78
+
79
+ def __init__(self, n_vocab, args):
80
+ """Initialize class.
81
+
82
+ Args:
83
+ n_vocab (int): The size of the vocabulary
84
+ args (argparse.Namespace): configurations. see py:method:`add_arguments`
85
+
86
+ """
87
+ nn.Module.__init__(self)
88
+
89
+ # NOTE: for a compatibility with less than 0.9.7 version models
90
+ emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
91
+ # NOTE: for a compatibility with less than 0.9.7 version models
92
+ tie_weights = getattr(args, "tie_weights", False)
93
+ # NOTE: for a compatibility with less than 0.9.7 version models
94
+ att_dropout_rate = getattr(args, "att_dropout_rate", 0.0)
95
+
96
+ if args.pos_enc == "sinusoidal":
97
+ pos_enc_class = PositionalEncoding
98
+ elif args.pos_enc == "none":
99
+
100
+ def pos_enc_class(*args, **kwargs):
101
+ return nn.Sequential() # identity
102
+
103
+ else:
104
+ raise ValueError(f"unknown pos-enc option: {args.pos_enc}")
105
+
106
+ self.embed = nn.Embedding(n_vocab, args.embed_unit)
107
+
108
+ if emb_dropout_rate == 0.0:
109
+ self.embed_drop = None
110
+ else:
111
+ self.embed_drop = nn.Dropout(emb_dropout_rate)
112
+
113
+ self.encoder = Encoder(
114
+ idim=args.embed_unit,
115
+ attention_dim=args.att_unit,
116
+ attention_heads=args.head,
117
+ linear_units=args.unit,
118
+ num_blocks=args.layer,
119
+ dropout_rate=args.dropout_rate,
120
+ attention_dropout_rate=att_dropout_rate,
121
+ input_layer="linear",
122
+ pos_enc_class=pos_enc_class,
123
+ )
124
+ self.decoder = nn.Linear(args.att_unit, n_vocab)
125
+
126
+ logging.info("Tie weights set to {}".format(tie_weights))
127
+ logging.info("Dropout set to {}".format(args.dropout_rate))
128
+ logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
129
+ logging.info("Att Dropout set to {}".format(att_dropout_rate))
130
+
131
+ if tie_weights:
132
+ assert (
133
+ args.att_unit == args.embed_unit
134
+ ), "Tie Weights: True need embedding and final dimensions to match"
135
+ self.decoder.weight = self.embed.weight
136
+
137
+ def _target_mask(self, ys_in_pad):
138
+ ys_mask = ys_in_pad != 0
139
+ m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
140
+ return ys_mask.unsqueeze(-2) & m
141
+
142
+ def forward(
143
+ self, x: torch.Tensor, t: torch.Tensor
144
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
145
+ """Compute LM loss value from buffer sequences.
146
+
147
+ Args:
148
+ x (torch.Tensor): Input ids. (batch, len)
149
+ t (torch.Tensor): Target ids. (batch, len)
150
+
151
+ Returns:
152
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
153
+ loss to backward (scalar),
154
+ negative log-likelihood of t: -log p(t) (scalar) and
155
+ the number of elements in x (scalar)
156
+
157
+ Notes:
158
+ The last two return values are used
159
+ in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
160
+
161
+ """
162
+ xm = x != 0
163
+
164
+ if self.embed_drop is not None:
165
+ emb = self.embed_drop(self.embed(x))
166
+ else:
167
+ emb = self.embed(x)
168
+
169
+ h, _ = self.encoder(emb, self._target_mask(x))
170
+ y = self.decoder(h)
171
+ loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
172
+ mask = xm.to(dtype=loss.dtype)
173
+ logp = loss * mask.view(-1)
174
+ logp = logp.sum()
175
+ count = mask.sum()
176
+ return logp / count, logp, count
177
+
178
+ def score(
179
+ self, y: torch.Tensor, state: Any, x: torch.Tensor
180
+ ) -> Tuple[torch.Tensor, Any]:
181
+ """Score new token.
182
+
183
+ Args:
184
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
185
+ state: Scorer state for prefix tokens
186
+ x (torch.Tensor): encoder feature that generates ys.
187
+
188
+ Returns:
189
+ tuple[torch.Tensor, Any]: Tuple of
190
+ torch.float32 scores for next token (n_vocab)
191
+ and next state for ys
192
+
193
+ """
194
+ y = y.unsqueeze(0)
195
+
196
+ if self.embed_drop is not None:
197
+ emb = self.embed_drop(self.embed(y))
198
+ else:
199
+ emb = self.embed(y)
200
+
201
+ h, _, cache = self.encoder.forward_one_step(
202
+ emb, self._target_mask(y), cache=state
203
+ )
204
+ h = self.decoder(h[:, -1])
205
+ logp = h.log_softmax(dim=-1).squeeze(0)
206
+ return logp, cache
207
+
208
+ # batch beam search API (see BatchScorerInterface)
209
+ def batch_score(
210
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
211
+ ) -> Tuple[torch.Tensor, List[Any]]:
212
+ """Score new token batch (required).
213
+
214
+ Args:
215
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
216
+ states (List[Any]): Scorer states for prefix tokens.
217
+ xs (torch.Tensor):
218
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
219
+
220
+ Returns:
221
+ tuple[torch.Tensor, List[Any]]: Tuple of
222
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
223
+ and next state list for ys.
224
+
225
+ """
226
+ # merge states
227
+ n_batch = len(ys)
228
+ n_layers = len(self.encoder.encoders)
229
+ if states[0] is None:
230
+ batch_state = None
231
+ else:
232
+ # transpose state of [batch, layer] into [layer, batch]
233
+ batch_state = [
234
+ torch.stack([states[b][i] for b in range(n_batch)])
235
+ for i in range(n_layers)
236
+ ]
237
+
238
+ if self.embed_drop is not None:
239
+ emb = self.embed_drop(self.embed(ys))
240
+ else:
241
+ emb = self.embed(ys)
242
+
243
+ # batch decoding
244
+ h, _, states = self.encoder.forward_one_step(
245
+ emb, self._target_mask(ys), cache=batch_state
246
+ )
247
+ h = self.decoder(h[:, -1])
248
+ logp = h.log_softmax(dim=-1)
249
+
250
+ # transpose state of [layer, batch] into [batch, layer]
251
+ state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
252
+ return logp, state_list
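To make the contract above concrete, here is a minimal sketch (not part of the commit) that drives the forward pass with the default hyperparameters registered in add_arguments. The import path and class name TransformerLM are assumptions for illustration; only the constructor arguments and the (loss, nll, count) return values are taken from the code above.

import argparse
import torch
# Assumed location/name of the LM defined above; adjust to the actual module.
from espnet.nets.pytorch_backend.lm.transformer import TransformerLM

# Defaults mirror the add_arguments() options shown in this diff.
args = argparse.Namespace(
    layer=4, unit=1024, att_unit=256, embed_unit=128, head=2,
    dropout_rate=0.5, att_dropout_rate=0.0, emb_dropout_rate=0.0,
    tie_weights=False, pos_enc="sinusoidal",
)
lm = TransformerLM(n_vocab=1000, args=args)
x = torch.randint(1, 1000, (2, 16))   # token id 0 is treated as padding
t = torch.randint(1, 1000, (2, 16))
loss, nll, count = lm(x, t)
perplexity = torch.exp(nll / count)   # as noted in the forward() docstring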
espnet/nets/pytorch_backend/nets_utils.py ADDED
@@ -0,0 +1,526 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """Network related utility tools."""
4
+
5
+ import logging
6
+ from typing import Dict
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ def to_device(m, x):
13
+ """Send tensor into the device of the module.
14
+
15
+ Args:
16
+ m (torch.nn.Module): Torch module.
17
+ x (Tensor): Torch tensor.
18
+
19
+ Returns:
20
+ Tensor: Torch tensor located in the same place as torch module.
21
+
22
+ """
23
+ if isinstance(m, torch.nn.Module):
24
+ device = next(m.parameters()).device
25
+ elif isinstance(m, torch.Tensor):
26
+ device = m.device
27
+ else:
28
+ raise TypeError(
29
+ "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
30
+ )
31
+ return x.to(device)
32
+
33
+
34
+ def pad_list(xs, pad_value):
35
+ """Perform padding for the list of tensors.
36
+
37
+ Args:
38
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
39
+ pad_value (float): Value for padding.
40
+
41
+ Returns:
42
+ Tensor: Padded tensor (B, Tmax, `*`).
43
+
44
+ Examples:
45
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
46
+ >>> x
47
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
48
+ >>> pad_list(x, 0)
49
+ tensor([[1., 1., 1., 1.],
50
+ [1., 1., 0., 0.],
51
+ [1., 0., 0., 0.]])
52
+
53
+ """
54
+ n_batch = len(xs)
55
+ max_len = max(x.size(0) for x in xs)
56
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
57
+
58
+ for i in range(n_batch):
59
+ pad[i, : xs[i].size(0)] = xs[i]
60
+
61
+ return pad
62
+
63
+
64
+ def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
65
+ """Make mask tensor containing indices of padded part.
66
+
67
+ Args:
68
+ lengths (LongTensor or List): Batch of lengths (B,).
69
+ xs (Tensor, optional): The reference tensor.
70
+ If set, masks will be the same shape as this tensor.
71
+ length_dim (int, optional): Dimension indicator of the above tensor.
72
+ See the example.
73
+
74
+ Returns:
75
+ Tensor: Mask tensor containing indices of padded part.
76
+ dtype=torch.uint8 in PyTorch 1.2-
77
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
78
+
79
+ Examples:
80
+ With only lengths.
81
+
82
+ >>> lengths = [5, 3, 2]
83
+ >>> make_pad_mask(lengths)
84
+ masks = [[0, 0, 0, 0 ,0],
85
+ [0, 0, 0, 1, 1],
86
+ [0, 0, 1, 1, 1]]
87
+
88
+ With the reference tensor.
89
+
90
+ >>> xs = torch.zeros((3, 2, 4))
91
+ >>> make_pad_mask(lengths, xs)
92
+ tensor([[[0, 0, 0, 0],
93
+ [0, 0, 0, 0]],
94
+ [[0, 0, 0, 1],
95
+ [0, 0, 0, 1]],
96
+ [[0, 0, 1, 1],
97
+ [0, 0, 1, 1]]], dtype=torch.uint8)
98
+ >>> xs = torch.zeros((3, 2, 6))
99
+ >>> make_pad_mask(lengths, xs)
100
+ tensor([[[0, 0, 0, 0, 0, 1],
101
+ [0, 0, 0, 0, 0, 1]],
102
+ [[0, 0, 0, 1, 1, 1],
103
+ [0, 0, 0, 1, 1, 1]],
104
+ [[0, 0, 1, 1, 1, 1],
105
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
106
+
107
+ With the reference tensor and dimension indicator.
108
+
109
+ >>> xs = torch.zeros((3, 6, 6))
110
+ >>> make_pad_mask(lengths, xs, 1)
111
+ tensor([[[0, 0, 0, 0, 0, 0],
112
+ [0, 0, 0, 0, 0, 0],
113
+ [0, 0, 0, 0, 0, 0],
114
+ [0, 0, 0, 0, 0, 0],
115
+ [0, 0, 0, 0, 0, 0],
116
+ [1, 1, 1, 1, 1, 1]],
117
+ [[0, 0, 0, 0, 0, 0],
118
+ [0, 0, 0, 0, 0, 0],
119
+ [0, 0, 0, 0, 0, 0],
120
+ [1, 1, 1, 1, 1, 1],
121
+ [1, 1, 1, 1, 1, 1],
122
+ [1, 1, 1, 1, 1, 1]],
123
+ [[0, 0, 0, 0, 0, 0],
124
+ [0, 0, 0, 0, 0, 0],
125
+ [1, 1, 1, 1, 1, 1],
126
+ [1, 1, 1, 1, 1, 1],
127
+ [1, 1, 1, 1, 1, 1],
128
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
129
+ >>> make_pad_mask(lengths, xs, 2)
130
+ tensor([[[0, 0, 0, 0, 0, 1],
131
+ [0, 0, 0, 0, 0, 1],
132
+ [0, 0, 0, 0, 0, 1],
133
+ [0, 0, 0, 0, 0, 1],
134
+ [0, 0, 0, 0, 0, 1],
135
+ [0, 0, 0, 0, 0, 1]],
136
+ [[0, 0, 0, 1, 1, 1],
137
+ [0, 0, 0, 1, 1, 1],
138
+ [0, 0, 0, 1, 1, 1],
139
+ [0, 0, 0, 1, 1, 1],
140
+ [0, 0, 0, 1, 1, 1],
141
+ [0, 0, 0, 1, 1, 1]],
142
+ [[0, 0, 1, 1, 1, 1],
143
+ [0, 0, 1, 1, 1, 1],
144
+ [0, 0, 1, 1, 1, 1],
145
+ [0, 0, 1, 1, 1, 1],
146
+ [0, 0, 1, 1, 1, 1],
147
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
148
+
149
+ """
150
+ if length_dim == 0:
151
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
152
+
153
+ if not isinstance(lengths, list):
154
+ lengths = lengths.tolist()
155
+ bs = int(len(lengths))
156
+ if maxlen is None:
157
+ if xs is None:
158
+ maxlen = int(max(lengths))
159
+ else:
160
+ maxlen = xs.size(length_dim)
161
+ else:
162
+ assert xs is None
163
+ assert maxlen >= int(max(lengths))
164
+
165
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
166
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
167
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
168
+ mask = seq_range_expand >= seq_length_expand
169
+
170
+ if xs is not None:
171
+ assert xs.size(0) == bs, (xs.size(0), bs)
172
+
173
+ if length_dim < 0:
174
+ length_dim = xs.dim() + length_dim
175
+ # ind = (:, None, ..., None, :, None, ..., None)
176
+ ind = tuple(
177
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
178
+ )
179
+ mask = mask[ind].expand_as(xs).to(xs.device)
180
+ return mask
181
+
182
+
183
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
184
+ """Make mask tensor containing indices of non-padded part.
185
+
186
+ Args:
187
+ lengths (LongTensor or List): Batch of lengths (B,).
188
+ xs (Tensor, optional): The reference tensor.
189
+ If set, masks will be the same shape as this tensor.
190
+ length_dim (int, optional): Dimension indicator of the above tensor.
191
+ See the example.
192
+
193
+ Returns:
194
+ ByteTensor: mask tensor containing indices of non-padded part.
195
+ dtype=torch.uint8 in PyTorch 1.2-
196
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
197
+
198
+ Examples:
199
+ With only lengths.
200
+
201
+ >>> lengths = [5, 3, 2]
202
+ >>> make_non_pad_mask(lengths)
203
+ masks = [[1, 1, 1, 1 ,1],
204
+ [1, 1, 1, 0, 0],
205
+ [1, 1, 0, 0, 0]]
206
+
207
+ With the reference tensor.
208
+
209
+ >>> xs = torch.zeros((3, 2, 4))
210
+ >>> make_non_pad_mask(lengths, xs)
211
+ tensor([[[1, 1, 1, 1],
212
+ [1, 1, 1, 1]],
213
+ [[1, 1, 1, 0],
214
+ [1, 1, 1, 0]],
215
+ [[1, 1, 0, 0],
216
+ [1, 1, 0, 0]]], dtype=torch.uint8)
217
+ >>> xs = torch.zeros((3, 2, 6))
218
+ >>> make_non_pad_mask(lengths, xs)
219
+ tensor([[[1, 1, 1, 1, 1, 0],
220
+ [1, 1, 1, 1, 1, 0]],
221
+ [[1, 1, 1, 0, 0, 0],
222
+ [1, 1, 1, 0, 0, 0]],
223
+ [[1, 1, 0, 0, 0, 0],
224
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
225
+
226
+ With the reference tensor and dimension indicator.
227
+
228
+ >>> xs = torch.zeros((3, 6, 6))
229
+ >>> make_non_pad_mask(lengths, xs, 1)
230
+ tensor([[[1, 1, 1, 1, 1, 1],
231
+ [1, 1, 1, 1, 1, 1],
232
+ [1, 1, 1, 1, 1, 1],
233
+ [1, 1, 1, 1, 1, 1],
234
+ [1, 1, 1, 1, 1, 1],
235
+ [0, 0, 0, 0, 0, 0]],
236
+ [[1, 1, 1, 1, 1, 1],
237
+ [1, 1, 1, 1, 1, 1],
238
+ [1, 1, 1, 1, 1, 1],
239
+ [0, 0, 0, 0, 0, 0],
240
+ [0, 0, 0, 0, 0, 0],
241
+ [0, 0, 0, 0, 0, 0]],
242
+ [[1, 1, 1, 1, 1, 1],
243
+ [1, 1, 1, 1, 1, 1],
244
+ [0, 0, 0, 0, 0, 0],
245
+ [0, 0, 0, 0, 0, 0],
246
+ [0, 0, 0, 0, 0, 0],
247
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
248
+ >>> make_non_pad_mask(lengths, xs, 2)
249
+ tensor([[[1, 1, 1, 1, 1, 0],
250
+ [1, 1, 1, 1, 1, 0],
251
+ [1, 1, 1, 1, 1, 0],
252
+ [1, 1, 1, 1, 1, 0],
253
+ [1, 1, 1, 1, 1, 0],
254
+ [1, 1, 1, 1, 1, 0]],
255
+ [[1, 1, 1, 0, 0, 0],
256
+ [1, 1, 1, 0, 0, 0],
257
+ [1, 1, 1, 0, 0, 0],
258
+ [1, 1, 1, 0, 0, 0],
259
+ [1, 1, 1, 0, 0, 0],
260
+ [1, 1, 1, 0, 0, 0]],
261
+ [[1, 1, 0, 0, 0, 0],
262
+ [1, 1, 0, 0, 0, 0],
263
+ [1, 1, 0, 0, 0, 0],
264
+ [1, 1, 0, 0, 0, 0],
265
+ [1, 1, 0, 0, 0, 0],
266
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
267
+
268
+ """
269
+ return ~make_pad_mask(lengths, xs, length_dim)
270
+
271
+
272
+ def mask_by_length(xs, lengths, fill=0):
273
+ """Mask tensor according to length.
274
+
275
+ Args:
276
+ xs (Tensor): Batch of input tensor (B, `*`).
277
+ lengths (LongTensor or List): Batch of lengths (B,).
278
+ fill (int or float): Value to fill masked part.
279
+
280
+ Returns:
281
+ Tensor: Batch of masked input tensor (B, `*`).
282
+
283
+ Examples:
284
+ >>> x = torch.arange(5).repeat(3, 1) + 1
285
+ >>> x
286
+ tensor([[1, 2, 3, 4, 5],
287
+ [1, 2, 3, 4, 5],
288
+ [1, 2, 3, 4, 5]])
289
+ >>> lengths = [5, 3, 2]
290
+ >>> mask_by_length(x, lengths)
291
+ tensor([[1, 2, 3, 4, 5],
292
+ [1, 2, 3, 0, 0],
293
+ [1, 2, 0, 0, 0]])
294
+
295
+ """
296
+ assert xs.size(0) == len(lengths)
297
+ ret = xs.data.new(*xs.size()).fill_(fill)
298
+ for i, l in enumerate(lengths):
299
+ ret[i, :l] = xs[i, :l]
300
+ return ret
301
+
302
+
303
+ def th_accuracy(pad_outputs, pad_targets, ignore_label):
304
+ """Calculate accuracy.
305
+
306
+ Args:
307
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
308
+ pad_targets (LongTensor): Target label tensors (B, Lmax, D).
309
+ ignore_label (int): Ignore label id.
310
+
311
+ Returns:
312
+ float: Accuracy value (0.0 - 1.0).
313
+
314
+ """
315
+ pad_pred = pad_outputs.view(
316
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
317
+ ).argmax(2)
318
+ mask = pad_targets != ignore_label
319
+ numerator = torch.sum(
320
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
321
+ )
322
+ denominator = torch.sum(mask)
323
+ return float(numerator) / float(denominator)
324
+
325
+
326
+ def to_torch_tensor(x):
327
+ """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
328
+
329
+ Args:
330
+ x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
331
+
332
+ Returns:
333
+ Tensor or ComplexTensor: Type converted inputs.
334
+
335
+ Examples:
336
+ >>> xs = np.ones(3, dtype=np.float32)
337
+ >>> xs = to_torch_tensor(xs)
338
+ tensor([1., 1., 1.])
339
+ >>> xs = torch.ones(3, 4, 5)
340
+ >>> assert to_torch_tensor(xs) is xs
341
+ >>> xs = {'real': xs, 'imag': xs}
342
+ >>> to_torch_tensor(xs)
343
+ ComplexTensor(
344
+ Real:
345
+ tensor([1., 1., 1.])
346
+ Imag;
347
+ tensor([1., 1., 1.])
348
+ )
349
+
350
+ """
351
+ # If numpy, change to torch tensor
352
+ if isinstance(x, np.ndarray):
353
+ if x.dtype.kind == "c":
354
+ # Dynamically importing because torch_complex requires python3
355
+ from torch_complex.tensor import ComplexTensor
356
+
357
+ return ComplexTensor(x)
358
+ else:
359
+ return torch.from_numpy(x)
360
+
361
+ # If {'real': ..., 'imag': ...}, convert to ComplexTensor
362
+ elif isinstance(x, dict):
363
+ # Dynamically importing because torch_complex requires python3
364
+ from torch_complex.tensor import ComplexTensor
365
+
366
+ if "real" not in x or "imag" not in x:
367
+ raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
368
+ # Relative importing because of using python3 syntax
369
+ return ComplexTensor(x["real"], x["imag"])
370
+
371
+ # If torch.Tensor, as it is
372
+ elif isinstance(x, torch.Tensor):
373
+ return x
374
+
375
+ else:
376
+ error = (
377
+ "x must be numpy.ndarray, torch.Tensor or a dict like "
378
+ "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
379
+ "but got {}".format(type(x))
380
+ )
381
+ try:
382
+ from torch_complex.tensor import ComplexTensor
383
+ except Exception:
384
+ # If PY2
385
+ raise ValueError(error)
386
+ else:
387
+ # If PY3
388
+ if isinstance(x, ComplexTensor):
389
+ return x
390
+ else:
391
+ raise ValueError(error)
392
+
393
+
394
+ def get_subsample(train_args, mode, arch):
395
+ """Parse the subsampling factors from the args for the specified `mode` and `arch`.
396
+
397
+ Args:
398
+ train_args: argument Namespace containing options.
399
+ mode: one of ('asr', 'mt', 'st')
400
+ arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
401
+
402
+ Returns:
403
+ np.ndarray / List[np.ndarray]: subsampling factors.
404
+ """
405
+ if arch == "transformer":
406
+ return np.array([1])
407
+
408
+ elif mode == "mt" and arch == "rnn":
409
+ # +1 means the input (+1) plus the layer outputs (train_args.elayers)
410
+ subsample = np.ones(train_args.elayers + 1, dtype=int)
411
+ logging.warning("Subsampling is not performed for machine translation.")
412
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
413
+ return subsample
414
+
415
+ elif (
416
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
417
+ or (mode == "mt" and arch == "rnn")
418
+ or (mode == "st" and arch == "rnn")
419
+ ):
420
+ subsample = np.ones(train_args.elayers + 1, dtype=int)
421
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
422
+ ss = train_args.subsample.split("_")
423
+ for j in range(min(train_args.elayers + 1, len(ss))):
424
+ subsample[j] = int(ss[j])
425
+ else:
426
+ logging.warning(
427
+ "Subsampling is not performed for vgg*. "
428
+ "It is performed in max pooling layers at CNN."
429
+ )
430
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
431
+ return subsample
432
+
433
+ elif mode == "asr" and arch == "rnn_mix":
434
+ subsample = np.ones(
435
+ train_args.elayers_sd + train_args.elayers + 1, dtype=int
436
+ )
437
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
438
+ ss = train_args.subsample.split("_")
439
+ for j in range(
440
+ min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
441
+ ):
442
+ subsample[j] = int(ss[j])
443
+ else:
444
+ logging.warning(
445
+ "Subsampling is not performed for vgg*. "
446
+ "It is performed in max pooling layers at CNN."
447
+ )
448
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
449
+ return subsample
450
+
451
+ elif mode == "asr" and arch == "rnn_mulenc":
452
+ subsample_list = []
453
+ for idx in range(train_args.num_encs):
454
+ subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
455
+ if train_args.etype[idx].endswith("p") and not train_args.etype[
456
+ idx
457
+ ].startswith("vgg"):
458
+ ss = train_args.subsample[idx].split("_")
459
+ for j in range(min(train_args.elayers[idx] + 1, len(ss))):
460
+ subsample[j] = int(ss[j])
461
+ else:
462
+ logging.warning(
463
+ "Encoder %d: Subsampling is not performed for vgg*. "
464
+ "It is performed in max pooling layers at CNN.",
465
+ idx + 1,
466
+ )
467
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
468
+ subsample_list.append(subsample)
469
+ return subsample_list
470
+
471
+ else:
472
+ raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
473
+
474
+
475
+ def rename_state_dict(
476
+ old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
477
+ ):
478
+ """Replace keys of old prefix with new prefix in state dict."""
479
+ # need this list not to break the dict iterator
480
+ old_keys = [k for k in state_dict if k.startswith(old_prefix)]
481
+ if len(old_keys) > 0:
482
+ logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
483
+ for k in old_keys:
484
+ v = state_dict.pop(k)
485
+ new_k = k.replace(old_prefix, new_prefix)
486
+ state_dict[new_k] = v
487
+
488
+
489
+ def get_activation(act):
490
+ """Return activation function."""
491
+ # Lazy load to avoid unused import
492
+ from espnet.nets.pytorch_backend.conformer.swish import Swish
493
+
494
+ activation_funcs = {
495
+ "hardtanh": torch.nn.Hardtanh,
496
+ "tanh": torch.nn.Tanh,
497
+ "relu": torch.nn.ReLU,
498
+ "selu": torch.nn.SELU,
499
+ "swish": Swish,
500
+ }
501
+
502
+ return activation_funcs[act]()
503
+
504
+
505
+ class MLPHead(torch.nn.Module):
506
+ def __init__(self, idim, hdim, odim, norm="batchnorm"):
507
+ super(MLPHead, self).__init__()
508
+ self.norm = norm
509
+
510
+ self.fc1 = torch.nn.Linear(idim, hdim)
511
+ if norm == "batchnorm":
512
+ self.bn1 = torch.nn.BatchNorm1d(hdim)
513
+ elif norm == "layernorm":
514
+ self.norm1 = torch.nn.LayerNorm(hdim)
515
+ self.nonlin1 = torch.nn.ReLU(inplace=True)
516
+ self.fc2 = torch.nn.Linear(hdim, odim)
517
+
518
+ def forward(self, x):
519
+ x = self.fc1(x)
520
+ if self.norm == "batchnorm":
521
+ x = self.bn1(x.transpose(1, 2)).transpose(1, 2)
522
+ elif self.norm == "layernorm":
523
+ x = self.norm1(x)
524
+ x = self.nonlin1(x)
525
+ x = self.fc2(x)
526
+ return x
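As a quick sanity check of the helpers above, the following sketch (not part of the commit) pads a ragged batch and builds the matching non-padding mask; it uses only functions defined in this file.

import torch
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, pad_list

xs = [torch.ones(4, 80), torch.ones(2, 80)]              # two feature sequences
lengths = [x.size(0) for x in xs]                        # [4, 2]
padded = pad_list(xs, 0.0)                               # (2, 4, 80)
mask = make_non_pad_mask(lengths, padded, length_dim=1)  # True on real frames
assert padded.shape == (2, 4, 80) and mask.shape == padded.shape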
espnet/nets/pytorch_backend/transformer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/nets/pytorch_backend/transformer/add_sos_eos.py ADDED
@@ -0,0 +1,31 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Unility funcitons for Transformer."""
8
+
9
+ import torch
10
+
11
+
12
+ def add_sos_eos(ys_pad, sos, eos, ignore_id):
13
+ """Add <sos> and <eos> labels.
14
+
15
+ :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
16
+ :param int sos: index of <sos>
17
+ :param int eos: index of <eos>
18
+ :param int ignore_id: index of padding
19
+ :return: padded tensor with <sos> prepended (B, Lmax)
20
+ :rtype: torch.Tensor
21
+ :return: padded tensor with <eos> appended (B, Lmax)
22
+ :rtype: torch.Tensor
23
+ """
24
+ from espnet.nets.pytorch_backend.nets_utils import pad_list
25
+
26
+ _sos = ys_pad.new([sos])
27
+ _eos = ys_pad.new([eos])
28
+ ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
29
+ ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
30
+ ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
31
+ return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
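A short sketch (not part of the commit) of the helper above; the ignore index -1 is just an illustrative padding value.

import torch
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos

ys_pad = torch.tensor([[3, 5, 7], [4, 6, -1]])  # -1 marks padding
ys_in, ys_out = add_sos_eos(ys_pad, sos=8, eos=9, ignore_id=-1)
# ys_in:  [[8, 3, 5, 7], [8, 4, 6, 9]]   (padded with eos)
# ys_out: [[3, 5, 7, 9], [4, 6, 9, -1]]  (padded with ignore_id)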
espnet/nets/pytorch_backend/transformer/attention.py ADDED
@@ -0,0 +1,280 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Multi-Head Attention layer definition."""
8
+
9
+ import math
10
+
11
+ import numpy
12
+ import torch
13
+ from torch import nn
14
+
15
+
16
+ class MultiHeadedAttention(nn.Module):
17
+ """Multi-Head Attention layer.
18
+ Args:
19
+ n_head (int): The number of heads.
20
+ n_feat (int): The number of features.
21
+ dropout_rate (float): Dropout rate.
22
+ """
23
+
24
+ def __init__(self, n_head, n_feat, dropout_rate):
25
+ """Construct an MultiHeadedAttention object."""
26
+ super(MultiHeadedAttention, self).__init__()
27
+ assert n_feat % n_head == 0
28
+ # We assume d_v always equals d_k
29
+ self.d_k = n_feat // n_head
30
+ self.h = n_head
31
+ self.linear_q = nn.Linear(n_feat, n_feat)
32
+ self.linear_k = nn.Linear(n_feat, n_feat)
33
+ self.linear_v = nn.Linear(n_feat, n_feat)
34
+ self.linear_out = nn.Linear(n_feat, n_feat)
35
+ self.attn = None
36
+ self.dropout = nn.Dropout(p=dropout_rate)
37
+
38
+ def forward_qkv(self, query, key, value):
39
+ """Transform query, key and value.
40
+ Args:
41
+ query (torch.Tensor): Query tensor (#batch, time1, size).
42
+ key (torch.Tensor): Key tensor (#batch, time2, size).
43
+ value (torch.Tensor): Value tensor (#batch, time2, size).
44
+ Returns:
45
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
46
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
47
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
48
+ """
49
+ n_batch = query.size(0)
50
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
51
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
52
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
53
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
54
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
55
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
56
+
57
+ return q, k, v
58
+
59
+ def forward_attention(self, value, scores, mask, rtn_attn=False):
60
+ """Compute attention context vector.
61
+ Args:
62
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
63
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
64
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
65
+ rtn_attn (boolean): Flag of return attention score
66
+ Returns:
67
+ torch.Tensor: Transformed value (#batch, time1, d_model)
68
+ weighted by the attention score (#batch, time1, time2).
69
+ """
70
+ n_batch = value.size(0)
71
+ if mask is not None:
72
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
73
+ min_value = float(
74
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
75
+ )
76
+ scores = scores.masked_fill(mask, min_value)
77
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
78
+ mask, 0.0
79
+ ) # (batch, head, time1, time2)
80
+ else:
81
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
82
+
83
+ p_attn = self.dropout(self.attn)
84
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
85
+ x = (
86
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
87
+ ) # (batch, time1, d_model)
88
+ if rtn_attn:
89
+ return self.linear_out(x), self.attn
90
+ return self.linear_out(x) # (batch, time1, d_model)
91
+
92
+ def forward(self, query, key, value, mask, rtn_attn=False):
93
+ """Compute scaled dot product attention.
94
+ Args:
95
+ query (torch.Tensor): Query tensor (#batch, time1, size).
96
+ key (torch.Tensor): Key tensor (#batch, time2, size).
97
+ value (torch.Tensor): Value tensor (#batch, time2, size).
98
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
99
+ (#batch, time1, time2).
100
+ rtn_attn (boolean): Flag of return attention score
101
+ Returns:
102
+ torch.Tensor: Output tensor (#batch, time1, d_model).
103
+ """
104
+ q, k, v = self.forward_qkv(query, key, value)
105
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
106
+ return self.forward_attention(v, scores, mask, rtn_attn)
107
+
108
+
109
+ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
110
+ """Multi-Head Attention layer with relative position encoding (old version).
111
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
112
+ Paper: https://arxiv.org/abs/1901.02860
113
+ Args:
114
+ n_head (int): The number of heads.
115
+ n_feat (int): The number of features.
116
+ dropout_rate (float): Dropout rate.
117
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
118
+ """
119
+
120
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
121
+ """Construct an RelPositionMultiHeadedAttention object."""
122
+ super().__init__(n_head, n_feat, dropout_rate)
123
+ self.zero_triu = zero_triu
124
+ # linear transformation for positional encoding
125
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
126
+ # these two learnable bias are used in matrix c and matrix d
127
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
128
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
129
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
130
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
131
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
132
+
133
+ def rel_shift(self, x):
134
+ """Compute relative positional encoding.
135
+ Args:
136
+ x (torch.Tensor): Input tensor (batch, head, time1, time2).
137
+ Returns:
138
+ torch.Tensor: Output tensor.
139
+ """
140
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
141
+ x_padded = torch.cat([zero_pad, x], dim=-1)
142
+
143
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
144
+ x = x_padded[:, :, 1:].view_as(x)
145
+
146
+ if self.zero_triu:
147
+ ones = torch.ones((x.size(2), x.size(3)))
148
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
149
+
150
+ return x
151
+
152
+ def forward(self, query, key, value, pos_emb, mask):
153
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
154
+ Args:
155
+ query (torch.Tensor): Query tensor (#batch, time1, size).
156
+ key (torch.Tensor): Key tensor (#batch, time2, size).
157
+ value (torch.Tensor): Value tensor (#batch, time2, size).
158
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
159
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
160
+ (#batch, time1, time2).
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ """
164
+ q, k, v = self.forward_qkv(query, key, value)
165
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
166
+
167
+ n_batch_pos = pos_emb.size(0)
168
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
169
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
170
+
171
+ # (batch, head, time1, d_k)
172
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
173
+ # (batch, head, time1, d_k)
174
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
175
+
176
+ # compute attention score
177
+ # first compute matrix a and matrix c
178
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
179
+ # (batch, head, time1, time2)
180
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
181
+
182
+ # compute matrix b and matrix d
183
+ # (batch, head, time1, time1)
184
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
185
+ matrix_bd = self.rel_shift(matrix_bd)
186
+
187
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
188
+ self.d_k
189
+ ) # (batch, head, time1, time2)
190
+
191
+ return self.forward_attention(v, scores, mask)
192
+
193
+
194
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
195
+ """Multi-Head Attention layer with relative position encoding (new implementation).
196
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
197
+ Paper: https://arxiv.org/abs/1901.02860
198
+ Args:
199
+ n_head (int): The number of heads.
200
+ n_feat (int): The number of features.
201
+ dropout_rate (float): Dropout rate.
202
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
203
+ """
204
+
205
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
206
+ """Construct an RelPositionMultiHeadedAttention object."""
207
+ super().__init__(n_head, n_feat, dropout_rate)
208
+ self.zero_triu = zero_triu
209
+ # linear transformation for positional encoding
210
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
211
+ # these two learnable bias are used in matrix c and matrix d
212
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
213
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
214
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
215
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
216
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
217
+
218
+ def rel_shift(self, x):
219
+ """Compute relative positional encoding.
220
+ Args:
221
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
222
+ time1 means the length of query vector.
223
+ Returns:
224
+ torch.Tensor: Output tensor.
225
+ """
226
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
227
+ x_padded = torch.cat([zero_pad, x], dim=-1)
228
+
229
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
230
+ x = x_padded[:, :, 1:].view_as(x)[
231
+ :, :, :, : x.size(-1) // 2 + 1
232
+ ] # only keep the positions from 0 to time2
233
+
234
+ if self.zero_triu:
235
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
236
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
237
+
238
+ return x
239
+
240
+ def forward(self, query, key, value, pos_emb, mask):
241
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
242
+ Args:
243
+ query (torch.Tensor): Query tensor (#batch, time1, size).
244
+ key (torch.Tensor): Key tensor (#batch, time2, size).
245
+ value (torch.Tensor): Value tensor (#batch, time2, size).
246
+ pos_emb (torch.Tensor): Positional embedding tensor
247
+ (#batch, 2*time1-1, size).
248
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
249
+ (#batch, time1, time2).
250
+ Returns:
251
+ torch.Tensor: Output tensor (#batch, time1, d_model).
252
+ """
253
+ q, k, v = self.forward_qkv(query, key, value)
254
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
255
+
256
+ n_batch_pos = pos_emb.size(0)
257
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
258
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
259
+
260
+ # (batch, head, time1, d_k)
261
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
262
+ # (batch, head, time1, d_k)
263
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
264
+
265
+ # compute attention score
266
+ # first compute matrix a and matrix c
267
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
268
+ # (batch, head, time1, time2)
269
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
270
+
271
+ # compute matrix b and matrix d
272
+ # (batch, head, time1, 2*time1-1)
273
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
274
+ matrix_bd = self.rel_shift(matrix_bd)
275
+
276
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
277
+ self.d_k
278
+ ) # (batch, head, time1, time2)
279
+
280
+ return self.forward_attention(v, scores, mask)
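A minimal sketch (not part of the commit) of plain multi-head self-attention with a padding mask, combining this class with make_non_pad_mask from nets_utils.py above.

import torch
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

att = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
x = torch.randn(2, 10, 256)                      # (batch, time, n_feat)
mask = make_non_pad_mask([10, 7]).unsqueeze(-2)  # (batch, 1, time)
out = att(x, x, x, mask)                         # (2, 10, 256)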
espnet/nets/pytorch_backend/transformer/convolution.py ADDED
@@ -0,0 +1,73 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
5
+ # Northwestern Polytechnical University (Pengcheng Guo)
6
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
7
+
8
+ """ConvolutionModule definition."""
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class ConvolutionModule(nn.Module):
15
+ """ConvolutionModule in Conformer model.
16
+
17
+ :param int channels: channels of cnn
18
+ :param int kernel_size: kernel size of the CNN
19
+
20
+ """
21
+
22
+ def __init__(self, channels, kernel_size, bias=True):
23
+ """Construct an ConvolutionModule object."""
24
+ super(ConvolutionModule, self).__init__()
25
+ # kernel_size should be an odd number for 'SAME' padding
26
+ assert (kernel_size - 1) % 2 == 0
27
+
28
+ self.pointwise_cov1 = nn.Conv1d(
29
+ channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias,
30
+ )
31
+ self.depthwise_conv = nn.Conv1d(
32
+ channels,
33
+ channels,
34
+ kernel_size,
35
+ stride=1,
36
+ padding=(kernel_size - 1) // 2,
37
+ groups=channels,
38
+ bias=bias,
39
+ )
40
+ self.norm = nn.BatchNorm1d(channels)
41
+ self.pointwise_cov2 = nn.Conv1d(
42
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
43
+ )
44
+ self.activation = Swish()
45
+
46
+ def forward(self, x):
47
+ """Compute covolution module.
48
+
49
+ :param torch.Tensor x: (batch, time, size)
50
+ :return torch.Tensor: convolved `value` (batch, time, d_model)
51
+ """
52
+ # exchange the temporal dimension and the feature dimension
53
+ x = x.transpose(1, 2)
54
+
55
+ # GLU mechanism
56
+ x = self.pointwise_cov1(x) # (batch, 2*channel, time)
57
+ x = nn.functional.glu(x, dim=1) # (batch, channel, time)
58
+
59
+ # 1D Depthwise Conv
60
+ x = self.depthwise_conv(x)
61
+ x = self.activation(self.norm(x))
62
+
63
+ x = self.pointwise_cov2(x)
64
+
65
+ return x.transpose(1, 2)
66
+
67
+
68
+ class Swish(nn.Module):
69
+ """Construct an Swish object."""
70
+
71
+ def forward(self, x):
72
+ """Return Swich activation function."""
73
+ return x * torch.sigmoid(x)
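A minimal sketch (not part of the commit) of the Conformer convolution module above; the kernel size must be odd, as asserted in the constructor.

import torch
from espnet.nets.pytorch_backend.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(2, 50, 256)  # (batch, time, channels)
y = conv(x)                  # same shape: (2, 50, 256)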
espnet/nets/pytorch_backend/transformer/decoder.py ADDED
@@ -0,0 +1,229 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Decoder definition."""
8
+
9
+ from typing import Any
10
+ from typing import List
11
+ from typing import Tuple
12
+
13
+ import torch
14
+
15
+ from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
16
+ from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
17
+ from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
18
+ from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
19
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
20
+ from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
21
+ from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
22
+ PositionwiseFeedForward, # noqa: H301
23
+ )
24
+ from espnet.nets.pytorch_backend.transformer.repeat import repeat
25
+ from espnet.nets.scorer_interface import BatchScorerInterface
26
+
27
+
28
+ def _pre_hook(
29
+ state_dict,
30
+ prefix,
31
+ local_metadata,
32
+ strict,
33
+ missing_keys,
34
+ unexpected_keys,
35
+ error_msgs,
36
+ ):
37
+ # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
38
+ rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict)
39
+
40
+
41
+ class Decoder(BatchScorerInterface, torch.nn.Module):
42
+ """Transfomer decoder module.
43
+
44
+ :param int odim: output dim
45
+ :param int attention_dim: dimension of attention
46
+ :param int attention_heads: the number of heads of multi head attention
47
+ :param int linear_units: the number of units of position-wise feed forward
48
+ :param int num_blocks: the number of decoder blocks
49
+ :param float dropout_rate: dropout rate
50
+ :param float attention_dropout_rate: dropout rate for attention
51
+ :param str or torch.nn.Module input_layer: input layer type
52
+ :param bool use_output_layer: whether to use output layer
53
+ :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
54
+ :param bool normalize_before: whether to use layer_norm before the first block
55
+ :param bool concat_after: whether to concat attention layer's input and output
56
+ if True, additional linear will be applied.
57
+ i.e. x -> x + linear(concat(x, att(x)))
58
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ odim,
64
+ attention_dim=256,
65
+ attention_heads=4,
66
+ linear_units=2048,
67
+ num_blocks=6,
68
+ dropout_rate=0.1,
69
+ positional_dropout_rate=0.1,
70
+ self_attention_dropout_rate=0.0,
71
+ src_attention_dropout_rate=0.0,
72
+ input_layer="embed",
73
+ use_output_layer=True,
74
+ pos_enc_class=PositionalEncoding,
75
+ normalize_before=True,
76
+ concat_after=False,
77
+ ):
78
+ """Construct an Decoder object."""
79
+ torch.nn.Module.__init__(self)
80
+ self._register_load_state_dict_pre_hook(_pre_hook)
81
+ if input_layer == "embed":
82
+ self.embed = torch.nn.Sequential(
83
+ torch.nn.Embedding(odim, attention_dim),
84
+ pos_enc_class(attention_dim, positional_dropout_rate),
85
+ )
86
+ elif input_layer == "linear":
87
+ self.embed = torch.nn.Sequential(
88
+ torch.nn.Linear(odim, attention_dim),
89
+ torch.nn.LayerNorm(attention_dim),
90
+ torch.nn.Dropout(dropout_rate),
91
+ torch.nn.ReLU(),
92
+ pos_enc_class(attention_dim, positional_dropout_rate),
93
+ )
94
+ elif isinstance(input_layer, torch.nn.Module):
95
+ self.embed = torch.nn.Sequential(
96
+ input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
97
+ )
98
+ else:
99
+ raise NotImplementedError("only `embed`, `linear`, or torch.nn.Module is supported.")
100
+ self.normalize_before = normalize_before
101
+ self.decoders = repeat(
102
+ num_blocks,
103
+ lambda: DecoderLayer(
104
+ attention_dim,
105
+ MultiHeadedAttention(
106
+ attention_heads, attention_dim, self_attention_dropout_rate
107
+ ),
108
+ MultiHeadedAttention(
109
+ attention_heads, attention_dim, src_attention_dropout_rate
110
+ ),
111
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
112
+ dropout_rate,
113
+ normalize_before,
114
+ concat_after,
115
+ ),
116
+ )
117
+ if self.normalize_before:
118
+ self.after_norm = LayerNorm(attention_dim)
119
+ if use_output_layer:
120
+ self.output_layer = torch.nn.Linear(attention_dim, odim)
121
+ else:
122
+ self.output_layer = None
123
+
124
+ def forward(self, tgt, tgt_mask, memory, memory_mask):
125
+ """Forward decoder.
126
+ :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
127
+ if input_layer == "embed"
128
+ input tensor (batch, maxlen_out, #mels)
129
+ in the other cases
130
+ :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
131
+ dtype=torch.uint8 in PyTorch 1.2-
132
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
133
+ :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
134
+ :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
135
+ dtype=torch.uint8 in PyTorch 1.2-
136
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
137
+ :return x: decoded token score before softmax (batch, maxlen_out, token)
138
+ if use_output_layer is True,
139
+ final block outputs (batch, maxlen_out, attention_dim)
140
+ in the other cases
141
+ :rtype: torch.Tensor
142
+ :return tgt_mask: score mask before softmax (batch, maxlen_out)
143
+ :rtype: torch.Tensor
144
+ """
145
+ x = self.embed(tgt)
146
+ x, tgt_mask, memory, memory_mask = self.decoders(
147
+ x, tgt_mask, memory, memory_mask
148
+ )
149
+ if self.normalize_before:
150
+ x = self.after_norm(x)
151
+ if self.output_layer is not None:
152
+ x = self.output_layer(x)
153
+ return x, tgt_mask
154
+
155
+ def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
156
+ """Forward one step.
157
+ :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
158
+ :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
159
+ dtype=torch.uint8 in PyTorch 1.2-
160
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
161
+ :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
162
+ :param List[torch.Tensor] cache:
163
+ cached output list of (batch, max_time_out-1, size)
164
+ :return y, cache: NN output value and cache per `self.decoders`.
165
+ `y.shape` is (batch, maxlen_out, token)
166
+ :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
167
+ """
168
+ x = self.embed(tgt)
169
+ if cache is None:
170
+ cache = [None] * len(self.decoders)
171
+ new_cache = []
172
+ for c, decoder in zip(cache, self.decoders):
173
+ x, tgt_mask, memory, memory_mask = decoder(
174
+ x, tgt_mask, memory, memory_mask, cache=c
175
+ )
176
+ new_cache.append(x)
177
+
178
+ if self.normalize_before:
179
+ y = self.after_norm(x[:, -1])
180
+ else:
181
+ y = x[:, -1]
182
+ if self.output_layer is not None:
183
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
184
+
185
+ return y, new_cache
186
+
187
+ # beam search API (see ScorerInterface)
188
+ def score(self, ys, state, x):
189
+ """Score."""
190
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
191
+ logp, state = self.forward_one_step(
192
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
193
+ )
194
+ return logp.squeeze(0), state
195
+
196
+ # batch beam search API (see BatchScorerInterface)
197
+ def batch_score(
198
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
199
+ ) -> Tuple[torch.Tensor, List[Any]]:
200
+ """Score new token batch (required).
201
+ Args:
202
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
203
+ states (List[Any]): Scorer states for prefix tokens.
204
+ xs (torch.Tensor):
205
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
206
+ Returns:
207
+ tuple[torch.Tensor, List[Any]]: Tuple of
208
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
209
+ and next state list for ys.
210
+ """
211
+ # merge states
212
+ n_batch = len(ys)
213
+ n_layers = len(self.decoders)
214
+ if states[0] is None:
215
+ batch_state = None
216
+ else:
217
+ # transpose state of [batch, layer] into [layer, batch]
218
+ batch_state = [
219
+ torch.stack([states[b][l] for b in range(n_batch)])
220
+ for l in range(n_layers)
221
+ ]
222
+
223
+ # batch decoding
224
+ ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
225
+ logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
226
+
227
+ # transpose state of [layer, batch] into [batch, layer]
228
+ state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)]
229
+ return logp, state_list
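A minimal sketch (not part of the commit) of a single decoder forward pass over dummy encoder memory, using subsequent_mask as the causal target mask; shapes follow the docstrings above.

import torch
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

dec = Decoder(odim=500, attention_dim=256, attention_heads=4, num_blocks=2)
memory = torch.randn(2, 30, 256)                     # encoder output
memory_mask = torch.ones(2, 1, 30, dtype=torch.bool)
tgt = torch.randint(0, 500, (2, 8))                  # target token ids
tgt_mask = subsequent_mask(8, device=tgt.device).unsqueeze(0)  # (1, 8, 8)
scores, _ = dec(tgt, tgt_mask, memory, memory_mask)  # (2, 8, 500)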
espnet/nets/pytorch_backend/transformer/decoder_layer.py ADDED
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Decoder self-attention layer definition."""
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
13
+
14
+
15
+ class DecoderLayer(nn.Module):
16
+ """Single decoder layer module.
17
+ :param int size: input dim
18
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
19
+ self_attn: self attention module
20
+ :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
21
+ src_attn: source attention module
22
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
23
+ PositionwiseFeedForward feed_forward: feed forward layer module
24
+ :param float dropout_rate: dropout rate
25
+ :param bool normalize_before: whether to use layer_norm before the first block
26
+ :param bool concat_after: whether to concat attention layer's input and output
27
+ if True, additional linear will be applied.
28
+ i.e. x -> x + linear(concat(x, att(x)))
29
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ size,
35
+ self_attn,
36
+ src_attn,
37
+ feed_forward,
38
+ dropout_rate,
39
+ normalize_before=True,
40
+ concat_after=False,
41
+ ):
42
+ """Construct an DecoderLayer object."""
43
+ super(DecoderLayer, self).__init__()
44
+ self.size = size
45
+ self.self_attn = self_attn
46
+ self.src_attn = src_attn
47
+ self.feed_forward = feed_forward
48
+ self.norm1 = LayerNorm(size)
49
+ self.norm2 = LayerNorm(size)
50
+ self.norm3 = LayerNorm(size)
51
+ self.dropout = nn.Dropout(dropout_rate)
52
+ self.normalize_before = normalize_before
53
+ self.concat_after = concat_after
54
+ if self.concat_after:
55
+ self.concat_linear1 = nn.Linear(size + size, size)
56
+ self.concat_linear2 = nn.Linear(size + size, size)
57
+
58
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
59
+ """Compute decoded features.
60
+ Args:
61
+ tgt (torch.Tensor):
62
+ decoded previous target features (batch, max_time_out, size)
63
+ tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
64
+ memory (torch.Tensor): encoded source features (batch, max_time_in, size)
65
+ memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
66
+ cache (torch.Tensor): cached output (batch, max_time_out-1, size)
67
+ """
68
+ residual = tgt
69
+ if self.normalize_before:
70
+ tgt = self.norm1(tgt)
71
+
72
+ if cache is None:
73
+ tgt_q = tgt
74
+ tgt_q_mask = tgt_mask
75
+ else:
76
+ # compute only the last frame query keeping dim: max_time_out -> 1
77
+ assert cache.shape == (
78
+ tgt.shape[0],
79
+ tgt.shape[1] - 1,
80
+ self.size,
81
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
82
+ tgt_q = tgt[:, -1:, :]
83
+ residual = residual[:, -1:, :]
84
+ tgt_q_mask = None
85
+ if tgt_mask is not None:
86
+ tgt_q_mask = tgt_mask[:, -1:, :]
87
+
88
+ if self.concat_after:
89
+ tgt_concat = torch.cat(
90
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
91
+ )
92
+ x = residual + self.concat_linear1(tgt_concat)
93
+ else:
94
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ if self.concat_after:
102
+ x_concat = torch.cat(
103
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
104
+ )
105
+ x = residual + self.concat_linear2(x_concat)
106
+ else:
107
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
108
+ if not self.normalize_before:
109
+ x = self.norm2(x)
110
+
111
+ residual = x
112
+ if self.normalize_before:
113
+ x = self.norm3(x)
114
+ x = residual + self.dropout(self.feed_forward(x))
115
+ if not self.normalize_before:
116
+ x = self.norm3(x)
117
+
118
+ if cache is not None:
119
+ x = torch.cat([cache, x], dim=1)
120
+
121
+ return x, tgt_mask, memory, memory_mask
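A minimal sketch (not part of the commit) of the cache path in the layer above: when `cache` holds the (batch, time-1, size) outputs of earlier steps, only the newest position is recomputed and the concatenated sequence is returned.

import torch
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)

layer = DecoderLayer(
    256,
    MultiHeadedAttention(4, 256, 0.0),
    MultiHeadedAttention(4, 256, 0.0),
    PositionwiseFeedForward(256, 2048, 0.1),
    dropout_rate=0.1,
)
memory = torch.randn(1, 30, 256)   # dummy encoder memory
tgt = torch.randn(1, 5, 256)       # 5 decoded positions so far
cache = torch.randn(1, 4, 256)     # placeholder outputs of the first 4 positions
x, *_ = layer(tgt, None, memory, None, cache=cache)
assert x.shape == (1, 5, 256)      # cache plus the newly computed last step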
espnet/nets/pytorch_backend/transformer/embedding.py ADDED
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Positional Encoding Module."""
8
+
9
+ import math
10
+
11
+ import torch
12
+
13
+
14
+ def _pre_hook(
15
+ state_dict,
16
+ prefix,
17
+ local_metadata,
18
+ strict,
19
+ missing_keys,
20
+ unexpected_keys,
21
+ error_msgs,
22
+ ):
23
+ """Perform pre-hook in load_state_dict for backward compatibility.
24
+ Note:
25
+ We saved self.pe until v.0.5.2 but we have omitted it later.
26
+ Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
27
+ """
28
+ k = prefix + "pe"
29
+ if k in state_dict:
30
+ state_dict.pop(k)
31
+
32
+
33
+ class PositionalEncoding(torch.nn.Module):
34
+ """Positional encoding.
35
+ Args:
36
+ d_model (int): Embedding dimension.
37
+ dropout_rate (float): Dropout rate.
38
+ max_len (int): Maximum input length.
39
+ reverse (bool): Whether to reverse the input position. Only for
40
+ the class LegacyRelPositionalEncoding. We remove it in the current
41
+ class RelPositionalEncoding.
42
+ """
43
+
44
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
45
+ """Construct an PositionalEncoding object."""
46
+ super(PositionalEncoding, self).__init__()
47
+ self.d_model = d_model
48
+ self.reverse = reverse
49
+ self.xscale = math.sqrt(self.d_model)
50
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
51
+ self.pe = None
52
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
53
+ self._register_load_state_dict_pre_hook(_pre_hook)
54
+
55
+ def extend_pe(self, x):
56
+ """Reset the positional encodings."""
57
+ if self.pe is not None:
58
+ if self.pe.size(1) >= x.size(1):
59
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
60
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
61
+ return
62
+ pe = torch.zeros(x.size(1), self.d_model)
63
+ if self.reverse:
64
+ position = torch.arange(
65
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
66
+ ).unsqueeze(1)
67
+ else:
68
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
69
+ div_term = torch.exp(
70
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
71
+ * -(math.log(10000.0) / self.d_model)
72
+ )
73
+ pe[:, 0::2] = torch.sin(position * div_term)
74
+ pe[:, 1::2] = torch.cos(position * div_term)
75
+ pe = pe.unsqueeze(0)
76
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
77
+
78
+ def forward(self, x: torch.Tensor):
79
+ """Add positional encoding.
80
+ Args:
81
+ x (torch.Tensor): Input tensor (batch, time, `*`).
82
+ Returns:
83
+ torch.Tensor: Encoded tensor (batch, time, `*`).
84
+ """
85
+ self.extend_pe(x)
86
+ x = x * self.xscale + self.pe[:, : x.size(1)]
87
+ return self.dropout(x)
88
+
89
+
90
+ class ScaledPositionalEncoding(PositionalEncoding):
91
+ """Scaled positional encoding module.
92
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
93
+ Args:
94
+ d_model (int): Embedding dimension.
95
+ dropout_rate (float): Dropout rate.
96
+ max_len (int): Maximum input length.
97
+ """
98
+
99
+ def __init__(self, d_model, dropout_rate, max_len=5000):
100
+ """Initialize class."""
101
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
102
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
103
+
104
+ def reset_parameters(self):
105
+ """Reset parameters."""
106
+ self.alpha.data = torch.tensor(1.0)
107
+
108
+ def forward(self, x):
109
+ """Add positional encoding.
110
+ Args:
111
+ x (torch.Tensor): Input tensor (batch, time, `*`).
112
+ Returns:
113
+ torch.Tensor: Encoded tensor (batch, time, `*`).
114
+ """
115
+ self.extend_pe(x)
116
+ x = x + self.alpha * self.pe[:, : x.size(1)]
117
+ return self.dropout(x)
118
+
119
+
120
+ class LegacyRelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module (old version).
122
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
123
+ See : Appendix B in https://arxiv.org/abs/1901.02860
124
+ Args:
125
+ d_model (int): Embedding dimension.
126
+ dropout_rate (float): Dropout rate.
127
+ max_len (int): Maximum input length.
128
+ """
129
+
130
+ def __init__(self, d_model, dropout_rate, max_len=5000):
131
+ """Initialize class."""
132
+ super().__init__(
133
+ d_model=d_model,
134
+ dropout_rate=dropout_rate,
135
+ max_len=max_len,
136
+ reverse=True,
137
+ )
138
+
139
+ def forward(self, x):
140
+ """Compute positional encoding.
141
+ Args:
142
+ x (torch.Tensor): Input tensor (batch, time, `*`).
143
+ Returns:
144
+ torch.Tensor: Encoded tensor (batch, time, `*`).
145
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
146
+ """
147
+ self.extend_pe(x)
148
+ x = x * self.xscale
149
+ pos_emb = self.pe[:, : x.size(1)]
150
+ return self.dropout(x), self.dropout(pos_emb)
151
+
152
+
153
+ class RelPositionalEncoding(torch.nn.Module):
154
+ """Relative positional encoding module (new implementation).
155
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
156
+ See : Appendix B in https://arxiv.org/abs/1901.02860
157
+ Args:
158
+ d_model (int): Embedding dimension.
159
+ dropout_rate (float): Dropout rate.
160
+ max_len (int): Maximum input length.
161
+ """
162
+
163
+ def __init__(self, d_model, dropout_rate, max_len=5000):
164
+ """Construct an PositionalEncoding object."""
165
+ super(RelPositionalEncoding, self).__init__()
166
+ self.d_model = d_model
167
+ self.xscale = math.sqrt(self.d_model)
168
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
169
+ self.pe = None
170
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
171
+
172
+ def extend_pe(self, x):
173
+ """Reset the positional encodings."""
174
+ if self.pe is not None:
175
+ # self.pe contains both positive and negative parts
176
+ # the length of self.pe is 2 * input_len - 1
177
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
178
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
179
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
180
+ return
181
+ # Suppose `i` denotes the position of the query vector and `j` the
182
+ # position of the key vector. We use positive relative positions when keys
183
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
184
+ pe_positive = torch.zeros(x.size(1), self.d_model)
185
+ pe_negative = torch.zeros(x.size(1), self.d_model)
186
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
187
+ div_term = torch.exp(
188
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
189
+ * -(math.log(10000.0) / self.d_model)
190
+ )
191
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
192
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
193
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
194
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
195
+
196
+ # Reverse the order of positive indices and concat both positive and
197
+ # negative indices. This is used to support the shifting trick
198
+ # as in https://arxiv.org/abs/1901.02860
199
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
200
+ pe_negative = pe_negative[1:].unsqueeze(0)
201
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
202
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
203
+
204
+ def forward(self, x: torch.Tensor):
205
+ """Add positional encoding.
206
+ Args:
207
+ x (torch.Tensor): Input tensor (batch, time, `*`).
208
+ Returns:
209
+ torch.Tensor: Encoded tensor (batch, time, `*`).
210
+ """
211
+ self.extend_pe(x)
212
+ x = x * self.xscale
213
+ pos_emb = self.pe[
214
+ :,
215
+ self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
216
+ ]
217
+ return self.dropout(x), self.dropout(pos_emb)
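A minimal usage sketch of RelPositionalEncoding (assuming the espnet package in this repo is importable); the module returns the scaled input together with a relative position embedding covering both positive and negative offsets, i.e. 2*T - 1 positions for a length-T input, which RelPositionMultiHeadedAttention consumes:

import torch
from espnet.nets.pytorch_backend.transformer.embedding import RelPositionalEncoding

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 50, 256)        # (batch, time, d_model)
x_scaled, pos_emb = pos_enc(x)     # x is scaled by sqrt(d_model)
print(x_scaled.shape)              # torch.Size([2, 50, 256])
print(pos_emb.shape)               # torch.Size([1, 99, 256]): 2 * 50 - 1 relative positions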
espnet/nets/pytorch_backend/transformer/encoder.py ADDED
@@ -0,0 +1,283 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Encoder definition."""
8
+
9
+ import torch
10
+
11
+ from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
12
+ #from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
13
+ from espnet.nets.pytorch_backend.transformer.attention import (
14
+ MultiHeadedAttention, # noqa: H301
15
+ RelPositionMultiHeadedAttention, # noqa: H301
16
+ LegacyRelPositionMultiHeadedAttention, # noqa: H301
17
+ )
18
+ from espnet.nets.pytorch_backend.transformer.convolution import ConvolutionModule
19
+ from espnet.nets.pytorch_backend.transformer.embedding import (
20
+ PositionalEncoding, # noqa: H301
21
+ RelPositionalEncoding, # noqa: H301
22
+ LegacyRelPositionalEncoding, # noqa: H301
23
+ )
24
+ from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
25
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
26
+ from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
27
+ from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
28
+ from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
29
+ PositionwiseFeedForward, # noqa: H301
30
+ )
31
+ from espnet.nets.pytorch_backend.transformer.repeat import repeat
32
+ from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
33
+ from espnet.nets.pytorch_backend.transformer.raw_embeddings import VideoEmbedding
34
+ from espnet.nets.pytorch_backend.transformer.raw_embeddings import AudioEmbedding
35
+ from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet
36
+ from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet
37
+
38
+
39
+ def _pre_hook(
40
+ state_dict,
41
+ prefix,
42
+ local_metadata,
43
+ strict,
44
+ missing_keys,
45
+ unexpected_keys,
46
+ error_msgs,
47
+ ):
48
+ # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
49
+ rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
50
+ # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
51
+ rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
52
+
53
+
54
+ class Encoder(torch.nn.Module):
55
+ """Transformer encoder module.
56
+
57
+ :param int idim: input dim
58
+ :param int attention_dim: dimension of attention
59
+ :param int attention_heads: the number of heads of multi head attention
60
+ :param int linear_units: the number of units of position-wise feed forward
61
+ :param int num_blocks: the number of decoder blocks
62
+ :param float dropout_rate: dropout rate
63
+ :param float attention_dropout_rate: dropout rate in attention
64
+ :param float positional_dropout_rate: dropout rate after adding positional encoding
65
+ :param str or torch.nn.Module input_layer: input layer type
66
+ :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
67
+ :param bool normalize_before: whether to use layer_norm before the first block
68
+ :param bool concat_after: whether to concat attention layer's input and output
69
+ if True, additional linear will be applied.
70
+ i.e. x -> x + linear(concat(x, att(x)))
71
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
72
+ :param str positionwise_layer_type: linear or conv1d
73
+ :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
74
+ :param str encoder_attn_layer_type: encoder attention layer type
75
+ :param bool macaron_style: whether to use macaron style for positionwise layer
76
+ :param bool use_cnn_module: whether to use convolution module
77
+ :param bool zero_triu: whether to zero the upper triangular part of attention matrix
78
+ :param int cnn_module_kernel: kernel size of convolution module
79
+ :param int padding_idx: padding_idx for input_layer=embed
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ idim,
85
+ attention_dim=256,
86
+ attention_heads=4,
87
+ linear_units=2048,
88
+ num_blocks=6,
89
+ dropout_rate=0.1,
90
+ positional_dropout_rate=0.1,
91
+ attention_dropout_rate=0.0,
92
+ input_layer="conv2d",
93
+ pos_enc_class=PositionalEncoding,
94
+ normalize_before=True,
95
+ concat_after=False,
96
+ positionwise_layer_type="linear",
97
+ positionwise_conv_kernel_size=1,
98
+ macaron_style=False,
99
+ encoder_attn_layer_type="mha",
100
+ use_cnn_module=False,
101
+ zero_triu=False,
102
+ cnn_module_kernel=31,
103
+ padding_idx=-1,
104
+ relu_type="prelu",
105
+ a_upsample_ratio=1,
106
+ ):
107
+ """Construct an Encoder object."""
108
+ super(Encoder, self).__init__()
109
+ self._register_load_state_dict_pre_hook(_pre_hook)
110
+
111
+ if encoder_attn_layer_type == "rel_mha":
112
+ pos_enc_class = RelPositionalEncoding
113
+ elif encoder_attn_layer_type == "legacy_rel_mha":
114
+ pos_enc_class = LegacyRelPositionalEncoding
115
+ # -- frontend module.
116
+ if input_layer == "conv1d":
117
+ self.frontend = Conv1dResNet(
118
+ relu_type=relu_type,
119
+ a_upsample_ratio=a_upsample_ratio,
120
+ )
121
+ elif input_layer == "conv3d":
122
+ self.frontend = Conv3dResNet(relu_type=relu_type)
123
+ else:
124
+ self.frontend = None
125
+ # -- backend module.
126
+ if input_layer == "linear":
127
+ self.embed = torch.nn.Sequential(
128
+ torch.nn.Linear(idim, attention_dim),
129
+ torch.nn.LayerNorm(attention_dim),
130
+ torch.nn.Dropout(dropout_rate),
131
+ torch.nn.ReLU(),
132
+ pos_enc_class(attention_dim, positional_dropout_rate),
133
+ )
134
+ elif input_layer == "conv2d":
135
+ self.embed = Conv2dSubsampling(
136
+ idim,
137
+ attention_dim,
138
+ dropout_rate,
139
+ pos_enc_class(attention_dim, dropout_rate),
140
+ )
141
+ elif input_layer == "vgg2l":
142
+ self.embed = VGG2L(idim, attention_dim)  # NOTE: the VGG2L import is commented out above; re-enable it before using this branch
143
+ elif input_layer == "embed":
144
+ self.embed = torch.nn.Sequential(
145
+ torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
146
+ pos_enc_class(attention_dim, positional_dropout_rate),
147
+ )
148
+ elif isinstance(input_layer, torch.nn.Module):
149
+ self.embed = torch.nn.Sequential(
150
+ input_layer, pos_enc_class(attention_dim, positional_dropout_rate),
151
+ )
152
+ elif input_layer in ["conv1d", "conv3d"]:
153
+ self.embed = torch.nn.Sequential(
154
+ torch.nn.Linear(512, attention_dim),
155
+ pos_enc_class(attention_dim, positional_dropout_rate)
156
+ )
157
+ elif input_layer is None:
158
+ self.embed = torch.nn.Sequential(
159
+ pos_enc_class(attention_dim, positional_dropout_rate)
160
+ )
161
+ else:
162
+ raise ValueError("unknown input_layer: " + input_layer)
163
+ self.normalize_before = normalize_before
164
+ if positionwise_layer_type == "linear":
165
+ positionwise_layer = PositionwiseFeedForward
166
+ positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
167
+ elif positionwise_layer_type == "conv1d":
168
+ positionwise_layer = MultiLayeredConv1d
169
+ positionwise_layer_args = (
170
+ attention_dim,
171
+ linear_units,
172
+ positionwise_conv_kernel_size,
173
+ dropout_rate,
174
+ )
175
+ elif positionwise_layer_type == "conv1d-linear":
176
+ positionwise_layer = Conv1dLinear
177
+ positionwise_layer_args = (
178
+ attention_dim,
179
+ linear_units,
180
+ positionwise_conv_kernel_size,
181
+ dropout_rate,
182
+ )
183
+ else:
184
+ raise NotImplementedError("Support only linear or conv1d.")
185
+
186
+ if encoder_attn_layer_type == "mha":
187
+ encoder_attn_layer = MultiHeadedAttention
188
+ encoder_attn_layer_args = (
189
+ attention_heads,
190
+ attention_dim,
191
+ attention_dropout_rate,
192
+ )
193
+ elif encoder_attn_layer_type == "legacy_rel_mha":
194
+ encoder_attn_layer = LegacyRelPositionMultiHeadedAttention
195
+ encoder_attn_layer_args = (
196
+ attention_heads,
197
+ attention_dim,
198
+ attention_dropout_rate,
199
+ )
200
+ elif encoder_attn_layer_type == "rel_mha":
201
+ encoder_attn_layer = RelPositionMultiHeadedAttention
202
+ encoder_attn_layer_args = (
203
+ attention_heads,
204
+ attention_dim,
205
+ attention_dropout_rate,
206
+ zero_triu,
207
+ )
208
+ else:
209
+ raise ValueError("unknown encoder_attn_layer_type: " + encoder_attn_layer_type)
210
+
211
+ convolution_layer = ConvolutionModule
212
+ convolution_layer_args = (attention_dim, cnn_module_kernel)
213
+
214
+ self.encoders = repeat(
215
+ num_blocks,
216
+ lambda: EncoderLayer(
217
+ attention_dim,
218
+ encoder_attn_layer(*encoder_attn_layer_args),
219
+ positionwise_layer(*positionwise_layer_args),
220
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None,
221
+ dropout_rate,
222
+ normalize_before,
223
+ concat_after,
224
+ macaron_style,
225
+ ),
226
+ )
227
+ if self.normalize_before:
228
+ self.after_norm = LayerNorm(attention_dim)
229
+
230
+ def forward(self, xs, masks, extract_resnet_feats=False):
231
+ """Encode input sequence.
232
+
233
+ :param torch.Tensor xs: input tensor
234
+ :param torch.Tensor masks: input mask
235
+ :param bool extract_resnet_feats: if True, return only the frontend (ResNet) features
236
+ :return: position embedded tensor and mask
237
+ :rtype Tuple[torch.Tensor, torch.Tensor]:
238
+ """
239
+ if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)):
240
+ xs = self.frontend(xs)
241
+ if extract_resnet_feats:
242
+ return xs
243
+
244
+ if isinstance(self.embed, Conv2dSubsampling):
245
+ xs, masks = self.embed(xs, masks)
246
+ else:
247
+ xs = self.embed(xs)
248
+
249
+ xs, masks = self.encoders(xs, masks)
250
+
251
+ if isinstance(xs, tuple):
252
+ xs = xs[0]
253
+
254
+ if self.normalize_before:
255
+ xs = self.after_norm(xs)
256
+
257
+ return xs, masks
258
+
259
+ def forward_one_step(self, xs, masks, cache=None):
260
+ """Encode input frame.
261
+
262
+ :param torch.Tensor xs: input tensor
263
+ :param torch.Tensor masks: input mask
264
+ :param List[torch.Tensor] cache: cache tensors
265
+ :return: position embedded tensor, mask and new cache
266
+ :rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
267
+ """
268
+ if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)):
269
+ xs = self.frontend(xs)
270
+
271
+ if isinstance(self.embed, Conv2dSubsampling):
272
+ xs, masks = self.embed(xs, masks)
273
+ else:
274
+ xs = self.embed(xs)
275
+ if cache is None:
276
+ cache = [None for _ in range(len(self.encoders))]
277
+ new_cache = []
278
+ for c, e in zip(cache, self.encoders):
279
+ xs, masks = e(xs, masks, cache=c)
280
+ new_cache.append(xs)
281
+ if self.normalize_before:
282
+ xs = self.after_norm(xs)
283
+ return xs, masks, new_cache
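A hedged construction sketch for the Encoder above; the hyperparameter values are illustrative assumptions, not the values stored in the shipped model.json. Passing encoder_attn_layer_type="rel_mha" switches the positional encoding to RelPositionalEncoding, and macaron_style/use_cnn_module turn each block into a conformer-style layer:

import torch
from espnet.nets.pytorch_backend.transformer.encoder import Encoder

encoder = Encoder(
    idim=80,                            # feature dim for the "linear" input layer
    attention_dim=256,
    attention_heads=4,
    linear_units=2048,
    num_blocks=6,
    input_layer="linear",               # pre-extracted features; "conv3d" would enable the video frontend
    encoder_attn_layer_type="rel_mha",  # relative-position multi-head attention
    macaron_style=True,
    use_cnn_module=True,
)
feats = torch.randn(2, 100, 80)         # (batch, time, idim)
masks = torch.ones(2, 1, 100, dtype=torch.bool)
out, out_masks = encoder(feats, masks)  # out: (2, 100, 256)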
espnet/nets/pytorch_backend/transformer/encoder_layer.py ADDED
@@ -0,0 +1,149 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Encoder self-attention layer definition."""
8
+
9
+ import copy
10
+ import torch
11
+
12
+ from torch import nn
13
+
14
+ from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
15
+
16
+
17
+ class EncoderLayer(nn.Module):
18
+ """Encoder layer module.
19
+
20
+ :param int size: input dim
21
+ :param espnet.nets.pytorch_backend.transformer.attention.
22
+ MultiHeadedAttention self_attn: self attention module
23
+ RelPositionMultiHeadedAttention self_attn: self attention module
24
+ :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
25
+ PositionwiseFeedForward feed_forward:
26
+ feed forward module
27
+ :param espnet.nets.pytorch_backend.transformer.convolution.
28
+ ConvolutionModule conv_module:
29
+ convolution module
30
+ :param float dropout_rate: dropout rate
31
+ :param bool normalize_before: whether to use layer_norm before the first block
32
+ :param bool concat_after: whether to concat attention layer's input and output
33
+ if True, additional linear will be applied.
34
+ i.e. x -> x + linear(concat(x, att(x)))
35
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
36
+ :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward
37
+
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size,
43
+ self_attn,
44
+ feed_forward,
45
+ conv_module,
46
+ dropout_rate,
47
+ normalize_before=True,
48
+ concat_after=False,
49
+ macaron_style=False,
50
+ ):
51
+ """Construct an EncoderLayer object."""
52
+ super(EncoderLayer, self).__init__()
53
+ self.self_attn = self_attn
54
+ self.feed_forward = feed_forward
55
+ self.ff_scale = 1.0
56
+ self.conv_module = conv_module
57
+ self.macaron_style = macaron_style
58
+ self.norm_ff = LayerNorm(size) # for the FNN module
59
+ self.norm_mha = LayerNorm(size) # for the MHA module
60
+ if self.macaron_style:
61
+ self.feed_forward_macaron = copy.deepcopy(feed_forward)
62
+ self.ff_scale = 0.5
63
+ # for another FNN module in macaron style
64
+ self.norm_ff_macaron = LayerNorm(size)
65
+ if self.conv_module is not None:
66
+ self.norm_conv = LayerNorm(size) # for the CNN module
67
+ self.norm_final = LayerNorm(size) # for the final output of the block
68
+ self.dropout = nn.Dropout(dropout_rate)
69
+ self.size = size
70
+ self.normalize_before = normalize_before
71
+ self.concat_after = concat_after
72
+ if self.concat_after:
73
+ self.concat_linear = nn.Linear(size + size, size)
74
+
75
+ def forward(self, x_input, mask, cache=None):
76
+ """Compute encoded features.
77
+
78
+ :param torch.Tensor x_input: encoded source features (batch, max_time_in, size)
79
+ :param torch.Tensor mask: mask for x (batch, max_time_in)
80
+ :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
81
+ :rtype: Tuple[torch.Tensor, torch.Tensor]
82
+ """
83
+ if isinstance(x_input, tuple):
84
+ x, pos_emb = x_input[0], x_input[1]
85
+ else:
86
+ x, pos_emb = x_input, None
87
+
88
+ # whether to use macaron style
89
+ if self.macaron_style:
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm_ff_macaron(x)
93
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
94
+ if not self.normalize_before:
95
+ x = self.norm_ff_macaron(x)
96
+
97
+ # multi-headed self-attention module
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm_mha(x)
101
+
102
+ if cache is None:
103
+ x_q = x
104
+ else:
105
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
106
+ x_q = x[:, -1:, :]
107
+ residual = residual[:, -1:, :]
108
+ mask = None if mask is None else mask[:, -1:, :]
109
+
110
+ if pos_emb is not None:
111
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
112
+ else:
113
+ x_att = self.self_attn(x_q, x, x, mask)
114
+
115
+ if self.concat_after:
116
+ x_concat = torch.cat((x, x_att), dim=-1)
117
+ x = residual + self.concat_linear(x_concat)
118
+ else:
119
+ x = residual + self.dropout(x_att)
120
+ if not self.normalize_before:
121
+ x = self.norm_mha(x)
122
+
123
+ # convolution module
124
+ if self.conv_module is not None:
125
+ residual = x
126
+ if self.normalize_before:
127
+ x = self.norm_conv(x)
128
+ x = residual + self.dropout(self.conv_module(x))
129
+ if not self.normalize_before:
130
+ x = self.norm_conv(x)
131
+
132
+ # feed forward module
133
+ residual = x
134
+ if self.normalize_before:
135
+ x = self.norm_ff(x)
136
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
137
+ if not self.normalize_before:
138
+ x = self.norm_ff(x)
139
+
140
+ if self.conv_module is not None:
141
+ x = self.norm_final(x)
142
+
143
+ if cache is not None:
144
+ x = torch.cat([cache, x], dim=1)
145
+
146
+ if pos_emb is not None:
147
+ return (x, pos_emb), mask
148
+ else:
149
+ return x, mask
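A small sketch of wiring one block by hand (Encoder normally builds these via repeat()); with plain MultiHeadedAttention the layer receives a bare tensor rather than an (x, pos_emb) tuple:

import torch
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward

layer = EncoderLayer(
    size=256,
    self_attn=MultiHeadedAttention(4, 256, 0.1),
    feed_forward=PositionwiseFeedForward(256, 2048, 0.1),
    conv_module=None,       # pass a ConvolutionModule here for the conformer variant
    dropout_rate=0.1,
)
x = torch.randn(2, 50, 256)
mask = torch.ones(2, 1, 50, dtype=torch.bool)
y, y_mask = layer(x, mask)  # y: (2, 50, 256); the mask is passed through unchanged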
espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Label smoothing module."""
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ class LabelSmoothingLoss(nn.Module):
14
+ """Label-smoothing loss.
15
+
16
+ :param int size: the number of classes
17
+ :param int padding_idx: ignored class id
18
+ :param float smoothing: smoothing rate (0.0 means the conventional CE)
19
+ :param bool normalize_length: normalize loss by sequence length if True
20
+ :param torch.nn.Module criterion: loss function to be smoothed
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ size,
26
+ padding_idx,
27
+ smoothing,
28
+ normalize_length=False,
29
+ criterion=nn.KLDivLoss(reduction="none"),
30
+ ):
31
+ """Construct a LabelSmoothingLoss object."""
32
+ super(LabelSmoothingLoss, self).__init__()
33
+ self.criterion = criterion
34
+ self.padding_idx = padding_idx
35
+ self.confidence = 1.0 - smoothing
36
+ self.smoothing = smoothing
37
+ self.size = size
38
+ self.true_dist = None
39
+ self.normalize_length = normalize_length
40
+
41
+ def forward(self, x, target):
42
+ """Compute loss between x and target.
43
+
44
+ :param torch.Tensor x: prediction (batch, seqlen, class)
45
+ :param torch.Tensor target:
46
+ target signal masked with self.padding_id (batch, seqlen)
47
+ :return: scalar float value
48
+ :rtype torch.Tensor
49
+ """
50
+ assert x.size(2) == self.size
51
+ batch_size = x.size(0)
52
+ x = x.view(-1, self.size)
53
+ target = target.view(-1)
54
+ with torch.no_grad():
55
+ true_dist = x.clone()
56
+ true_dist.fill_(self.smoothing / (self.size - 1))
57
+ ignore = target == self.padding_idx # (B,)
58
+ total = len(target) - ignore.sum().item()
59
+ target = target.masked_fill(ignore, 0) # avoid -1 index
60
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
61
+ kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
62
+ denom = total if self.normalize_length else batch_size
63
+ return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
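A worked sketch of the smoothing: with size=5 and smoothing=0.1, each off-target class receives 0.1 / (5 - 1) = 0.025 probability mass while the target keeps 0.9, and positions equal to padding_idx are excluded from the loss:

import torch
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 3, 5)                   # (batch, seqlen, n_class)
target = torch.tensor([[1, 2, -1], [0, 3, 4]])  # -1 marks padded positions
loss = criterion(logits, target)                # scalar KL divergence to the smoothed targets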
espnet/nets/pytorch_backend/transformer/layer_norm.py ADDED
@@ -0,0 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Layer normalization module."""
8
+
9
+ import torch
10
+
11
+
12
+ class LayerNorm(torch.nn.LayerNorm):
13
+ """Layer normalization module.
14
+
15
+ :param int nout: output dim size
16
+ :param int dim: dimension to be normalized
17
+ """
18
+
19
+ def __init__(self, nout, dim=-1):
20
+ """Construct a LayerNorm object."""
21
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
22
+ self.dim = dim
23
+
24
+ def forward(self, x):
25
+ """Apply layer normalization.
26
+
27
+ :param torch.Tensor x: input tensor
28
+ :return: layer normalized tensor
29
+ :rtype torch.Tensor
30
+ """
31
+ if self.dim == -1:
32
+ return super(LayerNorm, self).forward(x)
33
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
espnet/nets/pytorch_backend/transformer/mask.py ADDED
@@ -0,0 +1,51 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Shigeki Karita
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ """Mask module."""
7
+
8
+ from distutils.version import LooseVersion
9
+
10
+ import torch
11
+
12
+ is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")
13
+ # LooseVersion('1.2.0') == LooseVersion(torch.__version__) can't include e.g. 1.2.0+aaa
14
+ is_torch_1_2 = (
15
+ LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2")
16
+ )
17
+ datatype = torch.bool if is_torch_1_2_plus else torch.uint8
18
+
19
+
20
+ def subsequent_mask(size, device="cpu", dtype=datatype):
21
+ """Create mask for subsequent steps (size, size).
22
+
23
+ :param int size: size of mask
24
+ :param str device: "cpu" or "cuda" or torch.Tensor.device
25
+ :param torch.dtype dtype: result dtype
26
+ :rtype: torch.Tensor
27
+ >>> subsequent_mask(3)
28
+ [[1, 0, 0],
29
+ [1, 1, 0],
30
+ [1, 1, 1]]
31
+ """
32
+ if is_torch_1_2 and dtype == torch.bool:
33
+ # torch=1.2 doesn't support tril for bool tensor
34
+ ret = torch.ones(size, size, device=device, dtype=torch.uint8)
35
+ return torch.tril(ret, out=ret).type(dtype)
36
+ else:
37
+ ret = torch.ones(size, size, device=device, dtype=dtype)
38
+ return torch.tril(ret, out=ret)
39
+
40
+
41
+ def target_mask(ys_in_pad, ignore_id):
42
+ """Create mask for decoder self-attention.
43
+
44
+ :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
45
+ :param int ignore_id: index of padding
46
+ :param torch.dtype dtype: result dtype
47
+ :rtype: torch.Tensor
48
+ """
49
+ ys_mask = ys_in_pad != ignore_id
50
+ m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
51
+ return ys_mask.unsqueeze(-2) & m
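Usage sketch of the two helpers: subsequent_mask builds the causal lower-triangular mask, and target_mask additionally masks out padded target positions, giving the (batch, Lmax, Lmax) mask the decoder self-attention expects:

import torch
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask, target_mask

print(subsequent_mask(3))
# lower-triangular (3, 3) boolean mask with the pattern
# [[1, 0, 0],
#  [1, 1, 0],
#  [1, 1, 1]]

ys_in_pad = torch.tensor([[5, 7, -1]])             # -1 is the padding index
print(target_mask(ys_in_pad, ignore_id=-1).shape)  # torch.Size([1, 3, 3])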
espnet/nets/pytorch_backend/transformer/multi_layer_conv.py ADDED
@@ -0,0 +1,105 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Tomoki Hayashi
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
8
+
9
+ import torch
10
+
11
+
12
+ class MultiLayeredConv1d(torch.nn.Module):
13
+ """Multi-layered conv1d for Transformer block.
14
+
15
+ This is a module of multi-layered conv1d designed
16
+ to replace the positionwise feed-forward network
17
+ in the Transformer block, which is introduced in
18
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
19
+
20
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
21
+ https://arxiv.org/pdf/1905.09263.pdf
22
+
23
+ """
24
+
25
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
26
+ """Initialize MultiLayeredConv1d module.
27
+
28
+ Args:
29
+ in_chans (int): Number of input channels.
30
+ hidden_chans (int): Number of hidden channels.
31
+ kernel_size (int): Kernel size of conv1d.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+ super(MultiLayeredConv1d, self).__init__()
36
+ self.w_1 = torch.nn.Conv1d(
37
+ in_chans,
38
+ hidden_chans,
39
+ kernel_size,
40
+ stride=1,
41
+ padding=(kernel_size - 1) // 2,
42
+ )
43
+ self.w_2 = torch.nn.Conv1d(
44
+ hidden_chans,
45
+ in_chans,
46
+ kernel_size,
47
+ stride=1,
48
+ padding=(kernel_size - 1) // 2,
49
+ )
50
+ self.dropout = torch.nn.Dropout(dropout_rate)
51
+
52
+ def forward(self, x):
53
+ """Calculate forward propagation.
54
+
55
+ Args:
56
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
57
+
58
+ Returns:
59
+ Tensor: Batch of output tensors (B, ..., in_chans).
60
+
61
+ """
62
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
63
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
64
+
65
+
66
+ class Conv1dLinear(torch.nn.Module):
67
+ """Conv1D + Linear for Transformer block.
68
+
69
+ A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
70
+
71
+ """
72
+
73
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
74
+ """Initialize Conv1dLinear module.
75
+
76
+ Args:
77
+ in_chans (int): Number of input channels.
78
+ hidden_chans (int): Number of hidden channels.
79
+ kernel_size (int): Kernel size of conv1d.
80
+ dropout_rate (float): Dropout rate.
81
+
82
+ """
83
+ super(Conv1dLinear, self).__init__()
84
+ self.w_1 = torch.nn.Conv1d(
85
+ in_chans,
86
+ hidden_chans,
87
+ kernel_size,
88
+ stride=1,
89
+ padding=(kernel_size - 1) // 2,
90
+ )
91
+ self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
92
+ self.dropout = torch.nn.Dropout(dropout_rate)
93
+
94
+ def forward(self, x):
95
+ """Calculate forward propagation.
96
+
97
+ Args:
98
+ x (Tensor): Batch of input tensors (B, ..., in_chans).
99
+
100
+ Returns:
101
+ Tensor: Batch of output tensors (B, ..., in_chans).
102
+
103
+ """
104
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
105
+ return self.w_2(self.dropout(x))
espnet/nets/pytorch_backend/transformer/optimizer.py ADDED
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Optimizer module."""
8
+
9
+ import torch
10
+
11
+
12
+ class NoamOpt(object):
13
+ """Optim wrapper that implements rate."""
14
+
15
+ def __init__(self, model_size, factor, warmup, optimizer):
16
+ """Construct a NoamOpt object."""
17
+ self.optimizer = optimizer
18
+ self._step = 0
19
+ self.warmup = warmup
20
+ self.factor = factor
21
+ self.model_size = model_size
22
+ self._rate = 0
23
+
24
+ @property
25
+ def param_groups(self):
26
+ """Return param_groups."""
27
+ return self.optimizer.param_groups
28
+
29
+ def step(self):
30
+ """Update parameters and rate."""
31
+ self._step += 1
32
+ rate = self.rate()
33
+ for p in self.optimizer.param_groups:
34
+ p["lr"] = rate
35
+ self._rate = rate
36
+ self.optimizer.step()
37
+
38
+ def rate(self, step=None):
39
+ """Return the Noam learning rate for the given step."""
40
+ if step is None:
41
+ step = self._step
42
+ return (
43
+ self.factor
44
+ * self.model_size ** (-0.5)
45
+ * min(step ** (-0.5), step * self.warmup ** (-1.5))
46
+ )
47
+
48
+ def zero_grad(self):
49
+ """Reset gradient."""
50
+ self.optimizer.zero_grad()
51
+
52
+ def state_dict(self):
53
+ """Return state_dict."""
54
+ return {
55
+ "_step": self._step,
56
+ "warmup": self.warmup,
57
+ "factor": self.factor,
58
+ "model_size": self.model_size,
59
+ "_rate": self._rate,
60
+ "optimizer": self.optimizer.state_dict(),
61
+ }
62
+
63
+ def load_state_dict(self, state_dict):
64
+ """Load state_dict."""
65
+ for key, value in state_dict.items():
66
+ if key == "optimizer":
67
+ self.optimizer.load_state_dict(state_dict["optimizer"])
68
+ else:
69
+ setattr(self, key, value)
70
+
71
+
72
+ def get_std_opt(model, d_model, warmup, factor):
73
+ """Get standard NoamOpt."""
74
+ base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
75
+ return NoamOpt(d_model, factor, warmup, base)
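The rate() method implements lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5), which grows linearly during warmup, peaks at step == warmup, and then decays as step**-0.5. A quick sketch with a stand-in model:

import torch
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

model = torch.nn.Linear(256, 256)  # stand-in for a real encoder/decoder
opt = get_std_opt(model, d_model=256, warmup=25000, factor=1.0)
print(opt.rate(step=100))          # small lr early in warmup
print(opt.rate(step=25000))        # peak lr = factor * 256 ** -0.5 * 25000 ** -0.5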
espnet/nets/pytorch_backend/transformer/plot.py ADDED
@@ -0,0 +1,134 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ import logging
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy
11
+
12
+ from espnet.asr import asr_utils
13
+
14
+
15
+ def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
16
+ # dynamically import matplotlib due to not found error
17
+ from matplotlib.ticker import MaxNLocator
18
+ import os
19
+
20
+ d = os.path.dirname(filename)
21
+ if not os.path.exists(d):
22
+ os.makedirs(d)
23
+ w, h = plt.figaspect(1.0 / len(att_w))
24
+ fig = plt.Figure(figsize=(w * 2, h * 2))
25
+ axes = fig.subplots(1, len(att_w))
26
+ if len(att_w) == 1:
27
+ axes = [axes]
28
+ for ax, aw in zip(axes, att_w):
29
+ # plt.subplot(1, len(att_w), h)
30
+ ax.imshow(aw.astype(numpy.float32), aspect="auto")
31
+ ax.set_xlabel("Input")
32
+ ax.set_ylabel("Output")
33
+ ax.xaxis.set_major_locator(MaxNLocator(integer=True))
34
+ ax.yaxis.set_major_locator(MaxNLocator(integer=True))
35
+ # Labels for major ticks
36
+ if xtokens is not None:
37
+ ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
38
+ ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
39
+ ax.set_xticklabels(xtokens + [""], rotation=40)
40
+ if ytokens is not None:
41
+ ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
42
+ ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
43
+ ax.set_yticklabels(ytokens + [""])
44
+ fig.tight_layout()
45
+ return fig
46
+
47
+
48
+ def savefig(plot, filename):
49
+ plot.savefig(filename)
50
+ plt.clf()
51
+
52
+
53
+ def plot_multi_head_attention(
54
+ data,
55
+ attn_dict,
56
+ outdir,
57
+ suffix="png",
58
+ savefn=savefig,
59
+ ikey="input",
60
+ iaxis=0,
61
+ okey="output",
62
+ oaxis=0,
63
+ ):
64
+ """Plot multi head attentions.
65
+
66
+ :param dict data: utts info from json file
67
+ :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
68
+ values should be torch.Tensor (head, input_length, output_length)
69
+ :param str outdir: dir to save fig
70
+ :param str suffix: filename suffix including image type (e.g., png)
71
+ :param savefn: function to save
72
+
73
+ """
74
+ for name, att_ws in attn_dict.items():
75
+ for idx, att_w in enumerate(att_ws):
76
+ filename = "%s/%s.%s.%s" % (outdir, data[idx][0], name, suffix)
77
+ dec_len = int(data[idx][1][okey][oaxis]["shape"][0])
78
+ enc_len = int(data[idx][1][ikey][iaxis]["shape"][0])
79
+ xtokens, ytokens = None, None
80
+ if "encoder" in name:
81
+ att_w = att_w[:, :enc_len, :enc_len]
82
+ # for MT
83
+ if "token" in data[idx][1][ikey][iaxis].keys():
84
+ xtokens = data[idx][1][ikey][iaxis]["token"].split()
85
+ ytokens = xtokens[:]
86
+ elif "decoder" in name:
87
+ if "self" in name:
88
+ att_w = att_w[:, : dec_len + 1, : dec_len + 1] # +1 for <sos>
89
+ else:
90
+ att_w = att_w[:, : dec_len + 1, :enc_len] # +1 for <sos>
91
+ # for MT
92
+ if "token" in data[idx][1][ikey][iaxis].keys():
93
+ xtokens = data[idx][1][ikey][iaxis]["token"].split()
94
+ # for ASR/ST/MT
95
+ if "token" in data[idx][1][okey][oaxis].keys():
96
+ ytokens = ["<sos>"] + data[idx][1][okey][oaxis]["token"].split()
97
+ if "self" in name:
98
+ xtokens = ytokens[:]
99
+ else:
100
+ logging.warning("unknown name for shaping attention")
101
+ fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
102
+ savefn(fig, filename)
103
+
104
+
105
+ class PlotAttentionReport(asr_utils.PlotAttentionReport):
106
+ def plotfn(self, *args, **kwargs):
107
+ kwargs["ikey"] = self.ikey
108
+ kwargs["iaxis"] = self.iaxis
109
+ kwargs["okey"] = self.okey
110
+ kwargs["oaxis"] = self.oaxis
111
+ plot_multi_head_attention(*args, **kwargs)
112
+
113
+ def __call__(self, trainer):
114
+ attn_dict = self.get_attention_weights()
115
+ suffix = "ep.{.updater.epoch}.png".format(trainer)
116
+ self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig)
117
+
118
+ def get_attention_weights(self):
119
+ batch = self.converter([self.transform(self.data)], self.device)
120
+ if isinstance(batch, tuple):
121
+ att_ws = self.att_vis_fn(*batch)
122
+ elif isinstance(batch, dict):
123
+ att_ws = self.att_vis_fn(**batch)
124
+ return att_ws
125
+
126
+ def log_attentions(self, logger, step):
127
+ def log_fig(plot, filename):
128
+ from os.path import basename
129
+
130
+ logger.add_figure(basename(filename), plot, step)
131
+ plt.clf()
132
+
133
+ attn_dict = self.get_attention_weights()
134
+ self.plotfn(self.data, attn_dict, self.outdir, "", log_fig)
espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Positionwise feed forward layer definition."""
8
+
9
+ import torch
10
+
11
+
12
+ class PositionwiseFeedForward(torch.nn.Module):
13
+ """Positionwise feed forward layer.
14
+
15
+ :param int idim: input dimension
16
+ :param int hidden_units: number of hidden units
17
+ :param float dropout_rate: dropout rate
18
+
19
+ """
20
+
21
+ def __init__(self, idim, hidden_units, dropout_rate):
22
+ """Construct a PositionwiseFeedForward object."""
23
+ super(PositionwiseFeedForward, self).__init__()
24
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
25
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
26
+ self.dropout = torch.nn.Dropout(dropout_rate)
27
+
28
+ def forward(self, x):
29
+ """Forward function."""
30
+ return self.w_2(self.dropout(torch.relu(self.w_1(x))))
espnet/nets/pytorch_backend/transformer/raw_embeddings.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import logging
3
+
4
+ from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet
5
+ from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet
6
+
7
+
8
+ class VideoEmbedding(torch.nn.Module):
9
+ """Video Embedding
10
+
11
+ :param int idim: input dim
12
+ :param int odim: output dim
13
+ :param float dropout_rate: dropout rate
14
+ """
15
+
16
+ def __init__(self, idim, odim, dropout_rate, pos_enc_class, backbone_type="resnet", relu_type="prelu"):
17
+ super(VideoEmbedding, self).__init__()
18
+ self.trunk = Conv3dResNet(
19
+ backbone_type=backbone_type,
20
+ relu_type=relu_type
21
+ )
22
+ self.out = torch.nn.Sequential(
23
+ torch.nn.Linear(idim, odim),
24
+ pos_enc_class,
25
+ )
26
+
27
+ def forward(self, x, x_mask, extract_feats=None):
28
+ """video embedding for x
29
+
30
+ :param torch.Tensor x: input tensor
31
+ :param torch.Tensor x_mask: input mask
32
+ :param str extract_features: the position for feature extraction
33
+ :return: subsampled x and mask
34
+ :rtype Tuple[torch.Tensor, torch.Tensor]
35
+ """
36
+ x_resnet, x_mask = self.trunk(x, x_mask)
37
+ x = self.out(x_resnet)
38
+ if extract_feats:
39
+ return x, x_mask, x_resnet
40
+ else:
41
+ return x, x_mask
42
+
43
+
44
+ class AudioEmbedding(torch.nn.Module):
45
+ """Audio Embedding
46
+
47
+ :param int idim: input dim
48
+ :param int odim: output dim
49
+ :param float dropout_rate: dropout rate
50
+ """
51
+
52
+ def __init__(self, idim, odim, dropout_rate, pos_enc_class, relu_type="prelu", a_upsample_ratio=1):
53
+ super(AudioEmbedding, self).__init__()
54
+ self.trunk = Conv1dResNet(
55
+ relu_type=relu_type,
56
+ a_upsample_ratio=a_upsample_ratio,
57
+ )
58
+ self.out = torch.nn.Sequential(
59
+ torch.nn.Linear(idim, odim),
60
+ pos_enc_class,
61
+ )
62
+
63
+ def forward(self, x, x_mask, extract_feats=None):
64
+ """audio embedding for x
65
+
66
+ :param torch.Tensor x: input tensor
67
+ :param torch.Tensor x_mask: input mask
68
+ :param str extract_features: the position for feature extraction
69
+ :return: subsampled x and mask
70
+ :rtype Tuple[torch.Tensor, torch.Tensor]
71
+ """
72
+ x_resnet, x_mask = self.trunk(x, x_mask)
73
+ x = self.out(x_resnet)
74
+ if extract_feats:
75
+ return x, x_mask, x_resnet
76
+ else:
77
+ return x, x_mask
espnet/nets/pytorch_backend/transformer/repeat.py ADDED
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Repeat the same layer definition."""
8
+
9
+ import torch
10
+
11
+
12
+ class MultiSequential(torch.nn.Sequential):
13
+ """Multi-input multi-output torch.nn.Sequential."""
14
+
15
+ def forward(self, *args):
16
+ """Repeat."""
17
+ for m in self:
18
+ args = m(*args)
19
+ return args
20
+
21
+
22
+ def repeat(N, fn):
23
+ """Repeat module N times.
24
+
25
+ :param int N: repeat time
26
+ :param function fn: function to generate module
27
+ :return: repeated modules
28
+ :rtype: MultiSequential
29
+ """
30
+ return MultiSequential(*[fn() for _ in range(N)])
espnet/nets/pytorch_backend/transformer/subsampling.py ADDED
@@ -0,0 +1,52 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Copyright 2019 Shigeki Karita
5
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6
+
7
+ """Subsampling layer definition."""
8
+
9
+ import torch
10
+
11
+
12
+ class Conv2dSubsampling(torch.nn.Module):
13
+ """Convolutional 2D subsampling (to 1/4 length).
14
+
15
+ :param int idim: input dim
16
+ :param int odim: output dim
17
+ :param float dropout_rate: dropout rate
18
+ :param nn.Module pos_enc_class: positional encoding layer
19
+
20
+ """
21
+
22
+ def __init__(self, idim, odim, dropout_rate, pos_enc_class):
23
+ """Construct a Conv2dSubsampling object."""
24
+ super(Conv2dSubsampling, self).__init__()
25
+ self.conv = torch.nn.Sequential(
26
+ torch.nn.Conv2d(1, odim, 3, 2),
27
+ torch.nn.ReLU(),
28
+ torch.nn.Conv2d(odim, odim, 3, 2),
29
+ torch.nn.ReLU(),
30
+ )
31
+ self.out = torch.nn.Sequential(
32
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), pos_enc_class,
33
+ )
34
+
35
+ def forward(self, x, x_mask):
36
+ """Subsample x.
37
+
38
+ :param torch.Tensor x: input tensor
39
+ :param torch.Tensor x_mask: input mask
40
+ :return: subsampled x and mask
41
+ :rtype Tuple[torch.Tensor, torch.Tensor]
42
+ or Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
43
+ """
44
+ x = x.unsqueeze(1) # (b, c, t, f)
45
+ x = self.conv(x)
46
+ b, c, t, f = x.size()
47
+ # if RelPositionalEncoding, x: Tuple[torch.Tensor, torch.Tensor]
48
+ # else x: torch.Tensor
49
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
50
+ if x_mask is None:
51
+ return x, None
52
+ return x, x_mask[:, :, :-2:2][:, :, :-2:2]
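Shape sketch for the subsampling: the two stride-2 convolutions shrink the time axis to roughly a quarter and the mask is sliced with the same pattern, so for idim=80 a 100-frame input comes out as 24 frames:

import torch
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling

subsample = Conv2dSubsampling(
    idim=80, odim=256, dropout_rate=0.1,
    pos_enc_class=PositionalEncoding(256, 0.1),
)
x = torch.randn(2, 100, 80)        # (batch, time, feature)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, y_mask = subsample(x, mask)
print(y.shape, y_mask.shape)       # (2, 24, 256) and (2, 1, 24)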
espnet/nets/scorer_interface.py ADDED
@@ -0,0 +1,188 @@
1
+ """Scorer interface module."""
2
+
3
+ from typing import Any
4
+ from typing import List
5
+ from typing import Tuple
6
+
7
+ import torch
8
+ import warnings
9
+
10
+
11
+ class ScorerInterface:
12
+ """Scorer interface for beam search.
13
+
14
+ The scorer performs scoring of all tokens in the vocabulary.
15
+
16
+ Examples:
17
+ * Search heuristics
18
+ * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
19
+ * Decoder networks of the sequence-to-sequence models
20
+ * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
21
+ * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
22
+ * Neural language models
23
+ * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
24
+ * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
25
+ * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
26
+
27
+ """
28
+
29
+ def init_state(self, x: torch.Tensor) -> Any:
30
+ """Get an initial state for decoding (optional).
31
+
32
+ Args:
33
+ x (torch.Tensor): The encoded feature tensor
34
+
35
+ Returns: initial state
36
+
37
+ """
38
+ return None
39
+
40
+ def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
41
+ """Select state with relative ids in the main beam search.
42
+
43
+ Args:
44
+ state: Decoder state for prefix tokens
45
+ i (int): Index to select a state in the main beam search
46
+ new_id (int): New label index to select a state if necessary
47
+
48
+ Returns:
49
+ state: pruned state
50
+
51
+ """
52
+ return None if state is None else state[i]
53
+
54
+ def score(
55
+ self, y: torch.Tensor, state: Any, x: torch.Tensor
56
+ ) -> Tuple[torch.Tensor, Any]:
57
+ """Score new token (required).
58
+
59
+ Args:
60
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
61
+ state: Scorer state for prefix tokens
62
+ x (torch.Tensor): The encoder feature that generates ys.
63
+
64
+ Returns:
65
+ tuple[torch.Tensor, Any]: Tuple of
66
+ scores for next token that has a shape of `(n_vocab)`
67
+ and next state for ys
68
+
69
+ """
70
+ raise NotImplementedError
71
+
72
+ def final_score(self, state: Any) -> float:
73
+ """Score eos (optional).
74
+
75
+ Args:
76
+ state: Scorer state for prefix tokens
77
+
78
+ Returns:
79
+ float: final score
80
+
81
+ """
82
+ return 0.0
83
+
84
+
85
+ class BatchScorerInterface(ScorerInterface):
86
+ """Batch scorer interface."""
87
+
88
+ def batch_init_state(self, x: torch.Tensor) -> Any:
89
+ """Get an initial state for decoding (optional).
90
+
91
+ Args:
92
+ x (torch.Tensor): The encoded feature tensor
93
+
94
+ Returns: initial state
95
+
96
+ """
97
+ return self.init_state(x)
98
+
99
+ def batch_score(
100
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
101
+ ) -> Tuple[torch.Tensor, List[Any]]:
102
+ """Score new token batch (required).
103
+
104
+ Args:
105
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
106
+ states (List[Any]): Scorer states for prefix tokens.
107
+ xs (torch.Tensor):
108
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
109
+
110
+ Returns:
111
+ tuple[torch.Tensor, List[Any]]: Tuple of
112
+ batchified scores for next token with shape of `(n_batch, n_vocab)`
113
+ and next state list for ys.
114
+
115
+ """
116
+ warnings.warn(
117
+ "{} batch score is implemented through for loop not parallelized".format(
118
+ self.__class__.__name__
119
+ )
120
+ )
121
+ scores = list()
122
+ outstates = list()
123
+ for i, (y, state, x) in enumerate(zip(ys, states, xs)):
124
+ score, outstate = self.score(y, state, x)
125
+ outstates.append(outstate)
126
+ scores.append(score)
127
+ scores = torch.cat(scores, 0).view(ys.shape[0], -1)
128
+ return scores, outstates
129
+
130
+
131
+ class PartialScorerInterface(ScorerInterface):
132
+ """Partial scorer interface for beam search.
133
+
134
+ The partial scorer performs scoring after the non-partial scorers have finished,
135
+ and receives pre-pruned next tokens because scoring all the tokens would be
136
+ too heavy.
137
+
138
+ Examples:
139
+ * Prefix search for connectionist-temporal-classification models
140
+ * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
141
+
142
+ """
143
+
144
+ def score_partial(
145
+ self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
146
+ ) -> Tuple[torch.Tensor, Any]:
147
+ """Score new token (required).
148
+
149
+ Args:
150
+ y (torch.Tensor): 1D prefix token
151
+ next_tokens (torch.Tensor): torch.int64 next token to score
152
+ state: decoder state for prefix tokens
153
+ x (torch.Tensor): The encoder feature that generates ys
154
+
155
+ Returns:
156
+ tuple[torch.Tensor, Any]:
157
+ Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
158
+ and next state for ys
159
+
160
+ """
161
+ raise NotImplementedError
162
+
163
+
164
+ class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
165
+ """Batch partial scorer interface for beam search."""
166
+
167
+ def batch_score_partial(
168
+ self,
169
+ ys: torch.Tensor,
170
+ next_tokens: torch.Tensor,
171
+ states: List[Any],
172
+ xs: torch.Tensor,
173
+ ) -> Tuple[torch.Tensor, Any]:
174
+ """Score new token (required).
175
+
176
+ Args:
177
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
178
+ next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
179
+ states (List[Any]): Scorer states for prefix tokens.
180
+ xs (torch.Tensor):
181
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
182
+
183
+ Returns:
184
+ tuple[torch.Tensor, Any]:
185
+ Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
186
+ and next states for ys
187
+ """
188
+ raise NotImplementedError
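A toy sketch of the contract: any object implementing score() (plus the optional init_state/select_state/final_score) can be plugged into the beam search next to the decoder, LM and CTC scorers; the class name below is hypothetical and only illustrates the interface:

import math
from typing import Any, Tuple

import torch

from espnet.nets.scorer_interface import ScorerInterface


class UniformScorer(ScorerInterface):
    """Toy scorer assigning the same log-probability to every token."""

    def __init__(self, n_vocab: int):
        self.n_vocab = n_vocab

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # uniform log-probability over the vocabulary, no recurrent state
        scores = torch.full((self.n_vocab,), -math.log(self.n_vocab), device=x.device, dtype=x.dtype)
        return scores, None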
espnet/nets/scorers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Initialize sub package."""
espnet/nets/scorers/ctc.py ADDED
@@ -0,0 +1,158 @@
1
+ """ScorerInterface implementation for CTC."""
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from espnet.nets.ctc_prefix_score import CTCPrefixScore
7
+ from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH
8
+ from espnet.nets.scorer_interface import BatchPartialScorerInterface
9
+
10
+
11
+ class CTCPrefixScorer(BatchPartialScorerInterface):
12
+ """Decoder interface wrapper for CTCPrefixScore."""
13
+
14
+ def __init__(self, ctc: torch.nn.Module, eos: int):
15
+ """Initialize class.
16
+
17
+ Args:
18
+ ctc (torch.nn.Module): The CTC implementation.
19
+ For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
20
+ eos (int): The end-of-sequence id.
21
+
22
+ """
23
+ self.ctc = ctc
24
+ self.eos = eos
25
+ self.impl = None
26
+
27
+ def init_state(self, x: torch.Tensor):
28
+ """Get an initial state for decoding.
29
+
30
+ Args:
31
+ x (torch.Tensor): The encoded feature tensor
32
+
33
+ Returns: initial state
34
+
35
+ """
36
+ logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
37
+ # TODO(karita): use CTCPrefixScoreTH
38
+ self.impl = CTCPrefixScore(logp, 0, self.eos, np)
39
+ return 0, self.impl.initial_state()
40
+
41
+ def select_state(self, state, i, new_id=None):
42
+ """Select state with relative ids in the main beam search.
43
+
44
+ Args:
45
+ state: Decoder state for prefix tokens
46
+ i (int): Index to select a state in the main beam search
47
+ new_id (int): New label id to select a state if necessary
48
+
49
+ Returns:
50
+ state: pruned state
51
+
52
+ """
53
+ if type(state) == tuple:
54
+ if len(state) == 2: # for CTCPrefixScore
55
+ sc, st = state
56
+ return sc[i], st[i]
57
+ else: # for CTCPrefixScoreTH (need new_id > 0)
58
+ r, log_psi, f_min, f_max, scoring_idmap = state
59
+ s = log_psi[i, new_id].expand(log_psi.size(1))
60
+ if scoring_idmap is not None:
61
+ return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
62
+ else:
63
+ return r[:, :, i, new_id], s, f_min, f_max
64
+ return None if state is None else state[i]
65
+
66
+ def score_partial(self, y, ids, state, x):
67
+ """Score new token.
68
+
69
+ Args:
70
+ y (torch.Tensor): 1D prefix token
71
+ ids (torch.Tensor): torch.int64 next tokens to score
72
+ state: decoder state for prefix tokens
73
+ x (torch.Tensor): 2D encoder feature that generates ys
74
+
75
+ Returns:
76
+ tuple[torch.Tensor, Any]:
77
+ Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
78
+ and next state for ys
79
+
80
+ """
81
+ prev_score, state = state
82
+ presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
83
+ tscore = torch.as_tensor(
84
+ presub_score - prev_score, device=x.device, dtype=x.dtype
85
+ )
86
+ return tscore, (presub_score, new_st)
87
+
88
+ def batch_init_state(self, x: torch.Tensor):
89
+ """Get an initial state for decoding.
90
+
91
+ Args:
92
+ x (torch.Tensor): The encoded feature tensor
93
+
94
+ Returns: initial state
95
+
96
+ """
97
+ logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
98
+ xlen = torch.tensor([logp.size(1)])
99
+ self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
100
+ return None
101
+
102
+ def batch_score_partial(self, y, ids, state, x):
103
+ """Score new token.
104
+
105
+ Args:
106
+ y (torch.Tensor): 1D prefix token
107
+ ids (torch.Tensor): torch.int64 next token to score
108
+ state: decoder state for prefix tokens
109
+ x (torch.Tensor): 2D encoder feature that generates ys
110
+
111
+ Returns:
112
+ tuple[torch.Tensor, Any]:
113
+ Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
114
+ and next state for ys
115
+
116
+ """
117
+ batch_state = (
118
+ (
119
+ torch.stack([s[0] for s in state], dim=2),
120
+ torch.stack([s[1] for s in state]),
121
+ state[0][2],
122
+ state[0][3],
123
+ )
124
+ if state[0] is not None
125
+ else None
126
+ )
127
+ return self.impl(y, batch_state, ids)
128
+
129
+ def extend_prob(self, x: torch.Tensor):
130
+ """Extend probs for decoding.
131
+
132
+ This extension is for streaming decoding
133
+ as in Eq (14) in https://arxiv.org/abs/2006.14941
134
+
135
+ Args:
136
+ x (torch.Tensor): The encoded feature tensor
137
+
138
+ """
139
+ logp = self.ctc.log_softmax(x.unsqueeze(0))
140
+ self.impl.extend_prob(logp)
141
+
142
+ def extend_state(self, state):
143
+ """Extend state for decoding.
144
+
145
+ This extension is for streaming decoding
146
+ as in Eq (14) in https://arxiv.org/abs/2006.14941
147
+
148
+ Args:
149
+ state: The states of hyps
150
+
151
+ Returns: extended state
152
+
153
+ """
154
+ new_state = []
155
+ for s in state:
156
+ new_state.append(self.impl.extend_state(s))
157
+
158
+ return new_state
espnet/nets/scorers/length_bonus.py ADDED
@@ -0,0 +1,61 @@
1
+ """Length bonus module."""
2
+ from typing import Any
3
+ from typing import List
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+ from espnet.nets.scorer_interface import BatchScorerInterface
9
+
10
+
11
+ class LengthBonus(BatchScorerInterface):
12
+ """Length bonus in beam search."""
13
+
14
+ def __init__(self, n_vocab: int):
15
+ """Initialize class.
16
+
17
+ Args:
18
+ n_vocab (int): The number of tokens in vocabulary for beam search
19
+
20
+ """
21
+ self.n = n_vocab
22
+
23
+ def score(self, y, state, x):
24
+ """Score new token.
25
+
26
+ Args:
27
+ y (torch.Tensor): 1D torch.int64 prefix tokens.
28
+ state: Scorer state for prefix tokens
29
+ x (torch.Tensor): 2D encoder feature that generates ys.
30
+
31
+ Returns:
32
+ tuple[torch.Tensor, Any]: Tuple of
33
+ torch.float32 scores for next token (n_vocab)
34
+ and None
35
+
36
+ """
37
+ return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
38
+
39
+ def batch_score(
40
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
41
+ ) -> Tuple[torch.Tensor, List[Any]]:
42
+ """Score new token batch.
43
+
44
+ Args:
45
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
46
+ states (List[Any]): Scorer states for prefix tokens.
47
+ xs (torch.Tensor):
48
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
49
+
50
+ Returns:
51
+ tuple[torch.Tensor, List[Any]]: Tuple of
52
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
53
+ and next state list for ys.
54
+
55
+ """
56
+ return (
57
+ torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
58
+ ys.shape[0], self.n
59
+ ),
60
+ None,
61
+ )
espnet/utils/cli_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ from collections.abc import Sequence
2
+ from distutils.util import strtobool as dist_strtobool
3
+ import sys
4
+
5
+ import numpy
6
+
7
+
8
+ def strtobool(x):
9
+ # distutils.util.strtobool returns an int, which is confusing, so cast to bool.
10
+ return bool(dist_strtobool(x))
11
+
12
+
13
+ def get_commandline_args():
14
+ extra_chars = [
15
+ " ",
16
+ ";",
17
+ "&",
18
+ "(",
19
+ ")",
20
+ "|",
21
+ "^",
22
+ "<",
23
+ ">",
24
+ "?",
25
+ "*",
26
+ "[",
27
+ "]",
28
+ "$",
29
+ "`",
30
+ '"',
31
+ "\\",
32
+ "!",
33
+ "{",
34
+ "}",
35
+ ]
36
+
37
+ # Escape the extra characters for shell
38
+ argv = [
39
+ arg.replace("'", "'\\''")
40
+ if all(char not in arg for char in extra_chars)
41
+ else "'" + arg.replace("'", "'\\''") + "'"
42
+ for arg in sys.argv
43
+ ]
44
+
45
+ return sys.executable + " " + " ".join(argv)
46
+
47
+
48
+ def is_scipy_wav_style(value):
49
+ # If Tuple[int, numpy.ndarray] or not
50
+ return (
51
+ isinstance(value, Sequence)
52
+ and len(value) == 2
53
+ and isinstance(value[0], int)
54
+ and isinstance(value[1], numpy.ndarray)
55
+ )
56
+
57
+
58
+ def assert_scipy_wav_style(value):
59
+ assert is_scipy_wav_style(
60
+ value
61
+ ), "Must be Tuple[int, numpy.ndarray], but got {}".format(
62
+ type(value)
63
+ if not isinstance(value, Sequence)
64
+ else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
65
+ )
espnet/utils/dynamic_import.py ADDED
@@ -0,0 +1,23 @@
1
+ import importlib
2
+
3
+
4
+ def dynamic_import(import_path, alias=dict()):
5
+ """Dynamically import a module and return the requested class.
6
+
7
+ :param str import_path: syntax 'module_name:class_name'
8
+ e.g., 'espnet.transform.add_deltas:AddDeltas'
9
+ :param dict alias: shortcut for registered class
10
+ :return: imported class
11
+ """
12
+ if import_path not in alias and ":" not in import_path:
13
+ raise ValueError(
14
+ "import_path should be one of {} or "
15
+ 'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
16
+ "{}".format(set(alias), import_path)
17
+ )
18
+ if ":" not in import_path:
19
+ import_path = alias[import_path]
20
+
21
+ module_name, objname = import_path.split(":")
22
+ m = importlib.import_module(module_name)
23
+ return getattr(m, objname)
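Usage sketch: the "module:Class" string maps directly onto importlib.import_module plus getattr, so the result is the same class object a regular import would give:

from espnet.utils.dynamic_import import dynamic_import

Encoder = dynamic_import("espnet.nets.pytorch_backend.transformer.encoder:Encoder")
encoder = Encoder(idim=80)  # identical to importing Encoder directly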
espnet/utils/fill_missing_args.py ADDED
@@ -0,0 +1,46 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2018 Nagoya University (Tomoki Hayashi)
4
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
5
+
6
+ import argparse
7
+ import logging
8
+
9
+
10
+ def fill_missing_args(args, add_arguments):
11
+ """Fill missing arguments in args.
12
+
13
+ Args:
14
+ args (Namespace or None): Namespace containing hyperparameters.
15
+ add_arguments (function): Function to add arguments.
16
+
17
+ Returns:
18
+ Namespace: Arguments whose missing ones are filled with default value.
19
+
20
+ Examples:
21
+ >>> from argparse import Namespace
22
+ >>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
23
+ >>> args = Namespace()
24
+ >>> fill_missing_args(args, Tacotron2.add_arguments_fn)
25
+ Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
26
+
27
+ """
28
+ # check argument type
29
+ assert isinstance(args, argparse.Namespace) or args is None
30
+ assert callable(add_arguments)
31
+
32
+ # get default arguments
33
+ default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args()
34
+
35
+ # convert to dict
36
+ args = {} if args is None else vars(args)
37
+ default_args = vars(default_args)
38
+
39
+ for key, value in default_args.items():
40
+ if key not in args:
41
+ logging.info(
42
+ 'attribute "%s" does not exist. use default %s.' % (key, str(value))
43
+ )
44
+ args[key] = value
45
+
46
+ return argparse.Namespace(**args)
pipelines/.DS_Store ADDED
Binary file (6.15 kB). View file
 
pipelines/data/.DS_Store ADDED
Binary file (6.15 kB). View file