import os
import io
import random

import cv2
import numpy as np
import torch
import decord
from PIL import Image
from decord import VideoReader, cpu
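
# petrel_client is an optional dependency used to fetch data from object
# storage (paths starting with 's3'); when it is unavailable the dataset
# falls back to reading from the local filesystem.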
try:
    from petrel_client.client import Client
    has_client = True
except ImportError:
    has_client = False


class VideoMAE(torch.utils.data.Dataset):
    """Load your own video classification dataset.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset.
    setting : str, required.
        A text file describing the dataset, one line per video sample.
        Each line contains three items: (1) video path; (2) video length; and (3) video label.
    prefix : str, default ''.
        The prefix for loading data.
    split : str, default ' '.
        The split character for the metadata.
    train : bool, default True.
        Whether to load the training or validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually a three-crop or ten-crop evaluation strategy is involved.
    name_pattern : str, default 'img_%05d.jpg'.
        The naming pattern of the decoded video frames.
        For example, img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, please specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded image is color or grayscale.
    modality : str, default 'rgb'.
        Input modality. Only rgb video frames are supported for now.
        Support for rgb difference images and optical flow images may be added later.
    num_segments : int, default 1.
        Number of segments used to evenly divide the video into clips.
        A useful technique to obtain global video-level information.
        Limin Wang, et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016.
    num_crop : int, default 1.
        Number of crops for each image. Default is 1.
        Common choices are three crops and ten crops during evaluation.
    new_length : int, default 1.
        The length of the input video clip. Default is a single image, but it can be multiple video frames.
        For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
    new_step : int, default 1.
        Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames.
        new_step=2 means we will extract a video clip of every other frame.
    temporal_jitter : bool, default False.
        Whether to temporally jitter if new_step > 1.
    video_loader : bool, default False.
        Whether to use a video loader to load data.
    use_decord : bool, default True.
        Whether to use the Decord video loader to load data. Otherwise load images.
    transform : function, default None.
        A function that takes data and label and transforms them.
    data_aug : str, default 'v1'.
        Type of data augmentation. Supports v1, v2, v3 and v4.
    num_sample : int, default 1.
        Number of augmented samples returned for each video clip.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any data.
    """

    def __init__(self,
                 root,
                 setting,
                 prefix='',
                 split=' ',
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 video_ext='mp4',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 transform=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=True,
                 lazy_init=False,
                 num_sample=1):
        super(VideoMAE, self).__init__()
        self.root = root
        self.setting = setting
        self.prefix = prefix
        self.split = split
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length
        self.new_step = new_step
        self.skip_length = self.new_length * self.new_step
        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.video_ext = video_ext
        self.use_decord = use_decord
        self.transform = transform
        self.lazy_init = lazy_init
        self.num_sample = num_sample
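
        # Sparse (TSN-style) sampling: when more than one segment is requested,
        # a single frame is taken per segment, so new_length is overridden by
        # the segment count and the dense skip window collapses to 1.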
        if self.num_segments != 1:
            print('Use sparse sampling, change frame and stride')
            self.new_length = self.num_segments
            self.skip_length = 1

        self.client = None
        if has_client:
            self.client = Client('~/petreloss.conf')

        if not self.lazy_init:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                   "Check your data directory (opt.data-dir).")
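
    # Loading keeps retrying with a fresh random index until a clip decodes
    # successfully, so a single corrupted video does not abort training.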
    def __getitem__(self, index):
        while True:
            try:
                images = None
                if self.use_decord:
                    directory, target = self.clips[index]
                    if self.video_loader:
                        if '.' in directory.split('/')[-1]:
                            # The path in the setting file already carries a file extension.
                            video_name = directory
                        else:
                            # The path has no extension; append the configured video format.
                            video_name = '{}.{}'.format(directory, self.video_ext)

                        video_name = os.path.join(self.prefix, video_name)
                        if video_name.startswith('s3'):
                            video_bytes = self.client.get(video_name)
                            decord_vr = VideoReader(io.BytesIO(video_bytes),
                                                    num_threads=1,
                                                    ctx=cpu(0))
                        else:
                            decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0))
                        duration = len(decord_vr)

                    segment_indices, skip_offsets = self._sample_train_indices(duration)
                    images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets)

                else:
                    video_name, total_frame, target = self.clips[index]
                    video_name = os.path.join(self.prefix, video_name)

                    segment_indices, skip_offsets = self._sample_train_indices(total_frame)
                    frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets)
                    images = []
                    for idx in frame_id_list:
                        # name_pattern is a printf-style pattern (e.g. 'img_%05d.jpg'),
                        # so apply it with the % operator.
                        frame_fname = os.path.join(video_name, self.name_pattern % idx)
                        img_bytes = self.client.get(frame_fname)
                        img_np = np.frombuffer(img_bytes, np.uint8)
                        img = cv2.imdecode(img_np, cv2.IMREAD_COLOR)
                        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
                        images.append(Image.fromarray(img))
                if images is not None:
                    break
            except Exception as e:
                print("Failed to load video from {} with error {}".format(
                    video_name, e))
                index = random.randint(0, len(self.clips) - 1)
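
        # The transform is expected to map (PIL image list, None) to a flattened
        # (new_length * 3, H, W) tensor plus a mask; reshape the tensor to
        # (C, T, H, W) before returning it.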
        if self.num_sample > 1:
            process_data_list = []
            mask_list = []
            for _ in range(self.num_sample):
                process_data, mask = self.transform((images, None))
                process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
                process_data_list.append(process_data)
                mask_list.append(mask)
            return process_data_list, mask_list
        else:
            process_data, mask = self.transform((images, None))
            process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
            return (process_data, mask)

    def __len__(self):
        return len(self.clips)
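
    # The setting file is parsed once at construction time. With decord, each
    # line is "<video path><split><label>"; with pre-extracted frames, each
    # line is "<frame folder><split><total frames><split><label>".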
    def _make_dataset(self, directory, setting):
        if not os.path.exists(setting):
            raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))
        clips = []

        print(f'Load dataset using decord: {self.use_decord}')
        with open(setting) as split_f:
            data = split_f.readlines()
            for line in data:
                line_info = line.split(self.split)
                if len(line_info) < 2:
                    raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
                if self.use_decord:
                    # line format: video path, video label
                    clip_path = os.path.join(line_info[0])
                    target = int(line_info[1])
                    item = (clip_path, target)
                else:
                    # line format: frame folder path, total frames, video label
                    clip_path = os.path.join(line_info[0])
                    total_frame = int(line_info[1])
                    target = int(line_info[2])
                    item = (clip_path, total_frame, target)
                clips.append(item)
        return clips
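
    # TSN-style temporal sampling: split the usable frame range into
    # num_segments equal windows and draw one random start offset per window;
    # clips shorter than the sampling window fall back to random or zero
    # offsets. The returned offsets are 1-based.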
    def _sample_train_indices(self, num_frames):
        average_duration = (num_frames - self.skip_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)),
                                  average_duration)
            offsets = offsets + np.random.randint(average_duration,
                                                  size=self.num_segments)
        elif num_frames > max(self.num_segments, self.skip_length):
            offsets = np.sort(np.random.randint(
                num_frames - self.skip_length + 1,
                size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))

        if self.temporal_jitter:
            skip_offsets = np.random.randint(
                self.new_step, size=self.skip_length // self.new_step)
        else:
            skip_offsets = np.zeros(
                self.skip_length // self.new_step, dtype=int)
        return offsets + 1, skip_offsets
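
    # Expand each segment's start index into concrete frame ids: step through
    # the clip by new_step, add the per-step jitter offset, and clamp at the
    # end of the available frames.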
    def _get_frame_id_list(self, duration, indices, skip_offsets):
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + self.new_step < duration:
                    offset += self.new_step
        return frame_id_list
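
    # Same frame-id expansion as above, but the frames are fetched from decord
    # in a single get_batch call and converted to RGB PIL images.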
    def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets):
        sampled_list = []
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, self.skip_length, self.new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + self.new_step < duration:
                    offset += self.new_step
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)]
        except Exception:
            raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration))
        return sampled_list
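

if __name__ == '__main__':
    # Minimal construction sketch, not the project's official entry point.
    # The setting path below is a hypothetical placeholder, and a real
    # transform must map (PIL image list, None) to a (tensor, mask) pair.
    # lazy_init=True keeps the example runnable without any annotation file
    # or videos on disk.
    dataset = VideoMAE(
        root='',
        setting='train.csv',   # hypothetical: "<video path> <label>" per line
        prefix='',
        new_length=16,         # 16-frame clips
        new_step=4,            # sample every 4th frame
        video_loader=True,
        use_decord=True,
        transform=None,
        lazy_init=True,
    )
    print(type(dataset).__name__)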