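# Gradio demo: Vietnamese Sign Language recognition from video.
# Pipeline (as wired below): MediaPipe Holistic keypoint extraction -> preprocessing -> ONNX inference.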
import json
from time import time

import gradio as gr
import onnxruntime as ort
from mediapipe.python.solutions import holistic
from torchvision.transforms.v2 import Compose, Lambda, Normalize

from utils import get_predictions, preprocess

title = '''
'''

cite_markdown = '''
'''

description = '''
'''

examples = [
    ['000_con_cho.mp4'],
]
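
# Load the exported ONNX model and its configuration files.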
ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')

with open('config.json') as f:
    model_config = json.load(f)
with open('preprocessor_config.json') as f:
    preprocessor_config = json.load(f)
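
# Normalization statistics and model input size come from the preprocessor config.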
mean = preprocessor_config['image_mean']
std = preprocessor_config['image_std']
if 'shortest_edge' in preprocessor_config['size']:
    model_input_height = model_input_width = preprocessor_config['size']['shortest_edge']
else:
    model_input_height = preprocessor_config['size']['height']
    model_input_width = preprocessor_config['size']['width']

# Define the transform: scale pixel values to [0, 1] and normalize with the model's statistics.
transform = Compose(
    [
        Lambda(lambda x: x / 255.0),
        Normalize(mean=mean, std=std),
    ]
)

def inference(
    video: str,
    progress: gr.Progress = gr.Progress(),
) -> str:
    '''
    Video-based inference for Vietnamese Sign Language recognition.

    Parameters
    ----------
    video : str
        The path to the video.
    progress : gr.Progress, optional
        The progress bar, by default gr.Progress().

    Returns
    -------
    str
        The top-3 predictions with timing information, or a message if no sign language is detected.
    '''
    progress(0, desc='Preprocessing video')
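    # MediaPipe Holistic detector for body, hand, and face keypoints.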
    keypoints_detector = holistic.Holistic(
        static_image_mode=False,
        model_complexity=2,
        enable_segmentation=True,
        refine_face_landmarks=True,
    )
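
    # Extract keypoints from the video and build the model input tensor.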
    start_time = time()
    inputs = preprocess(
        model_num_frames=model_config['num_frames'],
        keypoints_detector=keypoints_detector,
        source=video,
        model_input_height=model_input_height,
        model_input_width=model_input_width,
        transform=transform,
    )
    end_time = time()
    data_time = end_time - start_time
    progress(1/2, desc='Getting predictions')
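    # Run ONNX inference and decode the top-3 glosses.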
    start_time = time()
    predictions = get_predictions(
        inputs=inputs,
        ort_session=ort_session,
        id2gloss=model_config['id2label'],
        k=3,
    )
    end_time = time()
    model_time = end_time - start_time
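
    # Format the predictions and timing information for display.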
    if len(predictions) == 0:
        output_message = 'No sign language detected in the video. Please try again.'
    else:
        output_message = 'The top-3 predictions are:\n'
        for i, prediction in enumerate(predictions):
            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
        output_message += f'Data processing time: {data_time:.2f} seconds\n'
        output_message += f'Model inference time: {model_time:.2f} seconds\n'
        output_message += f'Total time: {data_time + model_time:.2f} seconds'
    progress(1, desc='Completed')

    return output_message
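
# Build and launch the Gradio demo.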
iface = gr.Interface(
    fn=inference,
    inputs='video',
    outputs='text',
    examples=examples,
    title=title,
    description=description,
)
iface.launch()

# print(inference('000_con_cho.mp4'))