import cv2
import numpy as np
import torch
from mediapipe.python.solutions import (drawing_styles, drawing_utils,
                                        holistic, pose)
from torchvision.transforms.v2 import Compose, UniformTemporalSubsample


def draw_skeleton_on_image(
    image: np.ndarray,
    detection_results,
    resize_to: tuple[int, int] | None = None,
) -> np.ndarray:
    """
    Draw the detected skeleton (pose and both hands) on the image.

    Parameters
    ----------
    image : np.ndarray
        Image to draw the skeleton on.
    detection_results
        MediaPipe Holistic detection results.
    resize_to : tuple[int, int], optional
        Target size as (width, height), as expected by cv2.resize.

    Returns
    -------
    np.ndarray
        Annotated image with the skeleton.
    """
    annotated_image = np.copy(image)

    # Draw pose connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.pose_landmarks,
        holistic.POSE_CONNECTIONS,
        landmark_drawing_spec=drawing_styles.get_default_pose_landmarks_style(),
    )
    # Draw left hand connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.left_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        landmark_drawing_spec=drawing_utils.DrawingSpec(
            color=(121, 22, 76), thickness=2, circle_radius=4
        ),
        connection_drawing_spec=drawing_utils.DrawingSpec(
            color=(121, 44, 250), thickness=2, circle_radius=2
        ),
    )
    # Draw right hand connections.
    drawing_utils.draw_landmarks(
        annotated_image,
        detection_results.right_hand_landmarks,
        holistic.HAND_CONNECTIONS,
        landmark_drawing_spec=drawing_utils.DrawingSpec(
            color=(245, 117, 66), thickness=2, circle_radius=4
        ),
        connection_drawing_spec=drawing_utils.DrawingSpec(
            color=(245, 66, 230), thickness=2, circle_radius=2
        ),
    )

    if resize_to is not None:
        annotated_image = cv2.resize(
            annotated_image,
            resize_to,
            interpolation=cv2.INTER_AREA,
        )
    return annotated_image
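

# A minimal usage sketch for draw_skeleton_on_image (illustrative, not part of
# the app's pipeline): run MediaPipe Holistic on a single image and render the
# skeleton on a blank canvas. The path 'frame.jpg' is a placeholder.
#
#     with holistic.Holistic(static_image_mode=True) as detector:
#         frame = cv2.cvtColor(cv2.imread('frame.jpg'), cv2.COLOR_BGR2RGB)
#         results = detector.process(frame)
#         skeleton = draw_skeleton_on_image(
#             image=np.zeros_like(frame),
#             detection_results=results,
#             resize_to=(224, 224),
#         )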


def are_hands_down(pose_landmarks) -> bool:
    """
    Check whether both hands are down (each wrist below its elbow).

    Parameters
    ----------
    pose_landmarks
        MediaPipe pose landmarks, or None if no pose was detected.

    Returns
    -------
    bool
        True if both hands are down (or no pose was detected), False otherwise.
    """
    if pose_landmarks is None:
        return True

    landmarks = pose_landmarks.landmark
    # Each entry is [x, y, visibility]; note that visibility comes from the
    # same-side shoulder rather than from the joint itself.
    left_elbow = [
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.LEFT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    left_wrist = [
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].x,
        landmarks[pose.PoseLandmark.LEFT_WRIST.value].y,
        landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
    ]
    right_elbow = [
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].x,
        landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]
    right_wrist = [
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].x,
        landmarks[pose.PoseLandmark.RIGHT_WRIST.value].y,
        landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
    ]
    is_visible = all(
        v > 0 for v in (left_elbow[2], left_wrist[2], right_elbow[2], right_wrist[2])
    )
    # Image y grows downward, so a larger y means the wrist is below the elbow.
    return (
        is_visible
        and left_wrist[1] > left_elbow[1]
        and right_wrist[1] > right_elbow[1]
    )


def get_predictions(
    inputs: dict,
    model,
    k: int = 3,
) -> list:
    """
    Run the model and return the top-k predictions.

    Parameters
    ----------
    inputs : dict
        Model inputs as returned by preprocess, or None.
    model
        Video classification model with a config.id2label mapping.
    k : int
        Number of top predictions to return.

    Returns
    -------
    list
        Top-k predictions as {'label': ..., 'score': ...} dicts.
    """
    if inputs is None:
        return []

    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Take the top-k logits and renormalize them with softmax.
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().cpu().numpy()
    topk_indices = topk_indices.squeeze().detach().cpu().numpy()
    return [
        {
            'label': model.config.id2label[int(topk_indices[i])],
            'score': topk_scores[i],
        }
        for i in range(k)
    ]


def preprocess(
    model_num_frames: int,
    keypoints_detector,
    source: str,
    model_input_height: int,
    model_input_width: int,
    device: str,
    transform: Compose,
) -> dict | None:
    """
    Extract the sign clip from a video and turn it into model inputs.

    Parameters
    ----------
    model_num_frames : int
        Number of frames the model expects.
    keypoints_detector
        MediaPipe Holistic detector.
    source : str
        Path to the input video.
    model_input_height : int
        Model input height in pixels.
    model_input_width : int
        Model input width in pixels.
    device : str
        Device to move the inputs to.
    transform : Compose
        Per-frame transform applied after the skeleton is drawn.

    Returns
    -------
    dict or None
        Model inputs with a batched 'pixel_values' tensor, or None if the
        extracted clip is shorter than model_num_frames.
    """
    skeleton_video = []
    did_sample_start = False
    cap = cv2.VideoCapture(source)

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect keypoints and draw the skeleton on a blank canvas.
        detection_results = keypoints_detector.process(frame)
        skeleton_frame = draw_skeleton_on_image(
            image=np.zeros((1080, 1080, 3), dtype=np.uint8),
            detection_results=detection_results,
            # cv2.resize expects (width, height).
            resize_to=(model_input_width, model_input_height),
        )
        # (height, width, channels) -> (channels, height, width)
        skeleton_frame = transform(torch.tensor(skeleton_frame).permute(2, 0, 1))

        # Sample frames while the hands are raised: start when they first come
        # up, stop once they drop again.
        if not are_hands_down(detection_results.pose_landmarks):
            did_sample_start = True
        elif did_sample_start:
            break
        if did_sample_start:
            skeleton_video.append(skeleton_frame)
    cap.release()

    if len(skeleton_video) < model_num_frames:
        return None

    # Stack frames and subsample them uniformly to the model's frame count.
    skeleton_video = torch.stack(skeleton_video)
    skeleton_video = UniformTemporalSubsample(model_num_frames)(skeleton_video)
    inputs = {
        'pixel_values': skeleton_video.unsqueeze(0),
    }
    return {key: value.to(device) for key, value in inputs.items()}
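

# End-to-end sketch of how these helpers fit together. The checkpoint name
# 'my-org/sign-videomae', the video path 'sign.mp4', and the VideoMAE-style
# config attributes are illustrative assumptions; any transformers video
# classifier that accepts `pixel_values` and exposes `config.id2label` matches
# the calls above.
if __name__ == '__main__':
    from torchvision.transforms import v2
    from transformers import AutoModelForVideoClassification

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AutoModelForVideoClassification.from_pretrained(
        'my-org/sign-videomae'  # hypothetical checkpoint
    ).to(device).eval()

    # Scale uint8 frames to float32 in [0, 1]; the real app may normalize further.
    transform = Compose([v2.ToDtype(torch.float32, scale=True)])

    with holistic.Holistic() as detector:
        inputs = preprocess(
            model_num_frames=model.config.num_frames,  # VideoMAE-style config
            keypoints_detector=detector,
            source='sign.mp4',  # hypothetical input video
            model_input_height=model.config.image_size,
            model_input_width=model.config.image_size,
            device=device,
            transform=transform,
        )

    for prediction in get_predictions(inputs, model, k=3):
        print(f"{prediction['label']}: {prediction['score']:.3f}")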