#!/usr/bin/env python

"""A demo of the ViTPose model.

This code is based on the implementation from the Colab notebook:
https://colab.research.google.com/drive/1e8fcby5rhKZWcr9LSN8mNbQ0TU4Dxxpo
"""

import pathlib
import tempfile

import cv2
import gradio as gr
import numpy as np
import PIL.Image
import spaces
import supervision as sv
import torch
import tqdm.auto
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

DESCRIPTION = "# ViTPose"

MAX_NUM_FRAMES = 300

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Person detector (RT-DETR) used to find people before pose estimation
person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(person_detector_name, device_map=device)

# Top-down pose estimator (ViTPose) run on each detected person box
pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_map=device)


# NOTE: the `spaces` import and the "Running on Zero" badge suggest a ZeroGPU Space;
# decorating the inference functions with `@spaces.GPU` is assumed here so they are scheduled on the GPU.
@spaces.GPU
def process_image(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
    inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = person_model(**inputs)
    results = person_image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]  # take the results for the first (and only) image

    # The "person" class has label index 0 in the COCO label set
    person_boxes_xyxy = result["boxes"][result["labels"] == 0]
    person_boxes_xyxy = person_boxes_xyxy.cpu().numpy()

    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format
    person_boxes = person_boxes_xyxy.copy()
    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]
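    # For example, a VOC box (10, 20, 110, 220) becomes the COCO box (10, 20, 100, 200),
    # since width = 110 - 10 = 100 and height = 220 - 20 = 200.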

    inputs = pose_image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

    # for the vitpose-plus-base checkpoint we should additionally provide dataset_index
    # to specify which MoE experts to use for inference
    if pose_model.config.backbone_config.num_experts > 1:
        dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
        dataset_index = dataset_index.to(inputs["pixel_values"].device)
        inputs["dataset_index"] = dataset_index

    with torch.no_grad():
        outputs = pose_model(**inputs)
    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
    image_pose_result = pose_results[0]  # results for the first image

    # make results more human-readable
    human_readable_results = []
    for i, person_pose in enumerate(image_pose_result):
        data = {
            "person_id": i,
            "bbox": person_pose["bbox"].numpy().tolist(),
            "keypoints": [],
        }
        for keypoint, label, score in zip(
            person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
        ):
            keypoint_name = pose_model.config.id2label[label.item()]
            x, y = keypoint
            data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
        human_readable_results.append(data)

    # stack the per-person keypoints and scores into arrays of shape (n_people, n_keypoints, ...)
    # so they can be drawn with supervision
    xy = [pose_result["keypoints"] for pose_result in image_pose_result]
    xy = torch.stack(xy).cpu().numpy()
    scores = [pose_result["scores"] for pose_result in image_pose_result]
    scores = torch.stack(scores).cpu().numpy()

    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
    detections = sv.Detections(xyxy=person_boxes_xyxy)

    edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
    vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)

    # annotate bounding boxes
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)
    # annotate edges and vertices
    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
    return vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints), human_readable_results
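
# Illustrative usage sketch (not part of the Gradio app flow): `process_image` accepts any
# PIL image; "person.jpg" below is a hypothetical local file.
#
#     annotated, keypoints = process_image(PIL.Image.open("person.jpg"))
#     annotated.save("person_pose.jpg")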


# As above, the `@spaces.GPU` decorator is assumed for ZeroGPU scheduling; long videos may
# need an explicit `duration` argument, which is not specified here.
@spaces.GPU
def process_video(
    video_path: str,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> str:
    cap = cv2.VideoCapture(video_path)

    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    fps = cap.get(cv2.CAP_PROP_FPS)
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
        writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
        for _ in tqdm.auto.tqdm(range(min(MAX_NUM_FRAMES, num_frames))):
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV reads frames as BGR; reverse the channel order to get RGB for PIL
            rgb_frame = frame[:, :, ::-1]
            annotated_frame, _ = process_image(PIL.Image.fromarray(rgb_frame))
            # convert the annotated RGB image back to BGR before writing
            writer.write(np.asarray(annotated_frame)[:, :, ::-1])
        writer.release()
    cap.release()
    return out_file.name
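
# Illustrative usage sketch (not part of the Gradio app flow): `process_video` writes an
# annotated MP4 to a temporary file and returns its path; "input.mp4" is a hypothetical local file.
#
#     out_path = process_video("input.mp4")
#     print(out_path)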


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil")
                    run_button_image = gr.Button()
                with gr.Column():
                    output_image = gr.Image(label="Output Image")
                    output_json = gr.JSON(label="Output JSON")
            gr.Examples(
                examples=sorted(pathlib.Path("images").glob("*.jpg")),
                inputs=input_image,
                outputs=[output_image, output_json],
                fn=process_image,
            )
            run_button_image.click(
                fn=process_image,
                inputs=input_image,
                outputs=[output_image, output_json],
            )
with gr.Tab("Video"): | |
gr.Markdown(f"The input video will be truncated to {MAX_NUM_FRAMES} frames.") | |
with gr.Row(): | |
with gr.Column(): | |
input_video = gr.Video(label="Input Video") | |
run_button_video = gr.Button() | |
with gr.Column(): | |
output_video = gr.Video(label="Output Video") | |
gr.Examples( | |
examples=sorted(pathlib.Path("videos").glob("*.mp4")), | |
inputs=input_video, | |
outputs=output_video, | |
fn=process_video, | |
cache_examples=False, | |
) | |
run_button_video.click( | |
fn=process_video, | |
inputs=input_video, | |
outputs=output_video, | |
) | |


if __name__ == "__main__":
    demo.queue(max_size=20).launch()