File size: 1,566 Bytes
208b0eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import gc
import random
from contextlib import contextmanager
from typing import List, Tuple, Optional
import numpy as np
from decord import VideoReader
from PIL import Image
@contextmanager
def video_reader(*args, **kwargs):
    """Yield a decord ``VideoReader`` and force a GC pass when the block exits.

    All positional and keyword arguments are forwarded unchanged to
    ``VideoReader``. The explicit ``del`` + ``gc.collect()`` on exit is the
    original author's workaround for a memory leak in decord.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        # Drop this frame's reference first so the collection below is not
        # kept from reclaiming the reader by the generator itself.
        del reader
        gc.collect()
def extract_frames(
    video_path: str,
    sample_method: str = "mid",
    num_sampled_frames: int = -1,
    sample_stride: int = -1,
    **kwargs
) -> "Optional[Tuple[List[int], List[Image.Image]]]":
    """Sample frames from a video and return them as PIL images.

    Args:
        video_path: Path of the video, passed to ``video_reader``.
        sample_method: One of

            * ``"mid"``     - the single frame at index ``len(video) // 2``;
            * ``"uniform"`` - ``num_sampled_frames`` indices evenly spaced over
              ``[0, len(video))``;
            * ``"random"``  - a randomly placed clip of at most
              ``(num_sampled_frames - 1) * sample_stride + 1`` frames (capped
              at the video length), with ``num_sampled_frames`` indices evenly
              spaced from its first to its last frame.
        num_sampled_frames: Number of frames to sample for ``"uniform"`` and
            ``"random"``; ignored for ``"mid"``. Must not be negative for
            those two methods, so the default of -1 has to be overridden.
        sample_stride: Frame distance between consecutive samples for
            ``"random"``; ignored otherwise. Must not be negative when more
            than one frame is requested.
        **kwargs: Extra keyword arguments forwarded to ``video_reader``.
            ``num_threads`` defaults to 2 but may be overridden here.

    Returns:
        A ``(frame_indices, frames)`` tuple: the sampled frame indices as plain
        Python ints and the corresponding frames as ``PIL.Image`` objects.
        The current implementation never returns ``None``.

    Raises:
        ValueError: If ``sample_method`` is unknown, or if
            ``num_sampled_frames`` / ``sample_stride`` are invalid for the
            chosen method.
    """
    # Validate before opening the video so bad arguments fail fast and with a
    # clear message instead of surfacing from numpy or the decoder.
    if sample_method not in ("mid", "uniform", "random"):
        raise ValueError(f"The sample_method {sample_method} must be mid, uniform or random.")
    if sample_method != "mid" and num_sampled_frames < 0:
        raise ValueError(
            f"num_sampled_frames must not be negative for sample_method "
            f"{sample_method}, got {num_sampled_frames}."
        )
    if sample_method == "random" and num_sampled_frames > 1 and sample_stride < 0:
        # A negative stride makes the clip length non-positive, which yields
        # reversed or out-of-range frame indices.
        raise ValueError(
            f"sample_stride must not be negative for sample_method random, "
            f"got {sample_stride}."
        )

    # Keep num_threads=2 as the default, but let callers override it instead
    # of failing with a duplicate-keyword TypeError.
    reader_kwargs = {"num_threads": 2, **kwargs}

    with video_reader(video_path, **reader_kwargs) as vr:
        total_frames = len(vr)
        if sample_method == "mid":
            frame_indices = [total_frames // 2]
        elif sample_method == "uniform":
            frame_indices = np.linspace(
                0, total_frames, num_sampled_frames, endpoint=False, dtype=int
            ).tolist()
        else:  # "random"
            clip_length = min(total_frames, (num_sampled_frames - 1) * sample_stride + 1)
            start_idx = random.randint(0, total_frames - clip_length)
            frame_indices = np.linspace(
                start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int
            ).tolist()

        decoded_frames = vr.get_batch(frame_indices).asnumpy()
        images = [Image.fromarray(frame) for frame in decoded_frames]

    return frame_indices, images
|