yuekai commited on
Commit
a4f8069
1 Parent(s): 95eeef3

Create export_onnx.py

Browse files
Files changed (1) hide show
  1. export_onnx.py +139 -0
export_onnx.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ import types
7
+ import torch
8
+ import torch.nn as nn
9
+ from funasr.register import tables
10
+
11
+
12
def export_rebuild_model(model, **kwargs):
    """Patch a loaded SenseVoice model in place so torch.onnx.export can trace it.

    Binds the export-oriented ``forward`` and the ``export_*`` metadata
    helpers defined in this module onto the model instance.

    Args:
        model: a loaded SenseVoice model instance (mutated in place).
        **kwargs: ``device`` (e.g. ``"cuda"``) is stored on the model and used
            by :func:`export_forward`; ``type`` defaults to ``"onnx"``.

    Returns:
        The same ``model`` object, with export methods attached.
    """
    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"  # kept for the encoder-export path below
    # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    # model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    # from funasr.utils.torch_function import sequence_mask
    # model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    # Bug fix: the original bound export_name as a method and then immediately
    # clobbered it with the plain string "model", so model.export_name() was
    # never callable. Keep the method binding only.
    model.export_name = types.MethodType(export_name, model)
    return model
33
+
34
+
35
def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    language: torch.Tensor,
    textnorm: torch.Tensor,
    **kwargs,
):
    """Export-friendly forward pass: prepend query embeddings, encode, emit CTC logits.

    Args:
        self: SenseVoice model carrying ``embed``, ``encoder``, ``ctc`` and
            (after export_rebuild_model) a ``device`` attribute.
        speech: (batch, frames, feat_dim) feature tensor — assumed layout, TODO confirm.
        speech_lengths: (batch,) valid frame counts; mutated in place (+4) to
            account for the prepended query tokens.
        language: (batch,) language token ids for the embedding table.
        textnorm: (batch,) text-normalization token ids.

    Returns:
        Tuple of (ctc_logits, encoder_out_lens).
    """
    # Bug fix: the original hard-coded device='cuda', ignoring the device
    # stored by export_rebuild_model. Fall back to the parameters' device.
    device = getattr(self, "device", None)
    if device is None:
        device = next(self.parameters()).device
    speech = speech.to(device=device)
    speech_lengths = speech_lengths.to(device=device)

    language_query = self.embed(language.to(speech.device)).unsqueeze(1)
    textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)

    # Prepend the text-norm query token to the feature sequence.
    speech = torch.cat((textnorm_query, speech), dim=1)
    speech_lengths += 1

    # Fixed event/emotion query tokens (ids 1 and 2), broadcast over the batch.
    event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
        speech.size(0), 1, 1
    )
    input_query = torch.cat((language_query, event_emo_query), dim=1)
    speech = torch.cat((input_query, speech), dim=1)
    speech_lengths += 3

    # Encoder
    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    # Raw CTC logits (no log_softmax) so the consumer chooses the activation.
    # ctc_logits = self.ctc.log_softmax(encoder_out)
    ctc_logits = self.ctc.ctc_lo(encoder_out)

    return ctc_logits, encoder_out_lens
71
+
72
+
73
def export_dummy_inputs(self):
    """Return placeholder (speech, speech_lengths, language, textnorm) tensors for tracing."""
    dummy_speech = torch.randn(2, 30, 560)
    dummy_lengths = torch.tensor([6, 30], dtype=torch.int32)
    dummy_language = torch.tensor([0, 0], dtype=torch.int32)
    dummy_textnorm = torch.tensor([15, 15], dtype=torch.int32)
    return dummy_speech, dummy_lengths, dummy_language, dummy_textnorm
79
+
80
+
81
def export_input_names(self):
    """Names of the exported graph's inputs, in positional order."""
    names = ("speech", "speech_lengths", "language", "textnorm")
    return list(names)
83
+
84
+
85
def export_output_names(self):
    """Names of the exported graph's outputs, matching export_forward's return order."""
    names = ("ctc_logits", "encoder_out_lens")
    return list(names)
87
+
88
+
89
def export_dynamic_axes(self):
    """Dynamic (symbolic) axes per tensor name, for torch.onnx.export."""
    dynamic = {}
    dynamic["speech"] = {0: "batch_size", 1: "feats_length"}
    dynamic["speech_lengths"] = {0: "batch_size"}
    dynamic["language"] = {0: "batch_size"}
    dynamic["textnorm"] = {0: "batch_size"}
    dynamic["ctc_logits"] = {0: "batch_size", 1: "logits_length"}
    return dynamic
99
+
100
+
101
def export_name(self):
    """Default output filename for the exported ONNX model."""
    filename = "model.onnx"
    return filename
105
+
106
+
107
+
108
if __name__ == "__main__":
    from model import SenseVoiceSmall

    model_dir = "iic/SenseVoiceSmall"
    # model_dir = "./SenseVoiceSmall"  # local checkpoint alternative
    sense_voice, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir)
    # sense_voice = sense_voice.to("cpu")

    # Attach the export-oriented forward and metadata helpers.
    sense_voice = export_rebuild_model(sense_voice, max_seq_len=512, device="cuda")
    # sense_voice.export()
    print("Export Done.")

    dummy_inputs = sense_voice.export_dummy_inputs()

    # Trace and export the patched model to ONNX.
    torch.onnx.export(
        sense_voice,
        dummy_inputs,
        "model.onnx",
        input_names=sense_voice.export_input_names(),
        output_names=sense_voice.export_output_names(),
        dynamic_axes=sense_voice.export_dynamic_axes(),
        opset_version=18,
    )
    # Optional fp16 post-conversion with onnxmltools:
    # import os
    # import onnxmltools
    # from onnxmltools.utils.float16_converter import convert_float_to_float16
    # fp32_model = onnxmltools.utils.load_model("model.onnx")
    # fp16_model = convert_float_to_float16(fp32_model)
    # onnxmltools.utils.save_model(fp16_model, "model_fp16.onnx")
    # print("Model has been successfully exported to model.onnx")