# VSL-VideoMAE / utils.py
import cv2
import numpy as np
import torch
from mediapipe.python.solutions import (drawing_styles, drawing_utils,
holistic, pose)
from torchvision.transforms.v2 import Compose, UniformTemporalSubsample
def draw_skeleton_on_image(
    image: np.ndarray,
    detection_results,
    resize_to: tuple[int, int] | None = None,
) -> np.ndarray:
    """
    Draw the detected skeleton (pose and both hands) on an image.

    Parameters
    ----------
    image : np.ndarray
        Image to draw the skeleton on.
    detection_results
        MediaPipe Holistic detection results.
    resize_to : tuple[int, int], optional
        Target size (width, height) to resize the annotated image to.

    Returns
    -------
    np.ndarray
        Annotated image with the skeleton drawn on it.
    """
annotated_image = np.copy(image)
# Draw pose connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.pose_landmarks,
holistic.POSE_CONNECTIONS,
landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
)
# Draw left hand connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.left_hand_landmarks,
holistic.HAND_CONNECTIONS,
drawing_utils.DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
drawing_utils.DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2),
)
# Draw right hand connections
drawing_utils.draw_landmarks(
annotated_image,
detection_results.right_hand_landmarks,
holistic.HAND_CONNECTIONS,
drawing_utils.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
drawing_utils.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2),
)
if resize_to is not None:
annotated_image = cv2.resize(
annotated_image,
resize_to,
interpolation=cv2.INTER_AREA,
)
return annotated_image
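# Example (sketch, not part of the original pipeline): render a skeleton-only
# frame from a single RGB image. `frame_rgb` is a placeholder for any
# HxWx3 uint8 array, and the 224x224 target size is an arbitrary choice here.
#
#     with holistic.Holistic(static_image_mode=True) as detector:
#         results = detector.process(frame_rgb)
#     canvas = draw_skeleton_on_image(
#         image=np.zeros((1080, 1080, 3), dtype=np.uint8),
#         detection_results=results,
#         resize_to=(224, 224),
#     )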
def are_hands_down(pose_landmarks) -> bool:
    """
    Check whether both hands are down (each wrist below its elbow).

    Parameters
    ----------
    pose_landmarks
        MediaPipe pose landmarks, or None if no pose was detected.

    Returns
    -------
    bool
        True if both hands are down (or no pose was detected), False otherwise.
    """
if pose_landmarks is None:
return True
landmarks = pose_landmarks.landmark
    left_elbow = landmarks[pose.PoseLandmark.LEFT_ELBOW.value]
    left_wrist = landmarks[pose.PoseLandmark.LEFT_WRIST.value]
    right_elbow = landmarks[pose.PoseLandmark.RIGHT_ELBOW.value]
    right_wrist = landmarks[pose.PoseLandmark.RIGHT_WRIST.value]

    # Shoulder visibility is used as a proxy for whether each arm is in frame;
    # wrist visibility often drops when the hands rest out of view.
    is_visible = all(
        landmarks[shoulder.value].visibility > 0
        for shoulder in (pose.PoseLandmark.LEFT_SHOULDER, pose.PoseLandmark.RIGHT_SHOULDER)
    )

    # Image y grows downwards, so a wrist below its elbow has a larger y value.
    return (
        is_visible
        and left_wrist.y > left_elbow.y
        and right_wrist.y > right_elbow.y
    )
def get_predictions(
    inputs: dict,
    model,
    k: int = 3,
) -> list:
    """
    Run the model on preprocessed inputs and return the top-k predictions.

    Parameters
    ----------
    inputs : dict
        Model inputs produced by `preprocess`, or None if preprocessing failed.
    model
        Classification model exposing `config.id2label`.
    k : int, optional
        Number of top predictions to return.

    Returns
    -------
    list
        A list of {'label', 'score'} dicts, or an empty list if `inputs` is None.
    """
    if inputs is None:
        return []

    with torch.no_grad():
        logits = model(**inputs).logits

    # Get the top-k predictions; scores are a softmax over the top-k logits only.
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze(0).cpu().numpy()
    topk_indices = topk_indices.squeeze(0).cpu().numpy()

    return [
        {
            'label': model.config.id2label[int(topk_indices[i])],
            'score': float(topk_scores[i]),
        }
        for i in range(k)
    ]
def preprocess(
    model_num_frames: int,
    keypoints_detector,
    source: str,
    model_input_height: int,
    model_input_width: int,
    device: str,
    transform: Compose,
) -> dict | None:
    """
    Extract the signing segment of a video as skeleton-only model inputs.

    Frames are kept from when the hands are first raised until they drop again,
    rendered as skeletons on a black canvas, resized to the model input size,
    and uniformly subsampled to `model_num_frames` frames. Returns None if the
    video does not contain enough signing frames.
    """
skeleton_video = []
did_sample_start = False
cap = cv2.VideoCapture(source)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Detect keypoints and render them on a blank 1080x1080 canvas.
        detection_results = keypoints_detector.process(frame)
        skeleton_frame = draw_skeleton_on_image(
            image=np.zeros((1080, 1080, 3), dtype=np.uint8),
            detection_results=detection_results,
            # cv2.resize expects the target size as (width, height).
            resize_to=(model_input_width, model_input_height),
        )
# (height, width, channels) -> (channels, height, width)
skeleton_frame = transform(torch.tensor(skeleton_frame).permute(2, 0, 1))
        # Keep only the signing segment: from the hands being raised until they drop again.
if not are_hands_down(detection_results.pose_landmarks):
if not did_sample_start:
did_sample_start = True
elif did_sample_start:
break
if did_sample_start:
skeleton_video.append(skeleton_frame)
cap.release()
if len(skeleton_video) < model_num_frames:
return None
skeleton_video = torch.stack(skeleton_video)
skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)
inputs = {
'pixel_values': skeleton_video.unsqueeze(0),
}
inputs = {k: v.to(device) for k, v in inputs.items()}
return inputs
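if __name__ == '__main__':
    # Usage sketch (an assumption, not part of the original module): wires the
    # helpers above together for a single video. The checkpoint path, the
    # normalization statistics, and 'example.mp4' are placeholders.
    from torchvision.transforms.v2 import Normalize, ToDtype
    from transformers import VideoMAEForVideoClassification

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VideoMAEForVideoClassification.from_pretrained(
        'path/to/videomae-checkpoint'  # placeholder fine-tuned checkpoint
    ).to(device).eval()

    frame_transform = Compose([
        ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float [0, 1]
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # assumed stats
    ])

    with holistic.Holistic(static_image_mode=False) as detector:
        inputs = preprocess(
            model_num_frames=model.config.num_frames,
            keypoints_detector=detector,
            source='example.mp4',  # placeholder video path
            model_input_height=model.config.image_size,
            model_input_width=model.config.image_size,
            device=device,
            transform=frame_transform,
        )

    print(get_predictions(inputs, model, k=3))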