"""Reconstruct a video with a pretrained CausalVAE: read and preprocess a clip,
encode it to latents, decode, and save the reconstruction (a video, or a single
image when only one frame is decoded)."""
import math
import argparse
from typing import Optional

import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from torchvision.transforms import Lambda, Compose

import sys

sys.path.append(".")
from opensora.dataset.transform import CenterCropVideo
from opensora.models.ae.videobase import CausalVAEModel


def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
    """Write a sequence of (H, W, C) uint8 RGB frames to an .mp4 file."""
    height, width, channels = image_array[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))

    for image in image_array:
        # OpenCV expects BGR channel order.
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        video_writer.write(image_bgr)

    video_writer.release()


def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
    """Convert a (C, T, H, W) tensor in [-1, 1] to uint8 RGB frames and write them as video."""
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2  # map [-1, 1] -> [0, 1]
    x = x.permute(1, 2, 3, 0).numpy()  # (C, T, H, W) -> (T, H, W, C)
    x = (255 * x).astype(np.uint8)
    array_to_video(x, fps=fps, output_file=output_file)


def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    """Read `num_frames` frames at stride `sample_rate`, returning a (C, T, H, W) uint8 tensor."""
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames

    if total_frames > sample_frames_len:
        s = 0  # deterministic start at the first frame
        e = s + sample_frames_len
    else:
        # Not enough frames: use the whole video and shrink num_frames proportionally.
        s = 0
        e = total_frames
        num_frames = int(total_frames / sample_frames_len * num_frames)
        print(f'{video_path} has only {total_frames} frames; wanted {sample_frames_len}, '
              f'sampling {num_frames * sample_rate} instead')

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


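# Note on sampling: np.linspace spreads the indices evenly over [s, e - 1], so
# when the clip is long enough the effective stride is approximately sample_rate.

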
class ResizeVideo:
    """Resize a (C, T, H, W) clip so its short side equals `size`, preserving aspect ratio."""

    def __init__(self, size, interpolation_mode="bilinear"):
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        _, _, h, w = clip.shape
        if w < h:
            new_h = int(math.floor((float(h) / w) * self.size))
            new_w = self.size
        else:
            new_h = self.size
            new_w = int(math.floor((float(w) / h) * self.size))
        return F.interpolate(
            clip, size=(new_h, new_w), mode=self.interpolation_mode, align_corners=False, antialias=True
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"


def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
    """Normalize, resize, and optionally center-crop a clip; returns a (1, C, T, H, W) tensor."""
    transform = Compose(
        [
            Lambda(lambda x: ((x / 255.0) * 2 - 1)),  # uint8 [0, 255] -> float [-1, 1]
            ResizeVideo(size=short_size),
            CenterCropVideo(crop_size) if crop_size is not None else Lambda(lambda x: x),
        ]
    )

    video_outputs = transform(video_data)
    video_outputs = torch.unsqueeze(video_outputs, 0)  # add a batch dimension

    return video_outputs


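# Shape flow, for reference: read_video yields (C, T, H, W) uint8 frames;
# preprocess maps them to a (1, C, T, H, W) float tensor in [-1, 1], which is
# the layout main() feeds to the VAE's encode below.

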
def main(args: argparse.Namespace):
    video_path = args.video_path
    num_frames = args.num_frames
    resolution = args.resolution
    crop_size = args.crop_size
    sample_fps = args.sample_fps
    sample_rate = args.sample_rate
    device = args.device

    vqvae = CausalVAEModel.from_pretrained(args.ckpt)
    if args.enable_tiling:
        vqvae.enable_tiling()
        vqvae.tile_overlap_factor = args.tile_overlap_factor
    vqvae.eval()
    vqvae = vqvae.to(device)
    vqvae = vqvae.to(torch.float16)

    with torch.no_grad():
        x_vae = preprocess(read_video(video_path, num_frames, sample_rate), resolution, crop_size)
        x_vae = x_vae.to(device)
        x_vae = x_vae.to(torch.float16)
        latents = vqvae.encode(x_vae).sample().to(torch.float16)
        video_recon = vqvae.decode(latents)

    if video_recon.shape[2] == 1:
        # Single decoded frame: save it as an image instead of a one-frame video.
        x = video_recon[0, :, 0, :, :]
        x = x.detach().cpu().numpy()
        x = np.clip(x, -1, 1)
        x = (x + 1) / 2
        x = (255 * x).astype(np.uint8)
        x = x.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
        image = Image.fromarray(x)
        image.save(args.rec_path.replace('.mp4', '.jpg'))
    else:
        custom_to_video(video_recon[0], fps=sample_fps / sample_rate, output_file=args.rec_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video-path', type=str, default='')
    parser.add_argument('--rec-path', type=str, default='')
    parser.add_argument('--ckpt', type=str, default='results/pretrained')
    parser.add_argument('--sample-fps', type=int, default=30)
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--crop-size', type=int, default=None)
    parser.add_argument('--num-frames', type=int, default=100)
    parser.add_argument('--sample-rate', type=int, default=1)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')

    args = parser.parse_args()
    main(args)
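# Example invocation (script name, paths, and frame count are illustrative):
#   python rec_video.py --video-path input.mp4 --rec-path rec.mp4 \
#       --resolution 336 --num-frames 65 --sample-rate 1 \
#       --ckpt results/pretrained --enable_tiling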