import json
import gradio as gr
from time import time
import onnxruntime as ort
from mediapipe.python.solutions import holistic
from torchvision.transforms.v2 import Compose, Lambda, Normalize
from utils import get_predictions, preprocess
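# Text blocks for the Gradio UI and the bundled example video.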
title = '''
'''
cite_markdown = '''
'''
description = '''
'''
examples = [
['000_con_cho.mp4'],
]
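# Load the exported ONNX model plus the model/preprocessor configs (label map, frame count, normalization stats, input size).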
ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')
with open('config.json') as config_file:
    model_config = json.load(config_file)
with open('preprocessor_config.json') as preprocessor_file:
    preprocessor_config = json.load(preprocessor_file)
mean = preprocessor_config['image_mean']
std = preprocessor_config['image_std']
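# The preprocessor config stores either a single shortest_edge or explicit height/width.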
if 'shortest_edge' in preprocessor_config['size']:
model_input_height = model_input_width = preprocessor_config['size']['shortest_edge']
else:
model_input_height = preprocessor_config['size']['height']
model_input_width = preprocessor_config['size']['width']
# Per-frame transform: scale pixel values to [0, 1], then normalize with the model's mean and std.
transform = Compose(
[
Lambda(lambda x: x / 255.0),
Normalize(mean=mean, std=std),
]
)
def inference(
video: str,
progress: gr.Progress = gr.Progress(),
) -> str:
'''
Video-based inference for Vietnamese Sign Language recognition.
Parameters
----------
video : str
The path to the video.
progress : gr.Progress, optional
The progress bar, by default gr.Progress()
Returns
-------
str
The top-3 predictions together with preprocessing and inference timings.
'''
progress(0, desc='Preprocessing video')
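# MediaPipe Holistic extracts body, hand, and face landmarks used during preprocessing.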
keypoints_detector = holistic.Holistic(
static_image_mode=False,
model_complexity=2,
enable_segmentation=True,
refine_face_landmarks=True,
)
start_time = time()
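# Build model inputs from the video using the keypoint detector, target input size, and normalization transform.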
inputs = preprocess(
model_num_frames=model_config['num_frames'],
keypoints_detector=keypoints_detector,
source=video,
model_input_height=model_input_height,
model_input_width=model_input_width,
transform=transform,
)
end_time = time()
data_time = end_time - start_time
progress(1/2, desc='Getting predictions')
start_time = time()
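# Run the ONNX session and map the top-3 scores to gloss labels via id2label.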
predictions = get_predictions(
inputs=inputs,
ort_session=ort_session,
id2gloss=model_config['id2label'],
k=3,
)
end_time = time()
model_time = end_time - start_time
if len(predictions) == 0:
output_message = 'No sign language detected in the video. Please try again.'
else:
output_message = 'The top-3 predictions are:\n'
for i, prediction in enumerate(predictions):
output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
output_message += f'Data processing time: {data_time:.2f} seconds\n'
output_message += f'Model inference time: {model_time:.2f} seconds\n'
output_message += f'Total time: {data_time + model_time:.2f} seconds'
progress(1, desc='Completed')
return output_message
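# Assemble the Gradio demo: video input, text output, and the example clip.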
iface = gr.Interface(
fn=inference,
inputs='video',
outputs='text',
examples=examples,
title=title,
description=description,
)
iface.launch()
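# Quick local check without launching the UI: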
# print(inference('000_con_cho.mp4'))