# MotionClone-Text-to-Video / i2v_video_sample.py
import argparse
import json
import os

import imageio
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf

from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from motionclone.models.unet import UNet3DConditionModel
from motionclone.models.sparse_controlnet import SparseControlNetModel
from motionclone.pipelines.pipeline_animation import AnimationPipeline
from motionclone.utils.util import load_weights, auto_download
from motionclone.utils.motionclone_functions import *
from motionclone.utils.xformer_attention import *
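

# MotionClone i2v sampling: for each example in the --examples JSONL file, extract a
# motion representation from a reference video, then sample a new video that follows
# that motion, conditioned on a new prompt and sparse ControlNet image(s).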
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpu or str(os.getenv('CUDA_VISIBLE_DEVICES', 0))
    config = OmegaConf.load(args.inference_config)
    adopted_dtype = torch.float16
    device = "cuda"
    set_all_seed(42)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device).to(dtype=adopted_dtype)
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device).to(dtype=adopted_dtype)

    config.width = config.get("W", args.W)
    config.height = config.get("H", args.H)
    config.video_length = config.get("L", args.L)

    if not os.path.exists(args.generated_videos_save_dir):
        os.makedirs(args.generated_videos_save_dir)
    OmegaConf.save(config, os.path.join(args.generated_videos_save_dir, "inference_config.json"))

    model_config = OmegaConf.load(config.get("model_config", ""))
    unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet",
                                                   unet_additional_kwargs=OmegaConf.to_container(model_config.unet_additional_kwargs)).to(device).to(dtype=adopted_dtype)
    # load controlnet model
    controlnet = None
    if config.get("controlnet_path", "") != "":
        # assert model_config.get("controlnet_images", "") != ""
        assert config.get("controlnet_config", "") != ""

        unet.config.num_attention_heads = 8
        unet.config.projection_class_embeddings_input_dim = None

        controlnet_config = OmegaConf.load(config.controlnet_config)
        controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})).to(device).to(dtype=adopted_dtype)

        auto_download(config.controlnet_path, is_dreambooth_lora=False)
        print(f"loading controlnet checkpoint from {config.controlnet_path} ...")
        controlnet_state_dict = torch.load(config.controlnet_path, map_location="cpu")
        controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
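        # drop positional-encoding entries so the model keeps its own freshly initialized
        # buffers (their size presumably depends on the temporal sequence length)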
        controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name}
        controlnet_state_dict.pop("animatediff_config", "")
        controlnet.load_state_dict(controlnet_state_dict)
        del controlnet_state_dict
    # set xformers
    if is_xformers_available() and (not args.without_xformers):
        unet.enable_xformers_memory_efficient_attention()
    pipeline = AnimationPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        controlnet=controlnet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(model_config.noise_scheduler_kwargs)),
    ).to(device)
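
    # load_weights merges the motion module, the optional domain-adapter LoRA,
    # and optional personalized (DreamBooth) image weights into the base pipeline.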
    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=config.get("motion_module", ""),
        # domain adapter
        adapter_lora_path=config.get("adapter_lora_path", ""),
        adapter_lora_scale=config.get("adapter_lora_scale", 1.0),
        # image layer
        dreambooth_model_path=config.get("dreambooth_path", ""),
    ).to(device)
    pipeline.text_encoder.to(dtype=adopted_dtype)
    # customized functions in motionclone_functions
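    # (each function is bound to its instance via __get__, overriding the default
    # scheduler stepping/timestep setup and the UNet forward pass)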
    pipeline.scheduler.customized_step = schedule_customized_step.__get__(pipeline.scheduler)
    pipeline.scheduler.customized_set_timesteps = schedule_set_timesteps.__get__(pipeline.scheduler)
    pipeline.unet.forward = unet_customized_forward.__get__(pipeline.unet)
    pipeline.sample_video = sample_video.__get__(pipeline)
    pipeline.single_step_video = single_step_video.__get__(pipeline)
    pipeline.get_temp_attn_prob = get_temp_attn_prob.__get__(pipeline)
    pipeline.add_noise = add_noise.__get__(pipeline)
    pipeline.compute_temp_loss = compute_temp_loss.__get__(pipeline)
    pipeline.obtain_motion_representation = obtain_motion_representation.__get__(pipeline)
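
    # Inference only: freeze the UNet and ControlNet parameters.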
    for param in pipeline.unet.parameters():
        param.requires_grad = False
    for param in pipeline.controlnet.parameters():
        param.requires_grad = False

    pipeline.input_config, pipeline.unet.input_config = config, config
    pipeline.unet = prep_unet_attention(pipeline.unet, pipeline.input_config.motion_guidance_blocks)
    pipeline.unet = prep_unet_conv(pipeline.unet)
    pipeline.scheduler.customized_set_timesteps(config.inference_steps, config.guidance_steps, config.guidance_scale,
                                                device=device, timestep_spacing_type="uneven")
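
    # Each line of the --examples JSONL file describes one case. Keys read below:
    #   "video_path" (reference video), "condition_image_paths" (ControlNet condition images),
    #   "new_prompt", plus optional "image_index", "seed", and "controlnet_scale".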
    with open(args.examples, 'r') as files:
        for line in files:
            # prepare info of each case
            example_infor = json.loads(line)
            config.video_path = example_infor["video_path"]
            config.condition_image_path_list = example_infor["condition_image_paths"]
            config.image_index = example_infor.get("image_index", [0])
            assert len(config.image_index) == len(config.condition_image_path_list)
            config.new_prompt = example_infor["new_prompt"] + config.get("positive_prompt", "")
            config.controlnet_scale = example_infor.get("controlnet_scale", 1.0)
            pipeline.input_config, pipeline.unet.input_config = config, config  # update config

            # perform motion representation extraction
            seed_motion = example_infor.get("seed", args.default_seed)
            generator = torch.Generator(device=pipeline.device)
            generator.manual_seed(seed_motion)
            if not os.path.exists(args.motion_representation_save_dir):
                os.makedirs(args.motion_representation_save_dir)
            motion_representation_path = os.path.join(args.motion_representation_save_dir,
                                                      os.path.splitext(os.path.basename(config.video_path))[0] + '.pt')
            pipeline.obtain_motion_representation(generator=generator, motion_representation_path=motion_representation_path, use_controlnet=True)

            # perform video generation
            seed = seed_motion  # can assign another seed here
            generator = torch.Generator(device=pipeline.device)
            generator.manual_seed(seed)
            pipeline.input_config.seed = seed
            videos = pipeline.sample_video(generator=generator, add_controlnet=True)
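            # sample_video returns (batch, channels, frames, height, width); move the
            # channel dimension last so each frame can be written as an image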
            videos = rearrange(videos, "b c f h w -> b f h w c")

            save_path = os.path.join(args.generated_videos_save_dir,
                                     os.path.splitext(os.path.basename(config.video_path))[0]
                                     + "_" + config.new_prompt.strip().replace(' ', '_')
                                     + str(seed_motion) + "_" + str(seed) + '.mp4')
            videos_uint8 = (videos[0] * 255).astype(np.uint8)
            imageio.mimwrite(save_path, videos_uint8, fps=8)
            print(save_path, "is done")
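

# Example invocation (the paths shown are the argparse defaults below):
#   python i2v_video_sample.py \
#       --inference_config configs/i2v_sketch.yaml \
#       --examples configs/i2v_sketch.jsonl \
#       --pretrained-model-path models/StableDiffusion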
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion")
    parser.add_argument("--inference_config", type=str, default="configs/i2v_sketch.yaml")
    parser.add_argument("--examples", type=str, default="configs/i2v_sketch.jsonl")
    parser.add_argument("--motion-representation-save-dir", type=str, default="motion_representation/")
    parser.add_argument("--generated-videos-save-dir", type=str, default="generated_videos/")
    parser.add_argument("--visible_gpu", type=str, default=None)
    parser.add_argument("--default-seed", type=int, default=76739)
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)
    parser.add_argument("--without-xformers", action="store_true")

    args = parser.parse_args()
    main(args)