Spaces:
Runtime error
Runtime error
File size: 8,651 Bytes
ce68674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import argparse
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from motionclone.models.unet import UNet3DConditionModel
from motionclone.models.sparse_controlnet import SparseControlNetModel
from motionclone.pipelines.pipeline_animation import AnimationPipeline
from motionclone.utils.util import load_weights, auto_download
from diffusers.utils.import_utils import is_xformers_available
from motionclone.utils.motionclone_functions import *
import json
from motionclone.utils.xformer_attention import *
def _load_controlnet(unet, config, device, adopted_dtype):
    """Build a SparseControlNetModel from ``unet`` and load ``config.controlnet_path`` into it.

    Side effect: overwrites ``unet.config.num_attention_heads`` and
    ``unet.config.projection_class_embeddings_input_dim`` before the ControlNet
    is derived from the UNet.

    Raises:
        ValueError: If the inference config has no ``controlnet_config``.
    """
    if not config.get("controlnet_config", ""):
        raise ValueError("controlnet_path is set but controlnet_config is missing from the inference config")
    unet.config.num_attention_heads = 8
    unet.config.projection_class_embeddings_input_dim = None
    controlnet_config = OmegaConf.load(config.controlnet_config)
    controlnet = SparseControlNetModel.from_unet(
        unet,
        controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}),
    ).to(device).to(dtype=adopted_dtype)

    auto_download(config.controlnet_path, is_dreambooth_lora=False)
    print(f"loading controlnet checkpoint from {config.controlnet_path} ...")
    # NOTE(review): torch.load unpickles the file; only point controlnet_path at trusted checkpoints.
    controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu")
    # Checkpoints may nest the weights under a "controlnet" key.
    if "controlnet" in controlnet_state_dict:
        controlnet_state_dict = controlnet_state_dict["controlnet"]
    # Skip positional-encoding entries and the "animatediff_config" entry; they are not loaded.
    controlnet_state_dict = {
        name: param
        for name, param in controlnet_state_dict.items()
        if "pos_encoder.pe" not in name
    }
    controlnet_state_dict.pop("animatediff_config", "")
    controlnet.load_state_dict(controlnet_state_dict)
    del controlnet_state_dict
    return controlnet


def _build_pipeline(args, config, device, adopted_dtype):
    """Load the models named by ``args``/``config`` and return a prepared AnimationPipeline.

    The returned pipeline has MotionClone's customized functions bound onto it,
    its UNet (and ControlNet, when one is configured) frozen, the UNet passed
    through ``prep_unet_attention``/``prep_unet_conv``, ``input_config`` set to
    ``config`` on both pipeline and UNet, and the scheduler's customized
    timesteps configured from ``config``.
    """
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)

    model_config = OmegaConf.load(config.get("model_config", ""))
    unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs),
    ).to(device).to(dtype=adopted_dtype)

    # ControlNet is optional: only built when the config names a checkpoint.
    controlnet = None
    if config.get("controlnet_path", ""):
        controlnet = _load_controlnet(unet, config, device, adopted_dtype)

    if is_xformers_available() and not args.without_xformers:
        unet.enable_xformers_memory_efficient_attention()

    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=controlnet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
    ).to(device)
    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=config.get("motion_module", ""),
        # domain adapter
        adapter_lora_path=config.get("adapter_lora_path", ""),
        adapter_lora_scale=config.get("adapter_lora_scale", 1.0),
        # image layer
        dreambooth_model_path=config.get("dreambooth_path", ""),
    ).to(device)
    pipeline.text_encoder.to(dtype=adopted_dtype)

    # Bind the customized functions from motionclone_functions as methods.
    pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
    pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
    pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
    pipeline.sample_video = sample_video.__get__(pipeline)
    pipeline.single_step_video = single_step_video.__get__(pipeline)
    pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
    pipeline.add_noise = add_noise.__get__(pipeline)
    pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
    pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)

    # Freeze model weights. The ControlNet may be absent (controlnet=None above).
    for param in pipeline.unet.parameters():
        param.requires_grad = False
    pipeline_controlnet = getattr(pipeline, "controlnet", None)
    if pipeline_controlnet is not None:
        for param in pipeline_controlnet.parameters():
            param.requires_grad = False

    pipeline.input_config, pipeline.unet.input_config = config, config
    pipeline.unet = prep_unet_attention(pipeline.unet, pipeline.input_config.motion_guidance_blocks)
    pipeline.unet = prep_unet_conv(pipeline.unet)
    pipeline.scheduler.customized_set_timesteps(
        config.inference_steps, config.guidance_steps, config.guidance_scale,
        device=device, timestep_spacing_type="uneven",
    )
    return pipeline


