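"""Gradio demo: VideoMAE fine-tuned on a subset of UCF-101 for video classification.

The app decodes an uploaded clip with OpenCV, applies validation-style preprocessing,
and reports per-class confidence scores through a `gr.Label` output.
"""
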
import cv2
import gradio as gr
import imutils
import numpy as np
import torch
from pytorchvideo.transforms import Normalize, UniformTemporalSubsample
from torchvision.transforms import Compose, Lambda, Resize
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification

MODEL_CKPT = "sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned checkpoint once at startup, move it to the available device,
# and switch to eval mode for inference.
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
MODEL.eval()
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

RESIZE_TO = PROCESSOR.size
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
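# The hard-coded values above are the standard ImageNet mean/std; PROCESSOR.image_mean and
# PROCESSOR.image_std should hold the same statistics if you prefer reading them from the processor.

# Validation-style preprocessing: temporally subsample NUM_FRAMES_TO_SAMPLE frames,
# scale pixel values to [0, 1], normalize with the ImageNet statistics, and resize each frame.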
VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)
# Order the labels by class index so that LABELS[i] matches logits[..., i].
LABELS = [MODEL.config.id2label[i] for i in range(MODEL.config.num_labels)]

def parse_video(video_file):
    """A utility to parse the input videos.

    Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
    """
    vs = cv2.VideoCapture(video_file)

    # Try to determine the total number of frames in the video file.
    try:
        prop = (
            cv2.cv.CV_CAP_PROP_FRAME_COUNT
            if imutils.is_cv2()
            else cv2.CAP_PROP_FRAME_COUNT
        )
        total = int(vs.get(prop))
        print("[INFO] {} total frames in video".format(total))
    # An error occurred while trying to determine the total
    # number of frames in the video file.
    except Exception:
        print("[INFO] could not determine # of frames in video")
        print("[INFO] no approx. completion time can be provided")
        total = -1

    frames = []
    # Loop over frames from the video file stream.
    while True:
        # Read the next frame from the file.
        (grabbed, frame) = vs.read()
        if frame is not None:
            # OpenCV decodes frames as BGR; convert to RGB for the model.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        # If the frame was not grabbed, we have reached the end of the stream.
        if not grabbed:
            break

    vs.release()
    return frames

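# For example, `parse_video("examples/baseball.mp4")` (one of the clips referenced in the
# Gradio examples below) returns the decoded frames as a list of RGB uint8 arrays of
# shape (height, width, 3).
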
def preprocess_video(frames: list):
    """Utility to apply preprocessing transformations to a video tensor."""
    # Each frame in the `frames` list has the shape: (height, width, num_channels).
    # Collated together, the `frames` list has the shape: (num_frames, height, width, num_channels).
    # So, after converting the `frames` list to a torch tensor, we permute the shape
    # such that it becomes (num_channels, num_frames, height, width) to make
    # the shape compatible with the preprocessing transformations. After applying the
    # preprocessing chain, we permute the shape to (num_frames, num_channels, height, width)
    # to make it compatible with the model. Finally, we add a batch dimension so that our video
    # classification model can operate on it.
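    # Concrete shape trace, assuming a 240x320 clip and the usual VideoMAE-base defaults of
    # NUM_FRAMES_TO_SAMPLE == 16 and RESIZE_TO == 224 (both are read from the model/processor
    # above, so the actual values may differ):
    #   np.array(frames)        -> (num_frames, 240, 320, 3)
    #   permute(3, 0, 1, 2)     -> (3, num_frames, 240, 320)
    #   VAL_TRANSFORMS          -> (3, 16, 224, 224)
    #   permute + unsqueeze(0)  -> (1, 16, 3, 224, 224)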
    video_tensor = torch.tensor(np.array(frames).astype(frames[0].dtype))
    video_tensor = video_tensor.permute(
        3, 0, 1, 2
    )  # (num_channels, num_frames, height, width)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    video_tensor_pp = video_tensor_pp.permute(
        1, 0, 2, 3
    )  # (num_frames, num_channels, height, width)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)
    return video_tensor_pp.to(DEVICE)

def infer(video_file):
    frames = parse_video(video_file)
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}

    # Forward pass.
    with torch.no_grad():
        outputs = MODEL(**inputs)
        logits = outputs.logits

    # `logits` has shape (1, num_labels): softmax over the label dimension and drop the
    # batch dimension before building the per-label confidence mapping.
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences

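# Quick local sanity check (without launching the Gradio UI), assuming one of the example
# clips referenced below is present on disk:
#
#   print(infer("examples/baseball.mp4"))  # -> dict mapping each class label to a probability
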
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/babycrawling.mp4"],
        ["examples/baseball.mp4"],
        ["examples/balancebeam.mp4"],
    ],
    title="VideoMAE fine-tuned on a subset of UCF-101",
    description=(
        "Gradio demo for VideoMAE video classification. To use it, simply upload your video or click one of the"
        " examples to load it. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging="never",
).launch()