ONNX
kimihailv committed
Commit 0ed065b
1 Parent(s): 18d9bf5

Upload convert_model.py

Files changed (1): convert_model.py (+132, −0)
convert_model.py (ADDED):
# Converts the UForm text and image encoders to Core ML 'mlprogram' packages.
import uform
import torch
import coremltools as ct
from os.path import join
from argparse import ArgumentParser

class TextEncoder(torch.nn.Module):
    """Traceable wrapper that returns both token features and pooled embeddings."""

    def __init__(self, model):
        super().__init__()
        self.model = model.eval()

    def forward(self, input_ids, attention_mask):
        features = self.model.forward_features(input_ids, attention_mask)
        embeddings = self.model.forward_embedding(features, attention_mask)
        return features, embeddings

class ImageEncoder(torch.nn.Module):
    """Traceable wrapper that returns both patch features and pooled embeddings."""

    def __init__(self, model):
        super().__init__()
        self.model = model.eval()

    def forward(self, image):
        features = self.model.forward_features(image)
        embeddings = self.model.forward_embedding(features)
        return features, embeddings

def convert_model(opts):
    src_model = uform.get_model(opts.model_name)

    # Dummy inputs for tracing: 77-token text sequences and 224x224 RGB images.
    input_ids = torch.ones(1, 77, dtype=torch.int32)
    attention_mask = torch.ones(1, 77, dtype=torch.int32)
    image = torch.ones(1, 3, 224, 224, dtype=torch.float32)

    print('Tracing models…')
    image_encoder = ImageEncoder(src_model.image_encoder).eval()
    image_encoder = torch.jit.trace(image_encoder, image)
    text_encoder = TextEncoder(src_model.text_encoder).eval()
    text_encoder = torch.jit.trace(text_encoder, (input_ids, attention_mask))

    print('Converting models…')

    # Convert the traced encoders to Core ML with a flexible batch dimension
    # bounded by the CLI-provided lower and upper batch sizes.
    image_encoder = ct.convert(
        image_encoder,
        convert_to='mlprogram',
        inputs=[
            ct.TensorType(
                name='image',
                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 3, 224, 224),
                dtype=image.numpy().dtype
            )
        ],
        outputs=[
            ct.TensorType(name='features'),
            ct.TensorType(name='embeddings')
        ],
        compute_precision=ct.precision.FLOAT16 if opts.use_fp16 else ct.precision.FLOAT32
    )

    text_encoder = ct.convert(
        text_encoder,
        convert_to='mlprogram',
        inputs=[
            ct.TensorType(
                name='input_ids',
                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 77),
                dtype=input_ids.numpy().dtype
            ),
            ct.TensorType(
                name='attention_mask',
                shape=(ct.RangeDim(lower_bound=opts.batchsize_lb, upper_bound=opts.batchsize_ub, default=1), 77),
                dtype=attention_mask.numpy().dtype
            )
        ],
        outputs=[
            ct.TensorType(name='features'),
            ct.TensorType(name='embeddings')
        ],
        compute_precision=ct.precision.FLOAT16 if opts.use_fp16 else ct.precision.FLOAT32
    )

    print('Image encoder:', image_encoder, sep='\n')
    print('Text encoder:', text_encoder, sep='\n')

    # '/' in the model name is replaced so the package file names stay valid paths.
    image_encoder.save(join(opts.output_dir, f"{opts.model_name.replace('/', '.')}.image-encoder.mlpackage"))
    text_encoder.save(join(opts.output_dir, f"{opts.model_name.replace('/', '.')}.text-encoder.mlpackage"))

if __name__ == '__main__':
    opts = ArgumentParser()
    opts.add_argument('--model_name', type=str, help='UForm model name')
    opts.add_argument('--batchsize_lb', type=int, help='lower bound of the batch size')
    opts.add_argument('--batchsize_ub', type=int, help='upper bound of the batch size')
    opts.add_argument('--use_fp16', action='store_true', help='whether to use fp16 precision for inference')
    opts.add_argument('--output_dir', type=str, help='output directory')

    convert_model(opts.parse_args())
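
Usage, as a minimal sketch: the checkpoint name below is just an example UForm model identifier, not part of this commit, and running the converted .mlpackage requires a machine with Core ML available (macOS):

    python convert_model.py \
        --model_name unum-cloud/uform-vl-english \
        --batchsize_lb 1 \
        --batchsize_ub 16 \
        --use_fp16 \
        --output_dir ./mlpackages

The saved packages can then be reloaded with coremltools for a quick smoke test; the path below is hypothetical and follows the naming scheme convert_model uses:

    import coremltools as ct
    import numpy as np

    # Example path produced by the invocation above (assumed, not from this commit).
    encoder = ct.models.MLModel('./mlpackages/unum-cloud.uform-vl-english.image-encoder.mlpackage')
    out = encoder.predict({'image': np.ones((1, 3, 224, 224), dtype=np.float32)})
    print(out['embeddings'].shape)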