# MagicTime / inference_magictime.py
# BestWishYsh's picture
# Upload 85 files
# 8a8dad9 verified
import os
import json
import time
import torch
import random
import inspect
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from utils.unet import UNet3DConditionModel
from utils.pipeline_magictime import MagicTimePipeline
from utils.util import save_videos_grid
from utils.util import load_weights
def _allocate_savedir(config_path, save_path=None):
    """Create and return a fresh output directory for one run.

    The directory is ``outputs/[<save_path>/]<config stem>-<id>`` where
    ``<id>`` comes from a module-level ``counter`` that is advanced until an
    unused path is found.
    """
    module_state = globals()
    # The counter is kept in module globals so repeated calls to main() within
    # one process keep advancing instead of restarting from 0.
    module_state.setdefault("counter", 0)

    run_name = Path(config_path).stem
    parent = os.path.join("outputs", save_path) if save_path else "outputs"
    while True:
        unique_id = module_state["counter"]
        module_state["counter"] += 1
        savedir = os.path.join(parent, f"{run_name}-{unique_id}")
        if not os.path.exists(savedir):
            break
    os.makedirs(savedir)
    return savedir


def _build_pipeline(model_config, inference_config, use_xformers):
    """Instantiate the MagicTime pipeline on CUDA and load the configured weights."""
    base_model = model_config.pretrained_model_path
    tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder").cuda()
    vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae").cuda()
    unet = UNet3DConditionModel.from_pretrained_2d(
        base_model,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).cuda()

    # xformers attention is used whenever it is installed, unless disabled by flag.
    if is_xformers_available() and use_xformers:
        unet.enable_xformers_memory_efficient_attention()

    scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs))
    pipeline = MagicTimePipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        scheduler=scheduler,
    ).to("cuda")

    return load_weights(
        pipeline,
        motion_module_path=model_config.get("motion_module", ""),
        dreambooth_model_path=model_config.get("dreambooth_path", ""),
        magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""),
        magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""),
        magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""),
    ).to("cuda")


def _sample_video(pipeline, model_config, prompt, **extra_kwargs):
    """Run the pipeline for one prompt with the sampling settings from ``model_config``."""
    return pipeline(
        prompt,
        num_inference_steps=model_config.steps,
        guidance_scale=model_config.guidance_scale,
        width=model_config.W,
        height=model_config.H,
        video_length=model_config.L,
        **extra_kwargs,
    ).videos


def _reseed_randomly():
    """Draw a random seed with torch, apply it via ``torch.manual_seed`` and return it."""
    random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
    torch.manual_seed(random_seed)
    print(f"current seed: {random_seed}")
    return random_seed


def _broadcast_per_prompt(model_config, prompts):
    """Return ``(n_prompts, random_seeds)``, repeating single entries once per prompt."""
    if len(model_config.n_prompt) == 1:
        n_prompts = list(model_config.n_prompt) * len(prompts)
    else:
        n_prompts = model_config.n_prompt

    random_seeds = model_config.get("seed", [-1])
    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
    if len(random_seeds) == 1:
        random_seeds = random_seeds * len(prompts)
    return n_prompts, random_seeds


def _run_interactive(pipeline, model_config, savedir):
    """Generate one video per prompt typed on stdin until the user enters 'exit'."""
    sample_idx = 0
    while True:
        user_prompt = input("Enter your prompt (or type 'exit' to quit): ")
        if user_prompt.lower() == "exit":
            break

        random_seed = _reseed_randomly()
        print(f"sampling {user_prompt} ...")
        sample = _sample_video(pipeline, model_config, user_prompt)

        # Only the first ten words of the prompt (with '/' removed) go into the filename.
        prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10])
        out_path = f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif"
        save_videos_grid(sample, out_path)
        print(f"save to {out_path}")
        sample_idx += 1


def _run_csv(pipeline, model_config, savedir, csv_path):
    """Generate one video per CSV row, using columns 'name' (prompt) and 'videoid' (filename)."""
    print("run_csv")
    data = pd.read_csv(csv_path)
    for _, row in data.iterrows():
        user_prompt = row['name']
        videoid = row['videoid']

        _reseed_randomly()
        print(f"sampling {user_prompt} ...")
        sample = _sample_video(pipeline, model_config, user_prompt)

        out_path = f"{savedir}/sample/{videoid}.gif"
        save_videos_grid(sample, out_path)
        print(f"save to {out_path}")


