tanthinhdt committed
Commit e7a4186 · Parent: dc6d681

feat: use ONNX model

Files changed (5):
  1. app.py (+40 -30)
  2. config.json (+235 -0)
  3. preprocessor_config.json (+27 -0)
  4. utils.py (+7 -12)
  5. videomae_skeleton_v2.3.onnx (+3 -0)
app.py CHANGED
@@ -1,10 +1,11 @@
+import json
 import gradio as gr
+from time import time
+import onnxruntime as ort
 from mediapipe.python.solutions import holistic
 from torchvision.transforms.v2 import Compose, Lambda, Normalize
-from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
 from utils import get_predictions, preprocess
 
-
 title = '''
 
 '''
@@ -21,21 +22,17 @@ examples = [
     ['000_con_cho.mp4'],
 ]
 
-# Initialize the model and image processor.
-device = 'cpu'
-model_name = 'VieSignLang/videomae_skeleton_v1.0'
-image_processor = VideoMAEImageProcessor.from_pretrained(model_name)
-model = VideoMAEForVideoClassification.from_pretrained(model_name)
-model = model.eval().to(device)
-
-# Get the mean, std, and model input size.
-mean = image_processor.image_mean
-std = image_processor.image_std
-if 'shortest_edge' in image_processor.size:
-    model_input_height = model_input_width = image_processor.size['shortest_edge']
+ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')
+model_config = json.load(open('config.json'))
+preprocessor_config = json.load(open('preprocessor_config.json'))
+
+mean = preprocessor_config['image_mean']
+std = preprocessor_config['image_std']
+if 'shortest_edge' in preprocessor_config['size']:
+    model_input_height = model_input_width = preprocessor_config['size']['shortest_edge']
 else:
-    model_input_height = image_processor.size['height']
-    model_input_width = image_processor.size['width']
+    model_input_height = preprocessor_config['size']['height']
+    model_input_width = preprocessor_config['size']['width']
 
 # Define the transform.
 transform = Compose(
@@ -73,38 +70,51 @@ def inference(
         refine_face_landmarks=True,
     )
 
+    start_time = time()
     inputs = preprocess(
-        model_num_frames=model.config.num_frames,
+        model_num_frames=model_config['num_frames'],
         keypoints_detector=keypoints_detector,
         source=video,
         model_input_height=model_input_height,
         model_input_width=model_input_width,
-        device=device,
         transform=transform,
     )
+    end_time = time()
+    data_time = end_time - start_time
 
     progress(1/2, desc='Getting predictions')
-    predictions = get_predictions(inputs=inputs, model=model)
+    start_time = time()
+    predictions = get_predictions(
+        inputs=inputs,
+        ort_session=ort_session,
+        id2gloss=model_config['id2label'],
+        k=3,
+    )
+    end_time = time()
+    model_time = end_time - start_time
 
     if len(predictions) == 0:
        output_message = 'No sign language detected in the video. Please try again.'
     else:
        output_message = 'The top-3 predictions are:\n'
        for i, prediction in enumerate(predictions):
-            output_message += f'{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
-        output_message = output_message.strip()
+            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
+        output_message += f'Data processing time: {data_time:.2f} seconds\n'
+        output_message += f'Model inference time: {model_time:.2f} seconds\n'
+        output_message += f'Total time: {data_time + model_time:.2f} seconds'
 
     progress(1/2, desc='Completed')
 
     return output_message
 
 
-iface = gr.Interface(
-    fn=inference,
-    inputs='video',
-    outputs='text',
-    examples=examples,
-    title=title,
-    description=description,
-)
-iface.launch()
+# iface = gr.Interface(
+#     fn=inference,
+#     inputs='video',
+#     outputs='text',
+#     examples=examples,
+#     title=title,
+#     description=description,
+# )
+# iface.launch()
+print(inference('000_con_cho.mp4'))
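Note on the switch above: `utils.preprocess` now builds a feed dict keyed by `pixel_values`, and `get_predictions` takes the first output of `ort_session.run` as the logits, so the exported graph must expose an input named `pixel_values` and a logits output. The export itself is not part of this commit; the sketch below shows one way the ONNX file could have been produced from the PyTorch checkpoint named in config.json. The opset, the dynamic axes, and the assumption that the checkpoint is loadable from the Hub are all illustrative.

```python
# Hypothetical export script, not part of this commit.
import torch
from transformers import VideoMAEForVideoClassification

# '_name_or_path' from the committed config.json; assumed to be loadable as a checkpoint.
model = VideoMAEForVideoClassification.from_pretrained('VieSignLang/videomae_skeleton_v2.3')
model.config.return_dict = False  # export a plain tuple output instead of a ModelOutput
model.eval()

# VideoMAE expects (batch, num_frames, channels, height, width); 16 and 224 come from config.json.
dummy = torch.randn(1, 16, 3, 224, 224)

torch.onnx.export(
    model,
    (dummy,),
    'videomae_skeleton_v2.3.onnx',
    input_names=['pixel_values'],    # must match the key built in utils.preprocess
    output_names=['logits'],         # get_predictions reads outputs[0]
    dynamic_axes={'pixel_values': {0: 'batch'}, 'logits': {0: 'batch'}},
    opset_version=14,                # assumed; any opset supported by onnxruntime would do
)
```

Setting `return_dict = False` keeps the traced output a plain tuple, which `torch.onnx.export` handles without extra wrapping.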
config.json ADDED
@@ -0,0 +1,235 @@
+{
+  "_name_or_path": "VieSignLang/videomae_skeleton_v2.3",
+  "architectures": [
+    "VideoMAEForVideoClassification"
+  ],
+  "attention_probs_dropout_prob": 0.0,
+  "decoder_hidden_size": 192,
+  "decoder_intermediate_size": 768,
+  "decoder_num_attention_heads": 3,
+  "decoder_num_hidden_layers": 12,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.0,
+  "hidden_size": 384,
+  "id2label": {
+    "0": "Con ch\u00f3",
+    "1": "Con m\u00e8o",
+    "2": "Con g\u00e0",
+    "3": "Con v\u1ecbt",
+    "4": "Con r\u00f9a",
+    "5": "Con th\u1ecf",
+    "6": "Con tr\u00e2u",
+    "7": "Con b\u00f2",
+    "8": "Con d\u00ea",
+    "9": "Con heo",
+    "10": "M\u00e0u \u0111en",
+    "11": "M\u00e0u tr\u1eafng",
+    "12": "M\u00e0u \u0111\u1ecf",
+    "13": "M\u00e0u cam",
+    "14": "M\u00e0u v\u00e0ng",
+    "15": "M\u00e0u l\u00e1 c\u00e2y",
+    "16": "M\u00e0u da tr\u1eddi",
+    "17": "M\u00e0u h\u1ed3ng",
+    "18": "M\u00e0u t\u00edm",
+    "19": "M\u00e0u n\u00e2u",
+    "20": "Qu\u1ea3 d\u00e2u",
+    "21": "Qu\u1ea3 m\u1eadn",
+    "22": "Qu\u1ea3 d\u1ee9a",
+    "23": "Qu\u1ea3 \u0111\u00e0o",
+    "24": "Qu\u1ea3 \u0111u \u0111\u1ee7",
+    "25": "Qu\u1ea3 cam",
+    "26": "Qu\u1ea3 b\u01a1",
+    "27": "Qu\u1ea3 chu\u1ed1i",
+    "28": "Qu\u1ea3 xo\u00e0i",
+    "29": "Qu\u1ea3 d\u1eeba",
+    "30": "B\u1ed1",
+    "31": "M\u1eb9",
+    "32": "Con trai",
+    "33": "Con g\u00e1i",
+    "34": "V\u1ee3",
+    "35": "Ch\u1ed3ng",
+    "36": "\u00d4ng n\u1ed9i",
+    "37": "B\u00e0 n\u1ed9i",
+    "38": "\u00d4ng ngo\u1ea1i",
+    "39": "B\u00e0 ngo\u1ea1i",
+    "40": "\u0102n",
+    "41": "U\u1ed1ng",
+    "42": "Xem",
+    "43": "Th\u00e8m",
+    "44": "M\u00e1ch",
+    "45": "Kh\u00f3c",
+    "46": "C\u01b0\u1eddi",
+    "47": "H\u1ecdc",
+    "48": "D\u1ed7i",
+    "49": "Ch\u1ebft",
+    "50": "\u0110i",
+    "51": "Ch\u1ea1y",
+    "52": "B\u1eadn",
+    "53": "H\u00e1t",
+    "54": "M\u00faa",
+    "55": "N\u1ea5u",
+    "56": "N\u01b0\u1edbng",
+    "57": "Nh\u1ea7m l\u1eabn",
+    "58": "Quan s\u00e1t",
+    "59": "C\u1eafm tr\u1ea1i",
+    "60": "Cung c\u1ea5p",
+    "61": "B\u1eaft ch\u01b0\u1edbc",
+    "62": "B\u1eaft bu\u1ed9c",
+    "63": "B\u00e1o c\u00e1o",
+    "64": "Mua b\u00e1n",
+    "65": "Kh\u00f4ng quen",
+    "66": "Kh\u00f4ng n\u00ean",
+    "67": "Kh\u00f4ng c\u1ea7n",
+    "68": "Kh\u00f4ng cho",
+    "69": "Kh\u00f4ng nghe l\u1eddi",
+    "70": "M\u1eb7n",
+    "71": "\u0110\u1eafng",
+    "72": "Cay",
+    "73": "Ng\u1ecdt",
+    "74": "\u0110\u1eadm",
+    "75": "Nh\u1ea1t",
+    "76": "Ngon mi\u1ec7ng",
+    "77": "X\u1ea5u",
+    "78": "\u0110\u1eb9p",
+    "79": "Ch\u1eadt",
+    "80": "H\u1eb9p",
+    "81": "R\u1ed9ng",
+    "82": "D\u00e0i",
+    "83": "Cao",
+    "84": "L\u00f9n",
+    "85": "\u1ed0m",
+    "86": "M\u1eadp",
+    "87": "Ngoan",
+    "88": "H\u01b0",
+    "89": "Kh\u1ecfe",
+    "90": "M\u1ec7t",
+    "91": "\u0110au",
+    "92": "Gi\u1ecfi",
+    "93": "Ch\u0103m ch\u1ec9",
+    "94": "L\u01b0\u1eddi bi\u1ebfng",
+    "95": "T\u1ed1t b\u1ee5ng",
+    "96": "Th\u00fa v\u1ecb",
+    "97": "H\u00e0i h\u01b0\u1edbc",
+    "98": "D\u0169ng c\u1ea3m",
+    "99": "S\u00e1ng t\u1ea1o"
+  },
+  "image_size": 224,
+  "initializer_range": 0.02,
+  "intermediate_size": 1536,
+  "label2id": {
+    "B\u00e0 ngo\u1ea1i": 39,
+    "B\u00e0 n\u1ed9i": 37,
+    "B\u00e1o c\u00e1o": 63,
+    "B\u1eadn": 52,
+    "B\u1eaft bu\u1ed9c": 62,
+    "B\u1eaft ch\u01b0\u1edbc": 61,
+    "B\u1ed1": 30,
+    "Cao": 83,
+    "Cay": 72,
+    "Ch\u0103m ch\u1ec9": 93,
+    "Ch\u1ea1y": 51,
+    "Ch\u1eadt": 79,
+    "Ch\u1ebft": 49,
+    "Ch\u1ed3ng": 35,
+    "Con b\u00f2": 7,
+    "Con ch\u00f3": 0,
+    "Con d\u00ea": 8,
+    "Con g\u00e0": 2,
+    "Con g\u00e1i": 33,
+    "Con heo": 9,
+    "Con m\u00e8o": 1,
+    "Con r\u00f9a": 4,
+    "Con th\u1ecf": 5,
+    "Con trai": 32,
+    "Con tr\u00e2u": 6,
+    "Con v\u1ecbt": 3,
+    "Cung c\u1ea5p": 60,
+    "C\u01b0\u1eddi": 46,
+    "C\u1eafm tr\u1ea1i": 59,
+    "D\u00e0i": 82,
+    "D\u0169ng c\u1ea3m": 98,
+    "D\u1ed7i": 48,
+    "Gi\u1ecfi": 92,
+    "H\u00e0i h\u01b0\u1edbc": 97,
+    "H\u00e1t": 53,
+    "H\u01b0": 88,
+    "H\u1eb9p": 80,
+    "H\u1ecdc": 47,
+    "Kh\u00f3c": 45,
+    "Kh\u00f4ng cho": 68,
+    "Kh\u00f4ng c\u1ea7n": 67,
+    "Kh\u00f4ng nghe l\u1eddi": 69,
+    "Kh\u00f4ng n\u00ean": 66,
+    "Kh\u00f4ng quen": 65,
+    "Kh\u1ecfe": 89,
+    "L\u00f9n": 84,
+    "L\u01b0\u1eddi bi\u1ebfng": 94,
+    "Mua b\u00e1n": 64,
+    "M\u00e0u cam": 13,
+    "M\u00e0u da tr\u1eddi": 16,
+    "M\u00e0u h\u1ed3ng": 17,
+    "M\u00e0u l\u00e1 c\u00e2y": 15,
+    "M\u00e0u n\u00e2u": 19,
+    "M\u00e0u tr\u1eafng": 11,
+    "M\u00e0u t\u00edm": 18,
+    "M\u00e0u v\u00e0ng": 14,
+    "M\u00e0u \u0111en": 10,
+    "M\u00e0u \u0111\u1ecf": 12,
+    "M\u00e1ch": 44,
+    "M\u00faa": 54,
+    "M\u1eadp": 86,
+    "M\u1eb7n": 70,
+    "M\u1eb9": 31,
+    "M\u1ec7t": 90,
+    "Ngoan": 87,
+    "Ngon mi\u1ec7ng": 76,
+    "Ng\u1ecdt": 73,
+    "Nh\u1ea1t": 75,
+    "Nh\u1ea7m l\u1eabn": 57,
+    "N\u01b0\u1edbng": 56,
+    "N\u1ea5u": 55,
+    "Quan s\u00e1t": 58,
+    "Qu\u1ea3 b\u01a1": 26,
+    "Qu\u1ea3 cam": 25,
+    "Qu\u1ea3 chu\u1ed1i": 27,
+    "Qu\u1ea3 d\u00e2u": 20,
+    "Qu\u1ea3 d\u1ee9a": 22,
+    "Qu\u1ea3 d\u1eeba": 29,
+    "Qu\u1ea3 m\u1eadn": 21,
+    "Qu\u1ea3 xo\u00e0i": 28,
+    "Qu\u1ea3 \u0111u \u0111\u1ee7": 24,
+    "Qu\u1ea3 \u0111\u00e0o": 23,
+    "R\u1ed9ng": 81,
+    "S\u00e1ng t\u1ea1o": 99,
+    "Th\u00e8m": 43,
+    "Th\u00fa v\u1ecb": 96,
+    "T\u1ed1t b\u1ee5ng": 95,
+    "U\u1ed1ng": 41,
+    "V\u1ee3": 34,
+    "Xem": 42,
+    "X\u1ea5u": 77,
+    "\u00d4ng ngo\u1ea1i": 38,
+    "\u00d4ng n\u1ed9i": 36,
+    "\u0102n": 40,
+    "\u0110au": 91,
+    "\u0110i": 50,
+    "\u0110\u1eadm": 74,
+    "\u0110\u1eafng": 71,
+    "\u0110\u1eb9p": 78,
+    "\u1ed0m": 85
+  },
+  "layer_norm_eps": 1e-12,
+  "model_type": "videomae",
+  "norm_pix_loss": true,
+  "num_attention_heads": 16,
+  "num_channels": 3,
+  "num_frames": 16,
+  "num_hidden_layers": 12,
+  "patch_size": 16,
+  "problem_type": "single_label_classification",
+  "qkv_bias": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.28.1",
+  "tubelet_size": 2,
+  "use_mean_pooling": true
+}
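A side effect of reading the label map from config.json instead of `model.config`: after `json.load`, the keys of `id2label` are strings ("0" to "99"), which is why the updated `get_predictions` in utils.py indexes with `str(topk_indices[i])`. A minimal sketch of the lookup; `predicted_index` is just an illustrative value.

```python
import json

# config.json as added in this commit: JSON object keys are always strings.
model_config = json.load(open('config.json'))
id2gloss = model_config['id2label']

predicted_index = 0                    # illustrative, e.g. an argmax over the 100 logits
print(id2gloss[str(predicted_index)])  # 'Con chó'; an integer key would raise KeyError
```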
preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
+{
+  "crop_size": {
+    "height": 224,
+    "width": 224
+  },
+  "do_center_crop": true,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "feature_extractor_type": "VideoMAEFeatureExtractor",
+  "image_mean": [
+    0.485,
+    0.456,
+    0.406
+  ],
+  "image_processor_type": "VideoMAEImageProcessor",
+  "image_std": [
+    0.229,
+    0.224,
+    0.225
+  ],
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "shortest_edge": 224
+  }
+}
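app.py now reads `image_mean`, `image_std`, and `size` from this file rather than from `VideoMAEImageProcessor`. The `transform` that consumes these values sits outside the hunks shown above, so the sketch below is only an assumed shape for it, rescaling by 1/255 to mirror `rescale_factor` and then normalizing; the actual Compose in the repo may differ.

```python
import json
import torch
from torchvision.transforms.v2 import Compose, Lambda, Normalize

preprocessor_config = json.load(open('preprocessor_config.json'))
mean = preprocessor_config['image_mean']  # [0.485, 0.456, 0.406]
std = preprocessor_config['image_std']    # [0.229, 0.224, 0.225]

# Assumed shape of the transform defined in app.py (not part of this diff).
transform = Compose([
    Lambda(lambda x: x / 255.0),   # mirrors "rescale_factor": 1/255
    Normalize(mean=mean, std=std),
])

frame = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # one RGB frame
normalized = transform(frame)
```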
utils.py CHANGED
@@ -1,10 +1,10 @@
 import cv2
-import torch
 import numpy as np
+import onnxruntime as ort
+import torch
 from mediapipe.python.solutions import (drawing_styles, drawing_utils,
                                         holistic, pose)
 from torchvision.transforms.v2 import Compose, UniformTemporalSubsample
-from transformers import VideoMAEForVideoClassification
 
 
 def draw_skeleton_on_image(
@@ -178,7 +178,8 @@ def do_hands_relax(
 
 def get_predictions(
     inputs: dict,
-    model: VideoMAEForVideoClassification,
+    ort_session: ort.InferenceSession,
+    id2gloss: dict,
     k: int = 3,
 ) -> list:
     '''
@@ -201,9 +202,7 @@ def get_predictions(
     if inputs is None:
         return []
 
-    with torch.no_grad():
-        outputs = model(**inputs)
-        logits = outputs.logits
+    logits = torch.from_numpy(ort_session.run(None, inputs)[0])
 
     # Get top-3 predictions
     topk_scores, topk_indices = torch.topk(logits, k, dim=1)
@@ -212,7 +211,7 @@
 
     return [
         {
-            'label': model.config.id2label[topk_indices[i]],
+            'label': id2gloss[str(topk_indices[i])],
             'score': topk_scores[i],
         }
         for i in range(k)
@@ -225,7 +224,6 @@ def preprocess(
     source: str,
     model_input_height: int,
     model_input_width: int,
-    device: str,
     transform: Compose,
 ) -> dict:
     '''
@@ -243,8 +241,6 @@
         Model input height.
     model_input_width : int
         Model input width.
-    device : str
-        Device to use.
     transform : Compose
         Transform to apply.
 
@@ -292,8 +288,7 @@
     skeleton_video = torch.stack(skeleton_video)
     skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)
     inputs = {
-        'pixel_values': skeleton_video.unsqueeze(0),
+        'pixel_values': skeleton_video.unsqueeze(0).numpy(),
     }
-    inputs = {k: v.to(device) for k, v in inputs.items()}
 
     return inputs
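With these changes, inference no longer needs a `device` argument: `preprocess` hands over a NumPy `pixel_values` array, onnxruntime runs the graph, and torch is only used again for `topk`. Below is a minimal standalone smoke test of that path, assuming the committed graph accepts a `(1, 16, 3, 224, 224)` float32 input (16 frames and 224 px per config.json) and returns the logits as its first output; the softmax is added here only to make the scores readable and is not taken from the repo code.

```python
import numpy as np
import onnxruntime as ort
import torch

ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')

# Random stand-in for the preprocessed skeleton clip that utils.preprocess would build.
inputs = {'pixel_values': np.random.randn(1, 16, 3, 224, 224).astype(np.float32)}

logits = torch.from_numpy(ort_session.run(None, inputs)[0])  # expected shape (1, 100)
probs = torch.softmax(logits, dim=1)
topk_scores, topk_indices = torch.topk(probs, k=3, dim=1)
print(topk_indices.squeeze(0).tolist(), topk_scores.squeeze(0).tolist())
```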
videomae_skeleton_v2.3.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:271d0e3d932fffc036b6cef4f8c90721e223e32816ef16bb853c890b0f3b90c7
+size 90390035