import os import cv2 import io import numpy as np import torch import decord from PIL import Image from decord import VideoReader, cpu import random try: from petrel_client.client import Client has_client = True except ImportError: has_client = False class VideoMAE(torch.utils.data.Dataset): """Load your own video classification dataset. Parameters ---------- root : str, required. Path to the root folder storing the dataset. setting : str, required. A text file describing the dataset, each line per video sample. There are three items in each line: (1) video path; (2) video length and (3) video label. prefix : str, required. The prefix for loading data. split : str, required. The split character for metadata. train : bool, default True. Whether to load the training or validation set. test_mode : bool, default False. Whether to perform evaluation on the test set. Usually there is three-crop or ten-crop evaluation strategy involved. name_pattern : str, default None. The naming pattern of the decoded video frames. For example, img_00012.jpg. video_ext : str, default 'mp4'. If video_loader is set to True, please specify the video format accordinly. is_color : bool, default True. Whether the loaded image is color or grayscale. modality : str, default 'rgb'. Input modalities, we support only rgb video frames for now. Will add support for rgb difference image and optical flow image later. num_segments : int, default 1. Number of segments to evenly divide the video into clips. A useful technique to obtain global video-level information. Limin Wang, etal, Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016. num_crop : int, default 1. Number of crops for each image. default is 1. Common choices are three crops and ten crops during evaluation. new_length : int, default 1. The length of input video clip. Default is a single image, but it can be multiple video frames. For example, new_length=16 means we will extract a video clip of consecutive 16 frames. new_step : int, default 1. Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames. new_step=2 means we will extract a video clip of every other frame. temporal_jitter : bool, default False. Whether to temporally jitter if new_step > 1. video_loader : bool, default False. Whether to use video loader to load data. use_decord : bool, default True. Whether to use Decord video loader to load data. Otherwise load image. transform : function, default None. A function that takes data and label and transforms them. data_aug : str, default 'v1'. Different types of data augmentation auto. Supports v1, v2, v3 and v4. lazy_init : bool, default False. If set to True, build a dataset instance without loading any dataset. """ def __init__(self, root, setting, prefix='', split=' ', train=True, test_mode=False, name_pattern='img_%05d.jpg', video_ext='mp4', is_color=True, modality='rgb', num_segments=1, num_crop=1, new_length=1, new_step=1, transform=None, temporal_jitter=False, video_loader=False, use_decord=True, lazy_init=False, num_sample=1, ): super(VideoMAE, self).__init__() self.root = root self.setting = setting self.prefix = prefix self.split = split self.train = train self.test_mode = test_mode self.is_color = is_color self.modality = modality self.num_segments = num_segments self.num_crop = num_crop self.new_length = new_length self.new_step = new_step self.skip_length = self.new_length * self.new_step self.temporal_jitter = temporal_jitter self.name_pattern = name_pattern self.video_loader = video_loader self.video_ext = video_ext self.use_decord = use_decord self.transform = transform self.lazy_init = lazy_init self.num_sample = num_sample # sparse sampling, num_segments != 1 if self.num_segments != 1: print('Use sparse sampling, change frame and stride') self.new_length = self.num_segments self.skip_length = 1 self.client = None if has_client: self.client = Client('~/petreloss.conf') if not self.lazy_init: self.clips = self._make_dataset(root, setting) if len(self.clips) == 0: raise(RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" "Check your data directory (opt.data-dir).")) def __getitem__(self, index): while True: try: images = None if self.use_decord: directory, target = self.clips[index] if self.video_loader: if '.' in directory.split('/')[-1]: # data in the "setting" file already have extension, e.g., demo.mp4 video_name = directory else: # data in the "setting" file do not have extension, e.g., demo # So we need to provide extension (i.e., .mp4) to complete the file name. video_name = '{}.{}'.format(directory, self.video_ext) video_name = os.path.join(self.prefix, video_name) if video_name.startswith('s3'): video_bytes = self.client.get(video_name) decord_vr = VideoReader(io.BytesIO(video_bytes), num_threads=1, ctx=cpu(0)) else: decord_vr = decord.VideoReader(video_name, num_threads=1, ctx=cpu(0)) duration = len(decord_vr) segment_indices, skip_offsets = self._sample_train_indices(duration) images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets) else: video_name, total_frame, target = self.clips[index] video_name = os.path.join(self.prefix, video_name) segment_indices, skip_offsets = self._sample_train_indices(total_frame) frame_id_list = self._get_frame_id_list(total_frame, segment_indices, skip_offsets) images = [] for idx in frame_id_list: frame_fname = os.path.join(video_name, self.name_pattern.format(idx)) img_bytes = self.client.get(frame_fname) img_np = np.frombuffer(img_bytes, np.uint8) img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) images.append(Image.fromarray(img)) if images is not None: break except Exception as e: print("Failed to load video from {} with error {}".format( video_name, e)) index = random.randint(0, len(self.clips) - 1) if self.num_sample > 1: process_data_list = [] mask_list = [] for _ in range(self.num_sample): process_data, mask = self.transform((images, None)) process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) process_data_list.append(process_data) mask_list.append(mask) return process_data_list, mask_list else: process_data, mask = self.transform((images, None)) # T*C,H,W process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W return (process_data, mask) def __len__(self): return len(self.clips) def _make_dataset(self, directory, setting): if not os.path.exists(setting): raise(RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) clips = [] print(f'Load dataset using decord: {self.use_decord}') with open(setting) as split_f: data = split_f.readlines() for line in data: line_info = line.split(self.split) if len(line_info) < 2: raise(RuntimeError('Video input format is not correct, missing one or more element. %s' % line)) if self.use_decord: # line format: video_path, video_label clip_path = os.path.join(line_info[0]) target = int(line_info[1]) item = (clip_path, target) else: # line format: video_path, video_duration, video_label clip_path = os.path.join(line_info[0]) total_frame = int(line_info[1]) target = int(line_info[2]) item = (clip_path, total_frame, target) clips.append(item) return clips def _sample_train_indices(self, num_frames): average_duration = (num_frames - self.skip_length + 1) // self.num_segments if average_duration > 0: offsets = np.multiply(list(range(self.num_segments)), average_duration) offsets = offsets + np.random.randint(average_duration, size=self.num_segments) elif num_frames > max(self.num_segments, self.skip_length): offsets = np.sort(np.random.randint( num_frames - self.skip_length + 1, size=self.num_segments)) else: offsets = np.zeros((self.num_segments,)) if self.temporal_jitter: skip_offsets = np.random.randint( self.new_step, size=self.skip_length // self.new_step) else: skip_offsets = np.zeros( self.skip_length // self.new_step, dtype=int) return offsets + 1, skip_offsets def _get_frame_id_list(self, duration, indices, skip_offsets): frame_id_list = [] for seg_ind in indices: offset = int(seg_ind) for i, _ in enumerate(range(0, self.skip_length, self.new_step)): if offset + skip_offsets[i] <= duration: frame_id = offset + skip_offsets[i] - 1 else: frame_id = offset - 1 frame_id_list.append(frame_id) if offset + self.new_step < duration: offset += self.new_step return frame_id_list def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets): sampled_list = [] frame_id_list = [] for seg_ind in indices: offset = int(seg_ind) for i, _ in enumerate(range(0, self.skip_length, self.new_step)): if offset + skip_offsets[i] <= duration: frame_id = offset + skip_offsets[i] - 1 else: frame_id = offset - 1 frame_id_list.append(frame_id) if offset + self.new_step < duration: offset += self.new_step try: video_data = video_reader.get_batch(frame_id_list).asnumpy() sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] except: raise RuntimeError('Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) return sampled_list