First push
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
- configs/LRS3_V_WER19.1.ini +18 -0
- espnet/.DS_Store +0 -0
- espnet/asr/asr_utils.py +990 -0
- espnet/nets/.DS_Store +0 -0
- espnet/nets/batch_beam_search.py +349 -0
- espnet/nets/beam_search.py +516 -0
- espnet/nets/ctc_prefix_score.py +359 -0
- espnet/nets/e2e_asr_common.py +249 -0
- espnet/nets/lm_interface.py +86 -0
- espnet/nets/pytorch_backend/backbones/conv1d_extractor.py +25 -0
- espnet/nets/pytorch_backend/backbones/conv3d_extractor.py +47 -0
- espnet/nets/pytorch_backend/backbones/modules/resnet.py +178 -0
- espnet/nets/pytorch_backend/backbones/modules/resnet1d.py +213 -0
- espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py +165 -0
- espnet/nets/pytorch_backend/ctc.py +283 -0
- espnet/nets/pytorch_backend/e2e_asr_transformer.py +320 -0
- espnet/nets/pytorch_backend/e2e_asr_transformer_av.py +352 -0
- espnet/nets/pytorch_backend/lm/__init__.py +1 -0
- espnet/nets/pytorch_backend/lm/default.py +431 -0
- espnet/nets/pytorch_backend/lm/seq_rnn.py +178 -0
- espnet/nets/pytorch_backend/lm/transformer.py +252 -0
- espnet/nets/pytorch_backend/nets_utils.py +526 -0
- espnet/nets/pytorch_backend/transformer/__init__.py +1 -0
- espnet/nets/pytorch_backend/transformer/add_sos_eos.py +31 -0
- espnet/nets/pytorch_backend/transformer/attention.py +280 -0
- espnet/nets/pytorch_backend/transformer/convolution.py +73 -0
- espnet/nets/pytorch_backend/transformer/decoder.py +229 -0
- espnet/nets/pytorch_backend/transformer/decoder_layer.py +121 -0
- espnet/nets/pytorch_backend/transformer/embedding.py +217 -0
- espnet/nets/pytorch_backend/transformer/encoder.py +283 -0
- espnet/nets/pytorch_backend/transformer/encoder_layer.py +149 -0
- espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py +63 -0
- espnet/nets/pytorch_backend/transformer/layer_norm.py +33 -0
- espnet/nets/pytorch_backend/transformer/mask.py +51 -0
- espnet/nets/pytorch_backend/transformer/multi_layer_conv.py +105 -0
- espnet/nets/pytorch_backend/transformer/optimizer.py +75 -0
- espnet/nets/pytorch_backend/transformer/plot.py +134 -0
- espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py +30 -0
- espnet/nets/pytorch_backend/transformer/raw_embeddings.py +77 -0
- espnet/nets/pytorch_backend/transformer/repeat.py +30 -0
- espnet/nets/pytorch_backend/transformer/subsampling.py +52 -0
- espnet/nets/scorer_interface.py +188 -0
- espnet/nets/scorers/__init__.py +1 -0
- espnet/nets/scorers/ctc.py +158 -0
- espnet/nets/scorers/length_bonus.py +61 -0
- espnet/utils/cli_utils.py +65 -0
- espnet/utils/dynamic_import.py +23 -0
- espnet/utils/fill_missing_args.py +46 -0
- pipelines/.DS_Store +0 -0
- pipelines/data/.DS_Store +0 -0
configs/LRS3_V_WER19.1.ini
ADDED
@@ -0,0 +1,18 @@
[input]
modality=video
v_fps=25

[model]
v_fps=25
model_path=benchmarks/LRS3/models/LRS3_V_WER19.1/model.pth
model_conf=benchmarks/LRS3/models/LRS3_V_WER19.1/model.json
rnnlm=benchmarks/LRS3/language_models/lm_en_subword/model.pth
rnnlm_conf=benchmarks/LRS3/language_models/lm_en_subword/model.json

[decode]
beam_size=40
penalty=0.0
maxlenratio=0.0
minlenratio=0.0
ctc_weight=0.1
lm_weight=0.3
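For reference, a config in this layout can be read with Python's standard configparser. The snippet below is only a minimal sketch: the section and option names come from the file above, but the surrounding decoding pipeline is assumed and is not part of this commit.

import configparser

# Load the benchmark config shown above.
config = configparser.ConfigParser()
config.read("configs/LRS3_V_WER19.1.ini")

# Decoding options; configparser stores strings, so cast explicitly.
beam_size = config.getint("decode", "beam_size")      # 40
ctc_weight = config.getfloat("decode", "ctc_weight")  # 0.1
lm_weight = config.getfloat("decode", "lm_weight")    # 0.3
model_path = config.get("model", "model_path")
rnnlm_path = config.get("model", "rnnlm")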
espnet/.DS_Store
ADDED
Binary file (6.15 kB)
espnet/asr/asr_utils.py
ADDED
@@ -0,0 +1,990 @@
1 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import copy
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
import shutil
|
10 |
+
import tempfile
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
|
16 |
+
# * -------------------- training iterator related -------------------- *
|
17 |
+
|
18 |
+
|
19 |
+
class CompareValueTrigger(object):
|
20 |
+
"""Trigger invoked when key value getting bigger or lower than before.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
key (str) : Key of value.
|
24 |
+
compare_fn ((float, float) -> bool) : Function to compare the values.
|
25 |
+
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
|
30 |
+
from chainer import training
|
31 |
+
|
32 |
+
self._key = key
|
33 |
+
self._best_value = None
|
34 |
+
self._interval_trigger = training.util.get_trigger(trigger)
|
35 |
+
self._init_summary()
|
36 |
+
self._compare_fn = compare_fn
|
37 |
+
|
38 |
+
def __call__(self, trainer):
|
39 |
+
"""Get value related to the key and compare with current value."""
|
40 |
+
observation = trainer.observation
|
41 |
+
summary = self._summary
|
42 |
+
key = self._key
|
43 |
+
if key in observation:
|
44 |
+
summary.add({key: observation[key]})
|
45 |
+
|
46 |
+
if not self._interval_trigger(trainer):
|
47 |
+
return False
|
48 |
+
|
49 |
+
stats = summary.compute_mean()
|
50 |
+
value = float(stats[key]) # copy to CPU
|
51 |
+
self._init_summary()
|
52 |
+
|
53 |
+
if self._best_value is None:
|
54 |
+
# initialize best value
|
55 |
+
self._best_value = value
|
56 |
+
return False
|
57 |
+
elif self._compare_fn(self._best_value, value):
|
58 |
+
return True
|
59 |
+
else:
|
60 |
+
self._best_value = value
|
61 |
+
return False
|
62 |
+
|
63 |
+
def _init_summary(self):
|
64 |
+
import chainer
|
65 |
+
|
66 |
+
self._summary = chainer.reporter.DictSummary()
|
67 |
+
|
68 |
+
|
69 |
+
try:
|
70 |
+
from chainer.training import extension
|
71 |
+
except ImportError:
|
72 |
+
PlotAttentionReport = None
|
73 |
+
else:
|
74 |
+
|
75 |
+
class PlotAttentionReport(extension.Extension):
|
76 |
+
"""Plot attention reporter.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
|
80 |
+
Function of attention visualization.
|
81 |
+
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
82 |
+
outdir (str): Directory to save figures.
|
83 |
+
converter (espnet.asr.*_backend.asr.CustomConverter):
|
84 |
+
Function to convert data.
|
85 |
+
device (int | torch.device): Device.
|
86 |
+
reverse (bool): If True, input and output length are reversed.
|
87 |
+
ikey (str): Key to access input
|
88 |
+
(for ASR/ST ikey="input", for MT ikey="output".)
|
89 |
+
iaxis (int): Dimension to access input
|
90 |
+
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
91 |
+
okey (str): Key to access output
|
92 |
+
(for ASR/ST okey="input", MT okay="output".)
|
93 |
+
oaxis (int): Dimension to access output
|
94 |
+
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
95 |
+
subsampling_factor (int): subsampling factor in encoder
|
96 |
+
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(
|
100 |
+
self,
|
101 |
+
att_vis_fn,
|
102 |
+
data,
|
103 |
+
outdir,
|
104 |
+
converter,
|
105 |
+
transform,
|
106 |
+
device,
|
107 |
+
reverse=False,
|
108 |
+
ikey="input",
|
109 |
+
iaxis=0,
|
110 |
+
okey="output",
|
111 |
+
oaxis=0,
|
112 |
+
subsampling_factor=1,
|
113 |
+
):
|
114 |
+
self.att_vis_fn = att_vis_fn
|
115 |
+
self.data = copy.deepcopy(data)
|
116 |
+
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
117 |
+
# key is utterance ID
|
118 |
+
self.outdir = outdir
|
119 |
+
self.converter = converter
|
120 |
+
self.transform = transform
|
121 |
+
self.device = device
|
122 |
+
self.reverse = reverse
|
123 |
+
self.ikey = ikey
|
124 |
+
self.iaxis = iaxis
|
125 |
+
self.okey = okey
|
126 |
+
self.oaxis = oaxis
|
127 |
+
self.factor = subsampling_factor
|
128 |
+
if not os.path.exists(self.outdir):
|
129 |
+
os.makedirs(self.outdir)
|
130 |
+
|
131 |
+
def __call__(self, trainer):
|
132 |
+
"""Plot and save image file of att_ws matrix."""
|
133 |
+
att_ws, uttid_list = self.get_attention_weights()
|
134 |
+
if isinstance(att_ws, list): # multi-encoder case
|
135 |
+
num_encs = len(att_ws) - 1
|
136 |
+
# atts
|
137 |
+
for i in range(num_encs):
|
138 |
+
for idx, att_w in enumerate(att_ws[i]):
|
139 |
+
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
|
140 |
+
self.outdir,
|
141 |
+
uttid_list[idx],
|
142 |
+
i + 1,
|
143 |
+
)
|
144 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
145 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
|
146 |
+
self.outdir,
|
147 |
+
uttid_list[idx],
|
148 |
+
i + 1,
|
149 |
+
)
|
150 |
+
np.save(np_filename.format(trainer), att_w)
|
151 |
+
self._plot_and_save_attention(att_w, filename.format(trainer))
|
152 |
+
# han
|
153 |
+
for idx, att_w in enumerate(att_ws[num_encs]):
|
154 |
+
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
|
155 |
+
self.outdir,
|
156 |
+
uttid_list[idx],
|
157 |
+
)
|
158 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
159 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
|
160 |
+
self.outdir,
|
161 |
+
uttid_list[idx],
|
162 |
+
)
|
163 |
+
np.save(np_filename.format(trainer), att_w)
|
164 |
+
self._plot_and_save_attention(
|
165 |
+
att_w, filename.format(trainer), han_mode=True
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
for idx, att_w in enumerate(att_ws):
|
169 |
+
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
170 |
+
self.outdir,
|
171 |
+
uttid_list[idx],
|
172 |
+
)
|
173 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
174 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
175 |
+
self.outdir,
|
176 |
+
uttid_list[idx],
|
177 |
+
)
|
178 |
+
np.save(np_filename.format(trainer), att_w)
|
179 |
+
self._plot_and_save_attention(att_w, filename.format(trainer))
|
180 |
+
|
181 |
+
def log_attentions(self, logger, step):
|
182 |
+
"""Add image files of att_ws matrix to the tensorboard."""
|
183 |
+
att_ws, uttid_list = self.get_attention_weights()
|
184 |
+
if isinstance(att_ws, list): # multi-encoder case
|
185 |
+
num_encs = len(att_ws) - 1
|
186 |
+
# atts
|
187 |
+
for i in range(num_encs):
|
188 |
+
for idx, att_w in enumerate(att_ws[i]):
|
189 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
190 |
+
plot = self.draw_attention_plot(att_w)
|
191 |
+
logger.add_figure(
|
192 |
+
"%s_att%d" % (uttid_list[idx], i + 1),
|
193 |
+
plot.gcf(),
|
194 |
+
step,
|
195 |
+
)
|
196 |
+
# han
|
197 |
+
for idx, att_w in enumerate(att_ws[num_encs]):
|
198 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
199 |
+
plot = self.draw_han_plot(att_w)
|
200 |
+
logger.add_figure(
|
201 |
+
"%s_han" % (uttid_list[idx]),
|
202 |
+
plot.gcf(),
|
203 |
+
step,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
for idx, att_w in enumerate(att_ws):
|
207 |
+
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
208 |
+
plot = self.draw_attention_plot(att_w)
|
209 |
+
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
210 |
+
|
211 |
+
def get_attention_weights(self):
|
212 |
+
"""Return attention weights.
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
numpy.ndarray: attention weights. float. Its shape would be
|
216 |
+
differ from backend.
|
217 |
+
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
|
218 |
+
other case => (B, Lmax, Tmax).
|
219 |
+
* chainer-> (B, Lmax, Tmax)
|
220 |
+
|
221 |
+
"""
|
222 |
+
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
223 |
+
batch = self.converter([return_batch], self.device)
|
224 |
+
if isinstance(batch, tuple):
|
225 |
+
att_ws = self.att_vis_fn(*batch)
|
226 |
+
else:
|
227 |
+
att_ws = self.att_vis_fn(**batch)
|
228 |
+
return att_ws, uttid_list
|
229 |
+
|
230 |
+
def trim_attention_weight(self, uttid, att_w):
|
231 |
+
"""Transform attention matrix with regard to self.reverse."""
|
232 |
+
if self.reverse:
|
233 |
+
enc_key, enc_axis = self.okey, self.oaxis
|
234 |
+
dec_key, dec_axis = self.ikey, self.iaxis
|
235 |
+
else:
|
236 |
+
enc_key, enc_axis = self.ikey, self.iaxis
|
237 |
+
dec_key, dec_axis = self.okey, self.oaxis
|
238 |
+
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
|
239 |
+
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
|
240 |
+
if self.factor > 1:
|
241 |
+
enc_len //= self.factor
|
242 |
+
if len(att_w.shape) == 3:
|
243 |
+
att_w = att_w[:, :dec_len, :enc_len]
|
244 |
+
else:
|
245 |
+
att_w = att_w[:dec_len, :enc_len]
|
246 |
+
return att_w
|
247 |
+
|
248 |
+
def draw_attention_plot(self, att_w):
|
249 |
+
"""Plot the att_w matrix.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
matplotlib.pyplot: pyplot object with attention matrix image.
|
253 |
+
|
254 |
+
"""
|
255 |
+
import matplotlib
|
256 |
+
|
257 |
+
matplotlib.use("Agg")
|
258 |
+
import matplotlib.pyplot as plt
|
259 |
+
|
260 |
+
plt.clf()
|
261 |
+
att_w = att_w.astype(np.float32)
|
262 |
+
if len(att_w.shape) == 3:
|
263 |
+
for h, aw in enumerate(att_w, 1):
|
264 |
+
plt.subplot(1, len(att_w), h)
|
265 |
+
plt.imshow(aw, aspect="auto")
|
266 |
+
plt.xlabel("Encoder Index")
|
267 |
+
plt.ylabel("Decoder Index")
|
268 |
+
else:
|
269 |
+
plt.imshow(att_w, aspect="auto")
|
270 |
+
plt.xlabel("Encoder Index")
|
271 |
+
plt.ylabel("Decoder Index")
|
272 |
+
plt.tight_layout()
|
273 |
+
return plt
|
274 |
+
|
275 |
+
def draw_han_plot(self, att_w):
|
276 |
+
"""Plot the att_w matrix for hierarchical attention.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
matplotlib.pyplot: pyplot object with attention matrix image.
|
280 |
+
|
281 |
+
"""
|
282 |
+
import matplotlib
|
283 |
+
|
284 |
+
matplotlib.use("Agg")
|
285 |
+
import matplotlib.pyplot as plt
|
286 |
+
|
287 |
+
plt.clf()
|
288 |
+
if len(att_w.shape) == 3:
|
289 |
+
for h, aw in enumerate(att_w, 1):
|
290 |
+
legends = []
|
291 |
+
plt.subplot(1, len(att_w), h)
|
292 |
+
for i in range(aw.shape[1]):
|
293 |
+
plt.plot(aw[:, i])
|
294 |
+
legends.append("Att{}".format(i))
|
295 |
+
plt.ylim([0, 1.0])
|
296 |
+
plt.xlim([0, aw.shape[0]])
|
297 |
+
plt.grid(True)
|
298 |
+
plt.ylabel("Attention Weight")
|
299 |
+
plt.xlabel("Decoder Index")
|
300 |
+
plt.legend(legends)
|
301 |
+
else:
|
302 |
+
legends = []
|
303 |
+
for i in range(att_w.shape[1]):
|
304 |
+
plt.plot(att_w[:, i])
|
305 |
+
legends.append("Att{}".format(i))
|
306 |
+
plt.ylim([0, 1.0])
|
307 |
+
plt.xlim([0, att_w.shape[0]])
|
308 |
+
plt.grid(True)
|
309 |
+
plt.ylabel("Attention Weight")
|
310 |
+
plt.xlabel("Decoder Index")
|
311 |
+
plt.legend(legends)
|
312 |
+
plt.tight_layout()
|
313 |
+
return plt
|
314 |
+
|
315 |
+
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
|
316 |
+
if han_mode:
|
317 |
+
plt = self.draw_han_plot(att_w)
|
318 |
+
else:
|
319 |
+
plt = self.draw_attention_plot(att_w)
|
320 |
+
plt.savefig(filename)
|
321 |
+
plt.close()
|
322 |
+
|
323 |
+
|
324 |
+
try:
|
325 |
+
from chainer.training import extension
|
326 |
+
except ImportError:
|
327 |
+
PlotCTCReport = None
|
328 |
+
else:
|
329 |
+
|
330 |
+
class PlotCTCReport(extension.Extension):
|
331 |
+
"""Plot CTC reporter.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
|
335 |
+
Function of CTC visualization.
|
336 |
+
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
337 |
+
outdir (str): Directory to save figures.
|
338 |
+
converter (espnet.asr.*_backend.asr.CustomConverter):
|
339 |
+
Function to convert data.
|
340 |
+
device (int | torch.device): Device.
|
341 |
+
reverse (bool): If True, input and output length are reversed.
|
342 |
+
ikey (str): Key to access input
|
343 |
+
(for ASR/ST ikey="input", for MT ikey="output".)
|
344 |
+
iaxis (int): Dimension to access input
|
345 |
+
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
346 |
+
okey (str): Key to access output
|
347 |
+
(for ASR/ST okey="input", MT okay="output".)
|
348 |
+
oaxis (int): Dimension to access output
|
349 |
+
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
350 |
+
subsampling_factor (int): subsampling factor in encoder
|
351 |
+
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
ctc_vis_fn,
|
357 |
+
data,
|
358 |
+
outdir,
|
359 |
+
converter,
|
360 |
+
transform,
|
361 |
+
device,
|
362 |
+
reverse=False,
|
363 |
+
ikey="input",
|
364 |
+
iaxis=0,
|
365 |
+
okey="output",
|
366 |
+
oaxis=0,
|
367 |
+
subsampling_factor=1,
|
368 |
+
):
|
369 |
+
self.ctc_vis_fn = ctc_vis_fn
|
370 |
+
self.data = copy.deepcopy(data)
|
371 |
+
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
372 |
+
# key is utterance ID
|
373 |
+
self.outdir = outdir
|
374 |
+
self.converter = converter
|
375 |
+
self.transform = transform
|
376 |
+
self.device = device
|
377 |
+
self.reverse = reverse
|
378 |
+
self.ikey = ikey
|
379 |
+
self.iaxis = iaxis
|
380 |
+
self.okey = okey
|
381 |
+
self.oaxis = oaxis
|
382 |
+
self.factor = subsampling_factor
|
383 |
+
if not os.path.exists(self.outdir):
|
384 |
+
os.makedirs(self.outdir)
|
385 |
+
|
386 |
+
def __call__(self, trainer):
|
387 |
+
"""Plot and save image file of ctc prob."""
|
388 |
+
ctc_probs, uttid_list = self.get_ctc_probs()
|
389 |
+
if isinstance(ctc_probs, list): # multi-encoder case
|
390 |
+
num_encs = len(ctc_probs) - 1
|
391 |
+
for i in range(num_encs):
|
392 |
+
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
393 |
+
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
394 |
+
self.outdir,
|
395 |
+
uttid_list[idx],
|
396 |
+
i + 1,
|
397 |
+
)
|
398 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
399 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
400 |
+
self.outdir,
|
401 |
+
uttid_list[idx],
|
402 |
+
i + 1,
|
403 |
+
)
|
404 |
+
np.save(np_filename.format(trainer), ctc_prob)
|
405 |
+
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
406 |
+
else:
|
407 |
+
for idx, ctc_prob in enumerate(ctc_probs):
|
408 |
+
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
409 |
+
self.outdir,
|
410 |
+
uttid_list[idx],
|
411 |
+
)
|
412 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
413 |
+
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
414 |
+
self.outdir,
|
415 |
+
uttid_list[idx],
|
416 |
+
)
|
417 |
+
np.save(np_filename.format(trainer), ctc_prob)
|
418 |
+
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
419 |
+
|
420 |
+
def log_ctc_probs(self, logger, step):
|
421 |
+
"""Add image files of ctc probs to the tensorboard."""
|
422 |
+
ctc_probs, uttid_list = self.get_ctc_probs()
|
423 |
+
if isinstance(ctc_probs, list): # multi-encoder case
|
424 |
+
num_encs = len(ctc_probs) - 1
|
425 |
+
for i in range(num_encs):
|
426 |
+
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
427 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
428 |
+
plot = self.draw_ctc_plot(ctc_prob)
|
429 |
+
logger.add_figure(
|
430 |
+
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
431 |
+
plot.gcf(),
|
432 |
+
step,
|
433 |
+
)
|
434 |
+
else:
|
435 |
+
for idx, ctc_prob in enumerate(ctc_probs):
|
436 |
+
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
437 |
+
plot = self.draw_ctc_plot(ctc_prob)
|
438 |
+
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
439 |
+
|
440 |
+
def get_ctc_probs(self):
|
441 |
+
"""Return CTC probs.
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
numpy.ndarray: CTC probs. float. Its shape would be
|
445 |
+
differ from backend. (B, Tmax, vocab).
|
446 |
+
|
447 |
+
"""
|
448 |
+
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
449 |
+
batch = self.converter([return_batch], self.device)
|
450 |
+
if isinstance(batch, tuple):
|
451 |
+
probs = self.ctc_vis_fn(*batch)
|
452 |
+
else:
|
453 |
+
probs = self.ctc_vis_fn(**batch)
|
454 |
+
return probs, uttid_list
|
455 |
+
|
456 |
+
def trim_ctc_prob(self, uttid, prob):
|
457 |
+
"""Trim CTC posteriors accoding to input lengths."""
|
458 |
+
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
|
459 |
+
if self.factor > 1:
|
460 |
+
enc_len //= self.factor
|
461 |
+
prob = prob[:enc_len]
|
462 |
+
return prob
|
463 |
+
|
464 |
+
def draw_ctc_plot(self, ctc_prob):
|
465 |
+
"""Plot the ctc_prob matrix.
|
466 |
+
|
467 |
+
Returns:
|
468 |
+
matplotlib.pyplot: pyplot object with CTC prob matrix image.
|
469 |
+
|
470 |
+
"""
|
471 |
+
import matplotlib
|
472 |
+
|
473 |
+
matplotlib.use("Agg")
|
474 |
+
import matplotlib.pyplot as plt
|
475 |
+
|
476 |
+
ctc_prob = ctc_prob.astype(np.float32)
|
477 |
+
|
478 |
+
plt.clf()
|
479 |
+
topk_ids = np.argsort(ctc_prob, axis=1)
|
480 |
+
n_frames, vocab = ctc_prob.shape
|
481 |
+
times_probs = np.arange(n_frames)
|
482 |
+
|
483 |
+
plt.figure(figsize=(20, 8))
|
484 |
+
|
485 |
+
# NOTE: index 0 is reserved for blank
|
486 |
+
for idx in set(topk_ids.reshape(-1).tolist()):
|
487 |
+
if idx == 0:
|
488 |
+
plt.plot(
|
489 |
+
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
|
490 |
+
)
|
491 |
+
else:
|
492 |
+
plt.plot(times_probs, ctc_prob[:, idx])
|
493 |
+
plt.xlabel("Input [frame]", fontsize=12)
|
494 |
+
plt.ylabel("Posteriors", fontsize=12)
|
495 |
+
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
|
496 |
+
plt.yticks(list(range(0, 2, 1)))
|
497 |
+
plt.tight_layout()
|
498 |
+
return plt
|
499 |
+
|
500 |
+
def _plot_and_save_ctc(self, ctc_prob, filename):
|
501 |
+
plt = self.draw_ctc_plot(ctc_prob)
|
502 |
+
plt.savefig(filename)
|
503 |
+
plt.close()
|
504 |
+
|
505 |
+
|
506 |
+
def restore_snapshot(model, snapshot, load_fn=None):
|
507 |
+
"""Extension to restore snapshot.
|
508 |
+
|
509 |
+
Returns:
|
510 |
+
An extension function.
|
511 |
+
|
512 |
+
"""
|
513 |
+
import chainer
|
514 |
+
from chainer import training
|
515 |
+
|
516 |
+
if load_fn is None:
|
517 |
+
load_fn = chainer.serializers.load_npz
|
518 |
+
|
519 |
+
@training.make_extension(trigger=(1, "epoch"))
|
520 |
+
def restore_snapshot(trainer):
|
521 |
+
_restore_snapshot(model, snapshot, load_fn)
|
522 |
+
|
523 |
+
return restore_snapshot
|
524 |
+
|
525 |
+
|
526 |
+
def _restore_snapshot(model, snapshot, load_fn=None):
|
527 |
+
if load_fn is None:
|
528 |
+
import chainer
|
529 |
+
|
530 |
+
load_fn = chainer.serializers.load_npz
|
531 |
+
|
532 |
+
load_fn(snapshot, model)
|
533 |
+
logging.info("restored from " + str(snapshot))
|
534 |
+
|
535 |
+
|
536 |
+
def adadelta_eps_decay(eps_decay):
|
537 |
+
"""Extension to perform adadelta eps decay.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
eps_decay (float): Decay rate of eps.
|
541 |
+
|
542 |
+
Returns:
|
543 |
+
An extension function.
|
544 |
+
|
545 |
+
"""
|
546 |
+
from chainer import training
|
547 |
+
|
548 |
+
@training.make_extension(trigger=(1, "epoch"))
|
549 |
+
def adadelta_eps_decay(trainer):
|
550 |
+
_adadelta_eps_decay(trainer, eps_decay)
|
551 |
+
|
552 |
+
return adadelta_eps_decay
|
553 |
+
|
554 |
+
|
555 |
+
def _adadelta_eps_decay(trainer, eps_decay):
|
556 |
+
optimizer = trainer.updater.get_optimizer("main")
|
557 |
+
# for chainer
|
558 |
+
if hasattr(optimizer, "eps"):
|
559 |
+
current_eps = optimizer.eps
|
560 |
+
setattr(optimizer, "eps", current_eps * eps_decay)
|
561 |
+
logging.info("adadelta eps decayed to " + str(optimizer.eps))
|
562 |
+
# pytorch
|
563 |
+
else:
|
564 |
+
for p in optimizer.param_groups:
|
565 |
+
p["eps"] *= eps_decay
|
566 |
+
logging.info("adadelta eps decayed to " + str(p["eps"]))
|
567 |
+
|
568 |
+
|
569 |
+
def adam_lr_decay(eps_decay):
|
570 |
+
"""Extension to perform adam lr decay.
|
571 |
+
|
572 |
+
Args:
|
573 |
+
eps_decay (float): Decay rate of lr.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
An extension function.
|
577 |
+
|
578 |
+
"""
|
579 |
+
from chainer import training
|
580 |
+
|
581 |
+
@training.make_extension(trigger=(1, "epoch"))
|
582 |
+
def adam_lr_decay(trainer):
|
583 |
+
_adam_lr_decay(trainer, eps_decay)
|
584 |
+
|
585 |
+
return adam_lr_decay
|
586 |
+
|
587 |
+
|
588 |
+
def _adam_lr_decay(trainer, eps_decay):
|
589 |
+
optimizer = trainer.updater.get_optimizer("main")
|
590 |
+
# for chainer
|
591 |
+
if hasattr(optimizer, "lr"):
|
592 |
+
current_lr = optimizer.lr
|
593 |
+
setattr(optimizer, "lr", current_lr * eps_decay)
|
594 |
+
logging.info("adam lr decayed to " + str(optimizer.lr))
|
595 |
+
# pytorch
|
596 |
+
else:
|
597 |
+
for p in optimizer.param_groups:
|
598 |
+
p["lr"] *= eps_decay
|
599 |
+
logging.info("adam lr decayed to " + str(p["lr"]))
|
600 |
+
|
601 |
+
|
602 |
+
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
|
603 |
+
"""Extension to take snapshot of the trainer for pytorch.
|
604 |
+
|
605 |
+
Returns:
|
606 |
+
An extension function.
|
607 |
+
|
608 |
+
"""
|
609 |
+
from chainer.training import extension
|
610 |
+
|
611 |
+
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
|
612 |
+
def torch_snapshot(trainer):
|
613 |
+
_torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
|
614 |
+
|
615 |
+
return torch_snapshot
|
616 |
+
|
617 |
+
|
618 |
+
def _torch_snapshot_object(trainer, target, filename, savefun):
|
619 |
+
from chainer.serializers import DictionarySerializer
|
620 |
+
|
621 |
+
# make snapshot_dict dictionary
|
622 |
+
s = DictionarySerializer()
|
623 |
+
s.save(trainer)
|
624 |
+
if hasattr(trainer.updater.model, "model"):
|
625 |
+
# (for TTS)
|
626 |
+
if hasattr(trainer.updater.model.model, "module"):
|
627 |
+
model_state_dict = trainer.updater.model.model.module.state_dict()
|
628 |
+
else:
|
629 |
+
model_state_dict = trainer.updater.model.model.state_dict()
|
630 |
+
else:
|
631 |
+
# (for ASR)
|
632 |
+
if hasattr(trainer.updater.model, "module"):
|
633 |
+
model_state_dict = trainer.updater.model.module.state_dict()
|
634 |
+
else:
|
635 |
+
model_state_dict = trainer.updater.model.state_dict()
|
636 |
+
snapshot_dict = {
|
637 |
+
"trainer": s.target,
|
638 |
+
"model": model_state_dict,
|
639 |
+
"optimizer": trainer.updater.get_optimizer("main").state_dict(),
|
640 |
+
}
|
641 |
+
|
642 |
+
# save snapshot dictionary
|
643 |
+
fn = filename.format(trainer)
|
644 |
+
prefix = "tmp" + fn
|
645 |
+
tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
|
646 |
+
tmppath = os.path.join(tmpdir, fn)
|
647 |
+
try:
|
648 |
+
savefun(snapshot_dict, tmppath)
|
649 |
+
shutil.move(tmppath, os.path.join(trainer.out, fn))
|
650 |
+
finally:
|
651 |
+
shutil.rmtree(tmpdir)
|
652 |
+
|
653 |
+
|
654 |
+
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
|
655 |
+
"""Adds noise from a standard normal distribution to the gradients.
|
656 |
+
|
657 |
+
The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
|
658 |
+
`sigma` goes to zero (no noise) with more iterations.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
model (torch.nn.model): Model.
|
662 |
+
iteration (int): Number of iterations.
|
663 |
+
duration (int) {100, 1000}:
|
664 |
+
Number of durations to control the interval of the `sigma` change.
|
665 |
+
eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
|
666 |
+
scale_factor (float) {0.55}: The scale of `sigma`.
|
667 |
+
"""
|
668 |
+
interval = (iteration // duration) + 1
|
669 |
+
sigma = eta / interval**scale_factor
|
670 |
+
for param in model.parameters():
|
671 |
+
if param.grad is not None:
|
672 |
+
_shape = param.grad.size()
|
673 |
+
noise = sigma * torch.randn(_shape).to(param.device)
|
674 |
+
param.grad += noise
|
675 |
+
|
676 |
+
|
677 |
+
# * -------------------- general -------------------- *
|
678 |
+
def get_model_conf(model_path, conf_path=None):
|
679 |
+
"""Get model config information by reading a model config file (model.json).
|
680 |
+
|
681 |
+
Args:
|
682 |
+
model_path (str): Model path.
|
683 |
+
conf_path (str): Optional model config path.
|
684 |
+
|
685 |
+
Returns:
|
686 |
+
list[int, int, dict[str, Any]]: Config information loaded from json file.
|
687 |
+
|
688 |
+
"""
|
689 |
+
if conf_path is None:
|
690 |
+
model_conf = os.path.dirname(model_path) + "/model.json"
|
691 |
+
else:
|
692 |
+
model_conf = conf_path
|
693 |
+
with open(model_conf, "rb") as f:
|
694 |
+
logging.info("reading a config file from " + model_conf)
|
695 |
+
confs = json.load(f)
|
696 |
+
if isinstance(confs, dict):
|
697 |
+
# for lm
|
698 |
+
args = confs
|
699 |
+
return argparse.Namespace(**args)
|
700 |
+
else:
|
701 |
+
# for asr, tts, mt
|
702 |
+
idim, odim, args = confs
|
703 |
+
return idim, odim, argparse.Namespace(**args)
|
704 |
+
|
705 |
+
|
706 |
+
def chainer_load(path, model):
|
707 |
+
"""Load chainer model parameters.
|
708 |
+
|
709 |
+
Args:
|
710 |
+
path (str): Model path or snapshot file path to be loaded.
|
711 |
+
model (chainer.Chain): Chainer model.
|
712 |
+
|
713 |
+
"""
|
714 |
+
import chainer
|
715 |
+
|
716 |
+
if "snapshot" in os.path.basename(path):
|
717 |
+
chainer.serializers.load_npz(path, model, path="updater/model:main/")
|
718 |
+
else:
|
719 |
+
chainer.serializers.load_npz(path, model)
|
720 |
+
|
721 |
+
|
722 |
+
def torch_save(path, model):
|
723 |
+
"""Save torch model states.
|
724 |
+
|
725 |
+
Args:
|
726 |
+
path (str): Model path to be saved.
|
727 |
+
model (torch.nn.Module): Torch model.
|
728 |
+
|
729 |
+
"""
|
730 |
+
if hasattr(model, "module"):
|
731 |
+
torch.save(model.module.state_dict(), path)
|
732 |
+
else:
|
733 |
+
torch.save(model.state_dict(), path)
|
734 |
+
|
735 |
+
|
736 |
+
def snapshot_object(target, filename):
|
737 |
+
"""Returns a trainer extension to take snapshots of a given object.
|
738 |
+
|
739 |
+
Args:
|
740 |
+
target (model): Object to serialize.
|
741 |
+
filename (str): Name of the file into which the object is serialized.It can
|
742 |
+
be a format string, where the trainer object is passed to
|
743 |
+
the :meth: `str.format` method. For example,
|
744 |
+
``'snapshot_{.updater.iteration}'`` is converted to
|
745 |
+
``'snapshot_10000'`` at the 10,000th iteration.
|
746 |
+
|
747 |
+
Returns:
|
748 |
+
An extension function.
|
749 |
+
|
750 |
+
"""
|
751 |
+
from chainer.training import extension
|
752 |
+
|
753 |
+
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
|
754 |
+
def snapshot_object(trainer):
|
755 |
+
torch_save(os.path.join(trainer.out, filename.format(trainer)), target)
|
756 |
+
|
757 |
+
return snapshot_object
|
758 |
+
|
759 |
+
|
760 |
+
def torch_load(path, model):
|
761 |
+
"""Load torch model states.
|
762 |
+
|
763 |
+
Args:
|
764 |
+
path (str): Model path or snapshot file path to be loaded.
|
765 |
+
model (torch.nn.Module): Torch model.
|
766 |
+
|
767 |
+
"""
|
768 |
+
if "snapshot" in os.path.basename(path):
|
769 |
+
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
|
770 |
+
"model"
|
771 |
+
]
|
772 |
+
else:
|
773 |
+
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
|
774 |
+
|
775 |
+
if hasattr(model, "module"):
|
776 |
+
model.module.load_state_dict(model_state_dict)
|
777 |
+
else:
|
778 |
+
model.load_state_dict(model_state_dict)
|
779 |
+
|
780 |
+
del model_state_dict
|
781 |
+
|
782 |
+
|
783 |
+
def torch_resume(snapshot_path, trainer):
|
784 |
+
"""Resume from snapshot for pytorch.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
snapshot_path (str): Snapshot file path.
|
788 |
+
trainer (chainer.training.Trainer): Chainer's trainer instance.
|
789 |
+
|
790 |
+
"""
|
791 |
+
from chainer.serializers import NpzDeserializer
|
792 |
+
|
793 |
+
# load snapshot
|
794 |
+
snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
|
795 |
+
|
796 |
+
# restore trainer states
|
797 |
+
d = NpzDeserializer(snapshot_dict["trainer"])
|
798 |
+
d.load(trainer)
|
799 |
+
|
800 |
+
# restore model states
|
801 |
+
if hasattr(trainer.updater.model, "model"):
|
802 |
+
# (for TTS model)
|
803 |
+
if hasattr(trainer.updater.model.model, "module"):
|
804 |
+
trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
|
805 |
+
else:
|
806 |
+
trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
|
807 |
+
else:
|
808 |
+
# (for ASR model)
|
809 |
+
if hasattr(trainer.updater.model, "module"):
|
810 |
+
trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
|
811 |
+
else:
|
812 |
+
trainer.updater.model.load_state_dict(snapshot_dict["model"])
|
813 |
+
|
814 |
+
# retore optimizer states
|
815 |
+
trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
|
816 |
+
|
817 |
+
# delete opened snapshot
|
818 |
+
del snapshot_dict
|
819 |
+
|
820 |
+
|
821 |
+
# * ------------------ recognition related ------------------ *
|
822 |
+
def parse_hypothesis(hyp, char_list):
|
823 |
+
"""Parse hypothesis.
|
824 |
+
|
825 |
+
Args:
|
826 |
+
hyp (list[dict[str, Any]]): Recognition hypothesis.
|
827 |
+
char_list (list[str]): List of characters.
|
828 |
+
|
829 |
+
Returns:
|
830 |
+
tuple(str, str, str, float)
|
831 |
+
|
832 |
+
"""
|
833 |
+
# remove sos and get results
|
834 |
+
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
|
835 |
+
token_as_list = [char_list[idx] for idx in tokenid_as_list]
|
836 |
+
score = float(hyp["score"])
|
837 |
+
|
838 |
+
# convert to string
|
839 |
+
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
|
840 |
+
token = " ".join(token_as_list)
|
841 |
+
text = "".join(token_as_list).replace("<space>", " ")
|
842 |
+
|
843 |
+
return text, token, tokenid, score
|
844 |
+
|
845 |
+
|
846 |
+
def add_results_to_json(nbest_hyps, char_list):
|
847 |
+
"""Add N-best results to json.
|
848 |
+
Args:
|
849 |
+
js (dict[str, Any]): Groundtruth utterance dict.
|
850 |
+
nbest_hyps_sd (list[dict[str, Any]]):
|
851 |
+
List of hypothesis for multi_speakers: nutts x nspkrs.
|
852 |
+
char_list (list[str]): List of characters.
|
853 |
+
Returns:
|
854 |
+
str: 1-best result
|
855 |
+
"""
|
856 |
+
assert len(nbest_hyps) == 1, "only 1-best result is supported."
|
857 |
+
# parse hypothesis
|
858 |
+
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(nbest_hyps[0], char_list)
|
859 |
+
return rec_text
|
860 |
+
|
861 |
+
|
862 |
+
def plot_spectrogram(
|
863 |
+
plt,
|
864 |
+
spec,
|
865 |
+
mode="db",
|
866 |
+
fs=None,
|
867 |
+
frame_shift=None,
|
868 |
+
bottom=True,
|
869 |
+
left=True,
|
870 |
+
right=True,
|
871 |
+
top=False,
|
872 |
+
labelbottom=True,
|
873 |
+
labelleft=True,
|
874 |
+
labelright=True,
|
875 |
+
labeltop=False,
|
876 |
+
cmap="inferno",
|
877 |
+
):
|
878 |
+
"""Plot spectrogram using matplotlib.
|
879 |
+
|
880 |
+
Args:
|
881 |
+
plt (matplotlib.pyplot): pyplot object.
|
882 |
+
spec (numpy.ndarray): Input stft (Freq, Time)
|
883 |
+
mode (str): db or linear.
|
884 |
+
fs (int): Sample frequency. To convert y-axis to kHz unit.
|
885 |
+
frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
|
886 |
+
bottom (bool):Whether to draw the respective ticks.
|
887 |
+
left (bool):
|
888 |
+
right (bool):
|
889 |
+
top (bool):
|
890 |
+
labelbottom (bool):Whether to draw the respective tick labels.
|
891 |
+
labelleft (bool):
|
892 |
+
labelright (bool):
|
893 |
+
labeltop (bool):
|
894 |
+
cmap (str): Colormap defined in matplotlib.
|
895 |
+
|
896 |
+
"""
|
897 |
+
spec = np.abs(spec)
|
898 |
+
if mode == "db":
|
899 |
+
x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
|
900 |
+
elif mode == "linear":
|
901 |
+
x = spec
|
902 |
+
else:
|
903 |
+
raise ValueError(mode)
|
904 |
+
|
905 |
+
if fs is not None:
|
906 |
+
ytop = fs / 2000
|
907 |
+
ylabel = "kHz"
|
908 |
+
else:
|
909 |
+
ytop = x.shape[0]
|
910 |
+
ylabel = "bin"
|
911 |
+
|
912 |
+
if frame_shift is not None and fs is not None:
|
913 |
+
xtop = x.shape[1] * frame_shift / fs
|
914 |
+
xlabel = "s"
|
915 |
+
else:
|
916 |
+
xtop = x.shape[1]
|
917 |
+
xlabel = "frame"
|
918 |
+
|
919 |
+
extent = (0, xtop, 0, ytop)
|
920 |
+
plt.imshow(x[::-1], cmap=cmap, extent=extent)
|
921 |
+
|
922 |
+
if labelbottom:
|
923 |
+
plt.xlabel("time [{}]".format(xlabel))
|
924 |
+
if labelleft:
|
925 |
+
plt.ylabel("freq [{}]".format(ylabel))
|
926 |
+
plt.colorbar().set_label("{}".format(mode))
|
927 |
+
|
928 |
+
plt.tick_params(
|
929 |
+
bottom=bottom,
|
930 |
+
left=left,
|
931 |
+
right=right,
|
932 |
+
top=top,
|
933 |
+
labelbottom=labelbottom,
|
934 |
+
labelleft=labelleft,
|
935 |
+
labelright=labelright,
|
936 |
+
labeltop=labeltop,
|
937 |
+
)
|
938 |
+
plt.axis("auto")
|
939 |
+
|
940 |
+
|
941 |
+
# * ------------------ recognition related ------------------ *
|
942 |
+
def format_mulenc_args(args):
|
943 |
+
"""Format args for multi-encoder setup.
|
944 |
+
|
945 |
+
It deals with following situations: (when args.num_encs=2):
|
946 |
+
1. args.elayers = None -> args.elayers = [4, 4];
|
947 |
+
2. args.elayers = 4 -> args.elayers = [4, 4];
|
948 |
+
3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].
|
949 |
+
|
950 |
+
"""
|
951 |
+
# default values when None is assigned.
|
952 |
+
default_dict = {
|
953 |
+
"etype": "blstmp",
|
954 |
+
"elayers": 4,
|
955 |
+
"eunits": 300,
|
956 |
+
"subsample": "1",
|
957 |
+
"dropout_rate": 0.0,
|
958 |
+
"atype": "dot",
|
959 |
+
"adim": 320,
|
960 |
+
"awin": 5,
|
961 |
+
"aheads": 4,
|
962 |
+
"aconv_chans": -1,
|
963 |
+
"aconv_filts": 100,
|
964 |
+
}
|
965 |
+
for k in default_dict.keys():
|
966 |
+
if isinstance(vars(args)[k], list):
|
967 |
+
if len(vars(args)[k]) != args.num_encs:
|
968 |
+
logging.warning(
|
969 |
+
"Length mismatch {}: Convert {} to {}.".format(
|
970 |
+
k, vars(args)[k], vars(args)[k][: args.num_encs]
|
971 |
+
)
|
972 |
+
)
|
973 |
+
vars(args)[k] = vars(args)[k][: args.num_encs]
|
974 |
+
else:
|
975 |
+
if not vars(args)[k]:
|
976 |
+
# assign default value if it is None
|
977 |
+
vars(args)[k] = default_dict[k]
|
978 |
+
logging.warning(
|
979 |
+
"{} is not specified, use default value {}.".format(
|
980 |
+
k, default_dict[k]
|
981 |
+
)
|
982 |
+
)
|
983 |
+
# duplicate
|
984 |
+
logging.warning(
|
985 |
+
"Type mismatch {}: Convert {} to {}.".format(
|
986 |
+
k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
|
987 |
+
)
|
988 |
+
)
|
989 |
+
vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
|
990 |
+
return args
|
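The recognition-related helpers in asr_utils.py (get_model_conf, torch_load, torch_save, torch_resume) are the pieces a decoding script typically combines to restore a checkpoint. A minimal sketch follows, assuming a hypothetical build_model constructor that is not part of this file; the paths are the ones named in the config above.

from espnet.asr.asr_utils import get_model_conf, torch_load

model_path = "benchmarks/LRS3/models/LRS3_V_WER19.1/model.pth"
model_conf = "benchmarks/LRS3/models/LRS3_V_WER19.1/model.json"

# For ASR-style configs, model.json stores (idim, odim, training args).
idim, odim, train_args = get_model_conf(model_path, model_conf)

model = build_model(idim, odim, train_args)  # hypothetical constructor from the recipe, not defined here
torch_load(model_path, model)                # loads the state dict into `model` in place
model.eval()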
espnet/nets/.DS_Store
ADDED
Binary file (6.15 kB)
espnet/nets/batch_beam_search.py
ADDED
@@ -0,0 +1,349 @@
"""Parallel beam search module."""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from espnet.nets.beam_search import BeamSearch
from espnet.nets.beam_search import Hypothesis


class BatchHypothesis(NamedTuple):
    """Batchfied/Vectorized hypothesis data type."""

    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()

    def __len__(self) -> int:
        """Return a batch size."""
        return len(self.length)


class BatchBeamSearch(BeamSearch):
    """Batch beam search implementation."""

    def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
        """Convert list to batch."""
        if len(hyps) == 0:
            return BatchHypothesis()
        yseq = pad_sequence(
            [h.yseq for h in hyps], batch_first=True, padding_value=self.eos
        )
        return BatchHypothesis(
            yseq=yseq,
            length=torch.tensor(
                [len(h.yseq) for h in hyps], dtype=torch.int64, device=yseq.device
            ),
            score=torch.tensor([h.score for h in hyps]).to(yseq.device),
            scores={
                k: torch.tensor([h.scores[k] for h in hyps], device=yseq.device)
                for k in self.scorers
            },
            states={k: [h.states[k] for h in hyps] for k in self.scorers},
        )

    def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
        return BatchHypothesis(
            yseq=hyps.yseq[ids],
            score=hyps.score[ids],
            length=hyps.length[ids],
            scores={k: v[ids] for k, v in hyps.scores.items()},
            states={
                k: [self.scorers[k].select_state(v, i) for i in ids]
                for k, v in hyps.states.items()
            },
        )

    def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
        return Hypothesis(
            yseq=hyps.yseq[i, : hyps.length[i]],
            score=hyps.score[i],
            scores={k: v[i] for k, v in hyps.scores.items()},
            states={
                k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
            },
        )

    def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
        """Revert batch to list."""
        return [
            Hypothesis(
                yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
                score=batch_hyps.score[i],
                scores={k: batch_hyps.scores[k][i] for k in self.scorers},
                states={
                    k: v.select_state(batch_hyps.states[k], i)
                    for k, v in self.scorers.items()
                },
            )
            for i in range(len(batch_hyps.length))
        ]

    def batch_beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Batch-compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
                Its shape is `(n_beam, self.vocab_size)`.
            ids (torch.Tensor): The partial token ids to compute topk.
                Its shape is `(n_beam, self.pre_beam_size)`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The topk full (prev_hyp, new_token) ids
                and partial (prev_hyp, new_token) ids.
                Their shapes are all `(self.beam_size,)`

        """
        top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
        # Because of the flatten above, `top_ids` is organized as:
        # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
        # where V is `self.n_vocab` and K is `self.beam_size`
        prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
        new_token_ids = top_ids % self.n_vocab
        return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids

    def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.batch_init_state(x)
            init_scores[k] = 0.0
        return self.batchfy(
            [
                Hypothesis(
                    score=0.0,
                    scores=init_scores,
                    states=init_states,
                    yseq=torch.tensor([self.sos], device=x.device),
                )
            ]
        )

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 2D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.batch_score_partial(
                hyp.yseq, ids, hyp.states[k], x
            )
        return scores, states

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, v in part_states.items():
            new_states[k] = v
        return new_states

    def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses

        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which has non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.

        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypothes: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a probmlem, number of hyps < beam)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1)
        return self._batch_select(running_hyps, remained_ids)
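The index arithmetic in batch_beam is the core of the vectorized search: the (n_beam, n_vocab) score matrix is flattened, the top beam_size entries are taken, and each flat index is split back into the hypothesis it extends (row) and the token it appends (column). A self-contained sketch with toy sizes, independent of any model:

import torch

n_beam, n_vocab, beam_size = 3, 5, 4
weighted_scores = torch.randn(n_beam, n_vocab)

top_ids = weighted_scores.view(-1).topk(beam_size)[1]              # indices into the flattened matrix
prev_hyp_ids = torch.div(top_ids, n_vocab, rounding_mode="trunc")  # which hypothesis to extend
new_token_ids = top_ids % n_vocab                                   # which token extends it

# The decomposition is exact: row * n_vocab + column recovers the flat index.
assert torch.equal(prev_hyp_ids * n_vocab + new_token_ids, top_ids)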
espnet/nets/beam_search.py
ADDED
@@ -0,0 +1,516 @@
1 |
+
"""Beam search module."""
|
2 |
+
|
3 |
+
from itertools import chain
|
4 |
+
import logging
|
5 |
+
from typing import Any
|
6 |
+
from typing import Dict
|
7 |
+
from typing import List
|
8 |
+
from typing import NamedTuple
|
9 |
+
from typing import Tuple
|
10 |
+
from typing import Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from espnet.nets.e2e_asr_common import end_detect
|
15 |
+
from espnet.nets.scorer_interface import PartialScorerInterface
|
16 |
+
from espnet.nets.scorer_interface import ScorerInterface
|
17 |
+
|
18 |
+
|
19 |
+
class Hypothesis(NamedTuple):
|
20 |
+
"""Hypothesis data type."""
|
21 |
+
|
22 |
+
yseq: torch.Tensor
|
23 |
+
score: Union[float, torch.Tensor] = 0
|
24 |
+
scores: Dict[str, Union[float, torch.Tensor]] = dict()
|
25 |
+
states: Dict[str, Any] = dict()
|
26 |
+
|
27 |
+
def asdict(self) -> dict:
|
28 |
+
"""Convert data to JSON-friendly dict."""
|
29 |
+
return self._replace(
|
30 |
+
yseq=self.yseq.tolist(),
|
31 |
+
score=float(self.score),
|
32 |
+
scores={k: float(v) for k, v in self.scores.items()},
|
33 |
+
)._asdict()
|
34 |
+
|
35 |
+
|
36 |
+
class BeamSearch(torch.nn.Module):
|
37 |
+
"""Beam search implementation."""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
scorers: Dict[str, ScorerInterface],
|
42 |
+
weights: Dict[str, float],
|
43 |
+
beam_size: int,
|
44 |
+
vocab_size: int,
|
45 |
+
sos: int,
|
46 |
+
eos: int,
|
47 |
+
token_list: List[str] = None,
|
48 |
+
pre_beam_ratio: float = 1.5,
|
49 |
+
pre_beam_score_key: str = None,
|
50 |
+
):
|
51 |
+
"""Initialize beam search.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
55 |
+
e.g., Decoder, CTCPrefixScorer, LM
|
56 |
+
The scorer will be ignored if it is `None`
|
57 |
+
weights (dict[str, float]): Dict of weights for each scorers
|
58 |
+
The scorer will be ignored if its weight is 0
|
59 |
+
beam_size (int): The number of hypotheses kept during search
|
60 |
+
vocab_size (int): The size of the vocabulary
|
61 |
+
sos (int): Start of sequence id
|
62 |
+
eos (int): End of sequence id
|
63 |
+
token_list (list[str]): List of tokens for debug log
|
64 |
+
pre_beam_score_key (str): key of scores to perform pre-beam search
|
65 |
+
pre_beam_ratio (float): beam size in the pre-beam search
|
66 |
+
will be `int(pre_beam_ratio * beam_size)`
|
67 |
+
|
68 |
+
"""
|
69 |
+
super().__init__()
|
70 |
+
# set scorers
|
71 |
+
self.weights = weights
|
72 |
+
self.scorers = dict()
|
73 |
+
self.full_scorers = dict()
|
74 |
+
self.part_scorers = dict()
|
75 |
+
# this module dict is required for recursive cast
|
76 |
+
# `self.to(device, dtype)` in `recog.py`
|
77 |
+
self.nn_dict = torch.nn.ModuleDict()
|
78 |
+
for k, v in scorers.items():
|
79 |
+
w = weights.get(k, 0)
|
80 |
+
if w == 0 or v is None:
|
81 |
+
continue
|
82 |
+
assert isinstance(
|
83 |
+
v, ScorerInterface
|
84 |
+
), f"{k} ({type(v)}) does not implement ScorerInterface"
|
85 |
+
self.scorers[k] = v
|
86 |
+
if isinstance(v, PartialScorerInterface):
|
87 |
+
self.part_scorers[k] = v
|
88 |
+
else:
|
89 |
+
self.full_scorers[k] = v
|
90 |
+
if isinstance(v, torch.nn.Module):
|
91 |
+
self.nn_dict[k] = v
|
92 |
+
|
93 |
+
# set configurations
|
94 |
+
self.sos = sos
|
95 |
+
self.eos = eos
|
96 |
+
self.token_list = token_list
|
97 |
+
self.pre_beam_size = int(pre_beam_ratio * beam_size)
|
98 |
+
self.beam_size = beam_size
|
99 |
+
self.n_vocab = vocab_size
|
100 |
+
if (
|
101 |
+
pre_beam_score_key is not None
|
102 |
+
and pre_beam_score_key != "full"
|
103 |
+
and pre_beam_score_key not in self.full_scorers
|
104 |
+
):
|
105 |
+
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
|
106 |
+
self.pre_beam_score_key = pre_beam_score_key
|
107 |
+
self.do_pre_beam = (
|
108 |
+
self.pre_beam_score_key is not None
|
109 |
+
and self.pre_beam_size < self.n_vocab
|
110 |
+
and len(self.part_scorers) > 0
|
111 |
+
)
|
112 |
+
|
113 |
+
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
|
114 |
+
"""Get an initial hypothesis data.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
x (torch.Tensor): The encoder output feature
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
List[Hypothesis]: The initial hypothesis list.
|
121 |
+
|
122 |
+
"""
|
123 |
+
init_states = dict()
|
124 |
+
init_scores = dict()
|
125 |
+
for k, d in self.scorers.items():
|
126 |
+
init_states[k] = d.init_state(x)
|
127 |
+
init_scores[k] = 0.0
|
128 |
+
return [
|
129 |
+
Hypothesis(
|
130 |
+
score=0.0,
|
131 |
+
scores=init_scores,
|
132 |
+
states=init_states,
|
133 |
+
yseq=torch.tensor([self.sos], device=x.device),
|
134 |
+
)
|
135 |
+
]
|
136 |
+
|
137 |
+
@staticmethod
|
138 |
+
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
|
139 |
+
"""Append new token to prefix tokens.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
xs (torch.Tensor): The prefix token
|
143 |
+
x (int): The new token to append
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
|
147 |
+
|
148 |
+
"""
|
149 |
+
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
|
150 |
+
return torch.cat((xs, x))
|
151 |
+
|
152 |
+
def score_full(
|
153 |
+
self, hyp: Hypothesis, x: torch.Tensor
|
154 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
155 |
+
"""Score new hypothesis by `self.full_scorers`.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
159 |
+
x (torch.Tensor): Corresponding input feature
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
163 |
+
score dict of `hyp` that has string keys of `self.full_scorers`
|
164 |
+
and tensor score values of shape: `(self.n_vocab,)`,
|
165 |
+
and state dict that has string keys
|
166 |
+
and state values of `self.full_scorers`
|
167 |
+
|
168 |
+
"""
|
169 |
+
scores = dict()
|
170 |
+
states = dict()
|
171 |
+
for k, d in self.full_scorers.items():
|
172 |
+
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
|
173 |
+
return scores, states
|
174 |
+
|
175 |
+
def score_partial(
|
176 |
+
self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
|
177 |
+
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
|
178 |
+
"""Score new hypothesis by `self.part_scorers`.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
182 |
+
ids (torch.Tensor): 1D tensor of new partial tokens to score
|
183 |
+
x (torch.Tensor): Corresponding input feature
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
|
187 |
+
score dict of `hyp` that has string keys of `self.part_scorers`
|
188 |
+
and tensor score values of shape: `(len(ids),)`,
|
189 |
+
and state dict that has string keys
|
190 |
+
and state values of `self.part_scorers`
|
191 |
+
|
192 |
+
"""
|
193 |
+
scores = dict()
|
194 |
+
states = dict()
|
195 |
+
for k, d in self.part_scorers.items():
|
196 |
+
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
|
197 |
+
return scores, states
|
198 |
+
|
199 |
+
def beam(
|
200 |
+
self, weighted_scores: torch.Tensor, ids: torch.Tensor
|
201 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
202 |
+
"""Compute topk full token ids and partial token ids.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
weighted_scores (torch.Tensor): The weighted sum scores for each token.
|
206 |
+
Its shape is `(self.n_vocab,)`.
|
207 |
+
ids (torch.Tensor): The partial token ids to compute topk
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
Tuple[torch.Tensor, torch.Tensor]:
|
211 |
+
The topk full token ids and partial token ids.
|
212 |
+
Their shapes are `(self.beam_size,)`
|
213 |
+
|
214 |
+
"""
|
215 |
+
# no pre beam performed
|
216 |
+
if weighted_scores.size(0) == ids.size(0):
|
217 |
+
top_ids = weighted_scores.topk(self.beam_size)[1]
|
218 |
+
return top_ids, top_ids
|
219 |
+
|
220 |
+
# mask pruned in pre-beam not to select in topk
|
221 |
+
tmp = weighted_scores[ids]
|
222 |
+
weighted_scores[:] = -float("inf")
|
223 |
+
weighted_scores[ids] = tmp
|
224 |
+
top_ids = weighted_scores.topk(self.beam_size)[1]
|
225 |
+
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
|
226 |
+
return top_ids, local_ids
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def merge_scores(
|
230 |
+
prev_scores: Dict[str, float],
|
231 |
+
next_full_scores: Dict[str, torch.Tensor],
|
232 |
+
full_idx: int,
|
233 |
+
next_part_scores: Dict[str, torch.Tensor],
|
234 |
+
part_idx: int,
|
235 |
+
) -> Dict[str, torch.Tensor]:
|
236 |
+
"""Merge scores for new hypothesis.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
prev_scores (Dict[str, float]):
|
240 |
+
The previous hypothesis scores by `self.scorers`
|
241 |
+
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
|
242 |
+
full_idx (int): The next token id for `next_full_scores`
|
243 |
+
next_part_scores (Dict[str, torch.Tensor]):
|
244 |
+
scores of partial tokens by `self.part_scorers`
|
245 |
+
part_idx (int): The new token id for `next_part_scores`
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
Dict[str, torch.Tensor]: The new score dict.
|
249 |
+
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
250 |
+
Its values are scalar tensors by the scorers.
|
251 |
+
|
252 |
+
"""
|
253 |
+
new_scores = dict()
|
254 |
+
for k, v in next_full_scores.items():
|
255 |
+
new_scores[k] = prev_scores[k] + v[full_idx]
|
256 |
+
for k, v in next_part_scores.items():
|
257 |
+
new_scores[k] = prev_scores[k] + v[part_idx]
|
258 |
+
return new_scores
|
259 |
+
|
260 |
+
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
|
261 |
+
"""Merge states for new hypothesis.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
states: states of `self.full_scorers`
|
265 |
+
part_states: states of `self.part_scorers`
|
266 |
+
part_idx (int): The new token id for `part_scores`
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
Dict[str, torch.Tensor]: The new score dict.
|
270 |
+
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
271 |
+
Its values are states of the scorers.
|
272 |
+
|
273 |
+
"""
|
274 |
+
new_states = dict()
|
275 |
+
for k, v in states.items():
|
276 |
+
new_states[k] = v
|
277 |
+
for k, d in self.part_scorers.items():
|
278 |
+
new_states[k] = d.select_state(part_states[k], part_idx)
|
279 |
+
return new_states
|
280 |
+
|
281 |
+
def search(
|
282 |
+
self, running_hyps: List[Hypothesis], x: torch.Tensor
|
283 |
+
) -> List[Hypothesis]:
|
284 |
+
"""Search new tokens for running hypotheses and encoded speech x.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
running_hyps (List[Hypothesis]): Running hypotheses on beam
|
288 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
List[Hypothesis]: Best sorted hypotheses
|
292 |
+
|
293 |
+
"""
|
294 |
+
best_hyps = []
|
295 |
+
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
|
296 |
+
for hyp in running_hyps:
|
297 |
+
# scoring
|
298 |
+
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
|
299 |
+
scores, states = self.score_full(hyp, x)
|
300 |
+
for k in self.full_scorers:
|
301 |
+
weighted_scores += self.weights[k] * scores[k]
|
302 |
+
# partial scoring
|
303 |
+
if self.do_pre_beam:
|
304 |
+
pre_beam_scores = (
|
305 |
+
weighted_scores
|
306 |
+
if self.pre_beam_score_key == "full"
|
307 |
+
else scores[self.pre_beam_score_key]
|
308 |
+
)
|
309 |
+
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
|
310 |
+
part_scores, part_states = self.score_partial(hyp, part_ids, x)
|
311 |
+
for k in self.part_scorers:
|
312 |
+
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
|
313 |
+
# add previous hyp score
|
314 |
+
weighted_scores += hyp.score
|
315 |
+
|
316 |
+
# update hyps
|
317 |
+
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
|
318 |
+
# will be (2 x beam at most)
|
319 |
+
best_hyps.append(
|
320 |
+
Hypothesis(
|
321 |
+
score=weighted_scores[j],
|
322 |
+
yseq=self.append_token(hyp.yseq, j),
|
323 |
+
scores=self.merge_scores(
|
324 |
+
hyp.scores, scores, j, part_scores, part_j
|
325 |
+
),
|
326 |
+
states=self.merge_states(states, part_states, part_j),
|
327 |
+
)
|
328 |
+
)
|
329 |
+
|
330 |
+
# sort and prune 2 x beam -> beam
|
331 |
+
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
|
332 |
+
: min(len(best_hyps), self.beam_size)
|
333 |
+
]
|
334 |
+
return best_hyps
|
335 |
+
|
336 |
+
def forward(
|
337 |
+
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
|
338 |
+
) -> List[Hypothesis]:
|
339 |
+
"""Perform beam search.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
343 |
+
maxlenratio (float): Input length ratio to obtain max output length.
|
344 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
345 |
+
to automatically find maximum hypothesis lengths
|
346 |
+
If maxlenratio<0.0, its absolute value is interpreted
|
347 |
+
as a constant max output length.
|
348 |
+
minlenratio (float): Input length ratio to obtain min output length.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
list[Hypothesis]: N-best decoding results
|
352 |
+
|
353 |
+
"""
|
354 |
+
# set length bounds
|
355 |
+
if maxlenratio == 0:
|
356 |
+
maxlen = x.shape[0]
|
357 |
+
elif maxlenratio < 0:
|
358 |
+
maxlen = -1 * int(maxlenratio)
|
359 |
+
else:
|
360 |
+
maxlen = max(1, int(maxlenratio * x.size(0)))
|
361 |
+
minlen = int(minlenratio * x.size(0))
|
362 |
+
logging.info("decoder input length: " + str(x.shape[0]))
|
363 |
+
logging.info("max output length: " + str(maxlen))
|
364 |
+
logging.info("min output length: " + str(minlen))
|
365 |
+
|
366 |
+
# main loop of prefix search
|
367 |
+
running_hyps = self.init_hyp(x)
|
368 |
+
ended_hyps = []
|
369 |
+
for i in range(maxlen):
|
370 |
+
logging.debug("position " + str(i))
|
371 |
+
best = self.search(running_hyps, x)
|
372 |
+
# post process of one iteration
|
373 |
+
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
|
374 |
+
# end detection
|
375 |
+
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
|
376 |
+
logging.info(f"end detected at {i}")
|
377 |
+
break
|
378 |
+
if len(running_hyps) == 0:
|
379 |
+
logging.info("no hypothesis. Finish decoding.")
|
380 |
+
break
|
381 |
+
else:
|
382 |
+
logging.debug(f"remained hypotheses: {len(running_hyps)}")
|
383 |
+
|
384 |
+
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
|
385 |
+
# check the number of hypotheses reaching to eos
|
386 |
+
if len(nbest_hyps) == 0:
|
387 |
+
logging.warning(
|
388 |
+
"there is no N-best results, perform recognition "
|
389 |
+
"again with smaller minlenratio."
|
390 |
+
)
|
391 |
+
return (
|
392 |
+
[]
|
393 |
+
if minlenratio < 0.1
|
394 |
+
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
|
395 |
+
)
|
396 |
+
|
397 |
+
# report the best result
|
398 |
+
best = nbest_hyps[0]
|
399 |
+
for k, v in best.scores.items():
|
400 |
+
logging.info(
|
401 |
+
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
|
402 |
+
)
|
403 |
+
logging.info(f"total log probability: {best.score:.2f}")
|
404 |
+
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
|
405 |
+
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
|
406 |
+
if self.token_list is not None:
|
407 |
+
logging.info(
|
408 |
+
"best hypo: "
|
409 |
+
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
|
410 |
+
+ "\n"
|
411 |
+
)
|
412 |
+
return nbest_hyps
|
413 |
+
|
414 |
+
def post_process(
|
415 |
+
self,
|
416 |
+
i: int,
|
417 |
+
maxlen: int,
|
418 |
+
maxlenratio: float,
|
419 |
+
running_hyps: List[Hypothesis],
|
420 |
+
ended_hyps: List[Hypothesis],
|
421 |
+
) -> List[Hypothesis]:
|
422 |
+
"""Perform post-processing of beam search iterations.
|
423 |
+
|
424 |
+
Args:
|
425 |
+
i (int): The length of hypothesis tokens.
|
426 |
+
maxlen (int): The maximum length of tokens in beam search.
|
427 |
+
maxlenratio (int): The maximum length ratio in beam search.
|
428 |
+
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
|
429 |
+
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
|
430 |
+
|
431 |
+
Returns:
|
432 |
+
List[Hypothesis]: The new running hypotheses.
|
433 |
+
|
434 |
+
"""
|
435 |
+
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
|
436 |
+
if self.token_list is not None:
|
437 |
+
logging.debug(
|
438 |
+
"best hypo: "
|
439 |
+
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
|
440 |
+
)
|
441 |
+
# add eos in the final loop so that there is at least one ended hypothesis
|
442 |
+
if i == maxlen - 1:
|
443 |
+
logging.info("adding <eos> in the last position in the loop")
|
444 |
+
running_hyps = [
|
445 |
+
h._replace(yseq=self.append_token(h.yseq, self.eos))
|
446 |
+
for h in running_hyps
|
447 |
+
]
|
448 |
+
|
449 |
+
# add ended hypotheses to the final list and remove them from the current hypotheses
|
450 |
+
# (this will be a problem, number of hyps < beam)
|
451 |
+
remained_hyps = []
|
452 |
+
for hyp in running_hyps:
|
453 |
+
if hyp.yseq[-1] == self.eos:
|
454 |
+
# e.g., Word LM needs to add final <eos> score
|
455 |
+
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
|
456 |
+
s = d.final_score(hyp.states[k])
|
457 |
+
hyp.scores[k] += s
|
458 |
+
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
|
459 |
+
ended_hyps.append(hyp)
|
460 |
+
else:
|
461 |
+
remained_hyps.append(hyp)
|
462 |
+
return remained_hyps
|
463 |
+
|
464 |
+
|
465 |
+
def beam_search(
|
466 |
+
x: torch.Tensor,
|
467 |
+
sos: int,
|
468 |
+
eos: int,
|
469 |
+
beam_size: int,
|
470 |
+
vocab_size: int,
|
471 |
+
scorers: Dict[str, ScorerInterface],
|
472 |
+
weights: Dict[str, float],
|
473 |
+
token_list: List[str] = None,
|
474 |
+
maxlenratio: float = 0.0,
|
475 |
+
minlenratio: float = 0.0,
|
476 |
+
pre_beam_ratio: float = 1.5,
|
477 |
+
pre_beam_score_key: str = "full",
|
478 |
+
) -> list:
|
479 |
+
"""Perform beam search with scorers.
|
480 |
+
|
481 |
+
Args:
|
482 |
+
x (torch.Tensor): Encoded speech feature (T, D)
|
483 |
+
sos (int): Start of sequence id
|
484 |
+
eos (int): End of sequence id
|
485 |
+
beam_size (int): The number of hypotheses kept during search
|
486 |
+
vocab_size (int): The size of the vocabulary
|
487 |
+
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
488 |
+
e.g., Decoder, CTCPrefixScorer, LM
|
489 |
+
The scorer will be ignored if it is `None`
|
490 |
+
weights (dict[str, float]): Dict of weights for each scorers
|
491 |
+
The scorer will be ignored if its weight is 0
|
492 |
+
token_list (list[str]): List of tokens for debug log
|
493 |
+
maxlenratio (float): Input length ratio to obtain max output length.
|
494 |
+
If maxlenratio=0.0 (default), it uses an end-detect function
|
495 |
+
to automatically find maximum hypothesis lengths
|
496 |
+
minlenratio (float): Input length ratio to obtain min output length.
|
497 |
+
pre_beam_score_key (str): key of scores to perform pre-beam search
|
498 |
+
pre_beam_ratio (float): beam size in the pre-beam search
|
499 |
+
will be `int(pre_beam_ratio * beam_size)`
|
500 |
+
|
501 |
+
Returns:
|
502 |
+
list: N-best decoding results
|
503 |
+
|
504 |
+
"""
|
505 |
+
ret = BeamSearch(
|
506 |
+
scorers,
|
507 |
+
weights,
|
508 |
+
beam_size=beam_size,
|
509 |
+
vocab_size=vocab_size,
|
510 |
+
pre_beam_ratio=pre_beam_ratio,
|
511 |
+
pre_beam_score_key=pre_beam_score_key,
|
512 |
+
sos=sos,
|
513 |
+
eos=eos,
|
514 |
+
token_list=token_list,
|
515 |
+
).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
|
516 |
+
return [h.asdict() for h in ret]
|
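For orientation, here is a minimal usage sketch of the BeamSearch module defined above with a single toy full scorer. The uniform scorer, vocabulary size, sos/eos ids, and fake encoder output are assumptions made only for this sketch; real decoding plugs in the transformer decoder, CTC prefix scorer, and LM listed in the docstring instead.

import math
import torch
from espnet.nets.beam_search import BeamSearch
from espnet.nets.scorer_interface import ScorerInterface

class UniformScorer(ScorerInterface):
    """Toy full scorer: every token gets the same log-probability."""
    def __init__(self, n_vocab):
        self.n_vocab = n_vocab
    def score(self, y, state, x):
        return torch.full((self.n_vocab,), -math.log(self.n_vocab)), None

n_vocab, sos, eos = 10, 9, 9              # assumed ids; espnet usually reuses one id for sos and eos
bs = BeamSearch(
    scorers={"uniform": UniformScorer(n_vocab)},
    weights={"uniform": 1.0},
    beam_size=3,
    vocab_size=n_vocab,
    sos=sos,
    eos=eos,
)
enc = torch.randn(20, 4)                  # fake encoder output of shape (T, D)
nbest = bs.forward(enc, maxlenratio=0.5)  # maxlen becomes int(0.5 * 20) = 10
print(len(nbest), nbest[0].yseq)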
espnet/nets/ctc_prefix_score.py
ADDED
@@ -0,0 +1,359 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import six
|
10 |
+
|
11 |
+
|
12 |
+
class CTCPrefixScoreTH(object):
|
13 |
+
"""Batch processing of CTCPrefixScore
|
14 |
+
|
15 |
+
which is based on Algorithm 2 in WATANABE et al.
|
16 |
+
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
17 |
+
but extended to efficiently compute the label probabilities for multiple
|
18 |
+
hypotheses simultaneously
|
19 |
+
See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
|
20 |
+
Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, x, xlens, blank, eos, margin=0):
|
24 |
+
"""Construct CTC prefix scorer
|
25 |
+
|
26 |
+
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
27 |
+
:param torch.Tensor xlens: input lengths (B,)
|
28 |
+
:param int blank: blank label id
|
29 |
+
:param int eos: end-of-sequence id
|
30 |
+
:param int margin: margin parameter for windowing (0 means no windowing)
|
31 |
+
"""
|
32 |
+
# In the comment lines,
|
33 |
+
# we assume T: input_length, B: batch size, W: beam width, O: output dim.
|
34 |
+
self.logzero = -10000000000.0
|
35 |
+
self.blank = blank
|
36 |
+
self.eos = eos
|
37 |
+
self.batch = x.size(0)
|
38 |
+
self.input_length = x.size(1)
|
39 |
+
self.odim = x.size(2)
|
40 |
+
self.dtype = x.dtype
|
41 |
+
self.device = (
|
42 |
+
torch.device("cuda:%d" % x.get_device())
|
43 |
+
if x.is_cuda
|
44 |
+
else torch.device("cpu")
|
45 |
+
)
|
46 |
+
# Pad the rest of posteriors in the batch
|
47 |
+
# TODO(takaaki-hori): need a better way without for-loops
|
48 |
+
for i, l in enumerate(xlens):
|
49 |
+
if l < self.input_length:
|
50 |
+
x[i, l:, :] = self.logzero
|
51 |
+
x[i, l:, blank] = 0
|
52 |
+
# Reshape input x
|
53 |
+
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
54 |
+
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
55 |
+
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
56 |
+
self.end_frames = torch.as_tensor(xlens) - 1
|
57 |
+
|
58 |
+
# Setup CTC windowing
|
59 |
+
self.margin = margin
|
60 |
+
if margin > 0:
|
61 |
+
self.frame_ids = torch.arange(
|
62 |
+
self.input_length, dtype=self.dtype, device=self.device
|
63 |
+
)
|
64 |
+
# Base indices for index conversion
|
65 |
+
self.idx_bh = None
|
66 |
+
self.idx_b = torch.arange(self.batch, device=self.device)
|
67 |
+
self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)
|
68 |
+
|
69 |
+
def __call__(self, y, state, scoring_ids=None, att_w=None):
|
70 |
+
"""Compute CTC prefix scores for next labels
|
71 |
+
|
72 |
+
:param list y: prefix label sequences
|
73 |
+
:param tuple state: previous CTC state
|
74 |
+
:param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
|
75 |
+
:param torch.Tensor att_w: attention weights to decide CTC window
|
76 |
+
:return new_state, ctc_local_scores (BW, O)
|
77 |
+
"""
|
78 |
+
output_length = len(y[0]) - 1 # ignore sos
|
79 |
+
last_ids = [yi[-1] for yi in y] # last output label ids
|
80 |
+
n_bh = len(last_ids) # batch * hyps
|
81 |
+
n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
|
82 |
+
self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
|
83 |
+
# prepare state info
|
84 |
+
if state is None:
|
85 |
+
r_prev = torch.full(
|
86 |
+
(self.input_length, 2, self.batch, n_hyps),
|
87 |
+
self.logzero,
|
88 |
+
dtype=self.dtype,
|
89 |
+
device=self.device,
|
90 |
+
)
|
91 |
+
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
|
92 |
+
r_prev = r_prev.view(-1, 2, n_bh)
|
93 |
+
s_prev = 0.0
|
94 |
+
f_min_prev = 0
|
95 |
+
f_max_prev = 1
|
96 |
+
else:
|
97 |
+
r_prev, s_prev, f_min_prev, f_max_prev = state
|
98 |
+
|
99 |
+
# select input dimensions for scoring
|
100 |
+
if self.scoring_num > 0:
|
101 |
+
scoring_idmap = torch.full(
|
102 |
+
(n_bh, self.odim), -1, dtype=torch.long, device=self.device
|
103 |
+
)
|
104 |
+
snum = self.scoring_num
|
105 |
+
if self.idx_bh is None or n_bh > len(self.idx_bh):
|
106 |
+
self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
|
107 |
+
scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
|
108 |
+
snum, device=self.device
|
109 |
+
)
|
110 |
+
scoring_idx = (
|
111 |
+
scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
|
112 |
+
).view(-1)
|
113 |
+
x_ = torch.index_select(
|
114 |
+
self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
|
115 |
+
).view(2, -1, n_bh, snum)
|
116 |
+
else:
|
117 |
+
scoring_ids = None
|
118 |
+
scoring_idmap = None
|
119 |
+
snum = self.odim
|
120 |
+
x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
|
121 |
+
|
122 |
+
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
|
123 |
+
# that corresponds to r_t^n(h) and r_t^b(h) in a batch.
|
124 |
+
r = torch.full(
|
125 |
+
(self.input_length, 2, n_bh, snum),
|
126 |
+
self.logzero,
|
127 |
+
dtype=self.dtype,
|
128 |
+
device=self.device,
|
129 |
+
)
|
130 |
+
if output_length == 0:
|
131 |
+
r[0, 0] = x_[0, 0]
|
132 |
+
|
133 |
+
r_sum = torch.logsumexp(r_prev, 1)
|
134 |
+
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
|
135 |
+
if scoring_ids is not None:
|
136 |
+
for idx in range(n_bh):
|
137 |
+
pos = scoring_idmap[idx, last_ids[idx]]
|
138 |
+
if pos >= 0:
|
139 |
+
log_phi[:, idx, pos] = r_prev[:, 1, idx]
|
140 |
+
else:
|
141 |
+
for idx in range(n_bh):
|
142 |
+
log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
|
143 |
+
|
144 |
+
# decide start and end frames based on attention weights
|
145 |
+
if att_w is not None and self.margin > 0:
|
146 |
+
f_arg = torch.matmul(att_w, self.frame_ids)
|
147 |
+
f_min = max(int(f_arg.min().cpu()), f_min_prev)
|
148 |
+
f_max = max(int(f_arg.max().cpu()), f_max_prev)
|
149 |
+
start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
|
150 |
+
end = min(f_max + self.margin, self.input_length)
|
151 |
+
else:
|
152 |
+
f_min = f_max = 0
|
153 |
+
start = max(output_length, 1)
|
154 |
+
end = self.input_length
|
155 |
+
|
156 |
+
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
|
157 |
+
for t in range(start, end):
|
158 |
+
rp = r[t - 1]
|
159 |
+
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
|
160 |
+
2, 2, n_bh, snum
|
161 |
+
)
|
162 |
+
r[t] = torch.logsumexp(rr, 1) + x_[:, t]
|
163 |
+
|
164 |
+
# compute log prefix probabilities log(psi)
|
165 |
+
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
|
166 |
+
if scoring_ids is not None:
|
167 |
+
log_psi = torch.full(
|
168 |
+
(n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
|
169 |
+
)
|
170 |
+
log_psi_ = torch.logsumexp(
|
171 |
+
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
172 |
+
dim=0,
|
173 |
+
)
|
174 |
+
for si in range(n_bh):
|
175 |
+
log_psi[si, scoring_ids[si]] = log_psi_[si]
|
176 |
+
else:
|
177 |
+
log_psi = torch.logsumexp(
|
178 |
+
torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
|
179 |
+
dim=0,
|
180 |
+
)
|
181 |
+
|
182 |
+
for si in range(n_bh):
|
183 |
+
log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
|
184 |
+
|
185 |
+
# exclude blank probs
|
186 |
+
log_psi[:, self.blank] = self.logzero
|
187 |
+
|
188 |
+
return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)
|
189 |
+
|
190 |
+
def index_select_state(self, state, best_ids):
|
191 |
+
"""Select CTC states according to best ids
|
192 |
+
|
193 |
+
:param state : CTC state
|
194 |
+
:param best_ids : index numbers selected by beam pruning (B, W)
|
195 |
+
:return selected_state
|
196 |
+
"""
|
197 |
+
r, s, f_min, f_max, scoring_idmap = state
|
198 |
+
# convert ids to BHO space
|
199 |
+
n_bh = len(s)
|
200 |
+
n_hyps = n_bh // self.batch
|
201 |
+
vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
|
202 |
+
# select hypothesis scores
|
203 |
+
s_new = torch.index_select(s.view(-1), 0, vidx)
|
204 |
+
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
|
205 |
+
# convert ids to BHS space (S: scoring_num)
|
206 |
+
if scoring_idmap is not None:
|
207 |
+
snum = self.scoring_num
|
208 |
+
hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
|
209 |
+
-1
|
210 |
+
)
|
211 |
+
label_ids = torch.fmod(best_ids, self.odim).view(-1)
|
212 |
+
score_idx = scoring_idmap[hyp_idx, label_ids]
|
213 |
+
score_idx[score_idx == -1] = 0
|
214 |
+
vidx = score_idx + hyp_idx * snum
|
215 |
+
else:
|
216 |
+
snum = self.odim
|
217 |
+
# select forward probabilities
|
218 |
+
r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
|
219 |
+
-1, 2, n_bh
|
220 |
+
)
|
221 |
+
return r_new, s_new, f_min, f_max
|
222 |
+
|
223 |
+
def extend_prob(self, x):
|
224 |
+
"""Extend CTC prob.
|
225 |
+
|
226 |
+
:param torch.Tensor x: input label posterior sequences (B, T, O)
|
227 |
+
"""
|
228 |
+
|
229 |
+
if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O)
|
230 |
+
# Pad the rest of posteriors in the batch
|
231 |
+
# TODO(takaaki-hori): need a better way without for-loops
|
232 |
+
xlens = [x.size(1)]
|
233 |
+
for i, l in enumerate(xlens):
|
234 |
+
if l < self.input_length:
|
235 |
+
x[i, l:, :] = self.logzero
|
236 |
+
x[i, l:, self.blank] = 0
|
237 |
+
tmp_x = self.x
|
238 |
+
xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O)
|
239 |
+
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
|
240 |
+
self.x = torch.stack([xn, xb]) # (2, T, B, O)
|
241 |
+
self.x[:, : tmp_x.shape[1], :, :] = tmp_x
|
242 |
+
self.input_length = x.size(1)
|
243 |
+
self.end_frames = torch.as_tensor(xlens) - 1
|
244 |
+
|
245 |
+
def extend_state(self, state):
|
246 |
+
"""Compute CTC prefix state.
|
247 |
+
|
248 |
+
|
249 |
+
:param state : CTC state
|
250 |
+
:return ctc_state
|
251 |
+
"""
|
252 |
+
|
253 |
+
if state is None:
|
254 |
+
# nothing to do
|
255 |
+
return state
|
256 |
+
else:
|
257 |
+
r_prev, s_prev, f_min_prev, f_max_prev = state
|
258 |
+
|
259 |
+
r_prev_new = torch.full(
|
260 |
+
(self.input_length, 2),
|
261 |
+
self.logzero,
|
262 |
+
dtype=self.dtype,
|
263 |
+
device=self.device,
|
264 |
+
)
|
265 |
+
start = max(r_prev.shape[0], 1)
|
266 |
+
r_prev_new[0:start] = r_prev
|
267 |
+
for t in six.moves.range(start, self.input_length):
|
268 |
+
r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
|
269 |
+
|
270 |
+
return (r_prev_new, s_prev, f_min_prev, f_max_prev)
|
271 |
+
|
272 |
+
|
273 |
+
class CTCPrefixScore(object):
|
274 |
+
"""Compute CTC label sequence scores
|
275 |
+
|
276 |
+
which is based on Algorithm 2 in WATANABE et al.
|
277 |
+
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
|
278 |
+
but extended to efficiently compute the probabilities of multiple labels
|
279 |
+
simultaneously
|
280 |
+
"""
|
281 |
+
|
282 |
+
def __init__(self, x, blank, eos, xp):
|
283 |
+
self.xp = xp
|
284 |
+
self.logzero = -10000000000.0
|
285 |
+
self.blank = blank
|
286 |
+
self.eos = eos
|
287 |
+
self.input_length = len(x)
|
288 |
+
self.x = x
|
289 |
+
|
290 |
+
def initial_state(self):
|
291 |
+
"""Obtain an initial CTC state
|
292 |
+
|
293 |
+
:return: CTC state
|
294 |
+
"""
|
295 |
+
# initial CTC state is made of a frame x 2 tensor that corresponds to
|
296 |
+
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
|
297 |
+
# superscripts n and b (non-blank and blank), respectively.
|
298 |
+
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
|
299 |
+
r[0, 1] = self.x[0, self.blank]
|
300 |
+
for i in six.moves.range(1, self.input_length):
|
301 |
+
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
|
302 |
+
return r
|
303 |
+
|
304 |
+
def __call__(self, y, cs, r_prev):
|
305 |
+
"""Compute CTC prefix scores for next labels
|
306 |
+
|
307 |
+
:param y : prefix label sequence
|
308 |
+
:param cs : array of next labels
|
309 |
+
:param r_prev: previous CTC state
|
310 |
+
:return ctc_scores, ctc_states
|
311 |
+
"""
|
312 |
+
# initialize CTC states
|
313 |
+
output_length = len(y) - 1 # ignore sos
|
314 |
+
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
|
315 |
+
# that corresponds to r_t^n(h) and r_t^b(h).
|
316 |
+
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
|
317 |
+
xs = self.x[:, cs]
|
318 |
+
if output_length == 0:
|
319 |
+
r[0, 0] = xs[0]
|
320 |
+
r[0, 1] = self.logzero
|
321 |
+
else:
|
322 |
+
r[output_length - 1] = self.logzero
|
323 |
+
|
324 |
+
# prepare forward probabilities for the last label
|
325 |
+
r_sum = self.xp.logaddexp(
|
326 |
+
r_prev[:, 0], r_prev[:, 1]
|
327 |
+
) # log(r_t^n(g) + r_t^b(g))
|
328 |
+
last = y[-1]
|
329 |
+
if output_length > 0 and last in cs:
|
330 |
+
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
|
331 |
+
for i in six.moves.range(len(cs)):
|
332 |
+
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
|
333 |
+
else:
|
334 |
+
log_phi = r_sum
|
335 |
+
|
336 |
+
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
|
337 |
+
# and log prefix probabilities log(psi)
|
338 |
+
start = max(output_length, 1)
|
339 |
+
log_psi = r[start - 1, 0]
|
340 |
+
for t in six.moves.range(start, self.input_length):
|
341 |
+
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
|
342 |
+
r[t, 1] = (
|
343 |
+
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
|
344 |
+
)
|
345 |
+
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
|
346 |
+
|
347 |
+
# get P(...eos|X) that ends with the prefix itself
|
348 |
+
eos_pos = self.xp.where(cs == self.eos)[0]
|
349 |
+
if len(eos_pos) > 0:
|
350 |
+
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
|
351 |
+
|
352 |
+
# exclude blank probs
|
353 |
+
blank_pos = self.xp.where(cs == self.blank)[0]
|
354 |
+
if len(blank_pos) > 0:
|
355 |
+
log_psi[blank_pos] = self.logzero
|
356 |
+
|
357 |
+
# return the log prefix probability and CTC states, where the label axis
|
358 |
+
# of the CTC states is moved to the first axis to slice it easily
|
359 |
+
return log_psi, self.xp.rollaxis(r, 2)
|
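A small worked example of the single-hypothesis CTCPrefixScore class above, using numpy as the xp backend. The frame count, output dimension, and the blank/eos ids are assumptions made for the sketch; espnet conventionally reuses the eos id as sos, so the initial prefix below is just [eos].

import numpy as np
from espnet.nets.ctc_prefix_score import CTCPrefixScore

T, O = 6, 5                                            # frames x output symbols (assumed)
rng = np.random.default_rng(0)
logp = np.log(rng.dirichlet(np.ones(O), size=T)).astype(np.float32)   # toy log posteriors

scorer = CTCPrefixScore(logp, blank=0, eos=4, xp=np)   # blank=0 and eos=4 are assumed ids
r_prev = scorer.initial_state()                        # state of the empty prefix
cs = np.array([1, 2, 3, 4])                            # candidate next labels
log_psi, states = scorer(y=[4], cs=cs, r_prev=r_prev)  # prefix = [<sos>]
print(log_psi)                                         # one CTC prefix log-probability per candidate

Each call scores every candidate label in one shot, and the returned states can be sliced per selected label on the next step; CTCPrefixScoreTH does the same for a whole batch of hypotheses at once.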
espnet/nets/e2e_asr_common.py
ADDED
@@ -0,0 +1,249 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# encoding: utf-8
|
3 |
+
|
4 |
+
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Common functions for ASR."""
|
8 |
+
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import sys
|
12 |
+
|
13 |
+
from itertools import groupby
|
14 |
+
import numpy as np
|
15 |
+
import six
|
16 |
+
|
17 |
+
|
18 |
+
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
|
19 |
+
"""End detection.
|
20 |
+
|
21 |
+
described in Eq. (50) of S. Watanabe et al
|
22 |
+
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
|
23 |
+
|
24 |
+
:param ended_hyps:
|
25 |
+
:param i:
|
26 |
+
:param M:
|
27 |
+
:param D_end:
|
28 |
+
:return:
|
29 |
+
"""
|
30 |
+
if len(ended_hyps) == 0:
|
31 |
+
return False
|
32 |
+
count = 0
|
33 |
+
best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
|
34 |
+
for m in six.moves.range(M):
|
35 |
+
# get ended_hyps with their length is i - m
|
36 |
+
hyp_length = i - m
|
37 |
+
hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
|
38 |
+
if len(hyps_same_length) > 0:
|
39 |
+
best_hyp_same_length = sorted(
|
40 |
+
hyps_same_length, key=lambda x: x["score"], reverse=True
|
41 |
+
)[0]
|
42 |
+
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
|
43 |
+
count += 1
|
44 |
+
|
45 |
+
if count == M:
|
46 |
+
return True
|
47 |
+
else:
|
48 |
+
return False
|
49 |
+
|
50 |
+
|
51 |
+
# TODO(takaaki-hori): add different smoothing methods
|
52 |
+
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
|
53 |
+
"""Obtain label distribution for loss smoothing.
|
54 |
+
|
55 |
+
:param odim:
|
56 |
+
:param lsm_type:
|
57 |
+
:param blank:
|
58 |
+
:param transcript:
|
59 |
+
:return:
|
60 |
+
"""
|
61 |
+
if transcript is not None:
|
62 |
+
with open(transcript, "rb") as f:
|
63 |
+
trans_json = json.load(f)["utts"]
|
64 |
+
|
65 |
+
if lsm_type == "unigram":
|
66 |
+
assert transcript is not None, (
|
67 |
+
"transcript is required for %s label smoothing" % lsm_type
|
68 |
+
)
|
69 |
+
labelcount = np.zeros(odim)
|
70 |
+
for k, v in trans_json.items():
|
71 |
+
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
|
72 |
+
# to avoid an error when there is no text in an utterance
|
73 |
+
if len(ids) > 0:
|
74 |
+
labelcount[ids] += 1
|
75 |
+
labelcount[odim - 1] = len(transcript) # count <eos>
|
76 |
+
labelcount[labelcount == 0] = 1 # flooring
|
77 |
+
labelcount[blank] = 0 # remove counts for blank
|
78 |
+
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
|
79 |
+
else:
|
80 |
+
logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
|
81 |
+
sys.exit()
|
82 |
+
|
83 |
+
return labeldist
|
84 |
+
|
85 |
+
|
86 |
+
def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
|
87 |
+
"""Return the output size of the VGG frontend.
|
88 |
+
|
89 |
+
:param in_channel: input channel size
|
90 |
+
:param out_channel: output channel size
|
91 |
+
:return: output size
|
92 |
+
:rtype int
|
93 |
+
"""
|
94 |
+
idim = idim / in_channel
|
95 |
+
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
|
96 |
+
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
|
97 |
+
return int(idim) * out_channel # number of channels
|
98 |
+
|
99 |
+
|
100 |
+
class ErrorCalculator(object):
|
101 |
+
"""Calculate CER and WER for E2E_ASR and CTC models during training.
|
102 |
+
|
103 |
+
:param y_hats: numpy array with predicted text
|
104 |
+
:param y_pads: numpy array with true (target) text
|
105 |
+
:param char_list:
|
106 |
+
:param sym_space:
|
107 |
+
:param sym_blank:
|
108 |
+
:return:
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
|
113 |
+
):
|
114 |
+
"""Construct an ErrorCalculator object."""
|
115 |
+
super(ErrorCalculator, self).__init__()
|
116 |
+
|
117 |
+
self.report_cer = report_cer
|
118 |
+
self.report_wer = report_wer
|
119 |
+
|
120 |
+
self.char_list = char_list
|
121 |
+
self.space = sym_space
|
122 |
+
self.blank = sym_blank
|
123 |
+
self.idx_blank = self.char_list.index(self.blank)
|
124 |
+
if self.space in self.char_list:
|
125 |
+
self.idx_space = self.char_list.index(self.space)
|
126 |
+
else:
|
127 |
+
self.idx_space = None
|
128 |
+
|
129 |
+
def __call__(self, ys_hat, ys_pad, is_ctc=False):
|
130 |
+
"""Calculate sentence-level WER/CER score.
|
131 |
+
|
132 |
+
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
133 |
+
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
134 |
+
:param bool is_ctc: calculate CER score for CTC
|
135 |
+
:return: sentence-level WER score
|
136 |
+
:rtype float
|
137 |
+
:return: sentence-level CER score
|
138 |
+
:rtype float
|
139 |
+
"""
|
140 |
+
cer, wer = None, None
|
141 |
+
if is_ctc:
|
142 |
+
return self.calculate_cer_ctc(ys_hat, ys_pad)
|
143 |
+
elif not self.report_cer and not self.report_wer:
|
144 |
+
return cer, wer
|
145 |
+
|
146 |
+
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
|
147 |
+
if self.report_cer:
|
148 |
+
cer = self.calculate_cer(seqs_hat, seqs_true)
|
149 |
+
|
150 |
+
if self.report_wer:
|
151 |
+
wer = self.calculate_wer(seqs_hat, seqs_true)
|
152 |
+
return cer, wer
|
153 |
+
|
154 |
+
def calculate_cer_ctc(self, ys_hat, ys_pad):
|
155 |
+
"""Calculate sentence-level CER score for CTC.
|
156 |
+
|
157 |
+
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
158 |
+
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
159 |
+
:return: average sentence-level CER score
|
160 |
+
:rtype float
|
161 |
+
"""
|
162 |
+
import editdistance
|
163 |
+
|
164 |
+
cers, char_ref_lens = [], []
|
165 |
+
for i, y in enumerate(ys_hat):
|
166 |
+
y_hat = [x[0] for x in groupby(y)]
|
167 |
+
y_true = ys_pad[i]
|
168 |
+
seq_hat, seq_true = [], []
|
169 |
+
for idx in y_hat:
|
170 |
+
idx = int(idx)
|
171 |
+
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
172 |
+
seq_hat.append(self.char_list[int(idx)])
|
173 |
+
|
174 |
+
for idx in y_true:
|
175 |
+
idx = int(idx)
|
176 |
+
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
177 |
+
seq_true.append(self.char_list[int(idx)])
|
178 |
+
|
179 |
+
hyp_chars = "".join(seq_hat)
|
180 |
+
ref_chars = "".join(seq_true)
|
181 |
+
if len(ref_chars) > 0:
|
182 |
+
cers.append(editdistance.eval(hyp_chars, ref_chars))
|
183 |
+
char_ref_lens.append(len(ref_chars))
|
184 |
+
|
185 |
+
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
|
186 |
+
return cer_ctc
|
187 |
+
|
188 |
+
def convert_to_char(self, ys_hat, ys_pad):
|
189 |
+
"""Convert index to character.
|
190 |
+
|
191 |
+
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
|
192 |
+
:param torch.Tensor seqs_true: reference (batch, seqlen)
|
193 |
+
:return: token list of prediction
|
194 |
+
:rtype list
|
195 |
+
:return: token list of reference
|
196 |
+
:rtype list
|
197 |
+
"""
|
198 |
+
seqs_hat, seqs_true = [], []
|
199 |
+
for i, y_hat in enumerate(ys_hat):
|
200 |
+
y_true = ys_pad[i]
|
201 |
+
eos_true = np.where(y_true == -1)[0]
|
202 |
+
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
203 |
+
# NOTE: padding index (-1) in y_true is used to pad y_hat
|
204 |
+
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
|
205 |
+
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
206 |
+
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
|
207 |
+
seq_hat_text = seq_hat_text.replace(self.blank, "")
|
208 |
+
seq_true_text = "".join(seq_true).replace(self.space, " ")
|
209 |
+
seqs_hat.append(seq_hat_text)
|
210 |
+
seqs_true.append(seq_true_text)
|
211 |
+
return seqs_hat, seqs_true
|
212 |
+
|
213 |
+
def calculate_cer(self, seqs_hat, seqs_true):
|
214 |
+
"""Calculate sentence-level CER score.
|
215 |
+
|
216 |
+
:param list seqs_hat: prediction
|
217 |
+
:param list seqs_true: reference
|
218 |
+
:return: average sentence-level CER score
|
219 |
+
:rtype float
|
220 |
+
"""
|
221 |
+
import editdistance
|
222 |
+
|
223 |
+
char_eds, char_ref_lens = [], []
|
224 |
+
for i, seq_hat_text in enumerate(seqs_hat):
|
225 |
+
seq_true_text = seqs_true[i]
|
226 |
+
hyp_chars = seq_hat_text.replace(" ", "")
|
227 |
+
ref_chars = seq_true_text.replace(" ", "")
|
228 |
+
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
229 |
+
char_ref_lens.append(len(ref_chars))
|
230 |
+
return float(sum(char_eds)) / sum(char_ref_lens)
|
231 |
+
|
232 |
+
def calculate_wer(self, seqs_hat, seqs_true):
|
233 |
+
"""Calculate sentence-level WER score.
|
234 |
+
|
235 |
+
:param list seqs_hat: prediction
|
236 |
+
:param list seqs_true: reference
|
237 |
+
:return: average sentence-level WER score
|
238 |
+
:rtype float
|
239 |
+
"""
|
240 |
+
import editdistance
|
241 |
+
|
242 |
+
word_eds, word_ref_lens = [], []
|
243 |
+
for i, seq_hat_text in enumerate(seqs_hat):
|
244 |
+
seq_true_text = seqs_true[i]
|
245 |
+
hyp_words = seq_hat_text.split()
|
246 |
+
ref_words = seq_true_text.split()
|
247 |
+
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
248 |
+
word_ref_lens.append(len(ref_words))
|
249 |
+
return float(sum(word_eds)) / sum(word_ref_lens)
|
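The end_detect rule above stops the search once the best hypotheses ended at the M most recent lengths all fall far below the global best score (by more than |D_end| = 10 in log-probability). A toy illustration with made-up scores:

from espnet.nets.e2e_asr_common import end_detect

ended = [
    {"yseq": [1, 2, 3, 4], "score": -5.0},               # global best, ended earlier
    {"yseq": [1, 2, 3, 4, 5], "score": -40.0},            # recent endings are all much worse
    {"yseq": [1, 2, 3, 4, 5, 6], "score": -42.0},
    {"yseq": [1, 2, 3, 4, 5, 6, 7], "score": -45.0},
]
print(end_detect(ended, i=7))   # True: lengths 7, 6 and 5 are each more than 10 worse than the best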
espnet/nets/lm_interface.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
"""Language model interface."""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
from espnet.nets.scorer_interface import ScorerInterface
|
6 |
+
from espnet.utils.dynamic_import import dynamic_import
|
7 |
+
from espnet.utils.fill_missing_args import fill_missing_args
|
8 |
+
|
9 |
+
|
10 |
+
class LMInterface(ScorerInterface):
|
11 |
+
"""LM Interface for ESPnet model implementation."""
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def add_arguments(parser):
|
15 |
+
"""Add arguments to command line argument parser."""
|
16 |
+
return parser
|
17 |
+
|
18 |
+
@classmethod
|
19 |
+
def build(cls, n_vocab: int, **kwargs):
|
20 |
+
"""Initialize this class with python-level args.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
n_vocab (int): The size of the vocabulary.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
LMInterface: A new instance of LMInterface.
|
27 |
+
|
28 |
+
"""
|
29 |
+
# local import to avoid cyclic import in lm_train
|
30 |
+
from espnet.bin.lm_train import get_parser
|
31 |
+
|
32 |
+
def wrap(parser):
|
33 |
+
return get_parser(parser, required=False)
|
34 |
+
|
35 |
+
args = argparse.Namespace(**kwargs)
|
36 |
+
args = fill_missing_args(args, wrap)
|
37 |
+
args = fill_missing_args(args, cls.add_arguments)
|
38 |
+
return cls(n_vocab, args)
|
39 |
+
|
40 |
+
def forward(self, x, t):
|
41 |
+
"""Compute LM loss value from buffer sequences.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
x (torch.Tensor): Input ids. (batch, len)
|
45 |
+
t (torch.Tensor): Target ids. (batch, len)
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
|
49 |
+
loss to backward (scalar),
|
50 |
+
negative log-likelihood of t: -log p(t) (scalar) and
|
51 |
+
the number of elements in x (scalar)
|
52 |
+
|
53 |
+
Notes:
|
54 |
+
The last two return values are used
|
55 |
+
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
|
56 |
+
|
57 |
+
"""
|
58 |
+
raise NotImplementedError("forward method is not implemented")
|
59 |
+
|
60 |
+
|
61 |
+
predefined_lms = {
|
62 |
+
"pytorch": {
|
63 |
+
"default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM",
|
64 |
+
"seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM",
|
65 |
+
"transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM",
|
66 |
+
},
|
67 |
+
"chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"},
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
def dynamic_import_lm(module, backend):
|
72 |
+
"""Import LM class dynamically.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
module (str): module_name:class_name or alias in `predefined_lms`
|
76 |
+
backend (str): NN backend. e.g., pytorch, chainer
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
type: LM class
|
80 |
+
|
81 |
+
"""
|
82 |
+
model_class = dynamic_import(module, predefined_lms.get(backend, dict()))
|
83 |
+
assert issubclass(
|
84 |
+
model_class, LMInterface
|
85 |
+
), f"{module} does not implement LMInterface"
|
86 |
+
return model_class
|
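A quick sketch of how the alias table above is consumed; it assumes only that the LM modules added elsewhere in this push import cleanly.

from espnet.nets.lm_interface import dynamic_import_lm

# "default" is looked up in predefined_lms["pytorch"] and resolved to
# espnet.nets.pytorch_backend.lm.default:DefaultRNNLM, then checked against LMInterface.
lm_class = dynamic_import_lm("default", backend="pytorch")
print(lm_class.__name__)   # DefaultRNNLM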
espnet/nets/pytorch_backend/backbones/conv1d_extractor.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
#! /usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2021 Imperial College London (Pingchuan Ma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
import torch
|
7 |
+
from espnet.nets.pytorch_backend.backbones.modules.resnet1d import ResNet1D, BasicBlock1D
|
8 |
+
|
9 |
+
class Conv1dResNet(torch.nn.Module):
|
10 |
+
def __init__(self, relu_type="swish", a_upsample_ratio=1):
|
11 |
+
super().__init__()
|
12 |
+
self.a_upsample_ratio = a_upsample_ratio
|
13 |
+
self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type, a_upsample_ratio=a_upsample_ratio)
|
14 |
+
|
15 |
+
|
16 |
+
def forward(self, xs_pad):
|
17 |
+
"""forward.
|
18 |
+
|
19 |
+
:param xs_pad: torch.Tensor, batch of padded input sequences (B, Tmax, idim)
|
20 |
+
"""
|
21 |
+
B, T, C = xs_pad.size()
|
22 |
+
xs_pad = xs_pad[:, :T // 640 * 640, :]
|
23 |
+
xs_pad = xs_pad.transpose(1, 2)
|
24 |
+
xs_pad = self.trunk(xs_pad)
|
25 |
+
return xs_pad.transpose(1, 2)
|
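A shape-level sketch of the audio front-end above. The T // 640 * 640 crop keeps whole 640-sample chunks, i.e. 40 ms at an assumed 16 kHz sampling rate, so one feature vector is produced per 40 ms of waveform; the concrete output width below assumes the trunk follows the usual ResNet-18 layout.

import torch
from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet

frontend = Conv1dResNet(relu_type="swish")
wav = torch.randn(2, 6400, 1)      # 2 utterances of 0.4 s raw audio at an assumed 16 kHz
feats = frontend(wav)
print(feats.shape)                 # expected (2, 10, 512): 6400 samples // 640 = 10 frames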
espnet/nets/pytorch_backend/backbones/conv3d_extractor.py
ADDED
@@ -0,0 +1,47 @@
1 |
+
#! /usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2021 Imperial College London (Pingchuan Ma)
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from espnet.nets.pytorch_backend.backbones.modules.resnet import ResNet, BasicBlock
|
10 |
+
from espnet.nets.pytorch_backend.transformer.convolution import Swish
|
11 |
+
|
12 |
+
|
13 |
+
def threeD_to_2D_tensor(x):
|
14 |
+
n_batch, n_channels, s_time, sx, sy = x.shape
|
15 |
+
x = x.transpose(1, 2)
|
16 |
+
return x.reshape(n_batch * s_time, n_channels, sx, sy)
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class Conv3dResNet(torch.nn.Module):
|
21 |
+
"""Conv3dResNet module
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, backbone_type="resnet", relu_type="swish"):
|
25 |
+
"""__init__.
|
26 |
+
|
27 |
+
:param backbone_type: str, the type of a visual front-end.
|
28 |
+
:param relu_type: str, activation function used in an audio front-end.
|
29 |
+
"""
|
30 |
+
super(Conv3dResNet, self).__init__()
|
31 |
+
self.frontend_nout = 64
|
32 |
+
self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
|
33 |
+
self.frontend3D = nn.Sequential(
|
34 |
+
nn.Conv3d(1, self.frontend_nout, (5, 7, 7), (1, 2, 2), (2, 3, 3), bias=False),
|
35 |
+
nn.BatchNorm3d(self.frontend_nout),
|
36 |
+
Swish(),
|
37 |
+
nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1))
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def forward(self, xs_pad):
|
42 |
+
B, C, T, H, W = xs_pad.size()
|
43 |
+
xs_pad = self.frontend3D(xs_pad)
|
44 |
+
Tnew = xs_pad.shape[2]
|
45 |
+
xs_pad = threeD_to_2D_tensor(xs_pad)
|
46 |
+
xs_pad = self.trunk(xs_pad)
|
47 |
+
return xs_pad.view(B, Tnew, xs_pad.size(1))
|
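And a matching sketch for the visual front-end above; the clip size is illustrative (one second of 25 fps grayscale mouth crops) and the 512-dim output assumes the standard ResNet-18 trunk defined in resnet.py below.

import torch
from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet

frontend = Conv3dResNet(relu_type="swish")
clip = torch.randn(1, 1, 25, 88, 88)   # (B, C, T, H, W): assumed 88x88 grayscale ROIs
feats = frontend(clip)
print(feats.shape)                     # expected (1, 25, 512): one feature vector per video frame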
espnet/nets/pytorch_backend/backbones/modules/resnet.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
import math
|
2 |
+
import torch.nn as nn
|
3 |
+
import pdb
|
4 |
+
|
5 |
+
from espnet.nets.pytorch_backend.transformer.convolution import Swish
|
6 |
+
|
7 |
+
|
8 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
9 |
+
"""conv3x3.
|
10 |
+
|
11 |
+
:param in_planes: int, number of channels in the input sequence.
|
12 |
+
:param out_planes: int, number of channels produced by the convolution.
|
13 |
+
:param stride: int, size of the convolving kernel.
|
14 |
+
"""
|
15 |
+
return nn.Conv2d(
|
16 |
+
in_planes,
|
17 |
+
out_planes,
|
18 |
+
kernel_size=3,
|
19 |
+
stride=stride,
|
20 |
+
padding=1,
|
21 |
+
bias=False,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def downsample_basic_block(inplanes, outplanes, stride):
|
26 |
+
"""downsample_basic_block.
|
27 |
+
|
28 |
+
:param inplanes: int, number of channels in the input sequence.
|
29 |
+
:param outplanes: int, number of channels produced by the convolution.
|
30 |
+
:param stride: int, size of the convolving kernel.
|
31 |
+
"""
|
32 |
+
return nn.Sequential(
|
33 |
+
nn.Conv2d(
|
34 |
+
inplanes,
|
35 |
+
outplanes,
|
36 |
+
kernel_size=1,
|
37 |
+
stride=stride,
|
38 |
+
bias=False,
|
39 |
+
),
|
40 |
+
nn.BatchNorm2d(outplanes),
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
class BasicBlock(nn.Module):
|
45 |
+
expansion = 1
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
inplanes,
|
50 |
+
planes,
|
51 |
+
stride=1,
|
52 |
+
downsample=None,
|
53 |
+
relu_type="swish",
|
54 |
+
):
|
55 |
+
"""__init__.
|
56 |
+
|
57 |
+
:param inplanes: int, number of channels in the input sequence.
|
58 |
+
:param planes: int, number of channels produced by the convolution.
|
59 |
+
:param stride: int, size of the convolving kernel.
|
60 |
+
:param downsample: torch.nn.Module or None, module applied to the residual branch when the resolution or channel count changes.
|
61 |
+
:param relu_type: str, type of activation function.
|
62 |
+
"""
|
63 |
+
super(BasicBlock, self).__init__()
|
64 |
+
|
65 |
+
assert relu_type in ["relu", "prelu", "swish"]
|
66 |
+
|
67 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
68 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
69 |
+
|
70 |
+
if relu_type == "relu":
|
71 |
+
self.relu1 = nn.ReLU(inplace=True)
|
72 |
+
self.relu2 = nn.ReLU(inplace=True)
|
73 |
+
elif relu_type == "prelu":
|
74 |
+
self.relu1 = nn.PReLU(num_parameters=planes)
|
75 |
+
self.relu2 = nn.PReLU(num_parameters=planes)
|
76 |
+
elif relu_type == "swish":
|
77 |
+
self.relu1 = Swish()
|
78 |
+
self.relu2 = Swish()
|
79 |
+
else:
|
80 |
+
raise NotImplementedError
|
81 |
+
# --------
|
82 |
+
|
83 |
+
self.conv2 = conv3x3(planes, planes)
|
84 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
85 |
+
|
86 |
+
self.downsample = downsample
|
87 |
+
self.stride = stride
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
"""forward.
|
91 |
+
|
92 |
+
:param x: torch.Tensor, input tensor with input size (B, C, H, W).
|
93 |
+
"""
|
94 |
+
residual = x
|
95 |
+
out = self.conv1(x)
|
96 |
+
out = self.bn1(out)
|
97 |
+
out = self.relu1(out)
|
98 |
+
out = self.conv2(out)
|
99 |
+
out = self.bn2(out)
|
100 |
+
if self.downsample is not None:
|
101 |
+
residual = self.downsample(x)
|
102 |
+
|
103 |
+
out += residual
|
104 |
+
out = self.relu2(out)
|
105 |
+
|
106 |
+
return out
|
107 |
+
|
108 |
+
|
109 |
+
class ResNet(nn.Module):
|
110 |
+
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
block,
|
114 |
+
layers,
|
115 |
+
relu_type="swish",
|
116 |
+
):
|
117 |
+
super(ResNet, self).__init__()
|
118 |
+
self.inplanes = 64
|
119 |
+
self.relu_type = relu_type
|
120 |
+
self.downsample_block = downsample_basic_block
|
121 |
+
|
122 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
123 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
124 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
125 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
126 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
127 |
+
|
128 |
+
|
129 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
130 |
+
"""_make_layer.
|
131 |
+
|
132 |
+
:param block: torch.nn.Module, class of blocks.
|
133 |
+
:param planes: int, number of channels produced by the convolution.
|
134 |
+
:param blocks: int, number of layers in a block.
|
135 |
+
:param stride: int, stride of the first block in the layer.
|
136 |
+
"""
|
137 |
+
downsample = None
|
138 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
139 |
+
downsample = self.downsample_block(
|
140 |
+
inplanes=self.inplanes,
|
141 |
+
outplanes=planes*block.expansion,
|
142 |
+
stride=stride,
|
143 |
+
)
|
144 |
+
|
145 |
+
layers = []
|
146 |
+
layers.append(
|
147 |
+
block(
|
148 |
+
self.inplanes,
|
149 |
+
planes,
|
150 |
+
stride,
|
151 |
+
downsample,
|
152 |
+
relu_type=self.relu_type,
|
153 |
+
)
|
154 |
+
)
|
155 |
+
self.inplanes = planes * block.expansion
|
156 |
+
for i in range(1, blocks):
|
157 |
+
layers.append(
|
158 |
+
block(
|
159 |
+
self.inplanes,
|
160 |
+
planes,
|
161 |
+
relu_type=self.relu_type,
|
162 |
+
)
|
163 |
+
)
|
164 |
+
|
165 |
+
return nn.Sequential(*layers)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
"""forward.
|
169 |
+
|
170 |
+
:param x: torch.Tensor, input tensor with input size (B, C, H, W).
|
171 |
+
"""
|
172 |
+
x = self.layer1(x)
|
173 |
+
x = self.layer2(x)
|
174 |
+
x = self.layer3(x)
|
175 |
+
x = self.layer4(x)
|
176 |
+
x = self.avgpool(x)
|
177 |
+
x = x.view(x.size(0), -1)
|
178 |
+
return x
|
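For orientation, a minimal sketch (shapes in the comments are assumptions) of instantiating this trunk with the ResNet-18 layer layout used by Conv3dResNet:

import torch

# [2, 2, 2, 2] is the ResNet-18 block layout; the 64-channel 22x22 maps below are an
# assumed input size matching what the 3D front-end produces after its pooling stage.
trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type="swish")
frames = torch.randn(32, 64, 22, 22)   # (B*T, C, H, W) after time is folded into the batch
out = trunk(frames)                    # -> (32, 512) after global average pooling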
espnet/nets/pytorch_backend/backbones/modules/resnet1d.py
ADDED
@@ -0,0 +1,213 @@
1 |
+
import math
|
2 |
+
import torch.nn as nn
|
3 |
+
import pdb
|
4 |
+
|
5 |
+
from espnet.nets.pytorch_backend.transformer.convolution import Swish
|
6 |
+
|
7 |
+
|
8 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
9 |
+
"""conv3x3.
|
10 |
+
|
11 |
+
:param in_planes: int, number of channels in the input sequence.
|
12 |
+
:param out_planes: int, number of channels produced by the convolution.
|
13 |
+
:param stride: int, stride of the convolution.
|
14 |
+
"""
|
15 |
+
return nn.Conv1d(
|
16 |
+
in_planes,
|
17 |
+
out_planes,
|
18 |
+
kernel_size=3,
|
19 |
+
stride=stride,
|
20 |
+
padding=1,
|
21 |
+
bias=False,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
def downsample_basic_block(inplanes, outplanes, stride):
|
26 |
+
"""downsample_basic_block.
|
27 |
+
|
28 |
+
:param inplanes: int, number of channels in the input sequence.
|
29 |
+
:param outplanes: int, number of channels produced by the convolution.
|
30 |
+
:param stride: int, stride of the convolution.
|
31 |
+
"""
|
32 |
+
return nn.Sequential(
|
33 |
+
nn.Conv1d(
|
34 |
+
inplanes,
|
35 |
+
outplanes,
|
36 |
+
kernel_size=1,
|
37 |
+
stride=stride,
|
38 |
+
bias=False,
|
39 |
+
),
|
40 |
+
nn.BatchNorm1d(outplanes),
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
class BasicBlock1D(nn.Module):
|
45 |
+
expansion = 1
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
inplanes,
|
50 |
+
planes,
|
51 |
+
stride=1,
|
52 |
+
downsample=None,
|
53 |
+
relu_type="relu",
|
54 |
+
):
|
55 |
+
"""__init__.
|
56 |
+
|
57 |
+
:param inplanes: int, number of channels in the input sequence.
|
58 |
+
:param planes: int, number of channels produced by the convolution.
|
59 |
+
:param stride: int, stride of the first convolution in the block.
|
60 |
+
:param downsample: torch.nn.Module or None, module applied to the residual branch when the temporal resolution or channel count changes.
|
61 |
+
:param relu_type: str, type of activation function.
|
62 |
+
"""
|
63 |
+
super(BasicBlock1D, self).__init__()
|
64 |
+
|
65 |
+
assert relu_type in ["relu","prelu", "swish"]
|
66 |
+
|
67 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
68 |
+
self.bn1 = nn.BatchNorm1d(planes)
|
69 |
+
|
70 |
+
# type of ReLU is an input option
|
71 |
+
if relu_type == "relu":
|
72 |
+
self.relu1 = nn.ReLU(inplace=True)
|
73 |
+
self.relu2 = nn.ReLU(inplace=True)
|
74 |
+
elif relu_type == "prelu":
|
75 |
+
self.relu1 = nn.PReLU(num_parameters=planes)
|
76 |
+
self.relu2 = nn.PReLU(num_parameters=planes)
|
77 |
+
elif relu_type == "swish":
|
78 |
+
self.relu1 = Swish()
|
79 |
+
self.relu2 = Swish()
|
80 |
+
else:
|
81 |
+
raise NotImplementedError
|
82 |
+
# --------
|
83 |
+
|
84 |
+
self.conv2 = conv3x3(planes, planes)
|
85 |
+
self.bn2 = nn.BatchNorm1d(planes)
|
86 |
+
|
87 |
+
self.downsample = downsample
|
88 |
+
self.stride = stride
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
"""forward.
|
92 |
+
|
93 |
+
:param x: torch.Tensor, input tensor with input size (B, C, T)
|
94 |
+
"""
|
95 |
+
residual = x
|
96 |
+
out = self.conv1(x)
|
97 |
+
out = self.bn1(out)
|
98 |
+
out = self.relu1(out)
|
99 |
+
out = self.conv2(out)
|
100 |
+
out = self.bn2(out)
|
101 |
+
if self.downsample is not None:
|
102 |
+
residual = self.downsample(x)
|
103 |
+
|
104 |
+
out += residual
|
105 |
+
out = self.relu2(out)
|
106 |
+
|
107 |
+
return out
|
108 |
+
|
109 |
+
|
110 |
+
class ResNet1D(nn.Module):
|
111 |
+
|
112 |
+
def __init__(self,
|
113 |
+
block,
|
114 |
+
layers,
|
115 |
+
relu_type="swish",
|
116 |
+
a_upsample_ratio=1,
|
117 |
+
):
|
118 |
+
"""__init__.
|
119 |
+
|
120 |
+
:param block: torch.nn.Module, class of blocks.
|
121 |
+
:param layers: List, customised layers in each block.
|
122 |
+
:param relu_type: str, type of activation function.
|
123 |
+
:param a_upsample_ratio: int, The ratio related to the \
|
124 |
+
temporal resolution of output features of the frontend. \
|
125 |
+
a_upsample_ratio=1 produce features with a fps of 25.
|
126 |
+
"""
|
127 |
+
super(ResNet1D, self).__init__()
|
128 |
+
self.inplanes = 64
|
129 |
+
self.relu_type = relu_type
|
130 |
+
self.downsample_block = downsample_basic_block
|
131 |
+
self.a_upsample_ratio = a_upsample_ratio
|
132 |
+
|
133 |
+
self.conv1 = nn.Conv1d(
|
134 |
+
in_channels=1,
|
135 |
+
out_channels=self.inplanes,
|
136 |
+
kernel_size=80,
|
137 |
+
stride=4,
|
138 |
+
padding=38,
|
139 |
+
bias=False,
|
140 |
+
)
|
141 |
+
self.bn1 = nn.BatchNorm1d(self.inplanes)
|
142 |
+
|
143 |
+
if relu_type == "relu":
|
144 |
+
self.relu = nn.ReLU(inplace=True)
|
145 |
+
elif relu_type == "prelu":
|
146 |
+
self.relu = nn.PReLU(num_parameters=self.inplanes)
|
147 |
+
elif relu_type == "swish":
|
148 |
+
self.relu = Swish()
|
149 |
+
|
150 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
151 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
152 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
153 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
154 |
+
self.avgpool = nn.AvgPool1d(
|
155 |
+
kernel_size=20//self.a_upsample_ratio,
|
156 |
+
stride=20//self.a_upsample_ratio,
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
161 |
+
"""_make_layer.
|
162 |
+
|
163 |
+
:param block: torch.nn.Module, class of blocks.
|
164 |
+
:param planes: int, number of channels produced by the convolution.
|
165 |
+
:param blocks: int, number of layers in a block.
|
166 |
+
:param stride: int, stride of the first block in the layer.
|
167 |
+
"""
|
168 |
+
|
169 |
+
downsample = None
|
170 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
171 |
+
downsample = self.downsample_block(
|
172 |
+
inplanes=self.inplanes,
|
173 |
+
outplanes=planes*block.expansion,
|
174 |
+
stride=stride,
|
175 |
+
)
|
176 |
+
|
177 |
+
layers = []
|
178 |
+
layers.append(
|
179 |
+
block(
|
180 |
+
self.inplanes,
|
181 |
+
planes,
|
182 |
+
stride,
|
183 |
+
downsample,
|
184 |
+
relu_type=self.relu_type,
|
185 |
+
)
|
186 |
+
)
|
187 |
+
self.inplanes = planes * block.expansion
|
188 |
+
for i in range(1, blocks):
|
189 |
+
layers.append(
|
190 |
+
block(
|
191 |
+
self.inplanes,
|
192 |
+
planes,
|
193 |
+
relu_type=self.relu_type,
|
194 |
+
)
|
195 |
+
)
|
196 |
+
|
197 |
+
return nn.Sequential(*layers)
|
198 |
+
|
199 |
+
def forward(self, x):
|
200 |
+
"""forward.
|
201 |
+
|
202 |
+
:param x: torch.Tensor, input tensor with input size (B, C, T)
|
203 |
+
"""
|
204 |
+
x = self.conv1(x)
|
205 |
+
x = self.bn1(x)
|
206 |
+
x = self.relu(x)
|
207 |
+
|
208 |
+
x = self.layer1(x)
|
209 |
+
x = self.layer2(x)
|
210 |
+
x = self.layer3(x)
|
211 |
+
x = self.layer4(x)
|
212 |
+
x = self.avgpool(x)
|
213 |
+
return x
|
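A rough sketch (the sample rate is an assumption) of the downsampling this raw-audio trunk performs with a_upsample_ratio=1: conv1 divides the length by 4, layers 2-4 by 2 each, and the final average pooling by 20, i.e. 640 samples per output frame, which gives 25 feature frames per second at 16 kHz.

import torch

audio_trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type="swish", a_upsample_ratio=1)
wav = torch.randn(2, 1, 16000)   # (B, C=1, T): one second of 16 kHz audio (assumed rate)
feats = audio_trunk(wav)         # -> (2, 512, 25): 512 channels at 25 frames per second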
espnet/nets/pytorch_backend/backbones/modules/shufflenetv2.py
ADDED
@@ -0,0 +1,165 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from collections import OrderedDict
|
6 |
+
from torch.nn import init
|
7 |
+
import math
|
8 |
+
|
9 |
+
import pdb
|
10 |
+
|
11 |
+
def conv_bn(inp, oup, stride):
|
12 |
+
return nn.Sequential(
|
13 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
14 |
+
nn.BatchNorm2d(oup),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def conv_1x1_bn(inp, oup):
|
20 |
+
return nn.Sequential(
|
21 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
22 |
+
nn.BatchNorm2d(oup),
|
23 |
+
nn.ReLU(inplace=True)
|
24 |
+
)
|
25 |
+
|
26 |
+
def channel_shuffle(x, groups):
|
27 |
+
batchsize, num_channels, height, width = x.data.size()
|
28 |
+
|
29 |
+
channels_per_group = num_channels // groups
|
30 |
+
|
31 |
+
# reshape
|
32 |
+
x = x.view(batchsize, groups,
|
33 |
+
channels_per_group, height, width)
|
34 |
+
|
35 |
+
x = torch.transpose(x, 1, 2).contiguous()
|
36 |
+
|
37 |
+
# flatten
|
38 |
+
x = x.view(batchsize, -1, height, width)
|
39 |
+
|
40 |
+
return x
|
41 |
+
|
42 |
+
class InvertedResidual(nn.Module):
|
43 |
+
def __init__(self, inp, oup, stride, benchmodel):
|
44 |
+
super(InvertedResidual, self).__init__()
|
45 |
+
self.benchmodel = benchmodel
|
46 |
+
self.stride = stride
|
47 |
+
assert stride in [1, 2]
|
48 |
+
|
49 |
+
oup_inc = oup//2
|
50 |
+
|
51 |
+
if self.benchmodel == 1:
|
52 |
+
#assert inp == oup_inc
|
53 |
+
self.banch2 = nn.Sequential(
|
54 |
+
# pw
|
55 |
+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
56 |
+
nn.BatchNorm2d(oup_inc),
|
57 |
+
nn.ReLU(inplace=True),
|
58 |
+
# dw
|
59 |
+
nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
|
60 |
+
nn.BatchNorm2d(oup_inc),
|
61 |
+
# pw-linear
|
62 |
+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
63 |
+
nn.BatchNorm2d(oup_inc),
|
64 |
+
nn.ReLU(inplace=True),
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
self.banch1 = nn.Sequential(
|
68 |
+
# dw
|
69 |
+
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
|
70 |
+
nn.BatchNorm2d(inp),
|
71 |
+
# pw-linear
|
72 |
+
nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
|
73 |
+
nn.BatchNorm2d(oup_inc),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
)
|
76 |
+
|
77 |
+
self.banch2 = nn.Sequential(
|
78 |
+
# pw
|
79 |
+
nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False),
|
80 |
+
nn.BatchNorm2d(oup_inc),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
# dw
|
83 |
+
nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False),
|
84 |
+
nn.BatchNorm2d(oup_inc),
|
85 |
+
# pw-linear
|
86 |
+
nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False),
|
87 |
+
nn.BatchNorm2d(oup_inc),
|
88 |
+
nn.ReLU(inplace=True),
|
89 |
+
)
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def _concat(x, out):
|
93 |
+
# concatenate along channel axis
|
94 |
+
return torch.cat((x, out), 1)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
if 1==self.benchmodel:
|
98 |
+
x1 = x[:, :(x.shape[1]//2), :, :]
|
99 |
+
x2 = x[:, (x.shape[1]//2):, :, :]
|
100 |
+
out = self._concat(x1, self.banch2(x2))
|
101 |
+
elif 2==self.benchmodel:
|
102 |
+
out = self._concat(self.banch1(x), self.banch2(x))
|
103 |
+
|
104 |
+
return channel_shuffle(out, 2)
|
105 |
+
|
106 |
+
|
107 |
+
class ShuffleNetV2(nn.Module):
|
108 |
+
def __init__(self, n_class=1000, input_size=224, width_mult=2.):
|
109 |
+
super(ShuffleNetV2, self).__init__()
|
110 |
+
|
111 |
+
assert input_size % 32 == 0, "Input size needs to be divisible by 32"
|
112 |
+
|
113 |
+
self.stage_repeats = [4, 8, 4]
|
114 |
+
# index 0 is invalid and should never be called.
|
115 |
+
# only used for indexing convenience.
|
116 |
+
if width_mult == 0.5:
|
117 |
+
self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
|
118 |
+
elif width_mult == 1.0:
|
119 |
+
self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
|
120 |
+
elif width_mult == 1.5:
|
121 |
+
self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
|
122 |
+
elif width_mult == 2.0:
|
123 |
+
self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
|
124 |
+
else:
|
125 |
+
raise ValueError(
|
126 |
+
"""Width multiplier should be in [0.5, 1.0, 1.5, 2.0]. Current value: {}""".format(width_mult))
|
127 |
+
|
128 |
+
# building first layer
|
129 |
+
input_channel = self.stage_out_channels[1]
|
130 |
+
self.conv1 = conv_bn(3, input_channel, 2)
|
131 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
132 |
+
|
133 |
+
self.features = []
|
134 |
+
# building inverted residual blocks
|
135 |
+
for idxstage in range(len(self.stage_repeats)):
|
136 |
+
numrepeat = self.stage_repeats[idxstage]
|
137 |
+
output_channel = self.stage_out_channels[idxstage+2]
|
138 |
+
for i in range(numrepeat):
|
139 |
+
if i == 0:
|
140 |
+
#inp, oup, stride, benchmodel):
|
141 |
+
self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
|
142 |
+
else:
|
143 |
+
self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
|
144 |
+
input_channel = output_channel
|
145 |
+
|
146 |
+
|
147 |
+
# make it nn.Sequential
|
148 |
+
self.features = nn.Sequential(*self.features)
|
149 |
+
|
150 |
+
# building last several layers
|
151 |
+
self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1])
|
152 |
+
self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
|
153 |
+
|
154 |
+
# building classifier
|
155 |
+
self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class))
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
x = self.conv1(x)
|
159 |
+
x = self.maxpool(x)
|
160 |
+
x = self.features(x)
|
161 |
+
x = self.conv_last(x)
|
162 |
+
x = self.globalpool(x)
|
163 |
+
x = x.view(-1, self.stage_out_channels[-1])
|
164 |
+
x = self.classifier(x)
|
165 |
+
return x
|
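The channel_shuffle step above is what lets the two branches of each InvertedResidual unit exchange information; a tiny sketch of its effect on a toy tensor:

import torch

x = torch.arange(6).float().view(1, 6, 1, 1)   # 6 channels, trivial spatial size
shuffled = channel_shuffle(x, 2)               # groups=2 interleaves the two channel halves
print(shuffled.view(-1).tolist())              # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]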
espnet/nets/pytorch_backend/ctc.py
ADDED
@@ -0,0 +1,283 @@
1 |
+
from distutils.version import LooseVersion
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import six
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from espnet.nets.pytorch_backend.nets_utils import to_device
|
10 |
+
|
11 |
+
|
12 |
+
class CTC(torch.nn.Module):
|
13 |
+
"""CTC module
|
14 |
+
|
15 |
+
:param int odim: dimension of outputs
|
16 |
+
:param int eprojs: number of encoder projection units
|
17 |
+
:param float dropout_rate: dropout rate (0.0 ~ 1.0)
|
18 |
+
:param str ctc_type: builtin or warpctc
|
19 |
+
:param bool reduce: reduce the CTC loss into a scalar
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
|
23 |
+
super().__init__()
|
24 |
+
self.dropout_rate = dropout_rate
|
25 |
+
self.loss = None
|
26 |
+
self.ctc_lo = torch.nn.Linear(eprojs, odim)
|
27 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
28 |
+
self.probs = None # for visualization
|
29 |
+
|
30 |
+
# In case of PyTorch >= 1.7.0, CTC is always builtin
|
31 |
+
self.ctc_type = (
|
32 |
+
ctc_type
|
33 |
+
if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
|
34 |
+
else "builtin"
|
35 |
+
)
|
36 |
+
|
37 |
+
if self.ctc_type == "builtin":
|
38 |
+
reduction_type = "sum" if reduce else "none"
|
39 |
+
self.ctc_loss = torch.nn.CTCLoss(
|
40 |
+
reduction=reduction_type, zero_infinity=True
|
41 |
+
)
|
42 |
+
elif self.ctc_type == "cudnnctc":
|
43 |
+
reduction_type = "sum" if reduce else "none"
|
44 |
+
self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
|
45 |
+
elif self.ctc_type == "warpctc":
|
46 |
+
import warpctc_pytorch as warp_ctc
|
47 |
+
|
48 |
+
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
|
49 |
+
elif self.ctc_type == "gtnctc":
|
50 |
+
from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
|
51 |
+
|
52 |
+
self.ctc_loss = GTNCTCLossFunction.apply
|
53 |
+
else:
|
54 |
+
raise ValueError(
|
55 |
+
'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.ignore_id = -1
|
59 |
+
self.reduce = reduce
|
60 |
+
|
61 |
+
def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
|
62 |
+
if self.ctc_type in ["builtin", "cudnnctc"]:
|
63 |
+
th_pred = th_pred.log_softmax(2)
|
64 |
+
# Use the deterministic CuDNN implementation of CTC loss to avoid
|
65 |
+
# [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
|
66 |
+
with torch.backends.cudnn.flags(deterministic=True):
|
67 |
+
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
|
68 |
+
# Batch-size average
|
69 |
+
loss = loss / th_pred.size(1)
|
70 |
+
return loss
|
71 |
+
elif self.ctc_type == "warpctc":
|
72 |
+
return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
|
73 |
+
elif self.ctc_type == "gtnctc":
|
74 |
+
targets = [t.tolist() for t in th_target]
|
75 |
+
log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
|
76 |
+
return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
|
77 |
+
else:
|
78 |
+
raise NotImplementedError
|
79 |
+
|
80 |
+
def forward(self, hs_pad, hlens, ys_pad):
|
81 |
+
"""CTC forward
|
82 |
+
|
83 |
+
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
|
84 |
+
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
|
85 |
+
:param torch.Tensor ys_pad:
|
86 |
+
batch of padded character id sequence tensor (B, Lmax)
|
87 |
+
:return: ctc loss value
|
88 |
+
:rtype: torch.Tensor
|
89 |
+
"""
|
90 |
+
# TODO(kan-bayashi): need to make more smart way
|
91 |
+
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
|
92 |
+
|
93 |
+
# zero padding for hs
|
94 |
+
ys_hat = self.ctc_lo(self.dropout(hs_pad))
|
95 |
+
if self.ctc_type != "gtnctc":
|
96 |
+
ys_hat = ys_hat.transpose(0, 1)
|
97 |
+
|
98 |
+
if self.ctc_type == "builtin":
|
99 |
+
olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
|
100 |
+
hlens = hlens.long()
|
101 |
+
ys_pad = torch.cat(ys) # without this the code breaks for asr_mix
|
102 |
+
self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
|
103 |
+
else:
|
104 |
+
self.loss = None
|
105 |
+
hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
|
106 |
+
olens = torch.from_numpy(
|
107 |
+
np.fromiter((x.size(0) for x in ys), dtype=np.int32)
|
108 |
+
)
|
109 |
+
# zero padding for ys
|
110 |
+
ys_true = torch.cat(ys).cpu().int() # batch x olen
|
111 |
+
# get ctc loss
|
112 |
+
# expected shape of seqLength x batchSize x alphabet_size
|
113 |
+
dtype = ys_hat.dtype
|
114 |
+
if self.ctc_type == "warpctc" or dtype == torch.float16:
|
115 |
+
# warpctc only supports float32
|
116 |
+
# torch.ctc does not support float16 (#1751)
|
117 |
+
ys_hat = ys_hat.to(dtype=torch.float32)
|
118 |
+
if self.ctc_type == "cudnnctc":
|
119 |
+
# use GPU when using the cuDNN implementation
|
120 |
+
ys_true = to_device(hs_pad, ys_true)
|
121 |
+
if self.ctc_type == "gtnctc":
|
122 |
+
# keep as list for gtn
|
123 |
+
ys_true = ys
|
124 |
+
self.loss = to_device(
|
125 |
+
hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
|
126 |
+
).to(dtype=dtype)
|
127 |
+
|
128 |
+
# get length info
|
129 |
+
logging.info(
|
130 |
+
self.__class__.__name__
|
131 |
+
+ " input lengths: "
|
132 |
+
+ "".join(str(hlens).split("\n"))
|
133 |
+
)
|
134 |
+
logging.info(
|
135 |
+
self.__class__.__name__
|
136 |
+
+ " output lengths: "
|
137 |
+
+ "".join(str(olens).split("\n"))
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.reduce:
|
141 |
+
# NOTE: sum() is needed to keep consistency
|
142 |
+
# since warpctc return as tensor w/ shape (1,)
|
143 |
+
# but builtin return as tensor w/o shape (scalar).
|
144 |
+
self.loss = self.loss.sum()
|
145 |
+
logging.info("ctc loss:" + str(float(self.loss)))
|
146 |
+
|
147 |
+
return self.loss
|
148 |
+
|
149 |
+
def softmax(self, hs_pad):
|
150 |
+
"""softmax of frame activations
|
151 |
+
|
152 |
+
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
153 |
+
:return: log softmax applied 3d tensor (B, Tmax, odim)
|
154 |
+
:rtype: torch.Tensor
|
155 |
+
"""
|
156 |
+
self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
|
157 |
+
return self.probs
|
158 |
+
|
159 |
+
def log_softmax(self, hs_pad):
|
160 |
+
"""log_softmax of frame activations
|
161 |
+
|
162 |
+
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
163 |
+
:return: log softmax applied 3d tensor (B, Tmax, odim)
|
164 |
+
:rtype: torch.Tensor
|
165 |
+
"""
|
166 |
+
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
|
167 |
+
|
168 |
+
def argmax(self, hs_pad):
|
169 |
+
"""argmax of frame activations
|
170 |
+
|
171 |
+
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
172 |
+
:return: argmax applied 2d tensor (B, Tmax)
|
173 |
+
:rtype: torch.Tensor
|
174 |
+
"""
|
175 |
+
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
|
176 |
+
|
177 |
+
def forced_align(self, h, y, blank_id=0):
|
178 |
+
"""forced alignment.
|
179 |
+
|
180 |
+
:param torch.Tensor h: hidden state sequence, 3d tensor (1, T, D)
|
181 |
+
:param torch.Tensor y: id sequence tensor 1d tensor (L)
|
182 |
+
:param int blank_id: blank symbol index
|
183 |
+
:return: best alignment results
|
184 |
+
:rtype: list
|
185 |
+
"""
|
186 |
+
|
187 |
+
def interpolate_blank(label, blank_id=0):
|
188 |
+
"""Insert blank token between every two label token."""
|
189 |
+
label = np.expand_dims(label, 1)
|
190 |
+
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
|
191 |
+
label = np.concatenate([blanks, label], axis=1)
|
192 |
+
label = label.reshape(-1)
|
193 |
+
label = np.append(label, label[0])
|
194 |
+
return label
|
195 |
+
|
196 |
+
lpz = self.log_softmax(h)
|
197 |
+
lpz = lpz.squeeze(0)
|
198 |
+
|
199 |
+
y_int = interpolate_blank(y, blank_id)
|
200 |
+
|
201 |
+
logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero
|
202 |
+
state_path = (
|
203 |
+
np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
|
204 |
+
) # state path
|
205 |
+
|
206 |
+
logdelta[0, 0] = lpz[0][y_int[0]]
|
207 |
+
logdelta[0, 1] = lpz[0][y_int[1]]
|
208 |
+
|
209 |
+
for t in six.moves.range(1, lpz.size(0)):
|
210 |
+
for s in six.moves.range(len(y_int)):
|
211 |
+
if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
|
212 |
+
candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
|
213 |
+
prev_state = [s, s - 1]
|
214 |
+
else:
|
215 |
+
candidates = np.array(
|
216 |
+
[
|
217 |
+
logdelta[t - 1, s],
|
218 |
+
logdelta[t - 1, s - 1],
|
219 |
+
logdelta[t - 1, s - 2],
|
220 |
+
]
|
221 |
+
)
|
222 |
+
prev_state = [s, s - 1, s - 2]
|
223 |
+
logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
|
224 |
+
state_path[t, s] = prev_state[np.argmax(candidates)]
|
225 |
+
|
226 |
+
state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)
|
227 |
+
|
228 |
+
candidates = np.array(
|
229 |
+
[logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
|
230 |
+
)
|
231 |
+
prev_state = [len(y_int) - 1, len(y_int) - 2]
|
232 |
+
state_seq[-1] = prev_state[np.argmax(candidates)]
|
233 |
+
for t in six.moves.range(lpz.size(0) - 2, -1, -1):
|
234 |
+
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
|
235 |
+
|
236 |
+
output_state_seq = []
|
237 |
+
for t in six.moves.range(0, lpz.size(0)):
|
238 |
+
output_state_seq.append(y_int[state_seq[t, 0]])
|
239 |
+
|
240 |
+
return output_state_seq
|
241 |
+
|
242 |
+
|
243 |
+
def ctc_for(args, odim, reduce=True):
|
244 |
+
"""Returns the CTC module for the given args and output dimension
|
245 |
+
|
246 |
+
:param Namespace args: the program args
|
247 |
+
:param int odim : The output dimension
|
248 |
+
:param bool reduce : return the CTC loss in a scalar
|
249 |
+
:return: the corresponding CTC module
|
250 |
+
"""
|
251 |
+
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
|
252 |
+
if num_encs == 1:
|
253 |
+
# compatible with single encoder asr mode
|
254 |
+
return CTC(
|
255 |
+
odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
|
256 |
+
)
|
257 |
+
elif num_encs >= 1:
|
258 |
+
ctcs_list = torch.nn.ModuleList()
|
259 |
+
if args.share_ctc:
|
260 |
+
# use dropout_rate of the first encoder
|
261 |
+
ctc = CTC(
|
262 |
+
odim,
|
263 |
+
args.eprojs,
|
264 |
+
args.dropout_rate[0],
|
265 |
+
ctc_type=args.ctc_type,
|
266 |
+
reduce=reduce,
|
267 |
+
)
|
268 |
+
ctcs_list.append(ctc)
|
269 |
+
else:
|
270 |
+
for idx in range(num_encs):
|
271 |
+
ctc = CTC(
|
272 |
+
odim,
|
273 |
+
args.eprojs,
|
274 |
+
args.dropout_rate[idx],
|
275 |
+
ctc_type=args.ctc_type,
|
276 |
+
reduce=reduce,
|
277 |
+
)
|
278 |
+
ctcs_list.append(ctc)
|
279 |
+
return ctcs_list
|
280 |
+
else:
|
281 |
+
raise ValueError(
|
282 |
+
"Number of encoders needs to be more than one. {}".format(num_encs)
|
283 |
+
)
|
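A hedged sketch (dimensions and token ids below are made up) of calling forced_align on an encoder output to obtain one token id per frame; note that log_softmax expects a batch dimension, so h is passed as (1, T, eprojs):

import numpy as np
import torch

ctc = CTC(odim=5, eprojs=256, dropout_rate=0.0, ctc_type="builtin")
h = torch.randn(1, 50, 256)          # (1, T, eprojs) encoder states
y = np.array([3, 1, 4])              # reference token ids (blank id is 0)
alignment = ctc.forced_align(h, y)   # list of 50 ids drawn from {0, 3, 1, 4}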
espnet/nets/pytorch_backend/e2e_asr_transformer.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
# Copyright 2019 Shigeki Karita
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
"""Transformer speech recognition model (pytorch)."""
|
5 |
+
|
6 |
+
from argparse import Namespace
|
7 |
+
from distutils.util import strtobool
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
|
11 |
+
import numpy
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from espnet.nets.ctc_prefix_score import CTCPrefixScore
|
15 |
+
from espnet.nets.e2e_asr_common import end_detect
|
16 |
+
from espnet.nets.e2e_asr_common import ErrorCalculator
|
17 |
+
from espnet.nets.pytorch_backend.ctc import CTC
|
18 |
+
from espnet.nets.pytorch_backend.nets_utils import get_subsample
|
19 |
+
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
|
20 |
+
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
|
21 |
+
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
|
22 |
+
from espnet.nets.pytorch_backend.transformer.attention import (
|
23 |
+
MultiHeadedAttention, # noqa: H301
|
24 |
+
RelPositionMultiHeadedAttention, # noqa: H301
|
25 |
+
)
|
26 |
+
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
|
27 |
+
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
|
28 |
+
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
|
29 |
+
LabelSmoothingLoss, # noqa: H301
|
30 |
+
)
|
31 |
+
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
|
32 |
+
from espnet.nets.pytorch_backend.transformer.mask import target_mask
|
33 |
+
from espnet.nets.scorers.ctc import CTCPrefixScorer
|
34 |
+
|
35 |
+
|
36 |
+
class E2E(torch.nn.Module):
|
37 |
+
"""E2E module.
|
38 |
+
|
39 |
+
:param int idim: dimension of inputs
|
40 |
+
:param int odim: dimension of outputs
|
41 |
+
:param Namespace args: argument Namespace containing options
|
42 |
+
|
43 |
+
"""
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def add_arguments(parser):
|
47 |
+
"""Add arguments."""
|
48 |
+
group = parser.add_argument_group("transformer model setting")
|
49 |
+
|
50 |
+
group.add_argument(
|
51 |
+
"--transformer-init",
|
52 |
+
type=str,
|
53 |
+
default="pytorch",
|
54 |
+
choices=[
|
55 |
+
"pytorch",
|
56 |
+
"xavier_uniform",
|
57 |
+
"xavier_normal",
|
58 |
+
"kaiming_uniform",
|
59 |
+
"kaiming_normal",
|
60 |
+
],
|
61 |
+
help="how to initialize transformer parameters",
|
62 |
+
)
|
63 |
+
group.add_argument(
|
64 |
+
"--transformer-input-layer",
|
65 |
+
type=str,
|
66 |
+
default="conv2d",
|
67 |
+
choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
|
68 |
+
help="transformer input layer type",
|
69 |
+
)
|
70 |
+
group.add_argument(
|
71 |
+
"--transformer-encoder-attn-layer-type",
|
72 |
+
type=str,
|
73 |
+
default="mha",
|
74 |
+
choices=["mha", "rel_mha", "legacy_rel_mha"],
|
75 |
+
help="transformer encoder attention layer type",
|
76 |
+
)
|
77 |
+
group.add_argument(
|
78 |
+
"--transformer-attn-dropout-rate",
|
79 |
+
default=None,
|
80 |
+
type=float,
|
81 |
+
help="dropout in transformer attention. use --dropout-rate if None is set",
|
82 |
+
)
|
83 |
+
group.add_argument(
|
84 |
+
"--transformer-lr",
|
85 |
+
default=10.0,
|
86 |
+
type=float,
|
87 |
+
help="Initial value of learning rate",
|
88 |
+
)
|
89 |
+
group.add_argument(
|
90 |
+
"--transformer-warmup-steps",
|
91 |
+
default=25000,
|
92 |
+
type=int,
|
93 |
+
help="optimizer warmup steps",
|
94 |
+
)
|
95 |
+
group.add_argument(
|
96 |
+
"--transformer-length-normalized-loss",
|
97 |
+
default=True,
|
98 |
+
type=strtobool,
|
99 |
+
help="normalize loss by length",
|
100 |
+
)
|
101 |
+
group.add_argument(
|
102 |
+
"--dropout-rate",
|
103 |
+
default=0.0,
|
104 |
+
type=float,
|
105 |
+
help="Dropout rate for the encoder",
|
106 |
+
)
|
107 |
+
group.add_argument(
|
108 |
+
"--macaron-style",
|
109 |
+
default=False,
|
110 |
+
type=strtobool,
|
111 |
+
help="Whether to use macaron style for positionwise layer",
|
112 |
+
)
|
113 |
+
# -- input
|
114 |
+
group.add_argument(
|
115 |
+
"--a-upsample-ratio",
|
116 |
+
default=1,
|
117 |
+
type=int,
|
118 |
+
help="Upsample rate for audio",
|
119 |
+
)
|
120 |
+
group.add_argument(
|
121 |
+
"--relu-type",
|
122 |
+
default="swish",
|
123 |
+
type=str,
|
124 |
+
help="the type of activation layer",
|
125 |
+
)
|
126 |
+
# Encoder
|
127 |
+
group.add_argument(
|
128 |
+
"--elayers",
|
129 |
+
default=4,
|
130 |
+
type=int,
|
131 |
+
help="Number of encoder layers (for shared recognition part "
|
132 |
+
"in multi-speaker asr mode)",
|
133 |
+
)
|
134 |
+
group.add_argument(
|
135 |
+
"--eunits",
|
136 |
+
"-u",
|
137 |
+
default=300,
|
138 |
+
type=int,
|
139 |
+
help="Number of encoder hidden units",
|
140 |
+
)
|
141 |
+
group.add_argument(
|
142 |
+
"--use-cnn-module",
|
143 |
+
default=False,
|
144 |
+
type=strtobool,
|
145 |
+
help="Use convolution module or not",
|
146 |
+
)
|
147 |
+
group.add_argument(
|
148 |
+
"--cnn-module-kernel",
|
149 |
+
default=31,
|
150 |
+
type=int,
|
151 |
+
help="Kernel size of convolution module.",
|
152 |
+
)
|
153 |
+
# Attention
|
154 |
+
group.add_argument(
|
155 |
+
"--adim",
|
156 |
+
default=320,
|
157 |
+
type=int,
|
158 |
+
help="Number of attention transformation dimensions",
|
159 |
+
)
|
160 |
+
group.add_argument(
|
161 |
+
"--aheads",
|
162 |
+
default=4,
|
163 |
+
type=int,
|
164 |
+
help="Number of heads for multi head attention",
|
165 |
+
)
|
166 |
+
group.add_argument(
|
167 |
+
"--zero-triu",
|
168 |
+
default=False,
|
169 |
+
type=strtobool,
|
170 |
+
help="If true, zero the uppper triangular part of attention matrix.",
|
171 |
+
)
|
172 |
+
# Relative positional encoding
|
173 |
+
group.add_argument(
|
174 |
+
"--rel-pos-type",
|
175 |
+
type=str,
|
176 |
+
default="legacy",
|
177 |
+
choices=["legacy", "latest"],
|
178 |
+
help="Whether to use the latest relative positional encoding or the legacy one."
|
179 |
+
"The legacy relative positional encoding will be deprecated in the future."
|
180 |
+
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
|
181 |
+
)
|
182 |
+
# Decoder
|
183 |
+
group.add_argument(
|
184 |
+
"--dlayers", default=1, type=int, help="Number of decoder layers"
|
185 |
+
)
|
186 |
+
group.add_argument(
|
187 |
+
"--dunits", default=320, type=int, help="Number of decoder hidden units"
|
188 |
+
)
|
189 |
+
# -- pretrain
|
190 |
+
group.add_argument("--pretrain-dataset",
|
191 |
+
default="",
|
192 |
+
type=str,
|
193 |
+
help='pre-trained dataset for encoder'
|
194 |
+
)
|
195 |
+
# -- custom name
|
196 |
+
group.add_argument("--custom-pretrain-name",
|
197 |
+
default="",
|
198 |
+
type=str,
|
199 |
+
help='pre-trained model for encoder'
|
200 |
+
)
|
201 |
+
return parser
|
202 |
+
|
203 |
+
@property
|
204 |
+
def attention_plot_class(self):
|
205 |
+
"""Return PlotAttentionReport."""
|
206 |
+
return PlotAttentionReport
|
207 |
+
|
208 |
+
def __init__(self, odim, args, ignore_id=-1):
|
209 |
+
"""Construct an E2E object.
|
210 |
+
:param int odim: dimension of outputs
|
211 |
+
:param Namespace args: argument Namespace containing options
|
212 |
+
"""
|
213 |
+
torch.nn.Module.__init__(self)
|
214 |
+
if args.transformer_attn_dropout_rate is None:
|
215 |
+
args.transformer_attn_dropout_rate = args.dropout_rate
|
216 |
+
# Check the relative positional encoding type
|
217 |
+
self.rel_pos_type = getattr(args, "rel_pos_type", None)
|
218 |
+
if self.rel_pos_type is None and args.transformer_encoder_attn_layer_type == "rel_mha":
|
219 |
+
args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
|
220 |
+
logging.warning(
|
221 |
+
"Using legacy_rel_pos and it will be deprecated in the future."
|
222 |
+
)
|
223 |
+
|
224 |
+
idim = 80
|
225 |
+
|
226 |
+
self.encoder = Encoder(
|
227 |
+
idim=idim,
|
228 |
+
attention_dim=args.adim,
|
229 |
+
attention_heads=args.aheads,
|
230 |
+
linear_units=args.eunits,
|
231 |
+
num_blocks=args.elayers,
|
232 |
+
input_layer=args.transformer_input_layer,
|
233 |
+
dropout_rate=args.dropout_rate,
|
234 |
+
positional_dropout_rate=args.dropout_rate,
|
235 |
+
attention_dropout_rate=args.transformer_attn_dropout_rate,
|
236 |
+
encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
|
237 |
+
macaron_style=args.macaron_style,
|
238 |
+
use_cnn_module=args.use_cnn_module,
|
239 |
+
cnn_module_kernel=args.cnn_module_kernel,
|
240 |
+
zero_triu=getattr(args, "zero_triu", False),
|
241 |
+
a_upsample_ratio=args.a_upsample_ratio,
|
242 |
+
relu_type=getattr(args, "relu_type", "swish"),
|
243 |
+
)
|
244 |
+
|
245 |
+
self.transformer_input_layer = args.transformer_input_layer
|
246 |
+
self.a_upsample_ratio = args.a_upsample_ratio
|
247 |
+
|
248 |
+
if args.mtlalpha < 1:
|
249 |
+
self.decoder = Decoder(
|
250 |
+
odim=odim,
|
251 |
+
attention_dim=args.adim,
|
252 |
+
attention_heads=args.aheads,
|
253 |
+
linear_units=args.dunits,
|
254 |
+
num_blocks=args.dlayers,
|
255 |
+
dropout_rate=args.dropout_rate,
|
256 |
+
positional_dropout_rate=args.dropout_rate,
|
257 |
+
self_attention_dropout_rate=args.transformer_attn_dropout_rate,
|
258 |
+
src_attention_dropout_rate=args.transformer_attn_dropout_rate,
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
self.decoder = None
|
262 |
+
self.blank = 0
|
263 |
+
self.sos = odim - 1
|
264 |
+
self.eos = odim - 1
|
265 |
+
self.odim = odim
|
266 |
+
self.ignore_id = ignore_id
|
267 |
+
self.subsample = get_subsample(args, mode="asr", arch="transformer")
|
268 |
+
|
269 |
+
# self.lsm_weight = a
|
270 |
+
self.criterion = LabelSmoothingLoss(
|
271 |
+
self.odim,
|
272 |
+
self.ignore_id,
|
273 |
+
args.lsm_weight,
|
274 |
+
args.transformer_length_normalized_loss,
|
275 |
+
)
|
276 |
+
|
277 |
+
self.adim = args.adim
|
278 |
+
self.mtlalpha = args.mtlalpha
|
279 |
+
if args.mtlalpha > 0.0:
|
280 |
+
self.ctc = CTC(
|
281 |
+
odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
self.ctc = None
|
285 |
+
|
286 |
+
if args.report_cer or args.report_wer:
|
287 |
+
self.error_calculator = ErrorCalculator(
|
288 |
+
args.char_list,
|
289 |
+
args.sym_space,
|
290 |
+
args.sym_blank,
|
291 |
+
args.report_cer,
|
292 |
+
args.report_wer,
|
293 |
+
)
|
294 |
+
else:
|
295 |
+
self.error_calculator = None
|
296 |
+
self.rnnlm = None
|
297 |
+
|
298 |
+
def scorers(self):
|
299 |
+
"""Scorers."""
|
300 |
+
return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
|
301 |
+
|
302 |
+
def encode(self, x, extract_resnet_feats=False):
|
303 |
+
"""Encode acoustic features.
|
304 |
+
|
305 |
+
:param ndarray x: source acoustic feature (T, D)
|
306 |
+
:return: encoder outputs
|
307 |
+
:rtype: torch.Tensor
|
308 |
+
"""
|
309 |
+
self.eval()
|
310 |
+
x = torch.as_tensor(x).unsqueeze(0)
|
311 |
+
if extract_resnet_feats:
|
312 |
+
resnet_feats = self.encoder(
|
313 |
+
x,
|
314 |
+
None,
|
315 |
+
extract_resnet_feats=extract_resnet_feats,
|
316 |
+
)
|
317 |
+
return resnet_feats.squeeze(0)
|
318 |
+
else:
|
319 |
+
enc_output, _ = self.encoder(x, None)
|
320 |
+
return enc_output.squeeze(0)
|
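A minimal usage sketch (not part of the uploaded diff): assuming `args` is the full training Namespace restored from the model's json config, `char_list` the token inventory, and `video` a preprocessed input in whatever layout the chosen transformer input layer expects, the encoder output and scorers() plug into espnet's BeamSearch roughly as follows; the weight values are placeholders.

import torch
from espnet.nets.beam_search import BeamSearch

model = E2E(odim=len(char_list), args=args)
model.eval()

with torch.no_grad():
    enc = model.encode(video)                  # (T', adim) encoder memory

beam = BeamSearch(
    scorers=model.scorers(),                   # {"decoder": ..., "ctc": ...}
    weights={"decoder": 0.9, "ctc": 0.1},      # attention/CTC mix; values are placeholders
    beam_size=40,
    vocab_size=len(char_list),
    sos=model.sos,
    eos=model.eos,
    token_list=char_list,
)
nbest = beam(x=enc)                            # list of Hypothesis objects, best first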
espnet/nets/pytorch_backend/e2e_asr_transformer_av.py
ADDED
@@ -0,0 +1,352 @@
1 |
+
# Copyright 2019 Shigeki Karita
|
2 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
3 |
+
|
4 |
+
"""Transformer speech recognition model (pytorch)."""
|
5 |
+
|
6 |
+
from argparse import Namespace
|
7 |
+
from distutils.util import strtobool
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
|
11 |
+
import numpy
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from espnet.nets.ctc_prefix_score import CTCPrefixScore
|
15 |
+
from espnet.nets.e2e_asr_common import end_detect
|
16 |
+
from espnet.nets.e2e_asr_common import ErrorCalculator
|
17 |
+
from espnet.nets.pytorch_backend.ctc import CTC
|
18 |
+
from espnet.nets.pytorch_backend.nets_utils import get_subsample
|
19 |
+
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
|
20 |
+
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
|
21 |
+
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
|
22 |
+
from espnet.nets.pytorch_backend.transformer.attention import (
|
23 |
+
MultiHeadedAttention, # noqa: H301
|
24 |
+
RelPositionMultiHeadedAttention, # noqa: H301
|
25 |
+
)
|
26 |
+
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
|
27 |
+
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
|
28 |
+
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
|
29 |
+
LabelSmoothingLoss, # noqa: H301
|
30 |
+
)
|
31 |
+
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
|
32 |
+
from espnet.nets.pytorch_backend.transformer.mask import target_mask
|
33 |
+
from espnet.nets.scorers.ctc import CTCPrefixScorer
|
34 |
+
from espnet.nets.pytorch_backend.nets_utils import MLPHead
|
35 |
+
|
36 |
+
|
37 |
+
class E2E(torch.nn.Module):
|
38 |
+
"""E2E module.
|
39 |
+
|
40 |
+
:param int idim: dimension of inputs
|
41 |
+
:param int odim: dimension of outputs
|
42 |
+
:param Namespace args: argument Namespace containing options
|
43 |
+
|
44 |
+
"""
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def add_arguments(parser):
|
48 |
+
"""Add arguments."""
|
49 |
+
group = parser.add_argument_group("transformer model setting")
|
50 |
+
|
51 |
+
group.add_argument(
|
52 |
+
"--transformer-init",
|
53 |
+
type=str,
|
54 |
+
default="pytorch",
|
55 |
+
choices=[
|
56 |
+
"pytorch",
|
57 |
+
"xavier_uniform",
|
58 |
+
"xavier_normal",
|
59 |
+
"kaiming_uniform",
|
60 |
+
"kaiming_normal",
|
61 |
+
],
|
62 |
+
help="how to initialize transformer parameters",
|
63 |
+
)
|
64 |
+
group.add_argument(
|
65 |
+
"--transformer-input-layer",
|
66 |
+
type=str,
|
67 |
+
default="conv2d",
|
68 |
+
choices=["conv3d", "conv2d", "conv1d", "linear", "embed"],
|
69 |
+
help="transformer input layer type",
|
70 |
+
)
|
71 |
+
group.add_argument(
|
72 |
+
"--transformer-encoder-attn-layer-type",
|
73 |
+
type=str,
|
74 |
+
default="mha",
|
75 |
+
choices=["mha", "rel_mha", "legacy_rel_mha"],
|
76 |
+
help="transformer encoder attention layer type",
|
77 |
+
)
|
78 |
+
group.add_argument(
|
79 |
+
"--transformer-attn-dropout-rate",
|
80 |
+
default=None,
|
81 |
+
type=float,
|
82 |
+
help="dropout in transformer attention. use --dropout-rate if None is set",
|
83 |
+
)
|
84 |
+
group.add_argument(
|
85 |
+
"--transformer-lr",
|
86 |
+
default=10.0,
|
87 |
+
type=float,
|
88 |
+
help="Initial value of learning rate",
|
89 |
+
)
|
90 |
+
group.add_argument(
|
91 |
+
"--transformer-warmup-steps",
|
92 |
+
default=25000,
|
93 |
+
type=int,
|
94 |
+
help="optimizer warmup steps",
|
95 |
+
)
|
96 |
+
group.add_argument(
|
97 |
+
"--transformer-length-normalized-loss",
|
98 |
+
default=True,
|
99 |
+
type=strtobool,
|
100 |
+
help="normalize loss by length",
|
101 |
+
)
|
102 |
+
group.add_argument(
|
103 |
+
"--dropout-rate",
|
104 |
+
default=0.0,
|
105 |
+
type=float,
|
106 |
+
help="Dropout rate for the encoder",
|
107 |
+
)
|
108 |
+
group.add_argument(
|
109 |
+
"--macaron-style",
|
110 |
+
default=False,
|
111 |
+
type=strtobool,
|
112 |
+
help="Whether to use macaron style for positionwise layer",
|
113 |
+
)
|
114 |
+
# -- input
|
115 |
+
group.add_argument(
|
116 |
+
"--a-upsample-ratio",
|
117 |
+
default=1,
|
118 |
+
type=int,
|
119 |
+
help="Upsample rate for audio",
|
120 |
+
)
|
121 |
+
group.add_argument(
|
122 |
+
"--relu-type",
|
123 |
+
default="swish",
|
124 |
+
type=str,
|
125 |
+
help="the type of activation layer",
|
126 |
+
)
|
127 |
+
# Encoder
|
128 |
+
group.add_argument(
|
129 |
+
"--elayers",
|
130 |
+
default=4,
|
131 |
+
type=int,
|
132 |
+
help="Number of encoder layers (for shared recognition part "
|
133 |
+
"in multi-speaker asr mode)",
|
134 |
+
)
|
135 |
+
group.add_argument(
|
136 |
+
"--eunits",
|
137 |
+
"-u",
|
138 |
+
default=300,
|
139 |
+
type=int,
|
140 |
+
help="Number of encoder hidden units",
|
141 |
+
)
|
142 |
+
group.add_argument(
|
143 |
+
"--use-cnn-module",
|
144 |
+
default=False,
|
145 |
+
type=strtobool,
|
146 |
+
help="Use convolution module or not",
|
147 |
+
)
|
148 |
+
group.add_argument(
|
149 |
+
"--cnn-module-kernel",
|
150 |
+
default=31,
|
151 |
+
type=int,
|
152 |
+
help="Kernel size of convolution module.",
|
153 |
+
)
|
154 |
+
# Attention
|
155 |
+
group.add_argument(
|
156 |
+
"--adim",
|
157 |
+
default=320,
|
158 |
+
type=int,
|
159 |
+
help="Number of attention transformation dimensions",
|
160 |
+
)
|
161 |
+
group.add_argument(
|
162 |
+
"--aheads",
|
163 |
+
default=4,
|
164 |
+
type=int,
|
165 |
+
help="Number of heads for multi head attention",
|
166 |
+
)
|
167 |
+
group.add_argument(
|
168 |
+
"--zero-triu",
|
169 |
+
default=False,
|
170 |
+
type=strtobool,
|
171 |
+
help="If true, zero the uppper triangular part of attention matrix.",
|
172 |
+
)
|
173 |
+
# Relative positional encoding
|
174 |
+
group.add_argument(
|
175 |
+
"--rel-pos-type",
|
176 |
+
type=str,
|
177 |
+
default="legacy",
|
178 |
+
choices=["legacy", "latest"],
|
179 |
+
help="Whether to use the latest relative positional encoding or the legacy one."
|
180 |
+
"The legacy relative positional encoding will be deprecated in the future."
|
181 |
+
"More Details can be found in https://github.com/espnet/espnet/pull/2816.",
|
182 |
+
)
|
183 |
+
# Decoder
|
184 |
+
group.add_argument(
|
185 |
+
"--dlayers", default=1, type=int, help="Number of decoder layers"
|
186 |
+
)
|
187 |
+
group.add_argument(
|
188 |
+
"--dunits", default=320, type=int, help="Number of decoder hidden units"
|
189 |
+
)
|
190 |
+
# -- pretrain
|
191 |
+
group.add_argument("--pretrain-dataset",
|
192 |
+
default="",
|
193 |
+
type=str,
|
194 |
+
help='pre-trained dataset for encoder'
|
195 |
+
)
|
196 |
+
# -- custom name
|
197 |
+
group.add_argument("--custom-pretrain-name",
|
198 |
+
default="",
|
199 |
+
type=str,
|
200 |
+
help='pre-trained model for encoder'
|
201 |
+
)
|
202 |
+
return parser
|
203 |
+
|
204 |
+
@property
|
205 |
+
def attention_plot_class(self):
|
206 |
+
"""Return PlotAttentionReport."""
|
207 |
+
return PlotAttentionReport
|
208 |
+
|
209 |
+
def __init__(self, odim, args, ignore_id=-1):
|
210 |
+
"""Construct an E2E object.
|
211 |
+
:param int odim: dimension of outputs
|
212 |
+
:param Namespace args: argument Namespace containing options
|
213 |
+
"""
|
214 |
+
torch.nn.Module.__init__(self)
|
215 |
+
if args.transformer_attn_dropout_rate is None:
|
216 |
+
args.transformer_attn_dropout_rate = args.dropout_rate
|
217 |
+
# Check the relative positional encoding type
|
218 |
+
self.rel_pos_type = getattr(args, "rel_pos_type", None)
|
219 |
+
if self.rel_pos_type is None and args.transformer_encoder_attn_layer_type == "rel_mha":
|
220 |
+
args.transformer_encoder_attn_layer_type = "legacy_rel_mha"
|
221 |
+
logging.warning(
|
222 |
+
"Using legacy_rel_pos and it will be deprecated in the future."
|
223 |
+
)
|
224 |
+
|
225 |
+
idim = 80
|
226 |
+
|
227 |
+
self.encoder = Encoder(
|
228 |
+
idim=idim,
|
229 |
+
attention_dim=args.adim,
|
230 |
+
attention_heads=args.aheads,
|
231 |
+
linear_units=args.eunits,
|
232 |
+
num_blocks=args.elayers,
|
233 |
+
input_layer=args.transformer_input_layer,
|
234 |
+
dropout_rate=args.dropout_rate,
|
235 |
+
positional_dropout_rate=args.dropout_rate,
|
236 |
+
attention_dropout_rate=args.transformer_attn_dropout_rate,
|
237 |
+
encoder_attn_layer_type=args.transformer_encoder_attn_layer_type,
|
238 |
+
macaron_style=args.macaron_style,
|
239 |
+
use_cnn_module=args.use_cnn_module,
|
240 |
+
cnn_module_kernel=args.cnn_module_kernel,
|
241 |
+
zero_triu=getattr(args, "zero_triu", False),
|
242 |
+
a_upsample_ratio=args.a_upsample_ratio,
|
243 |
+
relu_type=getattr(args, "relu_type", "swish"),
|
244 |
+
)
|
245 |
+
|
246 |
+
self.transformer_input_layer = args.transformer_input_layer
|
247 |
+
self.a_upsample_ratio = args.a_upsample_ratio
|
248 |
+
|
249 |
+
self.aux_encoder = Encoder(
|
250 |
+
idim=idim,
|
251 |
+
attention_dim=args.aux_adim,
|
252 |
+
attention_heads=args.aux_aheads,
|
253 |
+
linear_units=args.aux_eunits,
|
254 |
+
num_blocks=args.aux_elayers,
|
255 |
+
input_layer=args.aux_transformer_input_layer,
|
256 |
+
dropout_rate=args.aux_dropout_rate,
|
257 |
+
positional_dropout_rate=args.aux_dropout_rate,
|
258 |
+
attention_dropout_rate=args.aux_transformer_attn_dropout_rate,
|
259 |
+
encoder_attn_layer_type=args.aux_transformer_encoder_attn_layer_type,
|
260 |
+
macaron_style=args.aux_macaron_style,
|
261 |
+
use_cnn_module=args.aux_use_cnn_module,
|
262 |
+
cnn_module_kernel=args.aux_cnn_module_kernel,
|
263 |
+
zero_triu=getattr(args, "aux_zero_triu", False),
|
264 |
+
a_upsample_ratio=args.aux_a_upsample_ratio,
|
265 |
+
relu_type=getattr(args, "aux_relu_type", "swish"),
|
266 |
+
)
|
267 |
+
self.aux_transformer_input_layer = args.aux_transformer_input_layer
|
268 |
+
|
269 |
+
self.fusion = MLPHead(
|
270 |
+
idim=args.adim + args.aux_adim,
|
271 |
+
hdim=args.fusion_hdim,
|
272 |
+
odim=args.adim,
|
273 |
+
norm=args.fusion_norm,
|
274 |
+
)
|
275 |
+
|
276 |
+
if args.mtlalpha < 1:
|
277 |
+
self.decoder = Decoder(
|
278 |
+
odim=odim,
|
279 |
+
attention_dim=args.adim,
|
280 |
+
attention_heads=args.aheads,
|
281 |
+
linear_units=args.dunits,
|
282 |
+
num_blocks=args.dlayers,
|
283 |
+
dropout_rate=args.dropout_rate,
|
284 |
+
positional_dropout_rate=args.dropout_rate,
|
285 |
+
self_attention_dropout_rate=args.transformer_attn_dropout_rate,
|
286 |
+
src_attention_dropout_rate=args.transformer_attn_dropout_rate,
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
self.decoder = None
|
290 |
+
self.blank = 0
|
291 |
+
self.sos = odim - 1
|
292 |
+
self.eos = odim - 1
|
293 |
+
self.odim = odim
|
294 |
+
self.ignore_id = ignore_id
|
295 |
+
self.subsample = get_subsample(args, mode="asr", arch="transformer")
|
296 |
+
|
297 |
+
# self.lsm_weight = a
|
298 |
+
self.criterion = LabelSmoothingLoss(
|
299 |
+
self.odim,
|
300 |
+
self.ignore_id,
|
301 |
+
args.lsm_weight,
|
302 |
+
args.transformer_length_normalized_loss,
|
303 |
+
)
|
304 |
+
|
305 |
+
self.adim = args.adim
|
306 |
+
self.mtlalpha = args.mtlalpha
|
307 |
+
if args.mtlalpha > 0.0:
|
308 |
+
self.ctc = CTC(
|
309 |
+
odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
self.ctc = None
|
313 |
+
|
314 |
+
if args.report_cer or args.report_wer:
|
315 |
+
self.error_calculator = ErrorCalculator(
|
316 |
+
args.char_list,
|
317 |
+
args.sym_space,
|
318 |
+
args.sym_blank,
|
319 |
+
args.report_cer,
|
320 |
+
args.report_wer,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
self.error_calculator = None
|
324 |
+
self.rnnlm = None
|
325 |
+
|
326 |
+
def scorers(self):
|
327 |
+
"""Scorers."""
|
328 |
+
return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
|
329 |
+
|
330 |
+
def encode(self, x, aux_x, extract_resnet_feats=False):
|
331 |
+
"""Encode acoustic features.
|
332 |
+
|
333 |
+
:param ndarray x: source acoustic feature (T, D)
|
334 |
+
:return: encoder outputs
|
335 |
+
:rtype: torch.Tensor
|
336 |
+
"""
|
337 |
+
self.eval()
|
338 |
+
if extract_resnet_feats:
|
339 |
+
x = torch.as_tensor(x).unsqueeze(0)
|
340 |
+
resnet_feats = self.encoder(
|
341 |
+
x,
|
342 |
+
None,
|
343 |
+
extract_resnet_feats=extract_resnet_feats,
|
344 |
+
)
|
345 |
+
return resnet_feats.squeeze(0)
|
346 |
+
else:
|
347 |
+
x = torch.as_tensor(x).unsqueeze(0)
|
348 |
+
aux_x = torch.as_tensor(aux_x).unsqueeze(0)
|
349 |
+
feat, _ = self.encoder(x, None)
|
350 |
+
aux_feat, _ = self.aux_encoder(aux_x, None)
|
351 |
+
fus_output = self.fusion(torch.cat((feat, aux_feat), dim=-1))
|
352 |
+
return fus_output.squeeze(0)
|
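For clarity, a sketch (sizes are placeholders, and the MLPHead norm value is an assumption) of what the audio-visual encode() above does at the fusion step: the video and audio encoder outputs are concatenated frame by frame and projected back to adim.

import torch

adim, aux_adim, fusion_hdim = 768, 768, 1024      # placeholder sizes
feat = torch.randn(1, 75, adim)                   # video encoder output (B, T, adim)
aux_feat = torch.randn(1, 75, aux_adim)           # audio encoder output (B, T, aux_adim)

fused_in = torch.cat((feat, aux_feat), dim=-1)    # (1, 75, adim + aux_adim)
# fusion = MLPHead(idim=adim + aux_adim, hdim=fusion_hdim, odim=adim, norm="batchnorm")  # norm value is a placeholder
# fused = fusion(fused_in)                        # (1, 75, adim), fed to the decoder and CTC heads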
espnet/nets/pytorch_backend/lm/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/nets/pytorch_backend/lm/default.py
ADDED
@@ -0,0 +1,431 @@
1 |
+
"""Default Recurrent Neural Network Languge Model in `lm_train.py`."""
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
from typing import List
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from espnet.nets.lm_interface import LMInterface
|
13 |
+
from espnet.nets.pytorch_backend.e2e_asr import to_device
|
14 |
+
from espnet.nets.scorer_interface import BatchScorerInterface
|
15 |
+
from espnet.utils.cli_utils import strtobool
|
16 |
+
|
17 |
+
|
18 |
+
class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
|
19 |
+
"""Default RNNLM for `LMInterface` Implementation.
|
20 |
+
|
21 |
+
Note:
|
22 |
+
PyTorch seems to have a memory leak when one GPU computes this after data parallel.
|
23 |
+
If parallel GPUs compute this, it seems to be fine.
|
24 |
+
See also https://github.com/espnet/espnet/issues/1075
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def add_arguments(parser):
|
30 |
+
"""Add arguments to command line argument parser."""
|
31 |
+
parser.add_argument(
|
32 |
+
"--type",
|
33 |
+
type=str,
|
34 |
+
default="lstm",
|
35 |
+
nargs="?",
|
36 |
+
choices=["lstm", "gru"],
|
37 |
+
help="Which type of RNN to use",
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--layer", "-l", type=int, default=2, help="Number of hidden layers"
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--unit", "-u", type=int, default=650, help="Number of hidden units"
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--embed-unit",
|
47 |
+
default=None,
|
48 |
+
type=int,
|
49 |
+
help="Number of hidden units in embedding layer, "
|
50 |
+
"if it is not specified, it keeps the same number with hidden units.",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--dropout-rate", type=float, default=0.5, help="dropout probability"
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--emb-dropout-rate",
|
57 |
+
type=float,
|
58 |
+
default=0.0,
|
59 |
+
help="emb dropout probability",
|
60 |
+
)
|
61 |
+
parser.add_argument(
|
62 |
+
"--tie-weights",
|
63 |
+
type=strtobool,
|
64 |
+
default=False,
|
65 |
+
help="Tie input and output embeddings",
|
66 |
+
)
|
67 |
+
return parser
|
68 |
+
|
69 |
+
def __init__(self, n_vocab, args):
|
70 |
+
"""Initialize class.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
n_vocab (int): The size of the vocabulary
|
74 |
+
args (argparse.Namespace): configurations. see py:method:`add_arguments`
|
75 |
+
|
76 |
+
"""
|
77 |
+
nn.Module.__init__(self)
|
78 |
+
# NOTE: for a compatibility with less than 0.5.0 version models
|
79 |
+
dropout_rate = getattr(args, "dropout_rate", 0.0)
|
80 |
+
# NOTE: for a compatibility with less than 0.6.1 version models
|
81 |
+
embed_unit = getattr(args, "embed_unit", None)
|
82 |
+
# NOTE: for a compatibility with less than 0.9.7 version models
|
83 |
+
emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
|
84 |
+
# NOTE: for a compatibility with less than 0.9.7 version models
|
85 |
+
tie_weights = getattr(args, "tie_weights", False)
|
86 |
+
|
87 |
+
self.model = ClassifierWithState(
|
88 |
+
RNNLM(
|
89 |
+
n_vocab,
|
90 |
+
args.layer,
|
91 |
+
args.unit,
|
92 |
+
embed_unit,
|
93 |
+
args.type,
|
94 |
+
dropout_rate,
|
95 |
+
emb_dropout_rate,
|
96 |
+
tie_weights,
|
97 |
+
)
|
98 |
+
)
|
99 |
+
|
100 |
+
def state_dict(self):
|
101 |
+
"""Dump state dict."""
|
102 |
+
return self.model.state_dict()
|
103 |
+
|
104 |
+
def load_state_dict(self, d):
|
105 |
+
"""Load state dict."""
|
106 |
+
self.model.load_state_dict(d)
|
107 |
+
|
108 |
+
def forward(self, x, t):
|
109 |
+
"""Compute LM loss value from buffer sequences.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
x (torch.Tensor): Input ids. (batch, len)
|
113 |
+
t (torch.Tensor): Target ids. (batch, len)
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
|
117 |
+
loss to backward (scalar),
|
118 |
+
negative log-likelihood of t: -log p(t) (scalar) and
|
119 |
+
the number of elements in x (scalar)
|
120 |
+
|
121 |
+
Notes:
|
122 |
+
The last two return values are used
|
123 |
+
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
|
124 |
+
|
125 |
+
"""
|
126 |
+
loss = 0
|
127 |
+
logp = 0
|
128 |
+
count = torch.tensor(0).long()
|
129 |
+
state = None
|
130 |
+
batch_size, sequence_length = x.shape
|
131 |
+
for i in range(sequence_length):
|
132 |
+
# Compute the loss at this time step and accumulate it
|
133 |
+
state, loss_batch = self.model(state, x[:, i], t[:, i])
|
134 |
+
non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
|
135 |
+
loss += loss_batch.mean() * non_zeros
|
136 |
+
logp += torch.sum(loss_batch * non_zeros)
|
137 |
+
count += int(non_zeros)
|
138 |
+
return loss / batch_size, loss, count.to(loss.device)
|
139 |
+
|
140 |
+
def score(self, y, state, x):
|
141 |
+
"""Score new token.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
145 |
+
state: Scorer state for prefix tokens
|
146 |
+
x (torch.Tensor): 2D encoder feature that generates ys.
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
tuple[torch.Tensor, Any]: Tuple of
|
150 |
+
torch.float32 scores for next token (n_vocab)
|
151 |
+
and next state for ys
|
152 |
+
|
153 |
+
"""
|
154 |
+
new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
|
155 |
+
return scores.squeeze(0), new_state
|
156 |
+
|
157 |
+
def final_score(self, state):
|
158 |
+
"""Score eos.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
state: Scorer state for prefix tokens
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
float: final score
|
165 |
+
|
166 |
+
"""
|
167 |
+
return self.model.final(state)
|
168 |
+
|
169 |
+
# batch beam search API (see BatchScorerInterface)
|
170 |
+
def batch_score(
|
171 |
+
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
172 |
+
) -> Tuple[torch.Tensor, List[Any]]:
|
173 |
+
"""Score new token batch.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
177 |
+
states (List[Any]): Scorer states for prefix tokens.
|
178 |
+
xs (torch.Tensor):
|
179 |
+
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
180 |
+
|
181 |
+
Returns:
|
182 |
+
tuple[torch.Tensor, List[Any]]: Tuple of
|
183 |
+
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
184 |
+
and next state list for ys.
|
185 |
+
|
186 |
+
"""
|
187 |
+
# merge states
|
188 |
+
n_batch = len(ys)
|
189 |
+
n_layers = self.model.predictor.n_layers
|
190 |
+
if self.model.predictor.typ == "lstm":
|
191 |
+
keys = ("c", "h")
|
192 |
+
else:
|
193 |
+
keys = ("h",)
|
194 |
+
|
195 |
+
if states[0] is None:
|
196 |
+
states = None
|
197 |
+
else:
|
198 |
+
# transpose state of [batch, key, layer] into [key, layer, batch]
|
199 |
+
states = {
|
200 |
+
k: [
|
201 |
+
torch.stack([states[b][k][i] for b in range(n_batch)])
|
202 |
+
for i in range(n_layers)
|
203 |
+
]
|
204 |
+
for k in keys
|
205 |
+
}
|
206 |
+
states, logp = self.model.predict(states, ys[:, -1])
|
207 |
+
|
208 |
+
# transpose state of [key, layer, batch] into [batch, key, layer]
|
209 |
+
return (
|
210 |
+
logp,
|
211 |
+
[
|
212 |
+
{k: [states[k][i][b] for i in range(n_layers)] for k in keys}
|
213 |
+
for b in range(n_batch)
|
214 |
+
],
|
215 |
+
)
|
216 |
+
|
217 |
+
|
218 |
+
class ClassifierWithState(nn.Module):
|
219 |
+
"""A wrapper for pytorch RNNLM."""
|
220 |
+
|
221 |
+
def __init__(
|
222 |
+
self, predictor, lossfun=nn.CrossEntropyLoss(reduction="none"), label_key=-1
|
223 |
+
):
|
224 |
+
"""Initialize class.
|
225 |
+
|
226 |
+
:param torch.nn.Module predictor : The RNNLM
|
227 |
+
:param function lossfun : The loss function to use
|
228 |
+
:param int/str label_key :
|
229 |
+
|
230 |
+
"""
|
231 |
+
if not (isinstance(label_key, (int, str))):
|
232 |
+
raise TypeError("label_key must be int or str, but is %s" % type(label_key))
|
233 |
+
super(ClassifierWithState, self).__init__()
|
234 |
+
self.lossfun = lossfun
|
235 |
+
self.y = None
|
236 |
+
self.loss = None
|
237 |
+
self.label_key = label_key
|
238 |
+
self.predictor = predictor
|
239 |
+
|
240 |
+
def forward(self, state, *args, **kwargs):
|
241 |
+
"""Compute the loss value for an input and label pair.
|
242 |
+
|
243 |
+
Notes:
|
244 |
+
It also computes accuracy and stores it to the attribute.
|
245 |
+
When ``label_key`` is ``int``, the corresponding element in ``args``
|
246 |
+
is treated as ground truth labels. And when it is ``str``, the
|
247 |
+
element in ``kwargs`` is used.
|
248 |
+
The all elements of ``args`` and ``kwargs`` except the groundtruth
|
249 |
+
labels are features.
|
250 |
+
It feeds features to the predictor and compare the result
|
251 |
+
with ground truth labels.
|
252 |
+
|
253 |
+
:param torch.Tensor state : the LM state
|
254 |
+
:param list[torch.Tensor] args : Input minibatch
|
255 |
+
:param dict[torch.Tensor] kwargs : Input minibatch
|
256 |
+
:return loss value
|
257 |
+
:rtype torch.Tensor
|
258 |
+
|
259 |
+
"""
|
260 |
+
if isinstance(self.label_key, int):
|
261 |
+
if not (-len(args) <= self.label_key < len(args)):
|
262 |
+
msg = "Label key %d is out of bounds" % self.label_key
|
263 |
+
raise ValueError(msg)
|
264 |
+
t = args[self.label_key]
|
265 |
+
if self.label_key == -1:
|
266 |
+
args = args[:-1]
|
267 |
+
else:
|
268 |
+
args = args[: self.label_key] + args[self.label_key + 1 :]
|
269 |
+
elif isinstance(self.label_key, str):
|
270 |
+
if self.label_key not in kwargs:
|
271 |
+
msg = 'Label key "%s" is not found' % self.label_key
|
272 |
+
raise ValueError(msg)
|
273 |
+
t = kwargs[self.label_key]
|
274 |
+
del kwargs[self.label_key]
|
275 |
+
|
276 |
+
self.y = None
|
277 |
+
self.loss = None
|
278 |
+
state, self.y = self.predictor(state, *args, **kwargs)
|
279 |
+
self.loss = self.lossfun(self.y, t)
|
280 |
+
return state, self.loss
|
281 |
+
|
282 |
+
def predict(self, state, x):
|
283 |
+
"""Predict log probabilities for given state and input x using the predictor.
|
284 |
+
|
285 |
+
:param torch.Tensor state : The current state
|
286 |
+
:param torch.Tensor x : The input
|
287 |
+
:return a tuple (new state, log prob vector)
|
288 |
+
:rtype (torch.Tensor, torch.Tensor)
|
289 |
+
"""
|
290 |
+
if hasattr(self.predictor, "normalized") and self.predictor.normalized:
|
291 |
+
return self.predictor(state, x)
|
292 |
+
else:
|
293 |
+
state, z = self.predictor(state, x)
|
294 |
+
return state, F.log_softmax(z, dim=1)
|
295 |
+
|
296 |
+
def buff_predict(self, state, x, n):
|
297 |
+
"""Predict new tokens from buffered inputs."""
|
298 |
+
if self.predictor.__class__.__name__ == "RNNLM":
|
299 |
+
return self.predict(state, x)
|
300 |
+
|
301 |
+
new_state = []
|
302 |
+
new_log_y = []
|
303 |
+
for i in range(n):
|
304 |
+
state_i = None if state is None else state[i]
|
305 |
+
state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
|
306 |
+
new_state.append(state_i)
|
307 |
+
new_log_y.append(log_y)
|
308 |
+
|
309 |
+
return new_state, torch.cat(new_log_y)
|
310 |
+
|
311 |
+
def final(self, state, index=None):
|
312 |
+
"""Predict final log probabilities for given state using the predictor.
|
313 |
+
|
314 |
+
:param state: The state
|
315 |
+
:return The final log probabilities
|
316 |
+
:rtype torch.Tensor
|
317 |
+
"""
|
318 |
+
if hasattr(self.predictor, "final"):
|
319 |
+
if index is not None:
|
320 |
+
return self.predictor.final(state[index])
|
321 |
+
else:
|
322 |
+
return self.predictor.final(state)
|
323 |
+
else:
|
324 |
+
return 0.0
|
325 |
+
|
326 |
+
|
327 |
+
# Definition of a recurrent net for language modeling
|
328 |
+
class RNNLM(nn.Module):
|
329 |
+
"""A pytorch RNNLM."""
|
330 |
+
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
n_vocab,
|
334 |
+
n_layers,
|
335 |
+
n_units,
|
336 |
+
n_embed=None,
|
337 |
+
typ="lstm",
|
338 |
+
dropout_rate=0.5,
|
339 |
+
emb_dropout_rate=0.0,
|
340 |
+
tie_weights=False,
|
341 |
+
):
|
342 |
+
"""Initialize class.
|
343 |
+
|
344 |
+
:param int n_vocab: The size of the vocabulary
|
345 |
+
:param int n_layers: The number of layers to create
|
346 |
+
:param int n_units: The number of units per layer
|
347 |
+
:param str typ: The RNN type
|
348 |
+
"""
|
349 |
+
super(RNNLM, self).__init__()
|
350 |
+
if n_embed is None:
|
351 |
+
n_embed = n_units
|
352 |
+
|
353 |
+
self.embed = nn.Embedding(n_vocab, n_embed)
|
354 |
+
|
355 |
+
if emb_dropout_rate == 0.0:
|
356 |
+
self.embed_drop = None
|
357 |
+
else:
|
358 |
+
self.embed_drop = nn.Dropout(emb_dropout_rate)
|
359 |
+
|
360 |
+
if typ == "lstm":
|
361 |
+
self.rnn = nn.ModuleList(
|
362 |
+
[nn.LSTMCell(n_embed, n_units)]
|
363 |
+
+ [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
|
364 |
+
)
|
365 |
+
else:
|
366 |
+
self.rnn = nn.ModuleList(
|
367 |
+
[nn.GRUCell(n_embed, n_units)]
|
368 |
+
+ [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
|
369 |
+
)
|
370 |
+
|
371 |
+
self.dropout = nn.ModuleList(
|
372 |
+
[nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
|
373 |
+
)
|
374 |
+
self.lo = nn.Linear(n_units, n_vocab)
|
375 |
+
self.n_layers = n_layers
|
376 |
+
self.n_units = n_units
|
377 |
+
self.typ = typ
|
378 |
+
|
379 |
+
logging.info("Tie weights set to {}".format(tie_weights))
|
380 |
+
logging.info("Dropout set to {}".format(dropout_rate))
|
381 |
+
logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
|
382 |
+
|
383 |
+
if tie_weights:
|
384 |
+
assert (
|
385 |
+
n_embed == n_units
|
386 |
+
), "Tie Weights: True need embedding and final dimensions to match"
|
387 |
+
self.lo.weight = self.embed.weight
|
388 |
+
|
389 |
+
# initialize parameters from uniform distribution
|
390 |
+
for param in self.parameters():
|
391 |
+
param.data.uniform_(-0.1, 0.1)
|
392 |
+
|
393 |
+
def zero_state(self, batchsize):
|
394 |
+
"""Initialize state."""
|
395 |
+
p = next(self.parameters())
|
396 |
+
return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)
|
397 |
+
|
398 |
+
def forward(self, state, x):
|
399 |
+
"""Forward neural networks."""
|
400 |
+
if state is None:
|
401 |
+
h = [to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers)]
|
402 |
+
state = {"h": h}
|
403 |
+
if self.typ == "lstm":
|
404 |
+
c = [
|
405 |
+
to_device(x, self.zero_state(x.size(0)))
|
406 |
+
for n in range(self.n_layers)
|
407 |
+
]
|
408 |
+
state = {"c": c, "h": h}
|
409 |
+
|
410 |
+
h = [None] * self.n_layers
|
411 |
+
if self.embed_drop is not None:
|
412 |
+
emb = self.embed_drop(self.embed(x))
|
413 |
+
else:
|
414 |
+
emb = self.embed(x)
|
415 |
+
if self.typ == "lstm":
|
416 |
+
c = [None] * self.n_layers
|
417 |
+
h[0], c[0] = self.rnn[0](
|
418 |
+
self.dropout[0](emb), (state["h"][0], state["c"][0])
|
419 |
+
)
|
420 |
+
for n in range(1, self.n_layers):
|
421 |
+
h[n], c[n] = self.rnn[n](
|
422 |
+
self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
|
423 |
+
)
|
424 |
+
state = {"c": c, "h": h}
|
425 |
+
else:
|
426 |
+
h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
|
427 |
+
for n in range(1, self.n_layers):
|
428 |
+
h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
|
429 |
+
state = {"h": h}
|
430 |
+
y = self.lo(self.dropout[-1](h[-1]))
|
431 |
+
return state, y
|
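A minimal sketch of stepping the RNNLM predictor defined above one token at a time, the way `ClassifierWithState.predict` drives it during decoding. It assumes the snippet runs inside this module so `RNNLM` is in scope; the vocabulary size and token ids are made up.

import torch

lm = RNNLM(n_vocab=100, n_layers=2, n_units=650)  # tiny illustrative configuration

state = None
for token in torch.tensor([1, 5, 7]):
    # each step consumes one token id (batch of 1) and returns logits over the vocabulary
    state, logits = lm(state, token.unsqueeze(0))

log_probs = torch.log_softmax(logits, dim=-1)  # what ClassifierWithState.predict would return
print(log_probs.shape)  # torch.Size([1, 100])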
espnet/nets/pytorch_backend/lm/seq_rnn.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
"""Sequential implementation of Recurrent Neural Network Language Model."""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from espnet.nets.lm_interface import LMInterface
|
8 |
+
|
9 |
+
|
10 |
+
class SequentialRNNLM(LMInterface, torch.nn.Module):
|
11 |
+
"""Sequential RNNLM.
|
12 |
+
|
13 |
+
See also:
|
14 |
+
https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
@staticmethod
|
19 |
+
def add_arguments(parser):
|
20 |
+
"""Add arguments to command line argument parser."""
|
21 |
+
parser.add_argument(
|
22 |
+
"--type",
|
23 |
+
type=str,
|
24 |
+
default="lstm",
|
25 |
+
nargs="?",
|
26 |
+
choices=["lstm", "gru"],
|
27 |
+
help="Which type of RNN to use",
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--layer", "-l", type=int, default=2, help="Number of hidden layers"
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--unit", "-u", type=int, default=650, help="Number of hidden units"
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--dropout-rate", type=float, default=0.5, help="dropout probability"
|
37 |
+
)
|
38 |
+
return parser
|
39 |
+
|
40 |
+
def __init__(self, n_vocab, args):
|
41 |
+
"""Initialize class.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
n_vocab (int): The size of the vocabulary
|
45 |
+
args (argparse.Namespace): configurations. see py:method:`add_arguments`
|
46 |
+
|
47 |
+
"""
|
48 |
+
torch.nn.Module.__init__(self)
|
49 |
+
self._setup(
|
50 |
+
rnn_type=args.type.upper(),
|
51 |
+
ntoken=n_vocab,
|
52 |
+
ninp=args.unit,
|
53 |
+
nhid=args.unit,
|
54 |
+
nlayers=args.layer,
|
55 |
+
dropout=args.dropout_rate,
|
56 |
+
)
|
57 |
+
|
58 |
+
def _setup(
|
59 |
+
self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
|
60 |
+
):
|
61 |
+
self.drop = nn.Dropout(dropout)
|
62 |
+
self.encoder = nn.Embedding(ntoken, ninp)
|
63 |
+
if rnn_type in ["LSTM", "GRU"]:
|
64 |
+
self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
|
65 |
+
else:
|
66 |
+
try:
|
67 |
+
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
|
68 |
+
except KeyError:
|
69 |
+
raise ValueError(
|
70 |
+
"An invalid option for `--model` was supplied, "
|
71 |
+
"options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
|
72 |
+
)
|
73 |
+
self.rnn = nn.RNN(
|
74 |
+
ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
|
75 |
+
)
|
76 |
+
self.decoder = nn.Linear(nhid, ntoken)
|
77 |
+
|
78 |
+
# Optionally tie weights as in:
|
79 |
+
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
|
80 |
+
# https://arxiv.org/abs/1608.05859
|
81 |
+
# and
|
82 |
+
# "Tying Word Vectors and Word Classifiers:
|
83 |
+
# A Loss Framework for Language Modeling" (Inan et al. 2016)
|
84 |
+
# https://arxiv.org/abs/1611.01462
|
85 |
+
if tie_weights:
|
86 |
+
if nhid != ninp:
|
87 |
+
raise ValueError(
|
88 |
+
"When using the tied flag, nhid must be equal to emsize"
|
89 |
+
)
|
90 |
+
self.decoder.weight = self.encoder.weight
|
91 |
+
|
92 |
+
self._init_weights()
|
93 |
+
|
94 |
+
self.rnn_type = rnn_type
|
95 |
+
self.nhid = nhid
|
96 |
+
self.nlayers = nlayers
|
97 |
+
|
98 |
+
def _init_weights(self):
|
99 |
+
# NOTE: original init in pytorch/examples
|
100 |
+
# initrange = 0.1
|
101 |
+
# self.encoder.weight.data.uniform_(-initrange, initrange)
|
102 |
+
# self.decoder.bias.data.zero_()
|
103 |
+
# self.decoder.weight.data.uniform_(-initrange, initrange)
|
104 |
+
# NOTE: our default.py:RNNLM init
|
105 |
+
for param in self.parameters():
|
106 |
+
param.data.uniform_(-0.1, 0.1)
|
107 |
+
|
108 |
+
def forward(self, x, t):
|
109 |
+
"""Compute LM loss value from buffer sequences.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
x (torch.Tensor): Input ids. (batch, len)
|
113 |
+
t (torch.Tensor): Target ids. (batch, len)
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
|
117 |
+
loss to backward (scalar),
|
118 |
+
negative log-likelihood of t: -log p(t) (scalar) and
|
119 |
+
the number of elements in x (scalar)
|
120 |
+
|
121 |
+
Notes:
|
122 |
+
The last two return values are used
|
123 |
+
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
|
124 |
+
|
125 |
+
"""
|
126 |
+
y = self._before_loss(x, None)[0]
|
127 |
+
mask = (x != 0).to(y.dtype)
|
128 |
+
loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
|
129 |
+
logp = loss * mask.view(-1)
|
130 |
+
logp = logp.sum()
|
131 |
+
count = mask.sum()
|
132 |
+
return logp / count, logp, count
|
133 |
+
|
134 |
+
def _before_loss(self, input, hidden):
|
135 |
+
emb = self.drop(self.encoder(input))
|
136 |
+
output, hidden = self.rnn(emb, hidden)
|
137 |
+
output = self.drop(output)
|
138 |
+
decoded = self.decoder(
|
139 |
+
output.view(output.size(0) * output.size(1), output.size(2))
|
140 |
+
)
|
141 |
+
return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
|
142 |
+
|
143 |
+
def init_state(self, x):
|
144 |
+
"""Get an initial state for decoding.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
x (torch.Tensor): The encoded feature tensor
|
148 |
+
|
149 |
+
Returns: initial state
|
150 |
+
|
151 |
+
"""
|
152 |
+
bsz = 1
|
153 |
+
weight = next(self.parameters())
|
154 |
+
if self.rnn_type == "LSTM":
|
155 |
+
return (
|
156 |
+
weight.new_zeros(self.nlayers, bsz, self.nhid),
|
157 |
+
weight.new_zeros(self.nlayers, bsz, self.nhid),
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
return weight.new_zeros(self.nlayers, bsz, self.nhid)
|
161 |
+
|
162 |
+
def score(self, y, state, x):
|
163 |
+
"""Score new token.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
167 |
+
state: Scorer state for prefix tokens
|
168 |
+
x (torch.Tensor): 2D encoder feature that generates ys.
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
tuple[torch.Tensor, Any]: Tuple of
|
172 |
+
torch.float32 scores for next token (n_vocab)
|
173 |
+
and next state for ys
|
174 |
+
|
175 |
+
"""
|
176 |
+
y, new_state = self._before_loss(y[-1].view(1, 1), state)
|
177 |
+
logp = y.log_softmax(dim=-1).view(-1)
|
178 |
+
return logp, new_state
|
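The perplexity note in `forward` boils down to ppl = exp(nll / count); a tiny worked example with made-up numbers:

import math

# Made-up values: summed negative log-likelihood of 60.0 nats over 20 scored tokens.
nll, count = 60.0, 20
ppl = math.exp(nll / count)  # exp(3.0)
print(round(ppl, 2))  # 20.09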
espnet/nets/pytorch_backend/lm/transformer.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
"""Transformer language model."""
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
from typing import List
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
from espnet.nets.lm_interface import LMInterface
|
13 |
+
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
14 |
+
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
|
15 |
+
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
|
16 |
+
from espnet.nets.scorer_interface import BatchScorerInterface
|
17 |
+
from espnet.utils.cli_utils import strtobool
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
|
21 |
+
"""Transformer language model."""
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def add_arguments(parser):
|
25 |
+
"""Add arguments to command line argument parser."""
|
26 |
+
parser.add_argument(
|
27 |
+
"--layer", type=int, default=4, help="Number of hidden layers"
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--unit",
|
31 |
+
type=int,
|
32 |
+
default=1024,
|
33 |
+
help="Number of hidden units in feedforward layer",
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--att-unit",
|
37 |
+
type=int,
|
38 |
+
default=256,
|
39 |
+
help="Number of hidden units in attention layer",
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"--embed-unit",
|
43 |
+
type=int,
|
44 |
+
default=128,
|
45 |
+
help="Number of hidden units in embedding layer",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--head", type=int, default=2, help="Number of multi head attention"
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--dropout-rate", type=float, default=0.5, help="dropout probability"
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--att-dropout-rate",
|
55 |
+
type=float,
|
56 |
+
default=0.0,
|
57 |
+
help="att dropout probability",
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--emb-dropout-rate",
|
61 |
+
type=float,
|
62 |
+
default=0.0,
|
63 |
+
help="emb dropout probability",
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--tie-weights",
|
67 |
+
type=strtobool,
|
68 |
+
default=False,
|
69 |
+
help="Tie input and output embeddings",
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--pos-enc",
|
73 |
+
default="sinusoidal",
|
74 |
+
choices=["sinusoidal", "none"],
|
75 |
+
help="positional encoding",
|
76 |
+
)
|
77 |
+
return parser
|
78 |
+
|
79 |
+
def __init__(self, n_vocab, args):
|
80 |
+
"""Initialize class.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
n_vocab (int): The size of the vocabulary
|
84 |
+
args (argparse.Namespace): configurations. see py:method:`add_arguments`
|
85 |
+
|
86 |
+
"""
|
87 |
+
nn.Module.__init__(self)
|
88 |
+
|
89 |
+
# NOTE: for a compatibility with less than 0.9.7 version models
|
90 |
+
emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
|
91 |
+
# NOTE: for a compatibility with less than 0.9.7 version models
|
92 |
+
tie_weights = getattr(args, "tie_weights", False)
|
93 |
+
# NOTE: for a compatibility with less than 0.9.7 version models
|
94 |
+
att_dropout_rate = getattr(args, "att_dropout_rate", 0.0)
|
95 |
+
|
96 |
+
if args.pos_enc == "sinusoidal":
|
97 |
+
pos_enc_class = PositionalEncoding
|
98 |
+
elif args.pos_enc == "none":
|
99 |
+
|
100 |
+
def pos_enc_class(*args, **kwargs):
|
101 |
+
return nn.Sequential()  # identity
|
102 |
+
|
103 |
+
else:
|
104 |
+
raise ValueError(f"unknown pos-enc option: {args.pos_enc}")
|
105 |
+
|
106 |
+
self.embed = nn.Embedding(n_vocab, args.embed_unit)
|
107 |
+
|
108 |
+
if emb_dropout_rate == 0.0:
|
109 |
+
self.embed_drop = None
|
110 |
+
else:
|
111 |
+
self.embed_drop = nn.Dropout(emb_dropout_rate)
|
112 |
+
|
113 |
+
self.encoder = Encoder(
|
114 |
+
idim=args.embed_unit,
|
115 |
+
attention_dim=args.att_unit,
|
116 |
+
attention_heads=args.head,
|
117 |
+
linear_units=args.unit,
|
118 |
+
num_blocks=args.layer,
|
119 |
+
dropout_rate=args.dropout_rate,
|
120 |
+
attention_dropout_rate=att_dropout_rate,
|
121 |
+
input_layer="linear",
|
122 |
+
pos_enc_class=pos_enc_class,
|
123 |
+
)
|
124 |
+
self.decoder = nn.Linear(args.att_unit, n_vocab)
|
125 |
+
|
126 |
+
logging.info("Tie weights set to {}".format(tie_weights))
|
127 |
+
logging.info("Dropout set to {}".format(args.dropout_rate))
|
128 |
+
logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
|
129 |
+
logging.info("Att Dropout set to {}".format(att_dropout_rate))
|
130 |
+
|
131 |
+
if tie_weights:
|
132 |
+
assert (
|
133 |
+
args.att_unit == args.embed_unit
|
134 |
+
), "Tie Weights: True need embedding and final dimensions to match"
|
135 |
+
self.decoder.weight = self.embed.weight
|
136 |
+
|
137 |
+
def _target_mask(self, ys_in_pad):
|
138 |
+
ys_mask = ys_in_pad != 0
|
139 |
+
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
|
140 |
+
return ys_mask.unsqueeze(-2) & m
|
141 |
+
|
142 |
+
def forward(
|
143 |
+
self, x: torch.Tensor, t: torch.Tensor
|
144 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
145 |
+
"""Compute LM loss value from buffer sequences.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
x (torch.Tensor): Input ids. (batch, len)
|
149 |
+
t (torch.Tensor): Target ids. (batch, len)
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
|
153 |
+
loss to backward (scalar),
|
154 |
+
negative log-likelihood of t: -log p(t) (scalar) and
|
155 |
+
the number of elements in x (scalar)
|
156 |
+
|
157 |
+
Notes:
|
158 |
+
The last two return values are used
|
159 |
+
in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)
|
160 |
+
|
161 |
+
"""
|
162 |
+
xm = x != 0
|
163 |
+
|
164 |
+
if self.embed_drop is not None:
|
165 |
+
emb = self.embed_drop(self.embed(x))
|
166 |
+
else:
|
167 |
+
emb = self.embed(x)
|
168 |
+
|
169 |
+
h, _ = self.encoder(emb, self._target_mask(x))
|
170 |
+
y = self.decoder(h)
|
171 |
+
loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
|
172 |
+
mask = xm.to(dtype=loss.dtype)
|
173 |
+
logp = loss * mask.view(-1)
|
174 |
+
logp = logp.sum()
|
175 |
+
count = mask.sum()
|
176 |
+
return logp / count, logp, count
|
177 |
+
|
178 |
+
def score(
|
179 |
+
self, y: torch.Tensor, state: Any, x: torch.Tensor
|
180 |
+
) -> Tuple[torch.Tensor, Any]:
|
181 |
+
"""Score new token.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
185 |
+
state: Scorer state for prefix tokens
|
186 |
+
x (torch.Tensor): encoder feature that generates ys.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
tuple[torch.Tensor, Any]: Tuple of
|
190 |
+
torch.float32 scores for next token (n_vocab)
|
191 |
+
and next state for ys
|
192 |
+
|
193 |
+
"""
|
194 |
+
y = y.unsqueeze(0)
|
195 |
+
|
196 |
+
if self.embed_drop is not None:
|
197 |
+
emb = self.embed_drop(self.embed(y))
|
198 |
+
else:
|
199 |
+
emb = self.embed(y)
|
200 |
+
|
201 |
+
h, _, cache = self.encoder.forward_one_step(
|
202 |
+
emb, self._target_mask(y), cache=state
|
203 |
+
)
|
204 |
+
h = self.decoder(h[:, -1])
|
205 |
+
logp = h.log_softmax(dim=-1).squeeze(0)
|
206 |
+
return logp, cache
|
207 |
+
|
208 |
+
# batch beam search API (see BatchScorerInterface)
|
209 |
+
def batch_score(
|
210 |
+
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
211 |
+
) -> Tuple[torch.Tensor, List[Any]]:
|
212 |
+
"""Score new token batch (required).
|
213 |
+
|
214 |
+
Args:
|
215 |
+
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
216 |
+
states (List[Any]): Scorer states for prefix tokens.
|
217 |
+
xs (torch.Tensor):
|
218 |
+
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
tuple[torch.Tensor, List[Any]]: Tuple of
|
222 |
+
batchfied scores for next token with shape of `(n_batch, n_vocab)`
|
223 |
+
and next state list for ys.
|
224 |
+
|
225 |
+
"""
|
226 |
+
# merge states
|
227 |
+
n_batch = len(ys)
|
228 |
+
n_layers = len(self.encoder.encoders)
|
229 |
+
if states[0] is None:
|
230 |
+
batch_state = None
|
231 |
+
else:
|
232 |
+
# transpose state of [batch, layer] into [layer, batch]
|
233 |
+
batch_state = [
|
234 |
+
torch.stack([states[b][i] for b in range(n_batch)])
|
235 |
+
for i in range(n_layers)
|
236 |
+
]
|
237 |
+
|
238 |
+
if self.embed_drop is not None:
|
239 |
+
emb = self.embed_drop(self.embed(ys))
|
240 |
+
else:
|
241 |
+
emb = self.embed(ys)
|
242 |
+
|
243 |
+
# batch decoding
|
244 |
+
h, _, states = self.encoder.forward_one_step(
|
245 |
+
emb, self._target_mask(ys), cache=batch_state
|
246 |
+
)
|
247 |
+
h = self.decoder(h[:, -1])
|
248 |
+
logp = h.log_softmax(dim=-1)
|
249 |
+
|
250 |
+
# transpose state of [layer, batch] into [batch, layer]
|
251 |
+
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
|
252 |
+
return logp, state_list
|
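For reference, `_target_mask` above combines a non-padding mask (token id 0 is treated as padding) with a causal mask. A small sketch of the same computation, using `torch.tril` in place of `subsequent_mask` and made-up token ids:

import torch

ys = torch.tensor([[2, 5, 7, 0]])                  # (batch=1, len=4); the last token is padding
pad_mask = (ys != 0).unsqueeze(-2)                 # (1, 1, 4)
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool)).unsqueeze(0)  # (1, 4, 4)
mask = pad_mask & causal
print(mask.int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 0]]], dtype=torch.int32)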
espnet/nets/pytorch_backend/nets_utils.py
ADDED
@@ -0,0 +1,526 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""Network related utility tools."""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
from typing import Dict
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def to_device(m, x):
|
13 |
+
"""Send tensor into the device of the module.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
m (torch.nn.Module): Torch module.
|
17 |
+
x (Tensor): Torch tensor.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Tensor: Torch tensor located in the same place as torch module.
|
21 |
+
|
22 |
+
"""
|
23 |
+
if isinstance(m, torch.nn.Module):
|
24 |
+
device = next(m.parameters()).device
|
25 |
+
elif isinstance(m, torch.Tensor):
|
26 |
+
device = m.device
|
27 |
+
else:
|
28 |
+
raise TypeError(
|
29 |
+
"Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
|
30 |
+
)
|
31 |
+
return x.to(device)
|
32 |
+
|
33 |
+
|
34 |
+
def pad_list(xs, pad_value):
|
35 |
+
"""Perform padding for the list of tensors.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
39 |
+
pad_value (float): Value for padding.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Tensor: Padded tensor (B, Tmax, `*`).
|
43 |
+
|
44 |
+
Examples:
|
45 |
+
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
46 |
+
>>> x
|
47 |
+
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
48 |
+
>>> pad_list(x, 0)
|
49 |
+
tensor([[1., 1., 1., 1.],
|
50 |
+
[1., 1., 0., 0.],
|
51 |
+
[1., 0., 0., 0.]])
|
52 |
+
|
53 |
+
"""
|
54 |
+
n_batch = len(xs)
|
55 |
+
max_len = max(x.size(0) for x in xs)
|
56 |
+
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
57 |
+
|
58 |
+
for i in range(n_batch):
|
59 |
+
pad[i, : xs[i].size(0)] = xs[i]
|
60 |
+
|
61 |
+
return pad
|
62 |
+
|
63 |
+
|
64 |
+
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
|
65 |
+
"""Make mask tensor containing indices of padded part.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
69 |
+
xs (Tensor, optional): The reference tensor.
|
70 |
+
If set, masks will be the same shape as this tensor.
|
71 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
72 |
+
See the example.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tensor: Mask tensor containing indices of padded part.
|
76 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
77 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
78 |
+
|
79 |
+
Examples:
|
80 |
+
With only lengths.
|
81 |
+
|
82 |
+
>>> lengths = [5, 3, 2]
|
83 |
+
>>> make_pad_mask(lengths)
|
84 |
+
masks = [[0, 0, 0, 0 ,0],
|
85 |
+
[0, 0, 0, 1, 1],
|
86 |
+
[0, 0, 1, 1, 1]]
|
87 |
+
|
88 |
+
With the reference tensor.
|
89 |
+
|
90 |
+
>>> xs = torch.zeros((3, 2, 4))
|
91 |
+
>>> make_pad_mask(lengths, xs)
|
92 |
+
tensor([[[0, 0, 0, 0],
|
93 |
+
[0, 0, 0, 0]],
|
94 |
+
[[0, 0, 0, 1],
|
95 |
+
[0, 0, 0, 1]],
|
96 |
+
[[0, 0, 1, 1],
|
97 |
+
[0, 0, 1, 1]]], dtype=torch.uint8)
|
98 |
+
>>> xs = torch.zeros((3, 2, 6))
|
99 |
+
>>> make_pad_mask(lengths, xs)
|
100 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
101 |
+
[0, 0, 0, 0, 0, 1]],
|
102 |
+
[[0, 0, 0, 1, 1, 1],
|
103 |
+
[0, 0, 0, 1, 1, 1]],
|
104 |
+
[[0, 0, 1, 1, 1, 1],
|
105 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
106 |
+
|
107 |
+
With the reference tensor and dimension indicator.
|
108 |
+
|
109 |
+
>>> xs = torch.zeros((3, 6, 6))
|
110 |
+
>>> make_pad_mask(lengths, xs, 1)
|
111 |
+
tensor([[[0, 0, 0, 0, 0, 0],
|
112 |
+
[0, 0, 0, 0, 0, 0],
|
113 |
+
[0, 0, 0, 0, 0, 0],
|
114 |
+
[0, 0, 0, 0, 0, 0],
|
115 |
+
[0, 0, 0, 0, 0, 0],
|
116 |
+
[1, 1, 1, 1, 1, 1]],
|
117 |
+
[[0, 0, 0, 0, 0, 0],
|
118 |
+
[0, 0, 0, 0, 0, 0],
|
119 |
+
[0, 0, 0, 0, 0, 0],
|
120 |
+
[1, 1, 1, 1, 1, 1],
|
121 |
+
[1, 1, 1, 1, 1, 1],
|
122 |
+
[1, 1, 1, 1, 1, 1]],
|
123 |
+
[[0, 0, 0, 0, 0, 0],
|
124 |
+
[0, 0, 0, 0, 0, 0],
|
125 |
+
[1, 1, 1, 1, 1, 1],
|
126 |
+
[1, 1, 1, 1, 1, 1],
|
127 |
+
[1, 1, 1, 1, 1, 1],
|
128 |
+
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
129 |
+
>>> make_pad_mask(lengths, xs, 2)
|
130 |
+
tensor([[[0, 0, 0, 0, 0, 1],
|
131 |
+
[0, 0, 0, 0, 0, 1],
|
132 |
+
[0, 0, 0, 0, 0, 1],
|
133 |
+
[0, 0, 0, 0, 0, 1],
|
134 |
+
[0, 0, 0, 0, 0, 1],
|
135 |
+
[0, 0, 0, 0, 0, 1]],
|
136 |
+
[[0, 0, 0, 1, 1, 1],
|
137 |
+
[0, 0, 0, 1, 1, 1],
|
138 |
+
[0, 0, 0, 1, 1, 1],
|
139 |
+
[0, 0, 0, 1, 1, 1],
|
140 |
+
[0, 0, 0, 1, 1, 1],
|
141 |
+
[0, 0, 0, 1, 1, 1]],
|
142 |
+
[[0, 0, 1, 1, 1, 1],
|
143 |
+
[0, 0, 1, 1, 1, 1],
|
144 |
+
[0, 0, 1, 1, 1, 1],
|
145 |
+
[0, 0, 1, 1, 1, 1],
|
146 |
+
[0, 0, 1, 1, 1, 1],
|
147 |
+
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
148 |
+
|
149 |
+
"""
|
150 |
+
if length_dim == 0:
|
151 |
+
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
152 |
+
|
153 |
+
if not isinstance(lengths, list):
|
154 |
+
lengths = lengths.tolist()
|
155 |
+
bs = int(len(lengths))
|
156 |
+
if maxlen is None:
|
157 |
+
if xs is None:
|
158 |
+
maxlen = int(max(lengths))
|
159 |
+
else:
|
160 |
+
maxlen = xs.size(length_dim)
|
161 |
+
else:
|
162 |
+
assert xs is None
|
163 |
+
assert maxlen >= int(max(lengths))
|
164 |
+
|
165 |
+
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
166 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
167 |
+
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
168 |
+
mask = seq_range_expand >= seq_length_expand
|
169 |
+
|
170 |
+
if xs is not None:
|
171 |
+
assert xs.size(0) == bs, (xs.size(0), bs)
|
172 |
+
|
173 |
+
if length_dim < 0:
|
174 |
+
length_dim = xs.dim() + length_dim
|
175 |
+
# ind = (:, None, ..., None, :, , None, ..., None)
|
176 |
+
ind = tuple(
|
177 |
+
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
|
178 |
+
)
|
179 |
+
mask = mask[ind].expand_as(xs).to(xs.device)
|
180 |
+
return mask
|
181 |
+
|
182 |
+
|
183 |
+
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
184 |
+
"""Make mask tensor containing indices of non-padded part.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
188 |
+
xs (Tensor, optional): The reference tensor.
|
189 |
+
If set, masks will be the same shape as this tensor.
|
190 |
+
length_dim (int, optional): Dimension indicator of the above tensor.
|
191 |
+
See the example.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
ByteTensor: mask tensor containing indices of padded part.
|
195 |
+
dtype=torch.uint8 in PyTorch 1.2-
|
196 |
+
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
197 |
+
|
198 |
+
Examples:
|
199 |
+
With only lengths.
|
200 |
+
|
201 |
+
>>> lengths = [5, 3, 2]
|
202 |
+
>>> make_non_pad_mask(lengths)
|
203 |
+
masks = [[1, 1, 1, 1 ,1],
|
204 |
+
[1, 1, 1, 0, 0],
|
205 |
+
[1, 1, 0, 0, 0]]
|
206 |
+
|
207 |
+
With the reference tensor.
|
208 |
+
|
209 |
+
>>> xs = torch.zeros((3, 2, 4))
|
210 |
+
>>> make_non_pad_mask(lengths, xs)
|
211 |
+
tensor([[[1, 1, 1, 1],
|
212 |
+
[1, 1, 1, 1]],
|
213 |
+
[[1, 1, 1, 0],
|
214 |
+
[1, 1, 1, 0]],
|
215 |
+
[[1, 1, 0, 0],
|
216 |
+
[1, 1, 0, 0]]], dtype=torch.uint8)
|
217 |
+
>>> xs = torch.zeros((3, 2, 6))
|
218 |
+
>>> make_non_pad_mask(lengths, xs)
|
219 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
220 |
+
[1, 1, 1, 1, 1, 0]],
|
221 |
+
[[1, 1, 1, 0, 0, 0],
|
222 |
+
[1, 1, 1, 0, 0, 0]],
|
223 |
+
[[1, 1, 0, 0, 0, 0],
|
224 |
+
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
225 |
+
|
226 |
+
With the reference tensor and dimension indicator.
|
227 |
+
|
228 |
+
>>> xs = torch.zeros((3, 6, 6))
|
229 |
+
>>> make_non_pad_mask(lengths, xs, 1)
|
230 |
+
tensor([[[1, 1, 1, 1, 1, 1],
|
231 |
+
[1, 1, 1, 1, 1, 1],
|
232 |
+
[1, 1, 1, 1, 1, 1],
|
233 |
+
[1, 1, 1, 1, 1, 1],
|
234 |
+
[1, 1, 1, 1, 1, 1],
|
235 |
+
[0, 0, 0, 0, 0, 0]],
|
236 |
+
[[1, 1, 1, 1, 1, 1],
|
237 |
+
[1, 1, 1, 1, 1, 1],
|
238 |
+
[1, 1, 1, 1, 1, 1],
|
239 |
+
[0, 0, 0, 0, 0, 0],
|
240 |
+
[0, 0, 0, 0, 0, 0],
|
241 |
+
[0, 0, 0, 0, 0, 0]],
|
242 |
+
[[1, 1, 1, 1, 1, 1],
|
243 |
+
[1, 1, 1, 1, 1, 1],
|
244 |
+
[0, 0, 0, 0, 0, 0],
|
245 |
+
[0, 0, 0, 0, 0, 0],
|
246 |
+
[0, 0, 0, 0, 0, 0],
|
247 |
+
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
248 |
+
>>> make_non_pad_mask(lengths, xs, 2)
|
249 |
+
tensor([[[1, 1, 1, 1, 1, 0],
|
250 |
+
[1, 1, 1, 1, 1, 0],
|
251 |
+
[1, 1, 1, 1, 1, 0],
|
252 |
+
[1, 1, 1, 1, 1, 0],
|
253 |
+
[1, 1, 1, 1, 1, 0],
|
254 |
+
[1, 1, 1, 1, 1, 0]],
|
255 |
+
[[1, 1, 1, 0, 0, 0],
|
256 |
+
[1, 1, 1, 0, 0, 0],
|
257 |
+
[1, 1, 1, 0, 0, 0],
|
258 |
+
[1, 1, 1, 0, 0, 0],
|
259 |
+
[1, 1, 1, 0, 0, 0],
|
260 |
+
[1, 1, 1, 0, 0, 0]],
|
261 |
+
[[1, 1, 0, 0, 0, 0],
|
262 |
+
[1, 1, 0, 0, 0, 0],
|
263 |
+
[1, 1, 0, 0, 0, 0],
|
264 |
+
[1, 1, 0, 0, 0, 0],
|
265 |
+
[1, 1, 0, 0, 0, 0],
|
266 |
+
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
267 |
+
|
268 |
+
"""
|
269 |
+
return ~make_pad_mask(lengths, xs, length_dim)
|
270 |
+
|
271 |
+
|
272 |
+
def mask_by_length(xs, lengths, fill=0):
|
273 |
+
"""Mask tensor according to length.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
xs (Tensor): Batch of input tensor (B, `*`).
|
277 |
+
lengths (LongTensor or List): Batch of lengths (B,).
|
278 |
+
fill (int or float): Value to fill masked part.
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
Tensor: Batch of masked input tensor (B, `*`).
|
282 |
+
|
283 |
+
Examples:
|
284 |
+
>>> x = torch.arange(5).repeat(3, 1) + 1
|
285 |
+
>>> x
|
286 |
+
tensor([[1, 2, 3, 4, 5],
|
287 |
+
[1, 2, 3, 4, 5],
|
288 |
+
[1, 2, 3, 4, 5]])
|
289 |
+
>>> lengths = [5, 3, 2]
|
290 |
+
>>> mask_by_length(x, lengths)
|
291 |
+
tensor([[1, 2, 3, 4, 5],
|
292 |
+
[1, 2, 3, 0, 0],
|
293 |
+
[1, 2, 0, 0, 0]])
|
294 |
+
|
295 |
+
"""
|
296 |
+
assert xs.size(0) == len(lengths)
|
297 |
+
ret = xs.data.new(*xs.size()).fill_(fill)
|
298 |
+
for i, l in enumerate(lengths):
|
299 |
+
ret[i, :l] = xs[i, :l]
|
300 |
+
return ret
|
301 |
+
|
302 |
+
|
303 |
+
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
304 |
+
"""Calculate accuracy.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
308 |
+
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
309 |
+
ignore_label (int): Ignore label id.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
float: Accuracy value (0.0 - 1.0).
|
313 |
+
|
314 |
+
"""
|
315 |
+
pad_pred = pad_outputs.view(
|
316 |
+
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
|
317 |
+
).argmax(2)
|
318 |
+
mask = pad_targets != ignore_label
|
319 |
+
numerator = torch.sum(
|
320 |
+
pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
|
321 |
+
)
|
322 |
+
denominator = torch.sum(mask)
|
323 |
+
return float(numerator) / float(denominator)
|
324 |
+
|
325 |
+
|
326 |
+
def to_torch_tensor(x):
|
327 |
+
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
328 |
+
|
329 |
+
Args:
|
330 |
+
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
331 |
+
|
332 |
+
Returns:
|
333 |
+
Tensor or ComplexTensor: Type converted inputs.
|
334 |
+
|
335 |
+
Examples:
|
336 |
+
>>> xs = np.ones(3, dtype=np.float32)
|
337 |
+
>>> xs = to_torch_tensor(xs)
|
338 |
+
tensor([1., 1., 1.])
|
339 |
+
>>> xs = torch.ones(3, 4, 5)
|
340 |
+
>>> assert to_torch_tensor(xs) is xs
|
341 |
+
>>> xs = {'real': xs, 'imag': xs}
|
342 |
+
>>> to_torch_tensor(xs)
|
343 |
+
ComplexTensor(
|
344 |
+
Real:
|
345 |
+
tensor([1., 1., 1.])
|
346 |
+
Imag:
|
347 |
+
tensor([1., 1., 1.])
|
348 |
+
)
|
349 |
+
|
350 |
+
"""
|
351 |
+
# If numpy, change to torch tensor
|
352 |
+
if isinstance(x, np.ndarray):
|
353 |
+
if x.dtype.kind == "c":
|
354 |
+
# Dynamically importing because torch_complex requires python3
|
355 |
+
from torch_complex.tensor import ComplexTensor
|
356 |
+
|
357 |
+
return ComplexTensor(x)
|
358 |
+
else:
|
359 |
+
return torch.from_numpy(x)
|
360 |
+
|
361 |
+
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
362 |
+
elif isinstance(x, dict):
|
363 |
+
# Dynamically importing because torch_complex requires python3
|
364 |
+
from torch_complex.tensor import ComplexTensor
|
365 |
+
|
366 |
+
if "real" not in x or "imag" not in x:
|
367 |
+
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
368 |
+
# Relative importing because of using python3 syntax
|
369 |
+
return ComplexTensor(x["real"], x["imag"])
|
370 |
+
|
371 |
+
# If torch.Tensor, as it is
|
372 |
+
elif isinstance(x, torch.Tensor):
|
373 |
+
return x
|
374 |
+
|
375 |
+
else:
|
376 |
+
error = (
|
377 |
+
"x must be numpy.ndarray, torch.Tensor or a dict like "
|
378 |
+
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
379 |
+
"but got {}".format(type(x))
|
380 |
+
)
|
381 |
+
try:
|
382 |
+
from torch_complex.tensor import ComplexTensor
|
383 |
+
except Exception:
|
384 |
+
# If PY2
|
385 |
+
raise ValueError(error)
|
386 |
+
else:
|
387 |
+
# If PY3
|
388 |
+
if isinstance(x, ComplexTensor):
|
389 |
+
return x
|
390 |
+
else:
|
391 |
+
raise ValueError(error)
|
392 |
+
|
393 |
+
|
394 |
+
def get_subsample(train_args, mode, arch):
|
395 |
+
"""Parse the subsampling factors from the args for the specified `mode` and `arch`.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
train_args: argument Namespace containing options.
|
399 |
+
mode: one of ('asr', 'mt', 'st')
|
400 |
+
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
401 |
+
|
402 |
+
Returns:
|
403 |
+
np.ndarray / List[np.ndarray]: subsampling factors.
|
404 |
+
"""
|
405 |
+
if arch == "transformer":
|
406 |
+
return np.array([1])
|
407 |
+
|
408 |
+
elif mode == "mt" and arch == "rnn":
|
409 |
+
# +1 means input (+1) and layers outputs (train_args.elayer)
|
410 |
+
subsample = np.ones(train_args.elayers + 1, dtype=int)
|
411 |
+
logging.warning("Subsampling is not performed for machine translation.")
|
412 |
+
logging.info("subsample: " + " ".join([str(x) for x in subsample]))
|
413 |
+
return subsample
|
414 |
+
|
415 |
+
elif (
|
416 |
+
(mode == "asr" and arch in ("rnn", "rnn-t"))
|
417 |
+
or (mode == "mt" and arch == "rnn")
|
418 |
+
or (mode == "st" and arch == "rnn")
|
419 |
+
):
|
420 |
+
subsample = np.ones(train_args.elayers + 1, dtype=int)
|
421 |
+
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
422 |
+
ss = train_args.subsample.split("_")
|
423 |
+
for j in range(min(train_args.elayers + 1, len(ss))):
|
424 |
+
subsample[j] = int(ss[j])
|
425 |
+
else:
|
426 |
+
logging.warning(
|
427 |
+
"Subsampling is not performed for vgg*. "
|
428 |
+
"It is performed in max pooling layers at CNN."
|
429 |
+
)
|
430 |
+
logging.info("subsample: " + " ".join([str(x) for x in subsample]))
|
431 |
+
return subsample
|
432 |
+
|
433 |
+
elif mode == "asr" and arch == "rnn_mix":
|
434 |
+
subsample = np.ones(
|
435 |
+
train_args.elayers_sd + train_args.elayers + 1, dtype=int
|
436 |
+
)
|
437 |
+
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
438 |
+
ss = train_args.subsample.split("_")
|
439 |
+
for j in range(
|
440 |
+
min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
|
441 |
+
):
|
442 |
+
subsample[j] = int(ss[j])
|
443 |
+
else:
|
444 |
+
logging.warning(
|
445 |
+
"Subsampling is not performed for vgg*. "
|
446 |
+
"It is performed in max pooling layers at CNN."
|
447 |
+
)
|
448 |
+
logging.info("subsample: " + " ".join([str(x) for x in subsample]))
|
449 |
+
return subsample
|
450 |
+
|
451 |
+
elif mode == "asr" and arch == "rnn_mulenc":
|
452 |
+
subsample_list = []
|
453 |
+
for idx in range(train_args.num_encs):
|
454 |
+
subsample = np.ones(train_args.elayers[idx] + 1, dtype=int)
|
455 |
+
if train_args.etype[idx].endswith("p") and not train_args.etype[
|
456 |
+
idx
|
457 |
+
].startswith("vgg"):
|
458 |
+
ss = train_args.subsample[idx].split("_")
|
459 |
+
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
460 |
+
subsample[j] = int(ss[j])
|
461 |
+
else:
|
462 |
+
logging.warning(
|
463 |
+
"Encoder %d: Subsampling is not performed for vgg*. "
|
464 |
+
"It is performed in max pooling layers at CNN.",
|
465 |
+
idx + 1,
|
466 |
+
)
|
467 |
+
logging.info("subsample: " + " ".join([str(x) for x in subsample]))
|
468 |
+
subsample_list.append(subsample)
|
469 |
+
return subsample_list
|
470 |
+
|
471 |
+
else:
|
472 |
+
raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
|
473 |
+
|
474 |
+
|
475 |
+
def rename_state_dict(
|
476 |
+
old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
|
477 |
+
):
|
478 |
+
"""Replace keys of old prefix with new prefix in state dict."""
|
479 |
+
# need this list not to break the dict iterator
|
480 |
+
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
481 |
+
if len(old_keys) > 0:
|
482 |
+
logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
|
483 |
+
for k in old_keys:
|
484 |
+
v = state_dict.pop(k)
|
485 |
+
new_k = k.replace(old_prefix, new_prefix)
|
486 |
+
state_dict[new_k] = v
|
487 |
+
|
488 |
+
|
489 |
+
def get_activation(act):
|
490 |
+
"""Return activation function."""
|
491 |
+
# Lazy load to avoid unused import
|
492 |
+
from espnet.nets.pytorch_backend.conformer.swish import Swish
|
493 |
+
|
494 |
+
activation_funcs = {
|
495 |
+
"hardtanh": torch.nn.Hardtanh,
|
496 |
+
"tanh": torch.nn.Tanh,
|
497 |
+
"relu": torch.nn.ReLU,
|
498 |
+
"selu": torch.nn.SELU,
|
499 |
+
"swish": Swish,
|
500 |
+
}
|
501 |
+
|
502 |
+
return activation_funcs[act]()
|
503 |
+
|
504 |
+
|
505 |
+
class MLPHead(torch.nn.Module):
|
506 |
+
def __init__(self, idim, hdim, odim, norm="batchnorm"):
|
507 |
+
super(MLPHead, self).__init__()
|
508 |
+
self.norm = norm
|
509 |
+
|
510 |
+
self.fc1 = torch.nn.Linear(idim, hdim)
|
511 |
+
if norm == "batchnorm":
|
512 |
+
self.bn1 = torch.nn.BatchNorm1d(hdim)
|
513 |
+
elif norm == "layernorm":
|
514 |
+
self.norm1 = torch.nn.LayerNorm(hdim)
|
515 |
+
self.nonlin1 = torch.nn.ReLU(inplace=True)
|
516 |
+
self.fc2 = torch.nn.Linear(hdim, odim)
|
517 |
+
|
518 |
+
def forward(self, x):
|
519 |
+
x = self.fc1(x)
|
520 |
+
if self.norm == "batchnorm":
|
521 |
+
x = self.bn1(x.transpose(1, 2)).transpose(1, 2)
|
522 |
+
elif self.norm == "layernorm":
|
523 |
+
x = self.norm1(x)
|
524 |
+
x = self.nonlin1(x)
|
525 |
+
x = self.fc2(x)
|
526 |
+
return x
|
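A minimal usage sketch of the `MLPHead` defined above; it assumes the snippet runs in this module so the class is in scope, and the dimensions are illustrative:

import torch

head = MLPHead(idim=1536, hdim=8192, odim=768, norm="layernorm")
x = torch.randn(2, 50, 1536)  # (batch, time, idim), e.g. two concatenated encoder streams
y = head(x)                   # Linear -> LayerNorm -> ReLU -> Linear
print(y.shape)  # torch.Size([2, 50, 768])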
espnet/nets/pytorch_backend/transformer/__init__.py
ADDED
@@ -0,0 +1 @@
"""Initialize sub package."""
espnet/nets/pytorch_backend/transformer/add_sos_eos.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Unility funcitons for Transformer."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def add_sos_eos(ys_pad, sos, eos, ignore_id):
|
13 |
+
"""Add <sos> and <eos> labels.
|
14 |
+
|
15 |
+
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
|
16 |
+
:param int sos: index of <sos>
|
17 |
+
:param int eos: index of <eeos>
|
18 |
+
:param int ignore_id: index of padding
|
19 |
+
:return: padded tensor (B, Lmax)
|
20 |
+
:rtype: torch.Tensor
|
21 |
+
:return: padded tensor (B, Lmax)
|
22 |
+
:rtype: torch.Tensor
|
23 |
+
"""
|
24 |
+
from espnet.nets.pytorch_backend.nets_utils import pad_list
|
25 |
+
|
26 |
+
_sos = ys_pad.new([sos])
|
27 |
+
_eos = ys_pad.new([eos])
|
28 |
+
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
|
29 |
+
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
|
30 |
+
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
|
31 |
+
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
|
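A short worked example of `add_sos_eos`; it assumes the snippet runs in this module, with made-up token ids and ignore_id=-1 as padding:

import torch

ys_pad = torch.tensor([[3, 5, 2, -1], [4, 6, 7, 8]])  # two padded target sequences
ys_in, ys_out = add_sos_eos(ys_pad, sos=9, eos=9, ignore_id=-1)
print(ys_in)   # decoder inputs, prepended with <sos> and padded with eos
# tensor([[9, 3, 5, 2, 9],
#         [9, 4, 6, 7, 8]])
print(ys_out)  # decoder targets, appended with <eos> and padded with ignore_id
# tensor([[ 3,  5,  2,  9, -1],
#         [ 4,  6,  7,  8,  9]])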
espnet/nets/pytorch_backend/transformer/attention.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Multi-Head Attention layer definition."""
|
8 |
+
|
9 |
+
import math
|
10 |
+
|
11 |
+
import numpy
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class MultiHeadedAttention(nn.Module):
|
17 |
+
"""Multi-Head Attention layer.
|
18 |
+
Args:
|
19 |
+
n_head (int): The number of heads.
|
20 |
+
n_feat (int): The number of features.
|
21 |
+
dropout_rate (float): Dropout rate.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, n_head, n_feat, dropout_rate):
|
25 |
+
"""Construct an MultiHeadedAttention object."""
|
26 |
+
super(MultiHeadedAttention, self).__init__()
|
27 |
+
assert n_feat % n_head == 0
|
28 |
+
# We assume d_v always equals d_k
|
29 |
+
self.d_k = n_feat // n_head
|
30 |
+
self.h = n_head
|
31 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
32 |
+
self.linear_k = nn.Linear(n_feat, n_feat)
|
33 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
34 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
35 |
+
self.attn = None
|
36 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
37 |
+
|
38 |
+
def forward_qkv(self, query, key, value):
|
39 |
+
"""Transform query, key and value.
|
40 |
+
Args:
|
41 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
42 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
43 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
44 |
+
Returns:
|
45 |
+
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
46 |
+
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
47 |
+
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
48 |
+
"""
|
49 |
+
n_batch = query.size(0)
|
50 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
51 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
52 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
53 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
54 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
55 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
56 |
+
|
57 |
+
return q, k, v
|
58 |
+
|
59 |
+
def forward_attention(self, value, scores, mask, rtn_attn=False):
|
60 |
+
"""Compute attention context vector.
|
61 |
+
Args:
|
62 |
+
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
63 |
+
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
64 |
+
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
65 |
+
rtn_attn (boolean): Flag of return attention score
|
66 |
+
Returns:
|
67 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
68 |
+
weighted by the attention score (#batch, time1, time2).
|
69 |
+
"""
|
70 |
+
n_batch = value.size(0)
|
71 |
+
if mask is not None:
|
72 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
73 |
+
min_value = float(
|
74 |
+
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
|
75 |
+
)
|
76 |
+
scores = scores.masked_fill(mask, min_value)
|
77 |
+
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
78 |
+
mask, 0.0
|
79 |
+
) # (batch, head, time1, time2)
|
80 |
+
else:
|
81 |
+
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
82 |
+
|
83 |
+
p_attn = self.dropout(self.attn)
|
84 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
85 |
+
x = (
|
86 |
+
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
87 |
+
) # (batch, time1, d_model)
|
88 |
+
if rtn_attn:
|
89 |
+
return self.linear_out(x), self.attn
|
90 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
91 |
+
|
92 |
+
def forward(self, query, key, value, mask, rtn_attn=False):
|
93 |
+
"""Compute scaled dot product attention.
|
94 |
+
Args:
|
95 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
96 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
97 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
98 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
99 |
+
(#batch, time1, time2).
|
100 |
+
rtn_attn (boolean): Flag of return attention score
|
101 |
+
Returns:
|
102 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
103 |
+
"""
|
104 |
+
q, k, v = self.forward_qkv(query, key, value)
|
105 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
106 |
+
return self.forward_attention(v, scores, mask, rtn_attn)
|
107 |
+
|
108 |
+
|
109 |
+
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
|
110 |
+
"""Multi-Head Attention layer with relative position encoding (old version).
|
111 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
112 |
+
Paper: https://arxiv.org/abs/1901.02860
|
113 |
+
Args:
|
114 |
+
n_head (int): The number of heads.
|
115 |
+
n_feat (int): The number of features.
|
116 |
+
dropout_rate (float): Dropout rate.
|
117 |
+
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
118 |
+
"""
|
119 |
+
|
120 |
+
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
|
121 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
122 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
123 |
+
self.zero_triu = zero_triu
|
124 |
+
# linear transformation for positional encoding
|
125 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
126 |
+
# these two learnable bias are used in matrix c and matrix d
|
127 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
128 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
129 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
130 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
131 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
132 |
+
|
133 |
+
def rel_shift(self, x):
|
134 |
+
"""Compute relative positional encoding.
|
135 |
+
Args:
|
136 |
+
x (torch.Tensor): Input tensor (batch, head, time1, time2).
|
137 |
+
Returns:
|
138 |
+
torch.Tensor: Output tensor.
|
139 |
+
"""
|
140 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
141 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
142 |
+
|
143 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
144 |
+
x = x_padded[:, :, 1:].view_as(x)
|
145 |
+
|
146 |
+
if self.zero_triu:
|
147 |
+
ones = torch.ones((x.size(2), x.size(3)))
|
148 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
149 |
+
|
150 |
+
return x
|
151 |
+
|
152 |
+
def forward(self, query, key, value, pos_emb, mask):
|
153 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
154 |
+
Args:
|
155 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
156 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
157 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
158 |
+
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
|
159 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
160 |
+
(#batch, time1, time2).
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
163 |
+
"""
|
164 |
+
q, k, v = self.forward_qkv(query, key, value)
|
165 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
166 |
+
|
167 |
+
n_batch_pos = pos_emb.size(0)
|
168 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
169 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
170 |
+
|
171 |
+
# (batch, head, time1, d_k)
|
172 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
173 |
+
# (batch, head, time1, d_k)
|
174 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
175 |
+
|
176 |
+
# compute attention score
|
177 |
+
# first compute matrix a and matrix c
|
178 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
179 |
+
# (batch, head, time1, time2)
|
180 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
181 |
+
|
182 |
+
# compute matrix b and matrix d
|
183 |
+
# (batch, head, time1, time1)
|
184 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
185 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
186 |
+
|
187 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
188 |
+
self.d_k
|
189 |
+
) # (batch, head, time1, time2)
|
190 |
+
|
191 |
+
return self.forward_attention(v, scores, mask)
|
192 |
+
|
193 |
+
|
194 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
195 |
+
"""Multi-Head Attention layer with relative position encoding (new implementation).
|
196 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
197 |
+
Paper: https://arxiv.org/abs/1901.02860
|
198 |
+
Args:
|
199 |
+
n_head (int): The number of heads.
|
200 |
+
n_feat (int): The number of features.
|
201 |
+
dropout_rate (float): Dropout rate.
|
202 |
+
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
|
206 |
+
"""Construct an RelPositionMultiHeadedAttention object."""
|
207 |
+
super().__init__(n_head, n_feat, dropout_rate)
|
208 |
+
self.zero_triu = zero_triu
|
209 |
+
# linear transformation for positional encoding
|
210 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
211 |
+
# these two learnable bias are used in matrix c and matrix d
|
212 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
213 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
214 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
215 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
216 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
217 |
+
|
218 |
+
def rel_shift(self, x):
|
219 |
+
"""Compute relative positional encoding.
|
220 |
+
Args:
|
221 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
222 |
+
time1 means the length of query vector.
|
223 |
+
Returns:
|
224 |
+
torch.Tensor: Output tensor.
|
225 |
+
"""
|
226 |
+
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
227 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
228 |
+
|
229 |
+
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
230 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
231 |
+
:, :, :, : x.size(-1) // 2 + 1
|
232 |
+
] # only keep the positions from 0 to time2
|
233 |
+
|
234 |
+
if self.zero_triu:
|
235 |
+
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
|
236 |
+
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
237 |
+
|
238 |
+
return x
|
239 |
+
|
240 |
+
def forward(self, query, key, value, pos_emb, mask):
|
241 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
242 |
+
Args:
|
243 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
244 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
245 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
246 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
247 |
+
(#batch, 2*time1-1, size).
|
248 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
249 |
+
(#batch, time1, time2).
|
250 |
+
Returns:
|
251 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
252 |
+
"""
|
253 |
+
q, k, v = self.forward_qkv(query, key, value)
|
254 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
255 |
+
|
256 |
+
n_batch_pos = pos_emb.size(0)
|
257 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
258 |
+
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
259 |
+
|
260 |
+
# (batch, head, time1, d_k)
|
261 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
262 |
+
# (batch, head, time1, d_k)
|
263 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
264 |
+
|
265 |
+
# compute attention score
|
266 |
+
# first compute matrix a and matrix c
|
267 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
268 |
+
# (batch, head, time1, time2)
|
269 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
270 |
+
|
271 |
+
# compute matrix b and matrix d
|
272 |
+
# (batch, head, time1, 2*time1-1)
|
273 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
274 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
275 |
+
|
276 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
277 |
+
self.d_k
|
278 |
+
) # (batch, head, time1, time2)
|
279 |
+
|
280 |
+
return self.forward_attention(v, scores, mask)
|
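A short self-attention sketch for MultiHeadedAttention; the batch size, sequence length and feature dimension are arbitrary example values:

import torch
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
x = torch.randn(2, 50, 256)          # (batch, time, n_feat)
mask = torch.ones(2, 1, 50).bool()   # attend to every frame
y = attn(x, x, x, mask)              # (2, 50, 256)
y, w = attn(x, x, x, mask, rtn_attn=True)  # w: (2, 4, 50, 50) attention weights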
espnet/nets/pytorch_backend/transformer/convolution.py
ADDED
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""ConvolutionModule definition."""

import torch
from torch import nn


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    :param int channels: channels of cnn
    :param int kernel_size: kernel size of cnn
    """

    def __init__(self, channels, kernel_size, bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_cov1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_cov2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias,
        )
        self.activation = Swish()

    def forward(self, x):
        """Compute convolution module.

        :param torch.Tensor x: (batch, time, size)
        :return torch.Tensor: convolved `value` (batch, time, d_model)
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_cov1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_cov2(x)

        return x.transpose(1, 2)


class Swish(nn.Module):
    """Construct a Swish object."""

    def forward(self, x):
        """Return Swish activation function."""
        return x * torch.sigmoid(x)
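A quick shape check for ConvolutionModule; the channel count and sequence length are arbitrary:

import torch
from espnet.nets.pytorch_backend.transformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=31)
x = torch.randn(2, 50, 256)   # (batch, time, channels)
y = conv(x)                   # (2, 50, 256), same shape as the input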
espnet/nets/pytorch_backend/transformer/decoder.py
ADDED
@@ -0,0 +1,229 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder definition."""

from typing import Any
from typing import List
from typing import Tuple

import torch

from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface


def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
    rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict)


class Decoder(BatchScorerInterface, torch.nn.Module):
    """Transformer decoder module.

    :param int odim: output dim
    :param int attention_dim: dimension of attention
    :param int attention_heads: the number of heads of multi head attention
    :param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks: the number of decoder blocks
    :param float dropout_rate: dropout rate
    :param float attention_dropout_rate: dropout rate for attention
    :param str or torch.nn.Module input_layer: input layer type
    :param bool use_output_layer: whether to use output layer
    :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        odim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        self_attention_dropout_rate=0.0,
        src_attention_dropout_rate=0.0,
        input_layer="embed",
        use_output_layer=True,
        pos_enc_class=PositionalEncoding,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a Decoder object."""
        torch.nn.Module.__init__(self)
        self._register_load_state_dict_pre_hook(_pre_hook)
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(odim, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(odim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise NotImplementedError("only `embed` or torch.nn.Module is supported.")
        self.normalize_before = normalize_before
        self.decoders = repeat(
            num_blocks,
            lambda: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, odim)
        else:
            self.output_layer = None

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Forward decoder.

        :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed",
            input tensor (batch, maxlen_out, #mels) in the other cases
        :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
        :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :return x: decoded token score before softmax (batch, maxlen_out, token)
            if use_output_layer is True,
            final block outputs (batch, maxlen_out, attention_dim) in the other cases
        :rtype: torch.Tensor
        :return tgt_mask: score mask before softmax (batch, maxlen_out)
        :rtype: torch.Tensor
        """
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        return x, tgt_mask

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Forward one step.

        :param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
        :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (include 1.2)
        :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
        :param List[torch.Tensor] cache:
            cached output list of (batch, max_time_out-1, size)
        :return y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        :rtype: Tuple[torch.Tensor, List[torch.Tensor]]
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    # beam search API (see ScorerInterface)
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][l] for b in range(n_batch)])
                for l in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[l][b] for l in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
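A teacher-forcing sketch for the Decoder; the vocabulary size, sequence lengths and the two-block setting are illustrative values, much smaller than the bundled models:

import torch
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

dec = Decoder(odim=1000, attention_dim=256, attention_heads=4, num_blocks=2)
memory = torch.randn(2, 50, 256)             # encoder output
memory_mask = torch.ones(2, 1, 50).bool()
tgt = torch.randint(0, 1000, (2, 12))        # target token ids
tgt_mask = subsequent_mask(12).unsqueeze(0)  # (1, 12, 12) causal mask
scores, _ = dec(tgt, tgt_mask, memory, memory_mask)  # (2, 12, 1000) pre-softmax scores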
espnet/nets/pytorch_backend/transformer/decoder_layer.py
ADDED
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder self-attention layer definition."""

import torch
from torch import nn

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        src_attn: source attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward layer module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, additional linear will be applied.
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor):
                decoded previous target features (batch, max_time_out, size)
            tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
            memory (torch.Tensor): encoded source features (batch, max_time_in, size)
            memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
            cache (torch.Tensor): cached output (batch, max_time_out-1, size)
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|
espnet/nets/pytorch_backend/transformer/embedding.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Positional Encoding Module."""
|
8 |
+
|
9 |
+
import math
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def _pre_hook(
|
15 |
+
state_dict,
|
16 |
+
prefix,
|
17 |
+
local_metadata,
|
18 |
+
strict,
|
19 |
+
missing_keys,
|
20 |
+
unexpected_keys,
|
21 |
+
error_msgs,
|
22 |
+
):
|
23 |
+
"""Perform pre-hook in load_state_dict for backward compatibility.
|
24 |
+
Note:
|
25 |
+
We saved self.pe until v.0.5.2 but we have omitted it later.
|
26 |
+
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
27 |
+
"""
|
28 |
+
k = prefix + "pe"
|
29 |
+
if k in state_dict:
|
30 |
+
state_dict.pop(k)
|
31 |
+
|
32 |
+
|
33 |
+
class PositionalEncoding(torch.nn.Module):
|
34 |
+
"""Positional encoding.
|
35 |
+
Args:
|
36 |
+
d_model (int): Embedding dimension.
|
37 |
+
dropout_rate (float): Dropout rate.
|
38 |
+
max_len (int): Maximum input length.
|
39 |
+
reverse (bool): Whether to reverse the input position. Only for
|
40 |
+
the class LegacyRelPositionalEncoding. We remove it in the current
|
41 |
+
class RelPositionalEncoding.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
45 |
+
"""Construct an PositionalEncoding object."""
|
46 |
+
super(PositionalEncoding, self).__init__()
|
47 |
+
self.d_model = d_model
|
48 |
+
self.reverse = reverse
|
49 |
+
self.xscale = math.sqrt(self.d_model)
|
50 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
51 |
+
self.pe = None
|
52 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
53 |
+
self._register_load_state_dict_pre_hook(_pre_hook)
|
54 |
+
|
55 |
+
def extend_pe(self, x):
|
56 |
+
"""Reset the positional encodings."""
|
57 |
+
if self.pe is not None:
|
58 |
+
if self.pe.size(1) >= x.size(1):
|
59 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
60 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
61 |
+
return
|
62 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
63 |
+
if self.reverse:
|
64 |
+
position = torch.arange(
|
65 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
66 |
+
).unsqueeze(1)
|
67 |
+
else:
|
68 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
69 |
+
div_term = torch.exp(
|
70 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
71 |
+
* -(math.log(10000.0) / self.d_model)
|
72 |
+
)
|
73 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
74 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
75 |
+
pe = pe.unsqueeze(0)
|
76 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
77 |
+
|
78 |
+
def forward(self, x: torch.Tensor):
|
79 |
+
"""Add positional encoding.
|
80 |
+
Args:
|
81 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
82 |
+
Returns:
|
83 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
84 |
+
"""
|
85 |
+
self.extend_pe(x)
|
86 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
87 |
+
return self.dropout(x)
|
88 |
+
|
89 |
+
|
90 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
91 |
+
"""Scaled positional encoding module.
|
92 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
93 |
+
Args:
|
94 |
+
d_model (int): Embedding dimension.
|
95 |
+
dropout_rate (float): Dropout rate.
|
96 |
+
max_len (int): Maximum input length.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
100 |
+
"""Initialize class."""
|
101 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
102 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
103 |
+
|
104 |
+
def reset_parameters(self):
|
105 |
+
"""Reset parameters."""
|
106 |
+
self.alpha.data = torch.tensor(1.0)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
"""Add positional encoding.
|
110 |
+
Args:
|
111 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
112 |
+
Returns:
|
113 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
114 |
+
"""
|
115 |
+
self.extend_pe(x)
|
116 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
117 |
+
return self.dropout(x)
|
118 |
+
|
119 |
+
|
120 |
+
class LegacyRelPositionalEncoding(PositionalEncoding):
|
121 |
+
"""Relative positional encoding module (old version).
|
122 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
123 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
124 |
+
Args:
|
125 |
+
d_model (int): Embedding dimension.
|
126 |
+
dropout_rate (float): Dropout rate.
|
127 |
+
max_len (int): Maximum input length.
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
131 |
+
"""Initialize class."""
|
132 |
+
super().__init__(
|
133 |
+
d_model=d_model,
|
134 |
+
dropout_rate=dropout_rate,
|
135 |
+
max_len=max_len,
|
136 |
+
reverse=True,
|
137 |
+
)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
"""Compute positional encoding.
|
141 |
+
Args:
|
142 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
143 |
+
Returns:
|
144 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
145 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
146 |
+
"""
|
147 |
+
self.extend_pe(x)
|
148 |
+
x = x * self.xscale
|
149 |
+
pos_emb = self.pe[:, : x.size(1)]
|
150 |
+
return self.dropout(x), self.dropout(pos_emb)
|
151 |
+
|
152 |
+
|
153 |
+
class RelPositionalEncoding(torch.nn.Module):
|
154 |
+
"""Relative positional encoding module (new implementation).
|
155 |
+
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
156 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
157 |
+
Args:
|
158 |
+
d_model (int): Embedding dimension.
|
159 |
+
dropout_rate (float): Dropout rate.
|
160 |
+
max_len (int): Maximum input length.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
164 |
+
"""Construct an PositionalEncoding object."""
|
165 |
+
super(RelPositionalEncoding, self).__init__()
|
166 |
+
self.d_model = d_model
|
167 |
+
self.xscale = math.sqrt(self.d_model)
|
168 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
169 |
+
self.pe = None
|
170 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
171 |
+
|
172 |
+
def extend_pe(self, x):
|
173 |
+
"""Reset the positional encodings."""
|
174 |
+
if self.pe is not None:
|
175 |
+
# self.pe contains both positive and negative parts
|
176 |
+
# the length of self.pe is 2 * input_len - 1
|
177 |
+
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
178 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
179 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
180 |
+
return
|
181 |
+
# Suppose `i` means to the position of query vecotr and `j` means the
|
182 |
+
# position of key vector. We use position relative positions when keys
|
183 |
+
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
184 |
+
pe_positive = torch.zeros(x.size(1), self.d_model)
|
185 |
+
pe_negative = torch.zeros(x.size(1), self.d_model)
|
186 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
187 |
+
div_term = torch.exp(
|
188 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
189 |
+
* -(math.log(10000.0) / self.d_model)
|
190 |
+
)
|
191 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
192 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
193 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
194 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
195 |
+
|
196 |
+
# Reserve the order of positive indices and concat both positive and
|
197 |
+
# negative indices. This is used to support the shifting trick
|
198 |
+
# as in https://arxiv.org/abs/1901.02860
|
199 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
200 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
201 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
202 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
203 |
+
|
204 |
+
def forward(self, x: torch.Tensor):
|
205 |
+
"""Add positional encoding.
|
206 |
+
Args:
|
207 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
208 |
+
Returns:
|
209 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
210 |
+
"""
|
211 |
+
self.extend_pe(x)
|
212 |
+
x = x * self.xscale
|
213 |
+
pos_emb = self.pe[
|
214 |
+
:,
|
215 |
+
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
|
216 |
+
]
|
217 |
+
return self.dropout(x), self.dropout(pos_emb)
|
espnet/nets/pytorch_backend/transformer/encoder.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Encoder definition."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
|
12 |
+
#from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
|
13 |
+
from espnet.nets.pytorch_backend.transformer.attention import (
|
14 |
+
MultiHeadedAttention, # noqa: H301
|
15 |
+
RelPositionMultiHeadedAttention, # noqa: H301
|
16 |
+
LegacyRelPositionMultiHeadedAttention, # noqa: H301
|
17 |
+
)
|
18 |
+
from espnet.nets.pytorch_backend.transformer.convolution import ConvolutionModule
|
19 |
+
from espnet.nets.pytorch_backend.transformer.embedding import (
|
20 |
+
PositionalEncoding, # noqa: H301
|
21 |
+
RelPositionalEncoding, # noqa: H301
|
22 |
+
LegacyRelPositionalEncoding, # noqa: H301
|
23 |
+
)
|
24 |
+
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
|
25 |
+
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
26 |
+
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
|
27 |
+
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
|
28 |
+
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
29 |
+
PositionwiseFeedForward, # noqa: H301
|
30 |
+
)
|
31 |
+
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
32 |
+
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
|
33 |
+
from espnet.nets.pytorch_backend.transformer.raw_embeddings import VideoEmbedding
|
34 |
+
from espnet.nets.pytorch_backend.transformer.raw_embeddings import AudioEmbedding
|
35 |
+
from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet
|
36 |
+
from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet
|
37 |
+
|
38 |
+
|
39 |
+
def _pre_hook(
|
40 |
+
state_dict,
|
41 |
+
prefix,
|
42 |
+
local_metadata,
|
43 |
+
strict,
|
44 |
+
missing_keys,
|
45 |
+
unexpected_keys,
|
46 |
+
error_msgs,
|
47 |
+
):
|
48 |
+
# https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
|
49 |
+
rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
|
50 |
+
# https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
|
51 |
+
rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
|
52 |
+
|
53 |
+
|
54 |
+
class Encoder(torch.nn.Module):
|
55 |
+
"""Transformer encoder module.
|
56 |
+
|
57 |
+
:param int idim: input dim
|
58 |
+
:param int attention_dim: dimention of attention
|
59 |
+
:param int attention_heads: the number of heads of multi head attention
|
60 |
+
:param int linear_units: the number of units of position-wise feed forward
|
61 |
+
:param int num_blocks: the number of decoder blocks
|
62 |
+
:param float dropout_rate: dropout rate
|
63 |
+
:param float attention_dropout_rate: dropout rate in attention
|
64 |
+
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
65 |
+
:param str or torch.nn.Module input_layer: input layer type
|
66 |
+
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
67 |
+
:param bool normalize_before: whether to use layer_norm before the first block
|
68 |
+
:param bool concat_after: whether to concat attention layer's input and output
|
69 |
+
if True, additional linear will be applied.
|
70 |
+
i.e. x -> x + linear(concat(x, att(x)))
|
71 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
72 |
+
:param str positionwise_layer_type: linear of conv1d
|
73 |
+
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
74 |
+
:param str encoder_attn_layer_type: encoder attention layer type
|
75 |
+
:param bool macaron_style: whether to use macaron style for positionwise layer
|
76 |
+
:param bool use_cnn_module: whether to use convolution module
|
77 |
+
:param bool zero_triu: whether to zero the upper triangular part of attention matrix
|
78 |
+
:param int cnn_module_kernel: kernerl size of convolution module
|
79 |
+
:param int padding_idx: padding_idx for input_layer=embed
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
idim,
|
85 |
+
attention_dim=256,
|
86 |
+
attention_heads=4,
|
87 |
+
linear_units=2048,
|
88 |
+
num_blocks=6,
|
89 |
+
dropout_rate=0.1,
|
90 |
+
positional_dropout_rate=0.1,
|
91 |
+
attention_dropout_rate=0.0,
|
92 |
+
input_layer="conv2d",
|
93 |
+
pos_enc_class=PositionalEncoding,
|
94 |
+
normalize_before=True,
|
95 |
+
concat_after=False,
|
96 |
+
positionwise_layer_type="linear",
|
97 |
+
positionwise_conv_kernel_size=1,
|
98 |
+
macaron_style=False,
|
99 |
+
encoder_attn_layer_type="mha",
|
100 |
+
use_cnn_module=False,
|
101 |
+
zero_triu=False,
|
102 |
+
cnn_module_kernel=31,
|
103 |
+
padding_idx=-1,
|
104 |
+
relu_type="prelu",
|
105 |
+
a_upsample_ratio=1,
|
106 |
+
):
|
107 |
+
"""Construct an Encoder object."""
|
108 |
+
super(Encoder, self).__init__()
|
109 |
+
self._register_load_state_dict_pre_hook(_pre_hook)
|
110 |
+
|
111 |
+
if encoder_attn_layer_type == "rel_mha":
|
112 |
+
pos_enc_class = RelPositionalEncoding
|
113 |
+
elif encoder_attn_layer_type == "legacy_rel_mha":
|
114 |
+
pos_enc_class = LegacyRelPositionalEncoding
|
115 |
+
# -- frontend module.
|
116 |
+
if input_layer == "conv1d":
|
117 |
+
self.frontend = Conv1dResNet(
|
118 |
+
relu_type=relu_type,
|
119 |
+
a_upsample_ratio=a_upsample_ratio,
|
120 |
+
)
|
121 |
+
elif input_layer == "conv3d":
|
122 |
+
self.frontend = Conv3dResNet(relu_type=relu_type)
|
123 |
+
else:
|
124 |
+
self.frontend = None
|
125 |
+
# -- backend module.
|
126 |
+
if input_layer == "linear":
|
127 |
+
self.embed = torch.nn.Sequential(
|
128 |
+
torch.nn.Linear(idim, attention_dim),
|
129 |
+
torch.nn.LayerNorm(attention_dim),
|
130 |
+
torch.nn.Dropout(dropout_rate),
|
131 |
+
torch.nn.ReLU(),
|
132 |
+
pos_enc_class(attention_dim, positional_dropout_rate),
|
133 |
+
)
|
134 |
+
elif input_layer == "conv2d":
|
135 |
+
self.embed = Conv2dSubsampling(
|
136 |
+
idim,
|
137 |
+
attention_dim,
|
138 |
+
dropout_rate,
|
139 |
+
pos_enc_class(attention_dim, dropout_rate),
|
140 |
+
)
|
141 |
+
elif input_layer == "vgg2l":
|
142 |
+
self.embed = VGG2L(idim, attention_dim)
|
143 |
+
elif input_layer == "embed":
|
144 |
+
self.embed = torch.nn.Sequential(
|
145 |
+
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
146 |
+
pos_enc_class(attention_dim, positional_dropout_rate),
|
147 |
+
)
|
148 |
+
elif isinstance(input_layer, torch.nn.Module):
|
149 |
+
self.embed = torch.nn.Sequential(
|
150 |
+
input_layer, pos_enc_class(attention_dim, positional_dropout_rate),
|
151 |
+
)
|
152 |
+
elif input_layer in ["conv1d", "conv3d"]:
|
153 |
+
self.embed = torch.nn.Sequential(
|
154 |
+
torch.nn.Linear(512, attention_dim),
|
155 |
+
pos_enc_class(attention_dim, positional_dropout_rate)
|
156 |
+
)
|
157 |
+
elif input_layer is None:
|
158 |
+
self.embed = torch.nn.Sequential(
|
159 |
+
pos_enc_class(attention_dim, positional_dropout_rate)
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
raise ValueError("unknown input_layer: " + input_layer)
|
163 |
+
self.normalize_before = normalize_before
|
164 |
+
if positionwise_layer_type == "linear":
|
165 |
+
positionwise_layer = PositionwiseFeedForward
|
166 |
+
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
|
167 |
+
elif positionwise_layer_type == "conv1d":
|
168 |
+
positionwise_layer = MultiLayeredConv1d
|
169 |
+
positionwise_layer_args = (
|
170 |
+
attention_dim,
|
171 |
+
linear_units,
|
172 |
+
positionwise_conv_kernel_size,
|
173 |
+
dropout_rate,
|
174 |
+
)
|
175 |
+
elif positionwise_layer_type == "conv1d-linear":
|
176 |
+
positionwise_layer = Conv1dLinear
|
177 |
+
positionwise_layer_args = (
|
178 |
+
attention_dim,
|
179 |
+
linear_units,
|
180 |
+
positionwise_conv_kernel_size,
|
181 |
+
dropout_rate,
|
182 |
+
)
|
183 |
+
else:
|
184 |
+
raise NotImplementedError("Support only linear or conv1d.")
|
185 |
+
|
186 |
+
if encoder_attn_layer_type == "mha":
|
187 |
+
encoder_attn_layer = MultiHeadedAttention
|
188 |
+
encoder_attn_layer_args = (
|
189 |
+
attention_heads,
|
190 |
+
attention_dim,
|
191 |
+
attention_dropout_rate,
|
192 |
+
)
|
193 |
+
elif encoder_attn_layer_type == "legacy_rel_mha":
|
194 |
+
encoder_attn_layer = LegacyRelPositionMultiHeadedAttention
|
195 |
+
encoder_attn_layer_args = (
|
196 |
+
attention_heads,
|
197 |
+
attention_dim,
|
198 |
+
attention_dropout_rate,
|
199 |
+
)
|
200 |
+
elif encoder_attn_layer_type == "rel_mha":
|
201 |
+
encoder_attn_layer = RelPositionMultiHeadedAttention
|
202 |
+
encoder_attn_layer_args = (
|
203 |
+
attention_heads,
|
204 |
+
attention_dim,
|
205 |
+
attention_dropout_rate,
|
206 |
+
zero_triu,
|
207 |
+
)
|
208 |
+
else:
|
209 |
+
raise ValueError("unknown encoder_attn_layer: " + encoder_attn_layer)
|
210 |
+
|
211 |
+
convolution_layer = ConvolutionModule
|
212 |
+
convolution_layer_args = (attention_dim, cnn_module_kernel)
|
213 |
+
|
214 |
+
self.encoders = repeat(
|
215 |
+
num_blocks,
|
216 |
+
lambda: EncoderLayer(
|
217 |
+
attention_dim,
|
218 |
+
encoder_attn_layer(*encoder_attn_layer_args),
|
219 |
+
positionwise_layer(*positionwise_layer_args),
|
220 |
+
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
221 |
+
dropout_rate,
|
222 |
+
normalize_before,
|
223 |
+
concat_after,
|
224 |
+
macaron_style,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
if self.normalize_before:
|
228 |
+
self.after_norm = LayerNorm(attention_dim)
|
229 |
+
|
230 |
+
def forward(self, xs, masks, extract_resnet_feats=False):
|
231 |
+
"""Encode input sequence.
|
232 |
+
|
233 |
+
:param torch.Tensor xs: input tensor
|
234 |
+
:param torch.Tensor masks: input mask
|
235 |
+
:param str extract_features: the position for feature extraction
|
236 |
+
:return: position embedded tensor and mask
|
237 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
238 |
+
"""
|
239 |
+
if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)):
|
240 |
+
xs = self.frontend(xs)
|
241 |
+
if extract_resnet_feats:
|
242 |
+
return xs
|
243 |
+
|
244 |
+
if isinstance(self.embed, Conv2dSubsampling):
|
245 |
+
xs, masks = self.embed(xs, masks)
|
246 |
+
else:
|
247 |
+
xs = self.embed(xs)
|
248 |
+
|
249 |
+
xs, masks = self.encoders(xs, masks)
|
250 |
+
|
251 |
+
if isinstance(xs, tuple):
|
252 |
+
xs = xs[0]
|
253 |
+
|
254 |
+
if self.normalize_before:
|
255 |
+
xs = self.after_norm(xs)
|
256 |
+
|
257 |
+
return xs, masks
|
258 |
+
|
259 |
+
def forward_one_step(self, xs, masks, cache=None):
|
260 |
+
"""Encode input frame.
|
261 |
+
|
262 |
+
:param torch.Tensor xs: input tensor
|
263 |
+
:param torch.Tensor masks: input mask
|
264 |
+
:param List[torch.Tensor] cache: cache tensors
|
265 |
+
:return: position embedded tensor, mask and new cache
|
266 |
+
:rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
267 |
+
"""
|
268 |
+
if isinstance(self.frontend, (Conv1dResNet, Conv3dResNet)):
|
269 |
+
xs = self.frontend(xs)
|
270 |
+
|
271 |
+
if isinstance(self.embed, Conv2dSubsampling):
|
272 |
+
xs, masks = self.embed(xs, masks)
|
273 |
+
else:
|
274 |
+
xs = self.embed(xs)
|
275 |
+
if cache is None:
|
276 |
+
cache = [None for _ in range(len(self.encoders))]
|
277 |
+
new_cache = []
|
278 |
+
for c, e in zip(cache, self.encoders):
|
279 |
+
xs, masks = e(xs, masks, cache=c)
|
280 |
+
new_cache.append(xs)
|
281 |
+
if self.normalize_before:
|
282 |
+
xs = self.after_norm(xs)
|
283 |
+
return xs, masks, new_cache
|
espnet/nets/pytorch_backend/transformer/encoder_layer.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Encoder self-attention layer definition."""
|
8 |
+
|
9 |
+
import copy
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
15 |
+
|
16 |
+
|
17 |
+
class EncoderLayer(nn.Module):
|
18 |
+
"""Encoder layer module.
|
19 |
+
|
20 |
+
:param int size: input dim
|
21 |
+
:param espnet.nets.pytorch_backend.transformer.attention.
|
22 |
+
MultiHeadedAttention self_attn: self attention module
|
23 |
+
RelPositionMultiHeadedAttention self_attn: self attention module
|
24 |
+
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
|
25 |
+
PositionwiseFeedForward feed_forward:
|
26 |
+
feed forward module
|
27 |
+
:param espnet.nets.pytorch_backend.transformer.convolution.
|
28 |
+
ConvolutionModule feed_foreard:
|
29 |
+
feed forward module
|
30 |
+
:param float dropout_rate: dropout rate
|
31 |
+
:param bool normalize_before: whether to use layer_norm before the first block
|
32 |
+
:param bool concat_after: whether to concat attention layer's input and output
|
33 |
+
if True, additional linear will be applied.
|
34 |
+
i.e. x -> x + linear(concat(x, att(x)))
|
35 |
+
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
36 |
+
:param bool macaron_style: whether to use macaron style for PositionwiseFeedForward
|
37 |
+
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
size,
|
43 |
+
self_attn,
|
44 |
+
feed_forward,
|
45 |
+
conv_module,
|
46 |
+
dropout_rate,
|
47 |
+
normalize_before=True,
|
48 |
+
concat_after=False,
|
49 |
+
macaron_style=False,
|
50 |
+
):
|
51 |
+
"""Construct an EncoderLayer object."""
|
52 |
+
super(EncoderLayer, self).__init__()
|
53 |
+
self.self_attn = self_attn
|
54 |
+
self.feed_forward = feed_forward
|
55 |
+
self.ff_scale = 1.0
|
56 |
+
self.conv_module = conv_module
|
57 |
+
self.macaron_style = macaron_style
|
58 |
+
self.norm_ff = LayerNorm(size) # for the FNN module
|
59 |
+
self.norm_mha = LayerNorm(size) # for the MHA module
|
60 |
+
if self.macaron_style:
|
61 |
+
self.feed_forward_macaron = copy.deepcopy(feed_forward)
|
62 |
+
self.ff_scale = 0.5
|
63 |
+
# for another FNN module in macaron style
|
64 |
+
self.norm_ff_macaron = LayerNorm(size)
|
65 |
+
if self.conv_module is not None:
|
66 |
+
self.norm_conv = LayerNorm(size) # for the CNN module
|
67 |
+
self.norm_final = LayerNorm(size) # for the final output of the block
|
68 |
+
self.dropout = nn.Dropout(dropout_rate)
|
69 |
+
self.size = size
|
70 |
+
self.normalize_before = normalize_before
|
71 |
+
self.concat_after = concat_after
|
72 |
+
if self.concat_after:
|
73 |
+
self.concat_linear = nn.Linear(size + size, size)
|
74 |
+
|
75 |
+
def forward(self, x_input, mask, cache=None):
|
76 |
+
"""Compute encoded features.
|
77 |
+
|
78 |
+
:param torch.Tensor x_input: encoded source features (batch, max_time_in, size)
|
79 |
+
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
80 |
+
:param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
|
81 |
+
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
82 |
+
"""
|
83 |
+
if isinstance(x_input, tuple):
|
84 |
+
x, pos_emb = x_input[0], x_input[1]
|
85 |
+
else:
|
86 |
+
x, pos_emb = x_input, None
|
87 |
+
|
88 |
+
# whether to use macaron style
|
89 |
+
if self.macaron_style:
|
90 |
+
residual = x
|
91 |
+
if self.normalize_before:
|
92 |
+
x = self.norm_ff_macaron(x)
|
93 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
94 |
+
if not self.normalize_before:
|
95 |
+
x = self.norm_ff_macaron(x)
|
96 |
+
|
97 |
+
# multi-headed self-attention module
|
98 |
+
residual = x
|
99 |
+
if self.normalize_before:
|
100 |
+
x = self.norm_mha(x)
|
101 |
+
|
102 |
+
if cache is None:
|
103 |
+
x_q = x
|
104 |
+
else:
|
105 |
+
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
106 |
+
x_q = x[:, -1:, :]
|
107 |
+
residual = residual[:, -1:, :]
|
108 |
+
mask = None if mask is None else mask[:, -1:, :]
|
109 |
+
|
110 |
+
if pos_emb is not None:
|
111 |
+
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
112 |
+
else:
|
113 |
+
x_att = self.self_attn(x_q, x, x, mask)
|
114 |
+
|
115 |
+
if self.concat_after:
|
116 |
+
x_concat = torch.cat((x, x_att), dim=-1)
|
117 |
+
x = residual + self.concat_linear(x_concat)
|
118 |
+
else:
|
119 |
+
x = residual + self.dropout(x_att)
|
120 |
+
if not self.normalize_before:
|
121 |
+
x = self.norm_mha(x)
|
122 |
+
|
123 |
+
# convolution module
|
124 |
+
if self.conv_module is not None:
|
125 |
+
residual = x
|
126 |
+
if self.normalize_before:
|
127 |
+
x = self.norm_conv(x)
|
128 |
+
x = residual + self.dropout(self.conv_module(x))
|
129 |
+
if not self.normalize_before:
|
130 |
+
x = self.norm_conv(x)
|
131 |
+
|
132 |
+
# feed forward module
|
133 |
+
residual = x
|
134 |
+
if self.normalize_before:
|
135 |
+
x = self.norm_ff(x)
|
136 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
137 |
+
if not self.normalize_before:
|
138 |
+
x = self.norm_ff(x)
|
139 |
+
|
140 |
+
if self.conv_module is not None:
|
141 |
+
x = self.norm_final(x)
|
142 |
+
|
143 |
+
if cache is not None:
|
144 |
+
x = torch.cat([cache, x], dim=1)
|
145 |
+
|
146 |
+
if pos_emb is not None:
|
147 |
+
return (x, pos_emb), mask
|
148 |
+
else:
|
149 |
+
return x, mask
|
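A minimal usage sketch for the EncoderLayer added above, wiring it to the MultiHeadedAttention and PositionwiseFeedForward modules that are part of this push; the sizes, head count and dropout values are illustrative only, not taken from any config in this repo:

import torch

from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward

size = 256  # model dimension (illustrative)
layer = EncoderLayer(
    size,
    MultiHeadedAttention(4, size, 0.1),        # 4 heads, dropout 0.1
    PositionwiseFeedForward(size, 1024, 0.1),  # 1024 hidden units
    None,                                      # conv_module: no Conformer convolution block
    0.1,                                       # dropout_rate
)

x = torch.randn(2, 50, size)                   # (batch, time, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)
y, y_mask = layer(x, mask)                     # y keeps the (2, 50, size) shape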
espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py
ADDED
@@ -0,0 +1,63 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Label smoothing module."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
|
13 |
+
class LabelSmoothingLoss(nn.Module):
|
14 |
+
"""Label-smoothing loss.
|
15 |
+
|
16 |
+
:param int size: the number of classes
|
17 |
+
:param int padding_idx: ignored class id
|
18 |
+
:param float smoothing: smoothing rate (0.0 means the conventional CE)
|
19 |
+
:param bool normalize_length: normalize loss by sequence length if True
|
20 |
+
:param torch.nn.Module criterion: loss function to be smoothed
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
size,
|
26 |
+
padding_idx,
|
27 |
+
smoothing,
|
28 |
+
normalize_length=False,
|
29 |
+
criterion=nn.KLDivLoss(reduction="none"),
|
30 |
+
):
|
31 |
+
"""Construct an LabelSmoothingLoss object."""
|
32 |
+
super(LabelSmoothingLoss, self).__init__()
|
33 |
+
self.criterion = criterion
|
34 |
+
self.padding_idx = padding_idx
|
35 |
+
self.confidence = 1.0 - smoothing
|
36 |
+
self.smoothing = smoothing
|
37 |
+
self.size = size
|
38 |
+
self.true_dist = None
|
39 |
+
self.normalize_length = normalize_length
|
40 |
+
|
41 |
+
def forward(self, x, target):
|
42 |
+
"""Compute loss between x and target.
|
43 |
+
|
44 |
+
:param torch.Tensor x: prediction (batch, seqlen, class)
|
45 |
+
:param torch.Tensor target:
|
46 |
+
target signal masked with self.padding_id (batch, seqlen)
|
47 |
+
:return: scalar float value
|
48 |
+
:rtype torch.Tensor
|
49 |
+
"""
|
50 |
+
assert x.size(2) == self.size
|
51 |
+
batch_size = x.size(0)
|
52 |
+
x = x.view(-1, self.size)
|
53 |
+
target = target.view(-1)
|
54 |
+
with torch.no_grad():
|
55 |
+
true_dist = x.clone()
|
56 |
+
true_dist.fill_(self.smoothing / (self.size - 1))
|
57 |
+
ignore = target == self.padding_idx # (B,)
|
58 |
+
total = len(target) - ignore.sum().item()
|
59 |
+
target = target.masked_fill(ignore, 0) # avoid -1 index
|
60 |
+
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
61 |
+
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
62 |
+
denom = total if self.normalize_length else batch_size
|
63 |
+
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
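A small, self-contained sketch of how the LabelSmoothingLoss above is typically called; the vocabulary size of 5 and the padding id of -1 are illustrative:

import torch

from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)

logits = torch.randn(2, 4, 5, requires_grad=True)  # (batch, seqlen, vocab)
target = torch.tensor([[1, 2, 3, -1],              # -1 positions are ignored
                       [2, 4, -1, -1]])
loss = criterion(logits, target)                   # scalar, normalized by batch size
loss.backward()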
espnet/nets/pytorch_backend/transformer/layer_norm.py
ADDED
@@ -0,0 +1,33 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Layer normalization module."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class LayerNorm(torch.nn.LayerNorm):
|
13 |
+
"""Layer normalization module.
|
14 |
+
|
15 |
+
:param int nout: output dim size
|
16 |
+
:param int dim: dimension to be normalized
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, nout, dim=-1):
|
20 |
+
"""Construct an LayerNorm object."""
|
21 |
+
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
22 |
+
self.dim = dim
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
"""Apply layer normalization.
|
26 |
+
|
27 |
+
:param torch.Tensor x: input tensor
|
28 |
+
:return: layer normalized tensor
|
29 |
+
:rtype torch.Tensor
|
30 |
+
"""
|
31 |
+
if self.dim == -1:
|
32 |
+
return super(LayerNorm, self).forward(x)
|
33 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
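A short sketch of why the dim argument of this LayerNorm wrapper matters: with dim=1 the same normalization is applied to a channel-first tensor by transposing internally (shapes are illustrative):

import torch

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm

x_last = torch.randn(2, 50, 256)              # features on the last axis
print(LayerNorm(256)(x_last).shape)           # torch.Size([2, 50, 256])

x_first = torch.randn(2, 256, 50)             # channel-first layout
print(LayerNorm(256, dim=1)(x_first).shape)   # torch.Size([2, 256, 50])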
espnet/nets/pytorch_backend/transformer/mask.py
ADDED
@@ -0,0 +1,51 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Copyright 2019 Shigeki Karita
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
"""Mask module."""
|
7 |
+
|
8 |
+
from distutils.version import LooseVersion
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")
|
13 |
+
# LooseVersion('1.2.0') == LooseVersion(torch.__version__) can't include e.g. 1.2.0+aaa
|
14 |
+
is_torch_1_2 = (
|
15 |
+
LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2")
|
16 |
+
)
|
17 |
+
datatype = torch.bool if is_torch_1_2_plus else torch.uint8
|
18 |
+
|
19 |
+
|
20 |
+
def subsequent_mask(size, device="cpu", dtype=datatype):
|
21 |
+
"""Create mask for subsequent steps (1, size, size).
|
22 |
+
|
23 |
+
:param int size: size of mask
|
24 |
+
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
25 |
+
:param torch.dtype dtype: result dtype
|
26 |
+
:rtype: torch.Tensor
|
27 |
+
>>> subsequent_mask(3)
|
28 |
+
[[1, 0, 0],
|
29 |
+
[1, 1, 0],
|
30 |
+
[1, 1, 1]]
|
31 |
+
"""
|
32 |
+
if is_torch_1_2 and dtype == torch.bool:
|
33 |
+
# torch=1.2 doesn't support tril for bool tensor
|
34 |
+
ret = torch.ones(size, size, device=device, dtype=torch.uint8)
|
35 |
+
return torch.tril(ret, out=ret).type(dtype)
|
36 |
+
else:
|
37 |
+
ret = torch.ones(size, size, device=device, dtype=dtype)
|
38 |
+
return torch.tril(ret, out=ret)
|
39 |
+
|
40 |
+
|
41 |
+
def target_mask(ys_in_pad, ignore_id):
|
42 |
+
"""Create mask for decoder self-attention.
|
43 |
+
|
44 |
+
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
|
45 |
+
:param int ignore_id: index of padding
|
46 |
+
:param torch.dtype dtype: result dtype
|
47 |
+
:rtype: torch.Tensor
|
48 |
+
"""
|
49 |
+
ys_mask = ys_in_pad != ignore_id
|
50 |
+
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
|
51 |
+
return ys_mask.unsqueeze(-2) & m
|
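A quick sketch of the two helpers above: subsequent_mask builds the causal (lower-triangular) mask and target_mask combines it with the padding mask of a padded target batch; the ignore id of -1 is illustrative:

import torch

from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask, target_mask

print(subsequent_mask(3))
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])

ys_in_pad = torch.tensor([[5, 7, 9], [5, 7, -1]])    # second sequence is padded
print(target_mask(ys_in_pad, ignore_id=-1).shape)    # torch.Size([2, 3, 3])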
espnet/nets/pytorch_backend/transformer/multi_layer_conv.py
ADDED
@@ -0,0 +1,105 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Tomoki Hayashi
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class MultiLayeredConv1d(torch.nn.Module):
|
13 |
+
"""Multi-layered conv1d for Transformer block.
|
14 |
+
|
15 |
+
This is a module of multi-layered conv1d designed
|
16 |
+
to replace positionwise feed-forward network
|
17 |
+
in the Transformer block, which is introduced in
|
18 |
+
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
19 |
+
|
20 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
21 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
22 |
+
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
26 |
+
"""Initialize MultiLayeredConv1d module.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
in_chans (int): Number of input channels.
|
30 |
+
hidden_chans (int): Number of hidden channels.
|
31 |
+
kernel_size (int): Kernel size of conv1d.
|
32 |
+
dropout_rate (float): Dropout rate.
|
33 |
+
|
34 |
+
"""
|
35 |
+
super(MultiLayeredConv1d, self).__init__()
|
36 |
+
self.w_1 = torch.nn.Conv1d(
|
37 |
+
in_chans,
|
38 |
+
hidden_chans,
|
39 |
+
kernel_size,
|
40 |
+
stride=1,
|
41 |
+
padding=(kernel_size - 1) // 2,
|
42 |
+
)
|
43 |
+
self.w_2 = torch.nn.Conv1d(
|
44 |
+
hidden_chans,
|
45 |
+
in_chans,
|
46 |
+
kernel_size,
|
47 |
+
stride=1,
|
48 |
+
padding=(kernel_size - 1) // 2,
|
49 |
+
)
|
50 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
"""Calculate forward propagation.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
Tensor: Batch of output tensors (B, ..., in_chans).
|
60 |
+
|
61 |
+
"""
|
62 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
63 |
+
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
64 |
+
|
65 |
+
|
66 |
+
class Conv1dLinear(torch.nn.Module):
|
67 |
+
"""Conv1D + Linear for Transformer block.
|
68 |
+
|
69 |
+
A variant of MultiLayeredConv1d, which replaces the second conv layer with a linear layer.
|
70 |
+
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
74 |
+
"""Initialize Conv1dLinear module.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
in_chans (int): Number of input channels.
|
78 |
+
hidden_chans (int): Number of hidden channels.
|
79 |
+
kernel_size (int): Kernel size of conv1d.
|
80 |
+
dropout_rate (float): Dropout rate.
|
81 |
+
|
82 |
+
"""
|
83 |
+
super(Conv1dLinear, self).__init__()
|
84 |
+
self.w_1 = torch.nn.Conv1d(
|
85 |
+
in_chans,
|
86 |
+
hidden_chans,
|
87 |
+
kernel_size,
|
88 |
+
stride=1,
|
89 |
+
padding=(kernel_size - 1) // 2,
|
90 |
+
)
|
91 |
+
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
92 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
"""Calculate forward propagation.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Tensor: Batch of output tensors (B, ..., in_chans).
|
102 |
+
|
103 |
+
"""
|
104 |
+
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
105 |
+
return self.w_2(self.dropout(x))
|
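Both modules above are drop-in replacements for the position-wise feed-forward block: they consume and return (batch, time, in_chans) tensors. A minimal sketch with illustrative sizes:

import torch

from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear, MultiLayeredConv1d

x = torch.randn(2, 50, 256)   # (batch, time, in_chans)

ff = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
print(ff(x).shape)            # torch.Size([2, 50, 256])

ff_lin = Conv1dLinear(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
print(ff_lin(x).shape)        # torch.Size([2, 50, 256])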
espnet/nets/pytorch_backend/transformer/optimizer.py
ADDED
@@ -0,0 +1,75 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Optimizer module."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class NoamOpt(object):
|
13 |
+
"""Optim wrapper that implements rate."""
|
14 |
+
|
15 |
+
def __init__(self, model_size, factor, warmup, optimizer):
|
16 |
+
"""Construct an NoamOpt object."""
|
17 |
+
self.optimizer = optimizer
|
18 |
+
self._step = 0
|
19 |
+
self.warmup = warmup
|
20 |
+
self.factor = factor
|
21 |
+
self.model_size = model_size
|
22 |
+
self._rate = 0
|
23 |
+
|
24 |
+
@property
|
25 |
+
def param_groups(self):
|
26 |
+
"""Return param_groups."""
|
27 |
+
return self.optimizer.param_groups
|
28 |
+
|
29 |
+
def step(self):
|
30 |
+
"""Update parameters and rate."""
|
31 |
+
self._step += 1
|
32 |
+
rate = self.rate()
|
33 |
+
for p in self.optimizer.param_groups:
|
34 |
+
p["lr"] = rate
|
35 |
+
self._rate = rate
|
36 |
+
self.optimizer.step()
|
37 |
+
|
38 |
+
def rate(self, step=None):
|
39 |
+
"""Implement `lrate` above."""
|
40 |
+
if step is None:
|
41 |
+
step = self._step
|
42 |
+
return (
|
43 |
+
self.factor
|
44 |
+
* self.model_size ** (-0.5)
|
45 |
+
* min(step ** (-0.5), step * self.warmup ** (-1.5))
|
46 |
+
)
|
47 |
+
|
48 |
+
def zero_grad(self):
|
49 |
+
"""Reset gradient."""
|
50 |
+
self.optimizer.zero_grad()
|
51 |
+
|
52 |
+
def state_dict(self):
|
53 |
+
"""Return state_dict."""
|
54 |
+
return {
|
55 |
+
"_step": self._step,
|
56 |
+
"warmup": self.warmup,
|
57 |
+
"factor": self.factor,
|
58 |
+
"model_size": self.model_size,
|
59 |
+
"_rate": self._rate,
|
60 |
+
"optimizer": self.optimizer.state_dict(),
|
61 |
+
}
|
62 |
+
|
63 |
+
def load_state_dict(self, state_dict):
|
64 |
+
"""Load state_dict."""
|
65 |
+
for key, value in state_dict.items():
|
66 |
+
if key == "optimizer":
|
67 |
+
self.optimizer.load_state_dict(state_dict["optimizer"])
|
68 |
+
else:
|
69 |
+
setattr(self, key, value)
|
70 |
+
|
71 |
+
|
72 |
+
def get_std_opt(model, d_model, warmup, factor):
|
73 |
+
"""Get standard NoamOpt."""
|
74 |
+
base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
|
75 |
+
return NoamOpt(d_model, factor, warmup, base)
|
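A small sketch of the Noam schedule implemented above: the learning rate grows linearly for `warmup` steps and then decays as step**-0.5; the toy model and hyperparameters are made up:

import torch

from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

model = torch.nn.Linear(256, 256)
opt = get_std_opt(model, d_model=256, warmup=25000, factor=1.0)

for step in (1, 25000, 100000):
    print(step, opt.rate(step))       # peaks around step == warmup, then decays

# used like a normal optimizer inside the training loop:
loss = model(torch.randn(8, 256)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()                            # also updates the learning rate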
espnet/nets/pytorch_backend/transformer/plot.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
import logging
|
8 |
+
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import numpy
|
11 |
+
|
12 |
+
from espnet.asr import asr_utils
|
13 |
+
|
14 |
+
|
15 |
+
def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
|
16 |
+
# dynamically import matplotlib due to not found error
|
17 |
+
from matplotlib.ticker import MaxNLocator
|
18 |
+
import os
|
19 |
+
|
20 |
+
d = os.path.dirname(filename)
|
21 |
+
if not os.path.exists(d):
|
22 |
+
os.makedirs(d)
|
23 |
+
w, h = plt.figaspect(1.0 / len(att_w))
|
24 |
+
fig = plt.Figure(figsize=(w * 2, h * 2))
|
25 |
+
axes = fig.subplots(1, len(att_w))
|
26 |
+
if len(att_w) == 1:
|
27 |
+
axes = [axes]
|
28 |
+
for ax, aw in zip(axes, att_w):
|
29 |
+
# plt.subplot(1, len(att_w), h)
|
30 |
+
ax.imshow(aw.astype(numpy.float32), aspect="auto")
|
31 |
+
ax.set_xlabel("Input")
|
32 |
+
ax.set_ylabel("Output")
|
33 |
+
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
34 |
+
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
35 |
+
# Labels for major ticks
|
36 |
+
if xtokens is not None:
|
37 |
+
ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens)))
|
38 |
+
ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True)
|
39 |
+
ax.set_xticklabels(xtokens + [""], rotation=40)
|
40 |
+
if ytokens is not None:
|
41 |
+
ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens)))
|
42 |
+
ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True)
|
43 |
+
ax.set_yticklabels(ytokens + [""])
|
44 |
+
fig.tight_layout()
|
45 |
+
return fig
|
46 |
+
|
47 |
+
|
48 |
+
def savefig(plot, filename):
|
49 |
+
plot.savefig(filename)
|
50 |
+
plt.clf()
|
51 |
+
|
52 |
+
|
53 |
+
def plot_multi_head_attention(
|
54 |
+
data,
|
55 |
+
attn_dict,
|
56 |
+
outdir,
|
57 |
+
suffix="png",
|
58 |
+
savefn=savefig,
|
59 |
+
ikey="input",
|
60 |
+
iaxis=0,
|
61 |
+
okey="output",
|
62 |
+
oaxis=0,
|
63 |
+
):
|
64 |
+
"""Plot multi head attentions.
|
65 |
+
|
66 |
+
:param dict data: utts info from json file
|
67 |
+
:param dict[str, torch.Tensor] attn_dict: multi head attention dict.
|
68 |
+
values should be torch.Tensor (head, input_length, output_length)
|
69 |
+
:param str outdir: dir to save fig
|
70 |
+
:param str suffix: filename suffix including image type (e.g., png)
|
71 |
+
:param savefn: function to save
|
72 |
+
|
73 |
+
"""
|
74 |
+
for name, att_ws in attn_dict.items():
|
75 |
+
for idx, att_w in enumerate(att_ws):
|
76 |
+
filename = "%s/%s.%s.%s" % (outdir, data[idx][0], name, suffix)
|
77 |
+
dec_len = int(data[idx][1][okey][oaxis]["shape"][0])
|
78 |
+
enc_len = int(data[idx][1][ikey][iaxis]["shape"][0])
|
79 |
+
xtokens, ytokens = None, None
|
80 |
+
if "encoder" in name:
|
81 |
+
att_w = att_w[:, :enc_len, :enc_len]
|
82 |
+
# for MT
|
83 |
+
if "token" in data[idx][1][ikey][iaxis].keys():
|
84 |
+
xtokens = data[idx][1][ikey][iaxis]["token"].split()
|
85 |
+
ytokens = xtokens[:]
|
86 |
+
elif "decoder" in name:
|
87 |
+
if "self" in name:
|
88 |
+
att_w = att_w[:, : dec_len + 1, : dec_len + 1] # +1 for <sos>
|
89 |
+
else:
|
90 |
+
att_w = att_w[:, : dec_len + 1, :enc_len] # +1 for <sos>
|
91 |
+
# for MT
|
92 |
+
if "token" in data[idx][1][ikey][iaxis].keys():
|
93 |
+
xtokens = data[idx][1][ikey][iaxis]["token"].split()
|
94 |
+
# for ASR/ST/MT
|
95 |
+
if "token" in data[idx][1][okey][oaxis].keys():
|
96 |
+
ytokens = ["<sos>"] + data[idx][1][okey][oaxis]["token"].split()
|
97 |
+
if "self" in name:
|
98 |
+
xtokens = ytokens[:]
|
99 |
+
else:
|
100 |
+
logging.warning("unknown name for shaping attention")
|
101 |
+
fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
|
102 |
+
savefn(fig, filename)
|
103 |
+
|
104 |
+
|
105 |
+
class PlotAttentionReport(asr_utils.PlotAttentionReport):
|
106 |
+
def plotfn(self, *args, **kwargs):
|
107 |
+
kwargs["ikey"] = self.ikey
|
108 |
+
kwargs["iaxis"] = self.iaxis
|
109 |
+
kwargs["okey"] = self.okey
|
110 |
+
kwargs["oaxis"] = self.oaxis
|
111 |
+
plot_multi_head_attention(*args, **kwargs)
|
112 |
+
|
113 |
+
def __call__(self, trainer):
|
114 |
+
attn_dict = self.get_attention_weights()
|
115 |
+
suffix = "ep.{.updater.epoch}.png".format(trainer)
|
116 |
+
self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig)
|
117 |
+
|
118 |
+
def get_attention_weights(self):
|
119 |
+
batch = self.converter([self.transform(self.data)], self.device)
|
120 |
+
if isinstance(batch, tuple):
|
121 |
+
att_ws = self.att_vis_fn(*batch)
|
122 |
+
elif isinstance(batch, dict):
|
123 |
+
att_ws = self.att_vis_fn(**batch)
|
124 |
+
return att_ws
|
125 |
+
|
126 |
+
def log_attentions(self, logger, step):
|
127 |
+
def log_fig(plot, filename):
|
128 |
+
from os.path import basename
|
129 |
+
|
130 |
+
logger.add_figure(basename(filename), plot, step)
|
131 |
+
plt.clf()
|
132 |
+
|
133 |
+
attn_dict = self.get_attention_weights()
|
134 |
+
self.plotfn(self.data, attn_dict, self.outdir, "", log_fig)
|
espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Positionwise feed forward layer definition."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
13 |
+
"""Positionwise feed forward layer.
|
14 |
+
|
15 |
+
:param int idim: input dimension
|
16 |
+
:param int hidden_units: number of hidden units
|
17 |
+
:param float dropout_rate: dropout rate
|
18 |
+
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, idim, hidden_units, dropout_rate):
|
22 |
+
"""Construct an PositionwiseFeedForward object."""
|
23 |
+
super(PositionwiseFeedForward, self).__init__()
|
24 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
25 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
26 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
"""Forward funciton."""
|
30 |
+
return self.w_2(self.dropout(torch.relu(self.w_1(x))))
|
espnet/nets/pytorch_backend/transformer/raw_embeddings.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
import torch
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from espnet.nets.pytorch_backend.backbones.conv3d_extractor import Conv3dResNet
|
5 |
+
from espnet.nets.pytorch_backend.backbones.conv1d_extractor import Conv1dResNet
|
6 |
+
|
7 |
+
|
8 |
+
class VideoEmbedding(torch.nn.Module):
|
9 |
+
"""Video Embedding
|
10 |
+
|
11 |
+
:param int idim: input dim
|
12 |
+
:param int odim: output dim
|
13 |
+
:param float dropout_rate: dropout rate
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, idim, odim, dropout_rate, pos_enc_class, backbone_type="resnet", relu_type="prelu"):
|
17 |
+
super(VideoEmbedding, self).__init__()
|
18 |
+
self.trunk = Conv3dResNet(
|
19 |
+
backbone_type=backbone_type,
|
20 |
+
relu_type=relu_type
|
21 |
+
)
|
22 |
+
self.out = torch.nn.Sequential(
|
23 |
+
torch.nn.Linear(idim, odim),
|
24 |
+
pos_enc_class,
|
25 |
+
)
|
26 |
+
|
27 |
+
def forward(self, x, x_mask, extract_feats=None):
|
28 |
+
"""video embedding for x
|
29 |
+
|
30 |
+
:param torch.Tensor x: input tensor
|
31 |
+
:param torch.Tensor x_mask: input mask
|
32 |
+
:param bool extract_feats: whether to also return the backbone features
|
33 |
+
:return: subsampled x and mask
|
34 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]
|
35 |
+
"""
|
36 |
+
x_resnet, x_mask = self.trunk(x, x_mask)
|
37 |
+
x = self.out(x_resnet)
|
38 |
+
if extract_feats:
|
39 |
+
return x, x_mask, x_resnet
|
40 |
+
else:
|
41 |
+
return x, x_mask
|
42 |
+
|
43 |
+
|
44 |
+
class AudioEmbedding(torch.nn.Module):
|
45 |
+
"""Audio Embedding
|
46 |
+
|
47 |
+
:param int idim: input dim
|
48 |
+
:param int odim: output dim
|
49 |
+
:param float dropout_rate: dropout rate
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, idim, odim, dropout_rate, pos_enc_class, relu_type="prelu", a_upsample_ratio=1):
|
53 |
+
super(AudioEmbedding, self).__init__()
|
54 |
+
self.trunk = Conv1dResNet(
|
55 |
+
relu_type=relu_type,
|
56 |
+
a_upsample_ratio=a_upsample_ratio,
|
57 |
+
)
|
58 |
+
self.out = torch.nn.Sequential(
|
59 |
+
torch.nn.Linear(idim, odim),
|
60 |
+
pos_enc_class,
|
61 |
+
)
|
62 |
+
|
63 |
+
def forward(self, x, x_mask, extract_feats=None):
|
64 |
+
"""audio embedding for x
|
65 |
+
|
66 |
+
:param torch.Tensor x: input tensor
|
67 |
+
:param torch.Tensor x_mask: input mask
|
68 |
+
:param bool extract_feats: whether to also return the backbone features
|
69 |
+
:return: subsampled x and mask
|
70 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]
|
71 |
+
"""
|
72 |
+
x_resnet, x_mask = self.trunk(x, x_mask)
|
73 |
+
x = self.out(x_resnet)
|
74 |
+
if extract_feats:
|
75 |
+
return x, x_mask, x_resnet
|
76 |
+
else:
|
77 |
+
return x, x_mask
|
espnet/nets/pytorch_backend/transformer/repeat.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Repeat the same layer definition."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class MultiSequential(torch.nn.Sequential):
|
13 |
+
"""Multi-input multi-output torch.nn.Sequential."""
|
14 |
+
|
15 |
+
def forward(self, *args):
|
16 |
+
"""Repeat."""
|
17 |
+
for m in self:
|
18 |
+
args = m(*args)
|
19 |
+
return args
|
20 |
+
|
21 |
+
|
22 |
+
def repeat(N, fn):
|
23 |
+
"""Repeat module N times.
|
24 |
+
|
25 |
+
:param int N: repeat time
|
26 |
+
:param function fn: function to generate module
|
27 |
+
:return: repeated modules
|
28 |
+
:rtype: MultiSequential
|
29 |
+
"""
|
30 |
+
return MultiSequential(*[fn() for _ in range(N)])
|
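MultiSequential differs from torch.nn.Sequential in that every layer receives and returns the full argument tuple, which is how a stack of EncoderLayer modules propagates (x, mask) pairs. A toy sketch with a hypothetical Scale module (not part of this repo) showing the calling convention:

import torch

from espnet.nets.pytorch_backend.transformer.repeat import repeat


class Scale(torch.nn.Module):
    """Toy layer: doubles x and passes the mask through unchanged."""

    def forward(self, x, mask):
        return x * 2.0, mask


layers = repeat(3, Scale)
x, mask = layers(torch.ones(2, 4), None)
print(x[0, 0].item())   # 8.0 after three doublings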
espnet/nets/pytorch_backend/transformer/subsampling.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
# Copyright 2019 Shigeki Karita
|
5 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
6 |
+
|
7 |
+
"""Subsampling layer definition."""
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class Conv2dSubsampling(torch.nn.Module):
|
13 |
+
"""Convolutional 2D subsampling (to 1/4 length).
|
14 |
+
|
15 |
+
:param int idim: input dim
|
16 |
+
:param int odim: output dim
|
17 |
+
:param float dropout_rate: dropout rate
|
18 |
+
:param nn.Module pos_enc_class: positional encoding layer
|
19 |
+
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, idim, odim, dropout_rate, pos_enc_class):
|
23 |
+
"""Construct an Conv2dSubsampling object."""
|
24 |
+
super(Conv2dSubsampling, self).__init__()
|
25 |
+
self.conv = torch.nn.Sequential(
|
26 |
+
torch.nn.Conv2d(1, odim, 3, 2),
|
27 |
+
torch.nn.ReLU(),
|
28 |
+
torch.nn.Conv2d(odim, odim, 3, 2),
|
29 |
+
torch.nn.ReLU(),
|
30 |
+
)
|
31 |
+
self.out = torch.nn.Sequential(
|
32 |
+
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), pos_enc_class,
|
33 |
+
)
|
34 |
+
|
35 |
+
def forward(self, x, x_mask):
|
36 |
+
"""Subsample x.
|
37 |
+
|
38 |
+
:param torch.Tensor x: input tensor
|
39 |
+
:param torch.Tensor x_mask: input mask
|
40 |
+
:return: subsampled x and mask
|
41 |
+
:rtype Tuple[torch.Tensor, torch.Tensor]
|
42 |
+
or Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
|
43 |
+
"""
|
44 |
+
x = x.unsqueeze(1) # (b, c, t, f)
|
45 |
+
x = self.conv(x)
|
46 |
+
b, c, t, f = x.size()
|
47 |
+
# if RelPositionalEncoding, x: Tuple[torch.Tensor, torch.Tensor]
|
48 |
+
# else x: torch.Tensor
|
49 |
+
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
50 |
+
if x_mask is None:
|
51 |
+
return x, None
|
52 |
+
return x, x_mask[:, :, :-2:2][:, :, :-2:2]
|
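A shape-level sketch of the 1/4 time subsampling above. pos_enc_class is normally a positional-encoding module from this repo; torch.nn.Identity() is used here only to keep the example self-contained:

import torch

from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling

sub = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.1, pos_enc_class=torch.nn.Identity())

x = torch.randn(1, 100, 80)                     # (batch, time, feature)
mask = torch.ones(1, 1, 100, dtype=torch.bool)
y, y_mask = sub(x, mask)
print(y.shape, y_mask.shape)                    # (1, 24, 256) and (1, 1, 24)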
espnet/nets/scorer_interface.py
ADDED
@@ -0,0 +1,188 @@
1 |
+
"""Scorer interface module."""
|
2 |
+
|
3 |
+
from typing import Any
|
4 |
+
from typing import List
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
|
11 |
+
class ScorerInterface:
|
12 |
+
"""Scorer interface for beam search.
|
13 |
+
|
14 |
+
The scorer performs scoring of all tokens in the vocabulary.
|
15 |
+
|
16 |
+
Examples:
|
17 |
+
* Search heuristics
|
18 |
+
* :class:`espnet.nets.scorers.length_bonus.LengthBonus`
|
19 |
+
* Decoder networks of the sequence-to-sequence models
|
20 |
+
* :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
|
21 |
+
* :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
|
22 |
+
* Neural language models
|
23 |
+
* :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
|
24 |
+
* :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
|
25 |
+
* :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
|
26 |
+
|
27 |
+
"""
|
28 |
+
|
29 |
+
def init_state(self, x: torch.Tensor) -> Any:
|
30 |
+
"""Get an initial state for decoding (optional).
|
31 |
+
|
32 |
+
Args:
|
33 |
+
x (torch.Tensor): The encoded feature tensor
|
34 |
+
|
35 |
+
Returns: initial state
|
36 |
+
|
37 |
+
"""
|
38 |
+
return None
|
39 |
+
|
40 |
+
def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
|
41 |
+
"""Select state with relative ids in the main beam search.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
state: Decoder state for prefix tokens
|
45 |
+
i (int): Index to select a state in the main beam search
|
46 |
+
new_id (int): New label index to select a state if necessary
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
state: pruned state
|
50 |
+
|
51 |
+
"""
|
52 |
+
return None if state is None else state[i]
|
53 |
+
|
54 |
+
def score(
|
55 |
+
self, y: torch.Tensor, state: Any, x: torch.Tensor
|
56 |
+
) -> Tuple[torch.Tensor, Any]:
|
57 |
+
"""Score new token (required).
|
58 |
+
|
59 |
+
Args:
|
60 |
+
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
61 |
+
state: Scorer state for prefix tokens
|
62 |
+
x (torch.Tensor): The encoder feature that generates ys.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
tuple[torch.Tensor, Any]: Tuple of
|
66 |
+
scores for next token that has a shape of `(n_vocab)`
|
67 |
+
and next state for ys
|
68 |
+
|
69 |
+
"""
|
70 |
+
raise NotImplementedError
|
71 |
+
|
72 |
+
def final_score(self, state: Any) -> float:
|
73 |
+
"""Score eos (optional).
|
74 |
+
|
75 |
+
Args:
|
76 |
+
state: Scorer state for prefix tokens
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
float: final score
|
80 |
+
|
81 |
+
"""
|
82 |
+
return 0.0
|
83 |
+
|
84 |
+
|
85 |
+
class BatchScorerInterface(ScorerInterface):
|
86 |
+
"""Batch scorer interface."""
|
87 |
+
|
88 |
+
def batch_init_state(self, x: torch.Tensor) -> Any:
|
89 |
+
"""Get an initial state for decoding (optional).
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x (torch.Tensor): The encoded feature tensor
|
93 |
+
|
94 |
+
Returns: initial state
|
95 |
+
|
96 |
+
"""
|
97 |
+
return self.init_state(x)
|
98 |
+
|
99 |
+
def batch_score(
|
100 |
+
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
101 |
+
) -> Tuple[torch.Tensor, List[Any]]:
|
102 |
+
"""Score new token batch (required).
|
103 |
+
|
104 |
+
Args:
|
105 |
+
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
106 |
+
states (List[Any]): Scorer states for prefix tokens.
|
107 |
+
xs (torch.Tensor):
|
108 |
+
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
tuple[torch.Tensor, List[Any]]: Tuple of
|
112 |
+
batchified scores for the next token with shape `(n_batch, n_vocab)`
|
113 |
+
and next state list for ys.
|
114 |
+
|
115 |
+
"""
|
116 |
+
warnings.warn(
|
117 |
+
"{} batch score is implemented through for loop not parallelized".format(
|
118 |
+
self.__class__.__name__
|
119 |
+
)
|
120 |
+
)
|
121 |
+
scores = list()
|
122 |
+
outstates = list()
|
123 |
+
for i, (y, state, x) in enumerate(zip(ys, states, xs)):
|
124 |
+
score, outstate = self.score(y, state, x)
|
125 |
+
outstates.append(outstate)
|
126 |
+
scores.append(score)
|
127 |
+
scores = torch.cat(scores, 0).view(ys.shape[0], -1)
|
128 |
+
return scores, outstates
|
129 |
+
|
130 |
+
|
131 |
+
class PartialScorerInterface(ScorerInterface):
|
132 |
+
"""Partial scorer interface for beam search.
|
133 |
+
|
134 |
+
The partial scorer performs scoring after the non-partial scorers have finished,
|
135 |
+
and receives pre-pruned next tokens to score because it is too heavy to score
|
136 |
+
all the tokens.
|
137 |
+
|
138 |
+
Examples:
|
139 |
+
* Prefix search for connectionist-temporal-classification models
|
140 |
+
* :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
|
141 |
+
|
142 |
+
"""
|
143 |
+
|
144 |
+
def score_partial(
|
145 |
+
self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
|
146 |
+
) -> Tuple[torch.Tensor, Any]:
|
147 |
+
"""Score new token (required).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
y (torch.Tensor): 1D prefix token
|
151 |
+
next_tokens (torch.Tensor): torch.int64 next token to score
|
152 |
+
state: decoder state for prefix tokens
|
153 |
+
x (torch.Tensor): The encoder feature that generates ys
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
tuple[torch.Tensor, Any]:
|
157 |
+
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
158 |
+
and next state for ys
|
159 |
+
|
160 |
+
"""
|
161 |
+
raise NotImplementedError
|
162 |
+
|
163 |
+
|
164 |
+
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
|
165 |
+
"""Batch partial scorer interface for beam search."""
|
166 |
+
|
167 |
+
def batch_score_partial(
|
168 |
+
self,
|
169 |
+
ys: torch.Tensor,
|
170 |
+
next_tokens: torch.Tensor,
|
171 |
+
states: List[Any],
|
172 |
+
xs: torch.Tensor,
|
173 |
+
) -> Tuple[torch.Tensor, Any]:
|
174 |
+
"""Score new token (required).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
178 |
+
next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
|
179 |
+
states (List[Any]): Scorer states for prefix tokens.
|
180 |
+
xs (torch.Tensor):
|
181 |
+
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tuple[torch.Tensor, Any]:
|
185 |
+
Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
|
186 |
+
and next states for ys
|
187 |
+
"""
|
188 |
+
raise NotImplementedError
|
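A minimal sketch of a custom scorer built on the interface above: a stateless toy scorer (not part of this repo) that gives every token the same log-probability, which is enough to plug into the beam search:

import math

import torch

from espnet.nets.scorer_interface import ScorerInterface


class UniformScorer(ScorerInterface):
    """Toy scorer: every next token gets the same log-probability."""

    def __init__(self, n_vocab):
        self.n_vocab = n_vocab

    def score(self, y, state, x):
        # log of a uniform distribution over the vocabulary; no state is kept
        return torch.full((self.n_vocab,), -math.log(self.n_vocab)), None


scorer = UniformScorer(10)
scores, state = scorer.score(torch.tensor([1, 2]), None, torch.zeros(5, 4))
print(scores.shape)   # torch.Size([10])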
espnet/nets/scorers/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
"""Initialize sub package."""
|
espnet/nets/scorers/ctc.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
"""ScorerInterface implementation for CTC."""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from espnet.nets.ctc_prefix_score import CTCPrefixScore
|
7 |
+
from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH
|
8 |
+
from espnet.nets.scorer_interface import BatchPartialScorerInterface
|
9 |
+
|
10 |
+
|
11 |
+
class CTCPrefixScorer(BatchPartialScorerInterface):
|
12 |
+
"""Decoder interface wrapper for CTCPrefixScore."""
|
13 |
+
|
14 |
+
def __init__(self, ctc: torch.nn.Module, eos: int):
|
15 |
+
"""Initialize class.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
ctc (torch.nn.Module): The CTC implementation.
|
19 |
+
For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
|
20 |
+
eos (int): The end-of-sequence id.
|
21 |
+
|
22 |
+
"""
|
23 |
+
self.ctc = ctc
|
24 |
+
self.eos = eos
|
25 |
+
self.impl = None
|
26 |
+
|
27 |
+
def init_state(self, x: torch.Tensor):
|
28 |
+
"""Get an initial state for decoding.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
x (torch.Tensor): The encoded feature tensor
|
32 |
+
|
33 |
+
Returns: initial state
|
34 |
+
|
35 |
+
"""
|
36 |
+
logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
|
37 |
+
# TODO(karita): use CTCPrefixScoreTH
|
38 |
+
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
|
39 |
+
return 0, self.impl.initial_state()
|
40 |
+
|
41 |
+
def select_state(self, state, i, new_id=None):
|
42 |
+
"""Select state with relative ids in the main beam search.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
state: Decoder state for prefix tokens
|
46 |
+
i (int): Index to select a state in the main beam search
|
47 |
+
new_id (int): New label id to select a state if necessary
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
state: pruned state
|
51 |
+
|
52 |
+
"""
|
53 |
+
if type(state) == tuple:
|
54 |
+
if len(state) == 2: # for CTCPrefixScore
|
55 |
+
sc, st = state
|
56 |
+
return sc[i], st[i]
|
57 |
+
else: # for CTCPrefixScoreTH (need new_id > 0)
|
58 |
+
r, log_psi, f_min, f_max, scoring_idmap = state
|
59 |
+
s = log_psi[i, new_id].expand(log_psi.size(1))
|
60 |
+
if scoring_idmap is not None:
|
61 |
+
return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
|
62 |
+
else:
|
63 |
+
return r[:, :, i, new_id], s, f_min, f_max
|
64 |
+
return None if state is None else state[i]
|
65 |
+
|
66 |
+
def score_partial(self, y, ids, state, x):
|
67 |
+
"""Score new token.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
y (torch.Tensor): 1D prefix token
|
71 |
+
next_tokens (torch.Tensor): torch.int64 next token to score
|
72 |
+
state: decoder state for prefix tokens
|
73 |
+
x (torch.Tensor): 2D encoder feature that generates ys
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
tuple[torch.Tensor, Any]:
|
77 |
+
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
78 |
+
and next state for ys
|
79 |
+
|
80 |
+
"""
|
81 |
+
prev_score, state = state
|
82 |
+
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
|
83 |
+
tscore = torch.as_tensor(
|
84 |
+
presub_score - prev_score, device=x.device, dtype=x.dtype
|
85 |
+
)
|
86 |
+
return tscore, (presub_score, new_st)
|
87 |
+
|
88 |
+
def batch_init_state(self, x: torch.Tensor):
|
89 |
+
"""Get an initial state for decoding.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x (torch.Tensor): The encoded feature tensor
|
93 |
+
|
94 |
+
Returns: initial state
|
95 |
+
|
96 |
+
"""
|
97 |
+
logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1
|
98 |
+
xlen = torch.tensor([logp.size(1)])
|
99 |
+
self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
|
100 |
+
return None
|
101 |
+
|
102 |
+
def batch_score_partial(self, y, ids, state, x):
|
103 |
+
"""Score new token.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
y (torch.Tensor): 1D prefix token
|
107 |
+
ids (torch.Tensor): torch.int64 next token to score
|
108 |
+
state: decoder state for prefix tokens
|
109 |
+
x (torch.Tensor): 2D encoder feature that generates ys
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
tuple[torch.Tensor, Any]:
|
113 |
+
Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
|
114 |
+
and next state for ys
|
115 |
+
|
116 |
+
"""
|
117 |
+
batch_state = (
|
118 |
+
(
|
119 |
+
torch.stack([s[0] for s in state], dim=2),
|
120 |
+
torch.stack([s[1] for s in state]),
|
121 |
+
state[0][2],
|
122 |
+
state[0][3],
|
123 |
+
)
|
124 |
+
if state[0] is not None
|
125 |
+
else None
|
126 |
+
)
|
127 |
+
return self.impl(y, batch_state, ids)
|
128 |
+
|
129 |
+
def extend_prob(self, x: torch.Tensor):
|
130 |
+
"""Extend probs for decoding.
|
131 |
+
|
132 |
+
This extension is for streaming decoding
|
133 |
+
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
134 |
+
|
135 |
+
Args:
|
136 |
+
x (torch.Tensor): The encoded feature tensor
|
137 |
+
|
138 |
+
"""
|
139 |
+
logp = self.ctc.log_softmax(x.unsqueeze(0))
|
140 |
+
self.impl.extend_prob(logp)
|
141 |
+
|
142 |
+
def extend_state(self, state):
|
143 |
+
"""Extend state for decoding.
|
144 |
+
|
145 |
+
This extension is for streaming decoding
|
146 |
+
as in Eq (14) in https://arxiv.org/abs/2006.14941
|
147 |
+
|
148 |
+
Args:
|
149 |
+
state: The states of hyps
|
150 |
+
|
151 |
+
Returns: extended state
|
152 |
+
|
153 |
+
"""
|
154 |
+
new_state = []
|
155 |
+
for s in state:
|
156 |
+
new_state.append(self.impl.extend_state(s))
|
157 |
+
|
158 |
+
return new_state
|
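A rough sketch of how the wrapper above is driven during decoding. FakeCTC is a stand-in (not part of this repo) exposing the same log_softmax interface as espnet.nets.pytorch_backend.ctc.CTC; shapes and ids are illustrative:

import torch

from espnet.nets.scorers.ctc import CTCPrefixScorer


class FakeCTC(torch.nn.Module):
    """Stand-in with the log_softmax interface of espnet's CTC module."""

    def __init__(self, odim):
        super().__init__()
        self.lin = torch.nn.Linear(4, odim)

    def log_softmax(self, hs_pad):
        return torch.log_softmax(self.lin(hs_pad), dim=-1)


enc = torch.randn(20, 4)                    # (time, feature) encoder output
scorer = CTCPrefixScorer(FakeCTC(odim=10), eos=9)
state = scorer.init_state(enc)

y = torch.tensor([9])                       # hypothesis prefix starting at <sos> (= <eos> id here)
top_ids = torch.tensor([3, 5, 7])           # candidates pre-pruned by the decoder
scores, state = scorer.score_partial(y, top_ids, state, enc)
print(scores.shape)                         # torch.Size([3])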
espnet/nets/scorers/length_bonus.py
ADDED
@@ -0,0 +1,61 @@
1 |
+
"""Length bonus module."""
|
2 |
+
from typing import Any
|
3 |
+
from typing import List
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from espnet.nets.scorer_interface import BatchScorerInterface
|
9 |
+
|
10 |
+
|
11 |
+
class LengthBonus(BatchScorerInterface):
|
12 |
+
"""Length bonus in beam search."""
|
13 |
+
|
14 |
+
def __init__(self, n_vocab: int):
|
15 |
+
"""Initialize class.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
n_vocab (int): The number of tokens in vocabulary for beam search
|
19 |
+
|
20 |
+
"""
|
21 |
+
self.n = n_vocab
|
22 |
+
|
23 |
+
def score(self, y, state, x):
|
24 |
+
"""Score new token.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
y (torch.Tensor): 1D torch.int64 prefix tokens.
|
28 |
+
state: Scorer state for prefix tokens
|
29 |
+
x (torch.Tensor): 2D encoder feature that generates ys.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
tuple[torch.Tensor, Any]: Tuple of
|
33 |
+
torch.float32 scores for next token (n_vocab)
|
34 |
+
and None
|
35 |
+
|
36 |
+
"""
|
37 |
+
return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None
|
38 |
+
|
39 |
+
def batch_score(
|
40 |
+
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
|
41 |
+
) -> Tuple[torch.Tensor, List[Any]]:
|
42 |
+
"""Score new token batch.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
|
46 |
+
states (List[Any]): Scorer states for prefix tokens.
|
47 |
+
xs (torch.Tensor):
|
48 |
+
The encoder feature that generates ys (n_batch, xlen, n_feat).
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
tuple[torch.Tensor, List[Any]]: Tuple of
|
52 |
+
batchified scores for the next token with shape `(n_batch, n_vocab)`
|
53 |
+
and next state list for ys.
|
54 |
+
|
55 |
+
"""
|
56 |
+
return (
|
57 |
+
torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand(
|
58 |
+
ys.shape[0], self.n
|
59 |
+
),
|
60 |
+
None,
|
61 |
+
)
|
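The length bonus simply contributes a constant 1.0 per emitted token (scaled by the beam-search weight); a two-line sketch:

import torch

from espnet.nets.scorers.length_bonus import LengthBonus

bonus = LengthBonus(n_vocab=10)
scores, state = bonus.score(torch.tensor([1, 2]), None, torch.zeros(5, 4))
print(scores)   # ten 1.0 values; state stays None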
espnet/utils/cli_utils.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
from collections.abc import Sequence
|
2 |
+
from distutils.util import strtobool as dist_strtobool
|
3 |
+
import sys
|
4 |
+
|
5 |
+
import numpy
|
6 |
+
|
7 |
+
|
8 |
+
def strtobool(x):
|
9 |
+
# distutils.util.strtobool returns an integer, which is confusing, so wrap it to return a bool
|
10 |
+
return bool(dist_strtobool(x))
|
11 |
+
|
12 |
+
|
13 |
+
def get_commandline_args():
|
14 |
+
extra_chars = [
|
15 |
+
" ",
|
16 |
+
";",
|
17 |
+
"&",
|
18 |
+
"(",
|
19 |
+
")",
|
20 |
+
"|",
|
21 |
+
"^",
|
22 |
+
"<",
|
23 |
+
">",
|
24 |
+
"?",
|
25 |
+
"*",
|
26 |
+
"[",
|
27 |
+
"]",
|
28 |
+
"$",
|
29 |
+
"`",
|
30 |
+
'"',
|
31 |
+
"\\",
|
32 |
+
"!",
|
33 |
+
"{",
|
34 |
+
"}",
|
35 |
+
]
|
36 |
+
|
37 |
+
# Escape the extra characters for shell
|
38 |
+
argv = [
|
39 |
+
arg.replace("'", "'\\''")
|
40 |
+
if all(char not in arg for char in extra_chars)
|
41 |
+
else "'" + arg.replace("'", "'\\''") + "'"
|
42 |
+
for arg in sys.argv
|
43 |
+
]
|
44 |
+
|
45 |
+
return sys.executable + " " + " ".join(argv)
|
46 |
+
|
47 |
+
|
48 |
+
def is_scipy_wav_style(value):
|
49 |
+
# If Tuple[int, numpy.ndarray] or not
|
50 |
+
return (
|
51 |
+
isinstance(value, Sequence)
|
52 |
+
and len(value) == 2
|
53 |
+
and isinstance(value[0], int)
|
54 |
+
and isinstance(value[1], numpy.ndarray)
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
def assert_scipy_wav_style(value):
|
59 |
+
assert is_scipy_wav_style(
|
60 |
+
value
|
61 |
+
), "Must be Tuple[int, numpy.ndarray], but got {}".format(
|
62 |
+
type(value)
|
63 |
+
if not isinstance(value, Sequence)
|
64 |
+
else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
|
65 |
+
)
|
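A short sketch of the two helpers above; the printed command line depends on how the script is invoked:

from espnet.utils.cli_utils import get_commandline_args, strtobool

print(strtobool("yes"), strtobool("0"))   # True False
print(get_commandline_args())             # the current command, shell-quoted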
espnet/utils/dynamic_import.py
ADDED
@@ -0,0 +1,23 @@
1 |
+
import importlib
|
2 |
+
|
3 |
+
|
4 |
+
def dynamic_import(import_path, alias=dict()):
|
5 |
+
"""dynamic import module and class
|
6 |
+
|
7 |
+
:param str import_path: syntax 'module_name:class_name'
|
8 |
+
e.g., 'espnet.transform.add_deltas:AddDeltas'
|
9 |
+
:param dict alias: shortcut for registered class
|
10 |
+
:return: imported class
|
11 |
+
"""
|
12 |
+
if import_path not in alias and ":" not in import_path:
|
13 |
+
raise ValueError(
|
14 |
+
"import_path should be one of {} or "
|
15 |
+
'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
|
16 |
+
"{}".format(set(alias), import_path)
|
17 |
+
)
|
18 |
+
if ":" not in import_path:
|
19 |
+
import_path = alias[import_path]
|
20 |
+
|
21 |
+
module_name, objname = import_path.split(":")
|
22 |
+
m = importlib.import_module(module_name)
|
23 |
+
return getattr(m, objname)
|
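dynamic_import resolves a 'module:object' string to the object itself, which is how model and language-model classes named in configs are loaded; a minimal sketch, assuming the usual E2E class name from upstream ESPnet:

from espnet.utils.dynamic_import import dynamic_import

model_class = dynamic_import("espnet.nets.pytorch_backend.e2e_asr_transformer:E2E")
print(model_class.__name__)   # E2E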
espnet/utils/fill_missing_args.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Copyright 2018 Nagoya University (Tomoki Hayashi)
|
4 |
+
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import logging
|
8 |
+
|
9 |
+
|
10 |
+
def fill_missing_args(args, add_arguments):
|
11 |
+
"""Fill missing arguments in args.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
args (Namespace or None): Namespace containing hyperparameters.
|
15 |
+
add_arguments (function): Function to add arguments.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
Namespace: Arguments whose missing ones are filled with default value.
|
19 |
+
|
20 |
+
Examples:
|
21 |
+
>>> from argparse import Namespace
|
22 |
+
>>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
|
23 |
+
>>> args = Namespace()
|
24 |
+
>>> fill_missing_args(args, Tacotron2.add_arguments_fn)
|
25 |
+
Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
|
26 |
+
|
27 |
+
"""
|
28 |
+
# check argument type
|
29 |
+
assert isinstance(args, argparse.Namespace) or args is None
|
30 |
+
assert callable(add_arguments)
|
31 |
+
|
32 |
+
# get default arguments
|
33 |
+
default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args()
|
34 |
+
|
35 |
+
# convert to dict
|
36 |
+
args = {} if args is None else vars(args)
|
37 |
+
default_args = vars(default_args)
|
38 |
+
|
39 |
+
for key, value in default_args.items():
|
40 |
+
if key not in args:
|
41 |
+
logging.info(
|
42 |
+
'attribute "%s" does not exist. use default %s.' % (key, str(value))
|
43 |
+
)
|
44 |
+
args[key] = value
|
45 |
+
|
46 |
+
return argparse.Namespace(**args)
|
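A self-contained sketch of fill_missing_args with a toy add_arguments function; the argument names below are made up:

import argparse

from espnet.utils.fill_missing_args import fill_missing_args


def add_arguments(parser):
    parser.add_argument("--adim", type=int, default=256)
    parser.add_argument("--dropout-rate", type=float, default=0.1)
    return parser


args = argparse.Namespace(adim=512)          # dropout_rate is missing
args = fill_missing_args(args, add_arguments)
print(args.adim, args.dropout_rate)          # 512 0.1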
pipelines/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
pipelines/data/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|