import cv2
import numpy as np
import onnxruntime as ort
import torch
from mediapipe.python.solutions import (drawing_styles, drawing_utils,
holistic, pose)
from torchvision.transforms.v2 import Compose, UniformTemporalSubsample
def draw_skeleton_on_image(
image: np.ndarray,
detection_results,
    resize_to: tuple[int, int] | None = None,
) -> np.ndarray:
'''
Draw skeleton on the image.
Parameters
----------
image : np.ndarray
Image to draw skeleton on.
    detection_results
        MediaPipe Holistic detection results containing the landmarks to draw.
    resize_to : tuple[int, int], optional
        Target size passed to `cv2.resize`; if None, keep the original size.
Returns
-------
np.ndarray
Annotated image with skeleton.
'''
annotated_image = np.copy(image)
# Draw pose connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.pose_landmarks,
holistic.POSE_CONNECTIONS,
landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
)
# Draw left hand connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.left_hand_landmarks,
holistic.HAND_CONNECTIONS,
drawing_utils.DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
drawing_utils.DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2),
)
# Draw right hand connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.right_hand_landmarks,
holistic.HAND_CONNECTIONS,
drawing_utils.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
drawing_utils.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2),
)
if resize_to is not None:
annotated_image = cv2.resize(
annotated_image,
resize_to,
interpolation=cv2.INTER_AREA,
)
return annotated_image
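# Illustrative usage sketch (not executed here): render detected landmarks on a
# blank canvas, as `preprocess` does further below. The detector setup and the
# 224x224 target size are assumptions for the example, not values from this file.
#
#   with holistic.Holistic(static_image_mode=False) as detector:
#       results = detector.process(rgb_frame)
#       canvas = draw_skeleton_on_image(
#           image=np.zeros((1080, 1080, 3), dtype=np.uint8),
#           detection_results=results,
#           resize_to=(224, 224),
#       )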
def calculate_angle(
shoulder: list,
elbow: list,
wrist: list,
) -> float:
'''
    Calculate the angle at the elbow formed by the shoulder, elbow, and wrist.
Parameters
----------
shoulder : list
Shoulder coordinates.
elbow : list
Elbow coordinates.
wrist : list
Wrist coordinates.
Returns
-------
float
        Angle in degrees at the elbow, formed by the shoulder-elbow and
        elbow-wrist segments.
'''
shoulder = np.array(shoulder)
elbow = np.array(elbow)
wrist = np.array(wrist)
radians = np.arctan2(wrist[1] - elbow[1], wrist[0] - elbow[0]) \
- np.arctan2(shoulder[1] - elbow[1], shoulder[0] - elbow[0])
angle = np.abs(radians * 180.0 / np.pi)
if angle > 180.0:
angle = 360 - angle
return angle
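# Illustrative values (made-up normalized coordinates, not MediaPipe output):
# a straight arm gives 180 degrees, a right-angled elbow roughly 90 degrees.
#
#   calculate_angle([0.2, 0.5], [0.5, 0.5], [0.8, 0.5])  # -> 180.0
#   calculate_angle([0.5, 0.3], [0.5, 0.5], [0.7, 0.5])  # -> ~90.0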
def do_hands_relax(
pose_landmarks: list,
angle_threshold: float = 160.0,
) -> bool:
'''
    Check whether both hands are relaxed (down).
    Parameters
    ----------
    pose_landmarks : list
        Pose landmarks from the MediaPipe detection results.
    angle_threshold : float, optional
        Elbow-angle threshold in degrees, by default 160.0.
    Returns
    -------
    bool
        True if both hands are down (or no pose was detected), False otherwise.
'''
if pose_landmarks is None:
return True
landmarks = pose_landmarks.landmark
left_shoulder = [
landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].x,
landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].y,
landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
]
left_elbow = [
landmarks[pose.PoseLandmark.LEFT_ELBOW.value].x,
landmarks[pose.PoseLandmark.LEFT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].visibility,
]
left_wrist = [
landmarks[pose.PoseLandmark.LEFT_WRIST.value].x,
landmarks[pose.PoseLandmark.LEFT_WRIST.value].y,
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].visibility,
]
left_angle = calculate_angle(left_shoulder, left_elbow, left_wrist)
right_shoulder = [
landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].x,
landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].y,
landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
]
right_elbow = [
landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].x,
landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].visibility,
]
right_wrist = [
landmarks[pose.PoseLandmark.RIGHT_WRIST.value].x,
landmarks[pose.PoseLandmark.RIGHT_WRIST.value].y,
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].visibility,
]
right_angle = calculate_angle(right_shoulder, right_elbow, right_wrist)
is_visible = all(
[
left_shoulder[2] > 0,
left_elbow[2] > 0,
left_wrist[2] > 0,
right_shoulder[2] > 0,
right_elbow[2] > 0,
right_wrist[2] > 0,
]
)
return all(
[
is_visible,
left_angle < angle_threshold,
right_angle < angle_threshold,
]
)
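# Illustrative check (names are assumptions): `results` would come from a
# MediaPipe Holistic detector as in the sketch after `draw_skeleton_on_image`.
#
#   if do_hands_relax(results.pose_landmarks, angle_threshold=160.0):
#       ...  # treat the frame as a resting pose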
def get_predictions(
inputs: dict,
ort_session: ort.InferenceSession,
id2gloss: dict,
k: int = 3,
) -> list:
'''
Get the top-k predictions.
Parameters
----------
    inputs : dict
        Model inputs produced by `preprocess`.
    ort_session : ort.InferenceSession
        ONNX Runtime session to run inference with.
    id2gloss : dict
        Mapping from class index (as a string) to gloss label.
    k : int, optional
        Number of predictions to return, by default 3.
Returns
-------
list
        Top-k predictions as dicts with 'label' and 'score' keys, or an empty
        list if `inputs` is None.
'''
if inputs is None:
return []
logits = torch.from_numpy(ort_session.run(None, inputs)[0])
    # Get the top-k logits and normalize them with softmax
    # (scores are relative to the top-k classes only).
topk_scores, topk_indices = torch.topk(logits, k, dim=1)
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
topk_indices = topk_indices.squeeze().detach().numpy()
return [
{
'label': id2gloss[str(topk_indices[i])],
'score': topk_scores[i],
}
for i in range(k)
]
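# Illustrative usage (the ONNX file name and label mapping are assumptions):
#
#   ort_session = ort.InferenceSession('videomae.onnx')
#   id2gloss = {'0': 'hello', '1': 'thanks'}  # made-up mapping
#   get_predictions(inputs, ort_session, id2gloss, k=3)
#   # -> [{'label': ..., 'score': ...}, ...] ordered by descending score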
def preprocess(
model_num_frames: int,
keypoints_detector,
source: str,
model_input_height: int,
model_input_width: int,
transform: Compose,
) -> dict | None:
'''
Preprocess the video.
Parameters
----------
    model_num_frames : int
        Number of frames the model expects as input.
    keypoints_detector
        MediaPipe Holistic detector used to extract keypoints.
    source : str
        Path to the input video (anything accepted by cv2.VideoCapture).
    model_input_height : int
        Model input height.
    model_input_width : int
        Model input width.
    transform : Compose
        Per-frame transform applied to each rendered skeleton frame.
Returns
-------
    dict
        Model inputs with a 'pixel_values' array, or None if fewer than
        `model_num_frames` frames were sampled.
'''
skeleton_video = []
did_sample_start = False
cap = cv2.VideoCapture(source)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Detect keypoints.
detection_results = keypoints_detector.process(frame)
        # Render the detected skeleton on a blank square canvas.
        # Note: cv2.resize expects (width, height), so the (height, width)
        # argument below relies on the model input being square.
        skeleton_frame = draw_skeleton_on_image(
            image=np.zeros((1080, 1080, 3), dtype=np.uint8),
            detection_results=detection_results,
            resize_to=(model_input_height, model_input_width),
        )
# (height, width, channels) -> (channels, height, width)
skeleton_frame = transform(torch.tensor(skeleton_frame).permute(2, 0, 1))
        # Extract the sign segment: start collecting frames when the hands
        # leave the resting position, stop once they return to it.
if not do_hands_relax(detection_results.pose_landmarks):
if not did_sample_start:
did_sample_start = True
elif did_sample_start:
break
if did_sample_start:
skeleton_video.append(skeleton_frame)
cap.release()
if len(skeleton_video) < model_num_frames:
return None
skeleton_video = torch.stack(skeleton_video)
skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)
inputs = {
'pixel_values': skeleton_video.unsqueeze(0).numpy(),
}
return inputs
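# Minimal end-to-end sketch. The ONNX file name, label-map file, video path,
# and the per-frame transform below are assumptions for illustration; they are
# not taken from this repository.
if __name__ == '__main__':
    import json

    from torchvision.transforms.v2 import Normalize, ToDtype

    # Assumed artifacts; replace with the paths actually used by the Space.
    ort_session = ort.InferenceSession('videomae.onnx')
    with open('id2gloss.json') as f:
        id2gloss = json.load(f)

    # Per-frame transform; the normalization statistics are assumptions.
    transform = Compose([
        ToDtype(torch.float32, scale=True),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    with holistic.Holistic(static_image_mode=False) as keypoints_detector:
        inputs = preprocess(
            model_num_frames=16,
            keypoints_detector=keypoints_detector,
            source='sample.mp4',
            model_input_height=224,
            model_input_width=224,
            transform=transform,
        )
    print(get_predictions(inputs, ort_session, id2gloss, k=3))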