Spaces:
Running
on
A10G
Running
on
A10G
import os | |
import json | |
import time | |
import torch | |
import random | |
import inspect | |
import argparse | |
import numpy as np | |
import pandas as pd | |
from pathlib import Path | |
from omegaconf import OmegaConf | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from diffusers import AutoencoderKL, DDIMScheduler | |
from diffusers.utils.import_utils import is_xformers_available | |
from utils.unet import UNet3DConditionModel | |
from utils.pipeline_magictime import MagicTimePipeline | |
from utils.util import save_videos_grid | |
from utils.util import load_weights | |
def main(args): | |
*_, func_args = inspect.getargvalues(inspect.currentframe()) | |
func_args = dict(func_args) | |
if 'counter' not in globals(): | |
globals()['counter'] = 0 | |
unique_id = globals()['counter'] | |
globals()['counter'] += 1 | |
savedir_base = f"{Path(args.config).stem}" | |
savedir_prefix = "outputs" | |
savedir = None | |
if args.save_path: | |
savedir = os.path.join(savedir_prefix, args.save_path, f"{savedir_base}-{unique_id}") | |
else: | |
savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}") | |
while os.path.exists(savedir): | |
unique_id = globals()['counter'] | |
globals()['counter'] += 1 | |
if args.save_path: | |
savedir = os.path.join(savedir_prefix, args.save_path, f"{savedir_base}-{unique_id}") | |
else: | |
savedir = os.path.join(savedir_prefix, f"{savedir_base}-{unique_id}") | |
os.makedirs(savedir) | |
print(f"The results will be save to {savedir}") | |
model_config = OmegaConf.load(args.config)[0] | |
inference_config = OmegaConf.load(args.config)[1] | |
if model_config.magic_adapter_s_path: | |
print("Use MagicAdapter-S") | |
if model_config.magic_adapter_t_path: | |
print("Use MagicAdapter-T") | |
if model_config.magic_text_encoder_path: | |
print("Use Magic_Text_Encoder") | |
samples = [] | |
# create validation pipeline | |
tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer") | |
text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda() | |
vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda() | |
unet = UNet3DConditionModel.from_pretrained_2d(model_config.pretrained_model_path, subfolder="unet", | |
unet_additional_kwargs=OmegaConf.to_container( | |
inference_config.unet_additional_kwargs)).cuda() | |
# set xformers | |
if is_xformers_available() and (not args.without_xformers): | |
unet.enable_xformers_memory_efficient_attention() | |
pipeline = MagicTimePipeline( | |
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, | |
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), | |
).to("cuda") | |
pipeline = load_weights( | |
pipeline, | |
motion_module_path=model_config.get("motion_module", ""), | |
dreambooth_model_path=model_config.get("dreambooth_path", ""), | |
magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""), | |
magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""), | |
magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""), | |
).to("cuda") | |
sample_idx = 0 | |
if args.human: | |
sample_idx = 0 # Initialize sample index | |
while True: | |
user_prompt = input("Enter your prompt (or type 'exit' to quit): ") | |
if user_prompt.lower() == "exit": | |
break | |
random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() | |
torch.manual_seed(random_seed) | |
print(f"current seed: {random_seed}") | |
print(f"sampling {user_prompt} ...") | |
# Now, you directly use `user_prompt` to generate a video. | |
# The following is a placeholder call; you need to adapt it to your actual video generation function. | |
sample = pipeline( | |
user_prompt, | |
num_inference_steps=model_config.steps, | |
guidance_scale=model_config.guidance_scale, | |
width=model_config.W, | |
height=model_config.H, | |
video_length=model_config.L, | |
).videos | |
# Adapt the filename to avoid conflicts and properly represent the content | |
prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10]) | |
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif") | |
print(f"save to {savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif") | |
sample_idx += 1 | |
elif args.run_csv: | |
print("run_csv") | |
file_path = args.run_csv | |
data = pd.read_csv(file_path) | |
for index, row in data.iterrows(): | |
user_prompt = row['name'] # Set the user_prompt to the 'name' field of the current row | |
videoid = row['videoid'] # Extract videoid for filename | |
random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() | |
torch.manual_seed(random_seed) | |
print(f"current seed: {random_seed}") | |
print(f"sampling {user_prompt} ...") | |
sample = pipeline( | |
user_prompt, | |
num_inference_steps=model_config.steps, | |
guidance_scale=model_config.guidance_scale, | |
width=model_config.W, | |
height=model_config.H, | |
video_length=model_config.L, | |
).videos | |
# Adapt the filename to avoid conflicts and properly represent the content | |
save_videos_grid(sample, f"{savedir}/sample/{videoid}.gif") | |
print(f"save to {savedir}/sample/{videoid}.gif") | |
elif args.run_json: | |
print("run_json") | |
file_path = args.run_json | |
with open(file_path, 'r') as file: | |
data = json.load(file) | |
prompts = [] | |
videoids = [] | |
senids = [] | |
for item in data: | |
prompts.append(item['caption']) | |
videoids.append(item['video_id']) | |
senids.append(item['sen_id']) | |
n_prompts = list(model_config.n_prompt) * len(prompts) if len( | |
model_config.n_prompt) == 1 else model_config.n_prompt | |
random_seeds = model_config.get("seed", [-1]) | |
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) | |
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds | |
model_config.random_seed = [] | |
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): | |
filename = f"MSRVTT/sample/{videoids[prompt_idx]}-{senids[prompt_idx]}.gif" | |
if os.path.exists(filename): | |
print(f"File {filename} already exists, skipping...") | |
continue | |
# manually set random seed for reproduction | |
if random_seed != -1: | |
torch.manual_seed(random_seed) | |
else: | |
torch.seed() | |
model_config.random_seed.append(torch.initial_seed()) | |
print(f"current seed: {torch.initial_seed()}") | |
print(f"sampling {prompt} ...") | |
sample = pipeline( | |
prompt, | |
num_inference_steps=model_config.steps, | |
guidance_scale=model_config.guidance_scale, | |
width=model_config.W, | |
height=model_config.H, | |
video_length=model_config.L, | |
).videos | |
# Adapt the filename to avoid conflicts and properly represent the content | |
save_videos_grid(sample, filename) | |
print(f"save to {filename}") | |
else: | |
prompts = model_config.prompt | |
n_prompts = list(model_config.n_prompt) * len(prompts) if len( | |
model_config.n_prompt) == 1 else model_config.n_prompt | |
random_seeds = model_config.get("seed", [-1]) | |
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) | |
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds | |
model_config.random_seed = [] | |
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): | |
# manually set random seed for reproduction | |
if random_seed != -1: | |
torch.manual_seed(random_seed) | |
np.random.seed(random_seed) | |
random.seed(random_seed) | |
else: | |
torch.seed() | |
model_config.random_seed.append(torch.initial_seed()) | |
print(f"current seed: {torch.initial_seed()}") | |
print(f"sampling {prompt} ...") | |
sample = pipeline( | |
prompt, | |
negative_prompt=n_prompt, | |
num_inference_steps=model_config.steps, | |
guidance_scale=model_config.guidance_scale, | |
width=model_config.W, | |
height=model_config.H, | |
video_length=model_config.L, | |
).videos | |
samples.append(sample) | |
prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) | |
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt}.gif") | |
print(f"save to {savedir}/sample/{random_seed}-{prompt}.gif") | |
sample_idx += 1 | |
samples = torch.concat(samples) | |
save_videos_grid(samples, f"{savedir}/merge_all.gif", n_rows=4) | |
OmegaConf.save(model_config, f"{savedir}/model_config.yaml") | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--config", type=str, required=True) | |
parser.add_argument("--without-xformers", action="store_true") | |
parser.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation") | |
parser.add_argument("--run-csv", type=str, default=None) | |
parser.add_argument("--run-json", type=str, default=None) | |
parser.add_argument("--save-path", type=str, default=None) | |
args = parser.parse_args() | |
main(args) | |