File size: 1,566 Bytes
208b0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gc
import random
from contextlib import contextmanager
from typing import List, Tuple, Optional

import numpy as np
from decord import VideoReader
from PIL import Image


@contextmanager
def video_reader(*args, **kwargs):
    """Open a decord ``VideoReader`` and clean up after it on exit.

    All positional and keyword arguments are forwarded unchanged to
    ``VideoReader``. When the ``with`` block ends, normally or through an
    exception, this function's reference to the reader is dropped and a
    garbage collection pass is forced, as a workaround for the memory leak
    of decord.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        # Drop this frame's reference first, then force a collection pass.
        del reader
        gc.collect()


def extract_frames(
    video_path: str,
    sample_method: str = "mid",
    num_sampled_frames: int = -1,
    sample_stride: int = -1,
    **kwargs
) -> Optional[Tuple[List[int], List[Image.Image]]]:
    """Sample frames from a video and return them as PIL images.

    Args:
        video_path: Video to read; passed to ``video_reader`` together with
            ``num_threads=2``.
        sample_method: How frame indices are chosen:

            * ``"mid"``: the single frame at index ``len(video) // 2``.
            * ``"uniform"``: ``num_sampled_frames`` indices evenly spaced
              over ``[0, len(video))`` (end excluded).
            * ``"random"``: ``num_sampled_frames`` indices evenly spaced
              over a randomly placed clip of
              ``(num_sampled_frames - 1) * sample_stride + 1`` frames. If
              the video is shorter than that, the clip shrinks to the whole
              video and indices may repeat.
        num_sampled_frames: Number of frames to sample. Must be >= 1 for
            ``"uniform"`` and ``"random"``; ignored for ``"mid"``.
        sample_stride: Distance between consecutive sampled frames for
            ``"random"``. Must be non-negative when more than one frame is
            requested; ignored by the other methods.
        **kwargs: Extra keyword arguments forwarded to ``video_reader``.

    Returns:
        A tuple ``(frame_indices, frames)`` where ``frame_indices`` is a list
        of Python ints and ``frames`` holds the matching ``PIL.Image``
        objects in the same order.

    Raises:
        ValueError: If ``sample_method`` is unknown, or if
            ``num_sampled_frames`` / ``sample_stride`` is invalid for the
            chosen method (including the unset ``-1`` defaults).
    """
    # Validate before opening the video so bad arguments fail fast with a
    # clear message instead of a numpy error or silently wrapped indices.
    if sample_method not in ("mid", "uniform", "random"):
        raise ValueError(f"The sample_method {sample_method} must be mid, uniform or random.")
    if sample_method != "mid" and num_sampled_frames < 1:
        raise ValueError(
            f"num_sampled_frames must be >= 1 for sample_method {sample_method}, "
            f"got {num_sampled_frames}."
        )
    if sample_method == "random" and num_sampled_frames > 1 and sample_stride < 0:
        # A negative stride makes the clip length non-positive, which would
        # produce descending / negative (wrapped) frame indices.
        raise ValueError(
            f"sample_stride must be non-negative for sample_method random, got {sample_stride}."
        )

    with video_reader(video_path, num_threads=2, **kwargs) as vr:
        total_frames = len(vr)
        if sample_method == "mid":
            sampled_frame_idx_list = [total_frames // 2]
        elif sample_method == "uniform":
            sampled_frame_idx_list = np.linspace(
                0, total_frames, num_sampled_frames, endpoint=False, dtype=int
            )
        else:  # "random"
            # Clip spans num_sampled_frames frames at the requested stride,
            # capped at the video length.
            clip_length = min(total_frames, (num_sampled_frames - 1) * sample_stride + 1)
            start_idx = random.randint(0, total_frames - clip_length)
            sampled_frame_idx_list = np.linspace(
                start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int
            )
        frame_batch = vr.get_batch(sampled_frame_idx_list).asnumpy()
        sampled_frame_list = [Image.fromarray(frame) for frame in frame_batch]

        # Return plain Python ints (not numpy integers) to match the
        # annotated List[int] regardless of the sampling method.
        return [int(idx) for idx in sampled_frame_idx_list], sampled_frame_list