# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
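# Serves two checkpoints: a text-to-video base model (1024x576 output) and an
# image-to-video model (512x320 output); the `task` input selects which one runs.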
import os
import sys
import argparse
import random
from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torchvision
from pytorch_lightning import seed_everything
from cog import BasePredictor, Input, Path
sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    load_model_checkpoint,
    load_image_batch,
    get_filelist,
)
from utils.utils import instantiate_from_config


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the models into memory to make running multiple predictions efficient"""
        ckpt_path_base = "checkpoints/base_1024_v1/model.ckpt"
        config_base = "configs/inference_t2v_1024_v1.0.yaml"
        ckpt_path_i2v = "checkpoints/i2v_512_v1/model.ckpt"
        config_i2v = "configs/inference_i2v_512_v1.0.yaml"

        # text-to-video base model
        config_base = OmegaConf.load(config_base)
        model_config_base = config_base.pop("model", OmegaConf.create())
        self.model_base = instantiate_from_config(model_config_base)
        self.model_base = self.model_base.cuda()
        self.model_base = load_model_checkpoint(self.model_base, ckpt_path_base)
        self.model_base.eval()

        # image-to-video model
        config_i2v = OmegaConf.load(config_i2v)
        model_config_i2v = config_i2v.pop("model", OmegaConf.create())
        self.model_i2v = instantiate_from_config(model_config_i2v)
        self.model_i2v = self.model_i2v.cuda()
        self.model_i2v = load_model_checkpoint(self.model_i2v, ckpt_path_i2v)
        self.model_i2v.eval()

    def predict(
        self,
        task: str = Input(
            description="Choose the task.",
            choices=["text2video", "image2video"],
            default="text2video",
        ),
        prompt: str = Input(
            description="Prompt for video generation.",
            default="A tiger walks in the forest, photorealistic, 4k, high definition.",
        ),
        image: Path = Input(
            description="Input image for the image2video task.", default=None
        ),
        ddim_steps: int = Input(description="Number of denoising steps.", default=50),
        unconditional_guidance_scale: float = Input(
            description="Classifier-free guidance scale.", default=12.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
        save_fps: int = Input(
            description="Frames per second for the generated video.", default=10
        ),
    ) -> Path:
        width = 1024 if task == "text2video" else 512
        height = 576 if task == "text2video" else 320
        model = self.model_base if task == "text2video" else self.model_i2v
        if task == "image2video":
            assert image is not None, "Please provide an image for image2video generation."

        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        seed_everything(seed)

        args = argparse.Namespace(
            mode="base" if task == "text2video" else "i2v",
            savefps=save_fps,
            n_samples=1,
            ddim_steps=ddim_steps,
            ddim_eta=1.0,
            bs=1,
            height=height,
            width=width,
            frames=-1,
            fps=28 if task == "text2video" else 8,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_guidance_scale_temporal=None,
        )

        ## latent noise shape
        h, w = args.height // 8, args.width // 8
        frames = model.temporal_length if args.frames < 0 else args.frames
        channels = model.channels
        batch_size = 1
        noise_shape = [batch_size, channels, frames, h, w]
        # frame-rate conditioning tensor passed to the model alongside the prompt
        fps = torch.tensor([args.fps] * batch_size).to(model.device).long()

        ## text (and image) conditioning
        prompts = [prompt]
        text_emb = model.get_learned_conditioning(prompts)
        if args.mode == "base":
            cond = {"c_crossattn": [text_emb], "fps": fps}
        elif args.mode == "i2v":
            cond_images = load_image_batch([str(image)], (args.height, args.width))
            cond_images = cond_images.to(model.device)
            img_emb = model.get_image_embeds(cond_images)
            # append the image embeddings after the text embeddings along the sequence dim
            imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            cond = {"c_crossattn": [imtext_cond], "fps": fps}
        else:
            raise NotImplementedError

        ## inference
        batch_samples = batch_ddim_sampling(
            model,
            cond,
            noise_shape,
            args.n_samples,
            args.ddim_steps,
            args.ddim_eta,
            args.unconditional_guidance_scale,
        )

        ## save the samples as an mp4 video
        out_path = "/tmp/output.mp4"
        vid_tensor = batch_samples[0]
        video = vid_tensor.detach().cpu()
        video = torch.clamp(video.float(), -1.0, 1.0)
        video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
        frame_grids = [
            torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples))
            for framesheet in video
        ]  # [3, 1*h, n*w]
        grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
        grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [t, h, w, c] uint8
        torchvision.io.write_video(
            out_path,
            grid,
            fps=args.savefps,
            video_codec="h264",
            options={"crf": "10"},
        )
        return Path(out_path)
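

# Example local invocations via the Cog CLI (a sketch only; the input file name is
# illustrative and assumes the checkpoints referenced in setup() are present):
#
#   cog predict -i task="text2video" \
#       -i prompt="A tiger walks in the forest, photorealistic, 4k, high definition."
#
#   cog predict -i task="image2video" -i image=@input.png \
#       -i prompt="A tiger walks in the forest." -i seed=42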