Upload torch.jit.trace() exported files
#1
by
csukuangfj
- opened
- exp/decoder_jit_trace.pt +3 -0
- exp/encoder_jit_trace.pt +3 -0
- exp/jit_trace_export-zh.py +323 -0
- exp/jit_trace_export-zh.sh +43 -0
- exp/joiner_jit_trace.pt +3 -0
exp/decoder_jit_trace.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4e3a51b423148a03481155e7785dba05d3eda920749fd940a81a2decde30a510
|
3 |
+
size 12830070
|
exp/encoder_jit_trace.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95344ad3309c566bde4f0582f29a88037ba6b2ca518584f84011f281d7677def
|
3 |
+
size 330440074
|
exp/jit_trace_export-zh.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
Usage:
|
5 |
+
./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
|
6 |
+
--exp-dir $dir/exp \
|
7 |
+
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
8 |
+
--lang-dir ./data/lang_char_bpe \
|
9 |
+
--epoch 99 \
|
10 |
+
--avg 1 \
|
11 |
+
--use-averaged-model 0 \
|
12 |
+
\
|
13 |
+
--decode-chunk-len 32 \
|
14 |
+
--num-encoder-layers "2,4,3,2,4" \
|
15 |
+
--feedforward-dims "1024,1024,1536,1536,1024" \
|
16 |
+
--nhead "8,8,8,8,8" \
|
17 |
+
--encoder-dims "384,384,384,384,384" \
|
18 |
+
--attention-dims "192,192,192,192,192" \
|
19 |
+
--encoder-unmasked-dims "256,256,256,256,256" \
|
20 |
+
--zipformer-downsampling-factors "1,2,4,8,2" \
|
21 |
+
--cnn-module-kernels "31,31,31,31,31" \
|
22 |
+
--decoder-dim 512 \
|
23 |
+
--joiner-dim 512
|
24 |
+
"""
|
25 |
+
|
26 |
+
import argparse
|
27 |
+
import logging
|
28 |
+
from pathlib import Path
|
29 |
+
|
30 |
+
import sentencepiece as spm
|
31 |
+
import torch
|
32 |
+
from scaling_converter import convert_scaled_to_non_scaled
|
33 |
+
from train import add_model_arguments, get_params, get_transducer_model
|
34 |
+
from icefall.lexicon import Lexicon
|
35 |
+
|
36 |
+
from icefall.checkpoint import (
|
37 |
+
average_checkpoints,
|
38 |
+
average_checkpoints_with_averaged_model,
|
39 |
+
find_checkpoints,
|
40 |
+
load_checkpoint,
|
41 |
+
)
|
42 |
+
from icefall.utils import AttributeDict, str2bool
|
43 |
+
|
44 |
+
|
45 |
+
def get_parser():
|
46 |
+
parser = argparse.ArgumentParser(
|
47 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
48 |
+
)
|
49 |
+
|
50 |
+
parser.add_argument(
|
51 |
+
"--epoch",
|
52 |
+
type=int,
|
53 |
+
default=28,
|
54 |
+
help="""It specifies the checkpoint to use for averaging.
|
55 |
+
Note: Epoch counts from 0.
|
56 |
+
You can specify --avg to use more checkpoints for model averaging.""",
|
57 |
+
)
|
58 |
+
|
59 |
+
parser.add_argument(
|
60 |
+
"--iter",
|
61 |
+
type=int,
|
62 |
+
default=0,
|
63 |
+
help="""If positive, --epoch is ignored and it
|
64 |
+
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
65 |
+
You can specify --avg to use more checkpoints for model averaging.
|
66 |
+
""",
|
67 |
+
)
|
68 |
+
|
69 |
+
parser.add_argument(
|
70 |
+
"--avg",
|
71 |
+
type=int,
|
72 |
+
default=15,
|
73 |
+
help="Number of checkpoints to average. Automatically select "
|
74 |
+
"consecutive checkpoints before the checkpoint specified by "
|
75 |
+
"'--epoch' and '--iter'",
|
76 |
+
)
|
77 |
+
|
78 |
+
parser.add_argument(
|
79 |
+
"--exp-dir",
|
80 |
+
type=str,
|
81 |
+
default="pruned_transducer_stateless2/exp",
|
82 |
+
help="""It specifies the directory where all training related
|
83 |
+
files, e.g., checkpoints, log, etc, are saved
|
84 |
+
""",
|
85 |
+
)
|
86 |
+
|
87 |
+
parser.add_argument(
|
88 |
+
"--lang-dir",
|
89 |
+
type=str,
|
90 |
+
default="data/lang_char",
|
91 |
+
help="The lang dir",
|
92 |
+
)
|
93 |
+
|
94 |
+
parser.add_argument(
|
95 |
+
"--context-size",
|
96 |
+
type=int,
|
97 |
+
default=2,
|
98 |
+
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
99 |
+
)
|
100 |
+
|
101 |
+
parser.add_argument(
|
102 |
+
"--use-averaged-model",
|
103 |
+
type=str2bool,
|
104 |
+
default=True,
|
105 |
+
help="Whether to load averaged model. Currently it only supports "
|
106 |
+
"using --epoch. If True, it would decode with the averaged model "
|
107 |
+
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
108 |
+
"Actually only the models with epoch number of `epoch-avg` and "
|
109 |
+
"`epoch` are loaded for averaging. ",
|
110 |
+
)
|
111 |
+
|
112 |
+
add_model_arguments(parser)
|
113 |
+
|
114 |
+
return parser
|
115 |
+
|
116 |
+
|
117 |
+
def export_encoder_model_jit_trace(
|
118 |
+
encoder_model: torch.nn.Module,
|
119 |
+
encoder_filename: str,
|
120 |
+
params: AttributeDict,
|
121 |
+
) -> None:
|
122 |
+
"""Export the given encoder model with torch.jit.trace()
|
123 |
+
|
124 |
+
Note: The warmup argument is fixed to 1.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
encoder_model:
|
128 |
+
The input encoder model
|
129 |
+
encoder_filename:
|
130 |
+
The filename to save the exported model.
|
131 |
+
"""
|
132 |
+
decode_chunk_len = params.decode_chunk_len # before subsampling
|
133 |
+
pad_length = 7
|
134 |
+
s = f"decode_chunk_len: {decode_chunk_len}"
|
135 |
+
logging.info(s)
|
136 |
+
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
|
137 |
+
encoder_model.decode_chunk_size,
|
138 |
+
decode_chunk_len,
|
139 |
+
)
|
140 |
+
|
141 |
+
T = decode_chunk_len + pad_length
|
142 |
+
|
143 |
+
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
144 |
+
x_lens = torch.full((1,), T, dtype=torch.int32)
|
145 |
+
states = encoder_model.get_init_state(device=x.device)
|
146 |
+
|
147 |
+
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
148 |
+
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
149 |
+
traced_model.save(encoder_filename)
|
150 |
+
logging.info(f"Saved to {encoder_filename}")
|
151 |
+
|
152 |
+
|
153 |
+
def export_decoder_model_jit_trace(
|
154 |
+
decoder_model: torch.nn.Module,
|
155 |
+
decoder_filename: str,
|
156 |
+
) -> None:
|
157 |
+
"""Export the given decoder model with torch.jit.trace()
|
158 |
+
|
159 |
+
Note: The argument need_pad is fixed to False.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
decoder_model:
|
163 |
+
The input decoder model
|
164 |
+
decoder_filename:
|
165 |
+
The filename to save the exported model.
|
166 |
+
"""
|
167 |
+
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
168 |
+
need_pad = torch.tensor([False])
|
169 |
+
|
170 |
+
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
171 |
+
traced_model.save(decoder_filename)
|
172 |
+
logging.info(f"Saved to {decoder_filename}")
|
173 |
+
|
174 |
+
|
175 |
+
def export_joiner_model_jit_trace(
|
176 |
+
joiner_model: torch.nn.Module,
|
177 |
+
joiner_filename: str,
|
178 |
+
) -> None:
|
179 |
+
"""Export the given joiner model with torch.jit.trace()
|
180 |
+
|
181 |
+
Note: The argument project_input is fixed to True. A user should not
|
182 |
+
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
183 |
+
will do that for the user.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
joiner_model:
|
187 |
+
The input joiner model
|
188 |
+
joiner_filename:
|
189 |
+
The filename to save the exported model.
|
190 |
+
|
191 |
+
"""
|
192 |
+
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
193 |
+
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
194 |
+
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
195 |
+
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
196 |
+
|
197 |
+
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
198 |
+
traced_model.save(joiner_filename)
|
199 |
+
logging.info(f"Saved to {joiner_filename}")
|
200 |
+
|
201 |
+
|
202 |
+
@torch.no_grad()
|
203 |
+
def main():
|
204 |
+
args = get_parser().parse_args()
|
205 |
+
args.exp_dir = Path(args.exp_dir)
|
206 |
+
|
207 |
+
params = get_params()
|
208 |
+
params.update(vars(args))
|
209 |
+
|
210 |
+
device = torch.device("cpu")
|
211 |
+
|
212 |
+
logging.info(f"device: {device}")
|
213 |
+
|
214 |
+
lexicon = Lexicon(params.lang_dir)
|
215 |
+
params.blank_id = 0
|
216 |
+
params.vocab_size = max(lexicon.tokens) + 1
|
217 |
+
|
218 |
+
logging.info(params)
|
219 |
+
|
220 |
+
logging.info("About to create model")
|
221 |
+
model = get_transducer_model(params)
|
222 |
+
|
223 |
+
if not params.use_averaged_model:
|
224 |
+
if params.iter > 0:
|
225 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
226 |
+
: params.avg
|
227 |
+
]
|
228 |
+
if len(filenames) == 0:
|
229 |
+
raise ValueError(
|
230 |
+
f"No checkpoints found for"
|
231 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
232 |
+
)
|
233 |
+
elif len(filenames) < params.avg:
|
234 |
+
raise ValueError(
|
235 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
236 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
237 |
+
)
|
238 |
+
logging.info(f"averaging {filenames}")
|
239 |
+
model.to(device)
|
240 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
241 |
+
elif params.avg == 1:
|
242 |
+
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
243 |
+
else:
|
244 |
+
start = params.epoch - params.avg + 1
|
245 |
+
filenames = []
|
246 |
+
for i in range(start, params.epoch + 1):
|
247 |
+
if i >= 1:
|
248 |
+
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
249 |
+
logging.info(f"averaging {filenames}")
|
250 |
+
model.to(device)
|
251 |
+
model.load_state_dict(average_checkpoints(filenames, device=device))
|
252 |
+
else:
|
253 |
+
if params.iter > 0:
|
254 |
+
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
255 |
+
: params.avg + 1
|
256 |
+
]
|
257 |
+
if len(filenames) == 0:
|
258 |
+
raise ValueError(
|
259 |
+
f"No checkpoints found for"
|
260 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
261 |
+
)
|
262 |
+
elif len(filenames) < params.avg + 1:
|
263 |
+
raise ValueError(
|
264 |
+
f"Not enough checkpoints ({len(filenames)}) found for"
|
265 |
+
f" --iter {params.iter}, --avg {params.avg}"
|
266 |
+
)
|
267 |
+
filename_start = filenames[-1]
|
268 |
+
filename_end = filenames[0]
|
269 |
+
logging.info(
|
270 |
+
"Calculating the averaged model over iteration checkpoints"
|
271 |
+
f" from {filename_start} (excluded) to {filename_end}"
|
272 |
+
)
|
273 |
+
model.to(device)
|
274 |
+
model.load_state_dict(
|
275 |
+
average_checkpoints_with_averaged_model(
|
276 |
+
filename_start=filename_start,
|
277 |
+
filename_end=filename_end,
|
278 |
+
device=device,
|
279 |
+
)
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
assert params.avg > 0, params.avg
|
283 |
+
start = params.epoch - params.avg
|
284 |
+
assert start >= 1, start
|
285 |
+
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
286 |
+
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
287 |
+
logging.info(
|
288 |
+
f"Calculating the averaged model over epoch range from "
|
289 |
+
f"{start} (excluded) to {params.epoch}"
|
290 |
+
)
|
291 |
+
model.to(device)
|
292 |
+
model.load_state_dict(
|
293 |
+
average_checkpoints_with_averaged_model(
|
294 |
+
filename_start=filename_start,
|
295 |
+
filename_end=filename_end,
|
296 |
+
device=device,
|
297 |
+
)
|
298 |
+
)
|
299 |
+
|
300 |
+
model.to("cpu")
|
301 |
+
model.eval()
|
302 |
+
|
303 |
+
convert_scaled_to_non_scaled(model, inplace=True)
|
304 |
+
logging.info("Using torch.jit.trace()")
|
305 |
+
|
306 |
+
logging.info("Exporting encoder")
|
307 |
+
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
308 |
+
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
|
309 |
+
|
310 |
+
logging.info("Exporting decoder")
|
311 |
+
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
312 |
+
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
313 |
+
|
314 |
+
logging.info("Exporting joiner")
|
315 |
+
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
316 |
+
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
321 |
+
|
322 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
323 |
+
main()
|
exp/jit_trace_export-zh.sh
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# Please go to
|
4 |
+
# https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
|
5 |
+
# to download the pre-trained models
|
6 |
+
|
7 |
+
#
|
8 |
+
# cd $dir
|
9 |
+
# ln -s icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt
|
10 |
+
|
11 |
+
. path.sh
|
12 |
+
|
13 |
+
export CUDA_VISIBLE_DEVICES=""
|
14 |
+
set -ex
|
15 |
+
|
16 |
+
dir=./k2fsa-zipformer-chinese-english-mixed
|
17 |
+
if [ ! -f $dir/exp/epoch-99.pt ]; then
|
18 |
+
pushd $dir/exp
|
19 |
+
ln -s pretrained.pt epoch-99.pt
|
20 |
+
popd
|
21 |
+
fi
|
22 |
+
|
23 |
+
./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
|
24 |
+
--exp-dir $dir/exp \
|
25 |
+
--lang-dir $dir/data/lang_char_bpe \
|
26 |
+
--epoch 99 \
|
27 |
+
--avg 1 \
|
28 |
+
--use-averaged-model 0 \
|
29 |
+
\
|
30 |
+
--decode-chunk-len 32 \
|
31 |
+
--num-encoder-layers "2,4,3,2,4" \
|
32 |
+
--feedforward-dims "1024,1024,1536,1536,1024" \
|
33 |
+
--nhead "8,8,8,8,8" \
|
34 |
+
--encoder-dims "384,384,384,384,384" \
|
35 |
+
--attention-dims "192,192,192,192,192" \
|
36 |
+
--encoder-unmasked-dims "256,256,256,256,256" \
|
37 |
+
--zipformer-downsampling-factors "1,2,4,8,2" \
|
38 |
+
--cnn-module-kernels "31,31,31,31,31" \
|
39 |
+
--decoder-dim 512 \
|
40 |
+
--joiner-dim 512
|
41 |
+
|
42 |
+
exit 0
|
43 |
+
|
exp/joiner_jit_trace.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7628036f64e4c0281d02ccb696df74376297645e0400265620bf3806c31e5621
|
3 |
+
size 14679599
|