import json
from time import time

import gradio as gr
import onnxruntime as ort
from mediapipe.python.solutions import holistic
from torchvision.transforms.v2 import Compose, Lambda, Normalize

from utils import get_predictions, preprocess

title = '''
'''

cite_markdown = '''
'''

description = '''
'''

examples = [
    ['000_con_cho.mp4'],
]

# Load the exported ONNX model and its configuration files.
ort_session = ort.InferenceSession('videomae_skeleton_v2.3.onnx')
with open('config.json') as f:
    model_config = json.load(f)
with open('preprocessor_config.json') as f:
    preprocessor_config = json.load(f)

# Resolve the model's expected input size and normalization statistics.
mean = preprocessor_config['image_mean']
std = preprocessor_config['image_std']
if 'shortest_edge' in preprocessor_config['size']:
    model_input_height = model_input_width = preprocessor_config['size']['shortest_edge']
else:
    model_input_height = preprocessor_config['size']['height']
    model_input_width = preprocessor_config['size']['width']

# Define the transform: scale pixel values to [0, 1], then normalize.
transform = Compose(
    [
        Lambda(lambda x: x / 255.0),
        Normalize(mean=mean, std=std),
    ]
)


def inference(
    video: str,
    progress: gr.Progress = gr.Progress(),
) -> str:
    '''
    Video-based inference for Vietnamese Sign Language recognition.

    Parameters
    ----------
    video : str
        The path to the video.
    progress : gr.Progress, optional
        The progress bar, by default gr.Progress()

    Returns
    -------
    str
        The top-3 predictions.
    '''
    progress(0, desc='Preprocessing video')

    # MediaPipe Holistic extracts body, hand, and face keypoints frame by frame.
    keypoints_detector = holistic.Holistic(
        static_image_mode=False,
        model_complexity=2,
        enable_segmentation=True,
        refine_face_landmarks=True,
    )

    start_time = time()
    inputs = preprocess(
        model_num_frames=model_config['num_frames'],
        keypoints_detector=keypoints_detector,
        source=video,
        model_input_height=model_input_height,
        model_input_width=model_input_width,
        transform=transform,
    )
    end_time = time()
    data_time = end_time - start_time

    progress(1/2, desc='Getting predictions')
    start_time = time()
    predictions = get_predictions(
        inputs=inputs,
        ort_session=ort_session,
        id2gloss=model_config['id2label'],
        k=3,
    )
    end_time = time()
    model_time = end_time - start_time

    if len(predictions) == 0:
        output_message = 'No sign language detected in the video. Please try again.'
    else:
        output_message = 'The top-3 predictions are:\n'
        for i, prediction in enumerate(predictions):
            output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
        output_message += f'Data processing time: {data_time:.2f} seconds\n'
        output_message += f'Model inference time: {model_time:.2f} seconds\n'
        output_message += f'Total time: {data_time + model_time:.2f} seconds'

    progress(1, desc='Completed')

    return output_message


iface = gr.Interface(
    fn=inference,
    inputs='video',
    outputs='text',
    examples=examples,
    title=title,
    description=description,
)
iface.launch()

# print(inference('000_con_cho.mp4'))