def _run_json(pipeline, model_config, json_path):
    """Generate one video per JSON item ('caption', 'video_id', 'sen_id'), skipping existing files."""
    print("run_json")
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_path, 'r', encoding="utf-8") as file:
        data = json.load(file)

    prompts = [item['caption'] for item in data]
    videoids = [item['video_id'] for item in data]
    senids = [item['sen_id'] for item in data]

    n_prompts, random_seeds = _broadcast_per_prompt(model_config, prompts)
    model_config.random_seed = []

    # NOTE(review): the negative prompt is zipped in (so it still bounds the
    # number of iterations) but is not passed to the pipeline in this mode --
    # confirm whether that is intentional.
    for prompt, videoid, senid, _n_prompt, random_seed in zip(
            prompts, videoids, senids, n_prompts, random_seeds):
        # NOTE(review): outputs go to a fixed relative "MSRVTT/sample" folder,
        # not under the per-run savedir; this is what makes the skip-if-exists
        # check work across runs -- confirm before changing.
        filename = f"MSRVTT/sample/{videoid}-{senid}.gif"
        if os.path.exists(filename):
            print(f"File {filename} already exists, skipping...")
            continue

        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
        else:
            torch.seed()
        model_config.random_seed.append(torch.initial_seed())

        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")
        sample = _sample_video(pipeline, model_config, prompt)

        save_videos_grid(sample, filename)
        print(f"save to {filename}")


def _run_config_prompts(pipeline, model_config, savedir):
    """Generate one video per prompt listed in the config, then save a merged grid."""
    prompts = model_config.prompt
    n_prompts, random_seeds = _broadcast_per_prompt(model_config, prompts)
    model_config.random_seed = []

    samples = []
    for sample_idx, (prompt, n_prompt, random_seed) in enumerate(
            zip(prompts, n_prompts, random_seeds)):
        # manually set random seed for reproduction
        if random_seed != -1:
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)
            random.seed(random_seed)
        else:
            torch.seed()
        model_config.random_seed.append(torch.initial_seed())

        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")
        sample = _sample_video(pipeline, model_config, prompt, negative_prompt=n_prompt)
        samples.append(sample)

        # The filename carries the configured seed value (possibly -1); the seed
        # actually used is recorded in model_config.random_seed above.
        prompt_for_filename = "-".join(prompt.replace("/", "").split(" ")[:10])
        out_path = f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif"
        save_videos_grid(sample, out_path)
        # Report the path that was really written (the old message dropped sample_idx).
        print(f"save to {out_path}")

    # torch.concat raises on an empty list, so skip the merged grid if the
    # config produced no samples.
    if samples:
        save_videos_grid(torch.concat(samples), f"{savedir}/merge_all.gif", n_rows=4)


@torch.no_grad()
def main(args):
    """Run MagicTime inference in one of four modes selected by ``args``.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``config`` (path of an OmegaConf file whose entry 0 is the model
            config and entry 1 the inference config), ``save_path`` (optional
            sub-folder under ``outputs``), ``without_xformers``, ``human``,
            ``run_csv`` and ``run_json``.

    Modes, checked in this order:
        * ``args.human``: prompts are read interactively from stdin.
        * ``args.run_csv``: prompts come from a CSV file.
        * ``args.run_json``: prompts come from a JSON file.
        * otherwise: prompts come from ``model_config.prompt``.

    A fresh output directory is created for every call, and the (possibly
    updated) model config is written there as ``model_config.yaml`` at the end.
    """
    savedir = _allocate_savedir(args.config, args.save_path)
    print(f"The results will be save to {savedir}")

    # Parse the config file once; entry 0 is the model config, entry 1 the
    # inference config.
    config = OmegaConf.load(args.config)
    model_config = config[0]
    inference_config = config[1]

    # Use .get() so an absent optional key behaves like an empty path, matching
    # how the same keys are read when the weights are loaded.
    if model_config.get("magic_adapter_s_path", ""):
        print("Use MagicAdapter-S")
    if model_config.get("magic_adapter_t_path", ""):
        print("Use MagicAdapter-T")
    if model_config.get("magic_text_encoder_path", ""):
        print("Use Magic_Text_Encoder")

    # create validation pipeline
    pipeline = _build_pipeline(
        model_config, inference_config, use_xformers=not args.without_xformers)

    if args.human:
        _run_interactive(pipeline, model_config, savedir)
    elif args.run_csv:
        _run_csv(pipeline, model_config, savedir, args.run_csv)
    elif args.run_json:
        _run_json(pipeline, model_config, args.run_json)
    else:
        _run_config_prompts(pipeline, model_config, savedir)

    OmegaConf.save(model_config, f"{savedir}/model_config.yaml")
if __name__ == "__main__":
    # Command-line entry point: build the argument parser, parse sys.argv and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser()

    # Path of the config file; the only mandatory argument.
    parser.add_argument("--config", type=str, required=True)

    # Boolean switches (False unless given on the command line).
    parser.add_argument("--without-xformers", action="store_true")
    parser.add_argument(
        "--human",
        action="store_true",
        help="Enable human mode for interactive video generation",
    )

    # Optional string options that all default to None.
    for optional_flag in ("--run-csv", "--run-json", "--save-path"):
        parser.add_argument(optional_flag, type=str, default=None)

    args = parser.parse_args()
    main(args)