|
import argparse, os, sys, glob, yaml, math, random |
|
import datetime, time |
|
import numpy as np |
|
from omegaconf import OmegaConf |
|
from collections import OrderedDict |
|
from tqdm import trange, tqdm |
|
from einops import repeat |
|
from einops import rearrange, repeat |
|
from functools import partial |
|
import torch |
|
from pytorch_lightning import seed_everything |
|
|
|
from funcs import load_model_checkpoint, load_prompts, load_image_batch, get_filelist, save_videos |
|
from funcs import batch_ddim_sampling |
|
from utils.utils import instantiate_from_config |
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the inference script.

    Returns:
        argparse.ArgumentParser: parser exposing the sampling, I/O and
        guidance options consumed by ``run_inference``.
    """
    parser = argparse.ArgumentParser()

    # Run setup / I/O locations.
    parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything")
    parser.add_argument("--mode", default="base", type=str, help="which kind of inference mode: {'base', 'i2v'}")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--prompt_file", type=str, default=None, help="a text file containing many prompts")
    parser.add_argument("--savedir", type=str, default=None, help="results saving path")
    # Must be numeric: the default is the int 10, so a CLI-supplied value has
    # to be parsed as int as well (it was previously parsed as a string).
    parser.add_argument("--savefps", type=int, default=10, help="video fps to generate")

    # Sampling options.
    parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt")
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM")
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)")
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference")

    # Output geometry; a negative --frames means "use the model's own length".
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frames", type=int, default=-1, help="frames num to inference")
    parser.add_argument("--fps", type=int, default=24)

    # Guidance scales.
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
    parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance")

    # Only read in 'i2v' mode.
    parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input")
    return parser
|
|
|
|
|
def run_inference(args, gpu_num, gpu_no, **kwargs):
    """Sample videos for this rank's share of the prompts and save them.

    Builds the model from ``args.config``, restores ``args.ckpt_path``,
    splits the prompts of ``args.prompt_file`` across ``gpu_num`` ranks and
    runs ``batch_ddim_sampling`` batch by batch on the slice owned by
    ``gpu_no``. Every batch is written to ``args.savedir`` via ``save_videos``.

    Args:
        args: Parsed command-line namespace (see ``get_parser``).
        gpu_num: Total number of ranks sharing the prompt list.
        gpu_no: Index of this rank; also passed to ``model.cuda``.
        **kwargs: Forwarded unchanged to ``batch_ddim_sampling``.

    Raises:
        NotImplementedError: ``args.mode`` is neither ``'base'`` nor ``'i2v'``.
        AssertionError: the checkpoint or prompt file is missing, the image
            size is not a multiple of 16, or (``'i2v'`` mode) the number of
            conditional inputs differs from the number of prompts.
    """
    # ---- Cheap precondition checks, done before the model is built so a bad
    # ---- invocation fails immediately instead of after loading the weights.
    if args.mode not in ("base", "i2v"):
        raise NotImplementedError(f"Unsupported inference mode: {args.mode!r}")
    # The path options default to None; guard explicitly because
    # os.path.exists(None) may raise TypeError instead of returning False.
    assert args.ckpt_path is not None and os.path.exists(args.ckpt_path), \
        f"Error: checkpoint [{args.ckpt_path}] Not Found!"
    assert (args.height % 16 == 0) and (args.width % 16 == 0), \
        "Error: image size [h,w] should be multiples of 16!"
    assert args.prompt_file is not None and os.path.exists(args.prompt_file), \
        "Error: prompt file NOT Found!"

    # ---- Model: instantiate from the "model" section of the config, move it
    # ---- to this rank's device, then load the checkpoint.
    config = OmegaConf.load(args.config)
    model_config = config.pop("model", OmegaConf.create())
    model = instantiate_from_config(model_config)
    model = model.cuda(gpu_no)
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()

    # Spatial size of the sampled noise is the pixel size divided by 8
    # (presumably the latent resolution of the model -- confirm against config).
    h, w = args.height // 8, args.width // 8
    frames = model.temporal_length if args.frames < 0 else args.frames
    channels = model.channels

    os.makedirs(args.savedir, exist_ok=True)

    # ---- Prompts and their distribution over ranks.
    prompt_list = load_prompts(args.prompt_file)
    num_samples = len(prompt_list)
    # Default output names are 1-based, zero-padded prompt indices.
    filename_list = [f"{i + 1:04d}" for i in range(num_samples)]

    samples_split = num_samples // gpu_num
    residual_tail = num_samples % gpu_num
    indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
    if gpu_no == 0 and residual_tail != 0:
        # Prompts that do not divide evenly are all handled by rank 0.
        indices.extend(range(num_samples - residual_tail, num_samples))
    # Report what this rank really processes (rank 0 may carry the tail too).
    print(f'[rank:{gpu_no}] {len(indices)}/{num_samples} samples loaded.')
    prompt_list_rank = [prompt_list[i] for i in indices]

    if args.mode == "i2v":
        # One conditional input per prompt; outputs are named after the inputs.
        cond_inputs = get_filelist(args.cond_input, ext='[mpj][pn][4gj]')
        assert len(cond_inputs) == num_samples, \
            f"Error: conditional input ({len(cond_inputs)}) NOT match prompt ({num_samples})!"
        filename_list = [os.path.splitext(os.path.basename(path))[0] for path in cond_inputs]
        cond_inputs_rank = [cond_inputs[i] for i in indices]

    filename_list_rank = [filename_list[i] for i in indices]

    # NOTE(review): args.unconditional_guidance_scale_temporal is parsed but
    # never forwarded to the sampler here -- confirm whether that is intended.

    # ---- Sampling loop over batches of at most args.bs prompts.
    start = time.time()
    num_rank_prompts = len(prompt_list_rank)
    n_rounds = (num_rank_prompts + args.bs - 1) // args.bs  # ceiling division
    for idx in range(n_rounds):
        print(f'[rank:{gpu_no}] batch-{idx+1} ({args.bs})x{args.n_samples} ...')
        idx_s = idx * args.bs
        idx_e = min(idx_s + args.bs, num_rank_prompts)
        batch_size = idx_e - idx_s  # the last batch may be smaller than args.bs
        filenames = filename_list_rank[idx_s:idx_e]
        noise_shape = [batch_size, channels, frames, h, w]
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()

        prompts = prompt_list_rank[idx_s:idx_e]
        text_emb = model.get_learned_conditioning(prompts)

        if args.mode == 'base':
            cond = {"c_crossattn": [text_emb], "fps": fps}
        else:
            # 'i2v' (the mode was validated on entry): concatenate the image
            # embeddings to the text embeddings along dim 1.
            cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            cond = {"c_crossattn": [imtext_cond], "fps": fps}

        batch_samples = batch_ddim_sampling(
            model, cond, noise_shape, args.n_samples,
            args.ddim_steps, args.ddim_eta, args.unconditional_guidance_scale, **kwargs)

        save_videos(batch_samples, args.savedir, filenames, fps=args.savefps)

    print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
|
|
|
|
|
if __name__ == '__main__':
    # Announce the launch time of this inference run.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    print("@CoLVDM Inference: %s" % timestamp)

    # Parse the CLI options and seed before any sampling happens.
    args = get_parser().parse_args()
    seed_everything(args.seed)

    # Direct script launch: a single process that acts as rank 0 of 1.
    gpu_num = 1
    rank = 0
    run_inference(args, gpu_num, rank)