def _run_example(pipeline, config, args, example_infor):
    """Extract motion from one example's reference video, then generate and save its new video.

    ``example_infor`` is one parsed JSONL record. Required keys: ``video_path``,
    ``condition_image_paths``, ``new_prompt``. Optional keys: ``image_index``
    (default ``[0]``), ``controlnet_scale`` (default 1.0), ``seed`` (default
    ``args.default_seed``). ``config`` is mutated in place and re-attached as
    ``input_config`` on the pipeline and its UNet.

    Raises:
        ValueError: If ``image_index`` and ``condition_image_paths`` differ in length.
    """
    config.video_path = example_infor["video_path"]
    config.condition_image_path_list = example_infor["condition_image_paths"]
    config.image_index = example_infor.get("image_index", [0])
    if len(config.image_index) != len(config.condition_image_path_list):
        raise ValueError(
            f"image_index has {len(config.image_index)} entries but condition_image_paths has "
            f"{len(config.condition_image_path_list)} for {config.video_path}"
        )
    config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
    config.controlnet_scale = example_infor.get("controlnet_scale", 1.0)
    pipeline.input_config, pipeline.unet.input_config = config, config  # update config

    video_stem = os.path.splitext(os.path.basename(config.video_path))[0]

    # Motion representation extraction.
    seed_motion = example_infor.get("seed", args.default_seed)
    generator = torch.Generator(device=pipeline.device).manual_seed(seed_motion)
    motion_representation_path = os.path.join(args.motion_representation_save_dir, video_stem + '.pt')
    pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path, use_controlnet=True)

    # Video generation; it reuses the motion seed (another seed could be assigned here).
    seed = seed_motion
    generator = torch.Generator(device=pipeline.device).manual_seed(seed)
    pipeline.input_config.seed = seed
    videos = pipeline.sample_video(generator=generator, add_controlnet=True)
    videos = rearrange(videos, "b c f h w -> b f h w c")

    # A "/" in the prompt would otherwise be read as a directory separator.
    prompt_tag = config.new_prompt.strip().replace(' ', '_').replace('/', '_')
    save_path = os.path.join(
        args.generated_videos_save_dir,
        video_stem + "_" + prompt_tag + str(seed_motion) + "_" + str(seed) + '.mp4',
    )
    # Clip before casting so values outside [0, 255] cannot wrap around in uint8.
    videos_uint8 = np.clip(videos[0] * 255, 0, 255).astype(np.uint8)
    imageio.mimwrite(save_path, videos_uint8, fps=8)
    print(save_path, "is done")


def main(args):
    """Run MotionClone generation for every example listed in ``args.examples``.

    ``args.examples`` is read as JSON Lines; blank lines are skipped. For each
    record, a motion representation is written to
    ``args.motion_representation_save_dir`` and a generated ``.mp4`` is written
    to ``args.generated_videos_save_dir``. The resolved inference config is also
    saved to ``args.generated_videos_save_dir``.

    Raises:
        ValueError: If the config names a ControlNet checkpoint without a
            ControlNet config, or if an example's ``image_index`` and
            ``condition_image_paths`` lengths differ.
    """
    # Honor --visible_gpu; otherwise keep the current CUDA_VISIBLE_DEVICES (defaulting to "0").
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
    config = OmegaConf.load(args.inference_config)
    adopted_dtype = torch.float16
    device = "cuda"
    set_all_seed(42)

    # Values present in the inference config take precedence over the CLI flags.
    config.width = config.get("W", args.W)
    config.height = config.get("H", args.H)
    config.video_length = config.get("L", args.L)
    os.makedirs(args.generated_videos_save_dir, exist_ok=True)
    # OmegaConf.save emits YAML; the historical ".json" file name is kept.
    OmegaConf.save(config, os.path.join(args.generated_videos_save_dir, "inference_config.json"))

    pipeline = _build_pipeline(args, config, device, adopted_dtype)

    with open(args.examples, 'r', encoding="utf-8") as files:
        os.makedirs(args.motion_representation_save_dir, exist_ok=True)
        for line in files:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            _run_example(pipeline, config, args, json.loads(line))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion",)
parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml")
parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl")
parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/")
parser.add_argument("--visible_gpu", type=str, default=None)
parser.add_argument("--default-seed", type=int, default=76739)
parser.add_argument("--L", type=int, default=16)
parser.add_argument("--W", type=int, default=512)
parser.add_argument("--H", type=int, default=512)
parser.add_argument("--without-xformers", action="store_true")
args = parser.parse_args()
main(args)
|