import os
import json
import jsonlines
import torch
import math
import random
import cv2

from tqdm import tqdm
from collections import OrderedDict

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import numpy as np
import subprocess
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms import functional as F


class ImageTextDataset(Dataset):
    """
    The dataset class for image-text pairs, used for image generation training.
    It supports multi-aspect-ratio training.

    Params:
        anno_file: the annotation file or list of annotation files (jsonlines)
        add_normalize: whether to normalize input image pixels to [-1, 1], default: True
        ratios: the aspect ratios used during training, expressed as width / height
        sizes: the training resolutions, one (width, height) pair per ratio
        crop_mode: 'center' or 'random' cropping to the target size
        p_random_ratio: probability of picking a random target size instead of the closest one
    """
    def __init__(
        self, anno_file, add_normalize=True,
        ratios=[1/1, 3/5, 5/3],
        sizes=[(1024, 1024), (768, 1280), (1280, 768)],
        crop_mode='random', p_random_ratio=0.0,
    ):
        super().__init__()

        self.image_annos = []
        if not isinstance(anno_file, list):
            anno_file = [anno_file]

        for anno_file_ in anno_file:
            print(f"Load image annotation files from {anno_file_}")
            with jsonlines.open(anno_file_, 'r') as reader:
                for item in reader:
                    self.image_annos.append(item)

        print(f"Loaded {len(self.image_annos)} images in total")

        transform_list = [
            transforms.ToTensor(),
        ]

        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

        self.transform = transforms.Compose(transform_list)

        print(f"Transform List is {transform_list}")

        assert crop_mode in ['center', 'random']
        self.crop_mode = crop_mode
        self.ratios = ratios
        self.sizes = sizes
        self.p_random_ratio = p_random_ratio

    def get_closest_size(self, x):
        # With probability p_random_ratio, pick a random target size; otherwise pick the
        # size whose aspect ratio is closest to that of the input image.
        if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
            best_size_idx = np.random.randint(len(self.ratios))
        else:
            w, h = x.width, x.height
            best_size_idx = np.argmin([abs(w/h - r) for r in self.ratios])
        return self.sizes[best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        # Compute the shorter-side length for resizing so that the resized image
        # covers the target (width, height) before cropping, preserving aspect ratio.
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size) * min(orig_size) / max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size) * max(orig_size) / min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))
        return resize_size

    def __len__(self):
        return len(self.image_annos)

    def __getitem__(self, index):
        image_anno = self.image_annos[index]

        try:
            img = Image.open(image_anno['image']).convert("RGB")
            text = image_anno['text']

            assert isinstance(text, str), "Text should be str"

            size = self.get_closest_size(img)
            resize_size = self.get_resize_size((img.width, img.height), size)

            img = transforms.functional.resize(img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)

            if self.crop_mode == 'center':
                img = transforms.functional.center_crop(img, (size[1], size[0]))
            elif self.crop_mode == 'random':
                img = transforms.RandomCrop((size[1], size[0]))(img)
            else:
                img = transforms.functional.center_crop(img, (size[1], size[0]))

            image_tensor = self.transform(img)

            return {
                "video": image_tensor,
                "text": text,
                "identifier": 'image',
            }

        except Exception as e:
            print(f'Load Image Error with {e}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))
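

# A minimal, hypothetical usage sketch for ImageTextDataset (not part of the original training
# pipeline); the annotation path is a placeholder. Samples can have different resolutions across
# aspect-ratio buckets, so batching with batch_size > 1 typically needs an aspect-ratio-aware
# sampler; batch_size=1 keeps this sketch valid with the default collate function.
def _example_image_text_loader(anno_file="annotations/image_text.jsonl"):
    dataset = ImageTextDataset(anno_file=anno_file, add_normalize=True, crop_mode='random')
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # batch["video"]: (1, 3, H, W) tensor in [-1, 1]; batch["text"]: list with one caption
    return batch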


class LengthGroupedVideoTextDataset(Dataset):
    """
    The dataset class for video-text pairs, used for video generation training.
    It groups videos with the same number of frames together.
    Currently only a fixed resolution is supported during training.

    Params:
        anno_file: the annotation file or list of annotation files (jsonlines)
        max_frames: the maximum temporal length of the VAE latent, e.g. 16 => (16 - 1) * 8 + 1 = 121 video frames
        load_vae_latent: load pre-extracted VAE latents during training; we recommend extracting
            the latents in advance to reduce the per-batch time cost
        load_text_fea: load pre-extracted text features during training; we recommend extracting
            the prompt text features in advance, since the T5 encoder consumes a lot of GPU memory
    """

    def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
        super().__init__()

        self.video_annos = []
        self.max_frames = max_frames
        self.load_vae_latent = load_vae_latent
        self.load_text_fea = load_text_fea
        self.resolution = resolution

        assert load_vae_latent, "Only loading VAE latents is supported for now; loading raw video frames directly will be supported in the future"

        if not isinstance(anno_file, list):
            anno_file = [anno_file]

        for anno_file_ in anno_file:
            with jsonlines.open(anno_file_, 'r') as reader:
                for item in tqdm(reader):
                    self.video_annos.append(item)

        print(f"Loaded {len(self.video_annos)} videos in total")

    def __len__(self):
        return len(self.video_annos)

    def __getitem__(self, index):
        try:
            video_anno = self.video_annos[index]
            text = video_anno['text']
            latent_path = video_anno['latent']
            latent = torch.load(latent_path, map_location='cpu')

            if self.resolution == '384p':
                assert latent.shape[-1] == 640 // 8
                assert latent.shape[-2] == 384 // 8
            else:
                assert self.resolution == '768p'
                assert latent.shape[-1] == 1280 // 8
                assert latent.shape[-2] == 768 // 8

            cur_temp = latent.shape[2]
            cur_temp = min(cur_temp, self.max_frames)

            video_latent = latent[:, :, :cur_temp].float()
            assert video_latent.shape[1] == 16

            if self.load_text_fea:
                text_fea_path = video_anno['text_fea']
                text_fea = torch.load(text_fea_path, map_location='cpu')
                return {
                    'video': video_latent,
                    'prompt_embed': text_fea['prompt_embed'],
                    'prompt_attention_mask': text_fea['prompt_attention_mask'],
                    'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
                    'identifier': 'video',
                }

            else:
                return {
                    'video': video_latent,
                    'text': text,
                    'identifier': 'video',
                }

        except Exception as e:
            print(f'Load Video Error with {e}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))
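

# A hypothetical helper that writes one annotation line in the format consumed by
# LengthGroupedVideoTextDataset above (key names come from __getitem__; the paths and
# caption are placeholders, not files shipped with this repository).
def _example_write_video_anno(anno_file="annotations/video_text.jsonl"):
    item = {
        "text": "a dog running on the beach",
        "latent": "vae_latents/clip_00001.pt",      # VAE latent, e.g. (1, 16, T, 48, 80) for '384p'
        "text_fea": "text_features/clip_00001.pt",  # dict with 'prompt_embed', 'prompt_attention_mask', 'pooled_prompt_embed'
    }
    with jsonlines.open(anno_file, mode='w') as writer:
        writer.write(item)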


class VideoFrameProcessor:
    # Sample a fixed number of frames from a video file and convert them into a
    # normalized tensor of shape (3, num_frames, resolution, resolution).

    def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
        image_size = resolution

        transform_list = [
            transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.CenterCrop(image_size),
        ]

        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

        print(f"Transform List is {transform_list}")
        self.num_frames = num_frames
        self.transform = transforms.Compose(transform_list)
        self.sample_fps = sample_fps

    def __call__(self, video_path):
        try:
            video_capture = cv2.VideoCapture(video_path)
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            frames = []

            while True:
                flag, frame = video_capture.read()
                if not flag:
                    break

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = torch.from_numpy(frame)
                frame = frame.permute(2, 0, 1)
                frames.append(frame)

            video_capture.release()
            sample_fps = self.sample_fps
            interval = max(int(fps / sample_fps), 1)
            frames = frames[::interval]

            # Pad by repeating the clip if it is shorter than num_frames
            if len(frames) < self.num_frames:
                num_frame_to_pack = self.num_frames - len(frames)
                recurrent_num = num_frame_to_pack // len(frames)
                frames = frames + recurrent_num * frames + frames[:(num_frame_to_pack % len(frames))]
                assert len(frames) >= self.num_frames, f'{len(frames)}'

            # Randomly pick the start of a num_frames-long window
            start_indexs = list(range(0, max(0, len(frames) - self.num_frames + 1)))
            start_index = random.choice(start_indexs)

            filtered_frames = frames[start_index : start_index + self.num_frames]
            assert len(filtered_frames) == self.num_frames, f"The number of sampled frames should equal {self.num_frames}"

            filtered_frames = torch.stack(filtered_frames).float() / 255
            filtered_frames = self.transform(filtered_frames)
            filtered_frames = filtered_frames.permute(1, 0, 2, 3)

            return filtered_frames, None

        except Exception as e:
            print(f"Load video: {video_path} Error, Exception {e}")
            return None, None
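

# A minimal, hypothetical usage sketch for VideoFrameProcessor; the video path is a
# placeholder. On success it returns a (3, num_frames, resolution, resolution) tensor
# in [-1, 1]; on a decoding failure it returns (None, None), as handled above.
def _example_process_video(video_path="videos/example.mp4"):
    processor = VideoFrameProcessor(resolution=256, num_frames=24, add_normalize=True)
    frames, _ = processor(video_path)
    return frames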


class VideoDataset(Dataset):
    def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
        super().__init__()

        self.video_annos = []
        self.max_frames = max_frames

        if not isinstance(anno_file, list):
            anno_file = [anno_file]

        print(f"The training video clip frame number is {max_frames}")

        for anno_file_ in anno_file:
            print(f"Load annotation file from {anno_file_}")

            with jsonlines.open(anno_file_, 'r') as reader:
                for item in tqdm(reader):
                    self.video_annos.append(item)

        print(f"Loaded {len(self.video_annos)} videos in total")

        self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)

    def __len__(self):
        return len(self.video_annos)

    def __getitem__(self, index):
        video_anno = self.video_annos[index]
        video_path = video_anno['video']

        try:
            video_tensors, video_frames = self.video_processor(video_path)

            assert video_tensors.shape[1] == self.max_frames

            return {
                "video": video_tensors,
                "identifier": 'video',
            }

        except Exception as e:
            print(f'Loading Video Error with {e}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))
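

# A minimal, hypothetical usage sketch for VideoDataset: each jsonlines entry only needs a
# 'video' key pointing at a video file, e.g. {"video": "videos/clip_00001.mp4"}. The
# annotation path and batch size below are placeholders. All clips share the same
# resolution and frame count, so the default collate function works here.
def _example_video_loader(anno_file="annotations/videos.jsonl"):
    dataset = VideoDataset(anno_file=anno_file, resolution=256, max_frames=6)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # batch["video"]: (2, 3, 6, 256, 256) tensor in [-1, 1]
    return batch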


class ImageDataset(Dataset):
    def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
        super().__init__()

        self.image_annos = []
        self.max_frames = max_frames
        image_paths = []

        if not isinstance(anno_file, list):
            anno_file = [anno_file]

        for anno_file_ in anno_file:
            print(f"Load annotation file from {anno_file_}")
            with jsonlines.open(anno_file_, 'r') as reader:
                for item in tqdm(reader):
                    image_paths.append(item['image'])

        print(f"Loaded {len(image_paths)} images in total")

        # Pack the image paths into groups of max_frames, padding the last group from the start of the list
        for idx in range(0, len(image_paths), self.max_frames):
            image_path_shard = image_paths[idx : idx + self.max_frames]
            if len(image_path_shard) < self.max_frames:
                image_path_shard = image_path_shard + image_paths[:self.max_frames - len(image_path_shard)]
            assert len(image_path_shard) == self.max_frames
            self.image_annos.append(image_path_shard)

        image_size = resolution
        transform_list = [
            transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

        print(f"Transform List is {transform_list}")
        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_annos)

    def __getitem__(self, index):
        image_paths = self.image_annos[index]

        try:
            packed_pil_frames = [Image.open(image_path).convert("RGB") for image_path in image_paths]
            filtered_frames = [self.transform(frame) for frame in packed_pil_frames]
            filtered_frames = torch.stack(filtered_frames)
            filtered_frames = filtered_frames.permute(1, 0, 2, 3)

            return {
                "video": filtered_frames,
                "identifier": 'image',
            }

        except Exception as e:
            print(f'Load Images Error with {e}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))
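

# A minimal, hypothetical usage sketch for ImageDataset: each jsonlines entry only needs an
# 'image' key with an image path, e.g. {"image": "images/00001.jpg"}. The annotation path
# below is a placeholder. Each sample packs max_frames images into one (3, max_frames, H, W)
# tensor under the "video" key, mirroring the video datasets above.
def _example_image_dataset_sample(anno_file="annotations/images.jsonl"):
    dataset = ImageDataset(anno_file=anno_file, resolution=256, max_frames=8)
    sample = dataset[0]
    # sample["video"]: (3, 8, 256, 256) tensor in [-1, 1]; sample["identifier"] == 'image'
    return sample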