Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import sys | |
import argparse | |
import random | |
import time | |
from omegaconf import OmegaConf | |
import torch | |
import torchvision | |
from pytorch_lightning import seed_everything | |
from huggingface_hub import hf_hub_download | |
from einops import repeat | |
import torchvision.transforms as transforms | |
from torchvision.utils import make_grid | |
from utils.utils import instantiate_from_config | |
from collections import OrderedDict | |
sys.path.insert(0, "scripts/evaluation") | |
from lvdm.models.samplers.ddim import DDIMSampler, DDIMStyleSampler | |
def load_model_checkpoint(model, ckpt): | |
state_dict = torch.load(ckpt, map_location="cpu") | |
if "state_dict" in list(state_dict.keys()): | |
state_dict = state_dict["state_dict"] | |
else: | |
# deepspeed | |
state_dict = OrderedDict() | |
for key in state_dict['module'].keys(): | |
state_dict[key[16:]]=state_dict['module'][key] | |
model.load_state_dict(state_dict, strict=False) | |
print('>>> model checkpoint loaded.') | |
return model | |
def download_model(): | |
REPO_ID = 'VideoCrafter/Text2Video-512' | |
filename_list = ['model.ckpt'] | |
os.makedirs('./checkpoints/videocrafter_t2v_320_512/', exist_ok=True) | |
for filename in filename_list: | |
local_file = os.path.join('./checkpoints/videocrafter_t2v_320_512/', filename) | |
if not os.path.exists(local_file): | |
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/videocrafter_t2v_320_512/', force_download=True) | |
REPO_ID = 'liuhuohuo/StyleCrafter' | |
filename_list = ['adapter_v1.pth', 'temporal_v1.pth'] | |
os.makedirs('./checkpoints/stylecrafter', exist_ok=True) | |
for filename in filename_list: | |
local_file = os.path.join('./checkpoints/stylecrafter', filename) | |
if not os.path.exists(local_file): | |
hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/stylecrafter', force_download=True) | |
def infer(image, prompt, infer_type='image', seed=123, style_strength=1.0, steps=50): | |
download_model() | |
ckpt_path = 'checkpoints/videocrafter_t2v_320_512/model.ckpt' | |
adapter_ckpt_path = 'checkpoints/stylecrafter/adapter_v1.pth' | |
temporal_ckpt_path = 'checkpoints/stylecrafter/temporal_v1.pth' | |
if infer_type == 'image': | |
config_file='configs/inference_image_512_512.yaml' | |
h, w = 512 // 8, 512 // 8 | |
unconditional_guidance_scale = 7.5 | |
unconditional_guidance_scale_style = None | |
else: | |
config_file='configs/inference_video_320_512.yaml' | |
h, w = 320 // 8, 512 // 8 | |
unconditional_guidance_scale = 15.0 | |
unconditional_guidance_scale_style = 7.5 | |
config = OmegaConf.load(config_file) | |
model_config = config.pop("model", OmegaConf.create()) | |
model_config['params']['adapter_config']['params']['scale'] = style_strength | |
model = instantiate_from_config(model_config) | |
model = model.cuda() | |
# load ckpt | |
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" | |
assert os.path.exists(adapter_ckpt_path), "Error: adapter checkpoint Not Found!" | |
assert os.path.exists(temporal_ckpt_path), "Error: temporal checkpoint Not Found!" | |
model = load_model_checkpoint(model, ckpt_path) | |
model.load_pretrained_adapter(adapter_ckpt_path) | |
if infer_type == 'video': | |
model.load_pretrained_temporal(temporal_ckpt_path) | |
model.eval() | |
seed_everything(seed) | |
batch_size=1 | |
channels = model.channels | |
frames = model.temporal_length if infer_type == 'video' else 1 | |
noise_shape = [batch_size, channels, frames, h, w] | |
# text cond | |
cond = model.get_learned_conditioning([prompt]) | |
neg_prompt = batch_size * [""] | |
uc = model.get_learned_conditioning(neg_prompt) | |
# style cond | |
style_transforms = torchvision.transforms.Compose([ | |
torchvision.transforms.Resize(512), | |
torchvision.transforms.CenterCrop(512), | |
torchvision.transforms.ToTensor(), | |
torchvision.transforms.Lambda(lambda x: x * 2. - 1.), | |
]) | |
style_img = style_transforms(image).unsqueeze(0).cuda() | |
style_cond = model.get_batch_style(style_img) | |
append_to_context = model.adapter(style_cond) | |
scale_scalar = model.adapter.scale_predictor(torch.concat([append_to_context, cond], dim=1)) | |
ddim_sampler = DDIMSampler(model) if infer_type == 'image' else DDIMStyleSampler(model) | |
samples, _ = ddim_sampler.sample(S=steps, | |
conditioning=cond, | |
batch_size=noise_shape[0], | |
shape=noise_shape[1:], | |
verbose=False, | |
unconditional_guidance_scale=unconditional_guidance_scale, | |
unconditional_guidance_scale_style=unconditional_guidance_scale_style, | |
unconditional_conditioning=uc, | |
eta=1.0, | |
temporal_length=noise_shape[2], | |
append_to_context=append_to_context, | |
scale_scalar=scale_scalar | |
) | |
samples = model.decode_first_stage(samples) | |
if infer_type == 'image': | |
samples = samples[:, :, 0, :, :].detach().cpu() | |
out_path = "./output.png" | |
torchvision.utils.save_image(samples, out_path, nrow=1, normalize=True, range=(-1, 1)) | |
elif infer_type == 'video': | |
samples = samples.detach().cpu() | |
out_path = "./output.mp4" | |
video = torch.clamp(samples, -1, 1) | |
video = video.permute(2, 0, 1, 3, 4) # [T, B, C, H, W] | |
frame_grids = [torchvision.utils.make_grid(video[t], nrow=1) for t in range(video.shape[0])] | |
grid = torch.stack(frame_grids, dim=0) | |
grid = (grid + 1.0) / 2.0 | |
grid = (grid * 255).permute(0, 2, 3, 1).numpy().astype('uint8') | |
torchvision.io.write_video(out_path, grid, fps=8, video_codec='h264', options={'crf': '10'}) | |
return out_path | |
def read_content(file_path: str) -> str: | |
"""read the content of target file | |
""" | |
with open(file_path, 'r', encoding='utf-8') as f: | |
content = f.read() | |
return content | |
demo_exaples = [ | |
['eval_data/3d_1.png', 'A bouquet of flowers in a vase.', 'image', 123, 1.0, 50], | |
['eval_data/craft_1.jpg', 'A modern cityscape with towering skyscrapers.', 'image', 124, 1.0, 50], | |
['eval_data/digital_art_2.jpeg', 'A lighthouse standing tall on a rocky coast.', 'image', 123, 1.0, 50], | |
['eval_data/oil_paint_2.jpg', 'A man playing the guitar on a city street.', 'image', 123, 1.0, 50], | |
['eval_data/craft_2.png', 'City street at night with bright lights and busy traffic.', 'video', 123, 1.0, 50], | |
['eval_data/anime_1.jpg', 'A field of sunflowers on a sunny day.', 'video', 123, 1.0, 50], | |
['eval_data/ink_2.jpeg', 'A knight riding a horse through a field.', 'video', 123, 1.0, 50], | |
['eval_data/oil_paint_2.jpg', 'A street performer playing the guitar.', 'video', 121, 1.0, 50], | |
['eval_data/icon_1.png', 'A campfire surrounded by tents.', 'video', 123, 1.0, 50], | |
] | |
css = """ | |
#input_img {max-height: 512px} | |
#output_vid {max-width: 512px;} | |
""" | |
with gr.Blocks(analytics_enabled=False, css=css) as demo_iface: | |
gr.HTML(read_content("header.html")) | |
with gr.Tab(label='Stylized Generation'): | |
with gr.Column(): | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
input_style_ref = gr.Image(label="Style Reference",elem_id="input_img") | |
with gr.Row(): | |
input_prompt = gr.Text(label='Prompts') | |
with gr.Row(): | |
input_seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123) | |
input_style_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, label='Style Strength', value=1.0) | |
with gr.Row(): | |
input_step = gr.Slider(minimum=1, maximum=75, step=1, elem_id="i2v_steps", label="Sampling steps", value=50) | |
input_type = gr.Radio(choices=["image", "video"], label="Generation Type", value="image") | |
input_end_btn = gr.Button("Generate") | |
# with gr.Tab(label='Result'): | |
with gr.Row(): | |
output_result = gr.Video(label="Generated Results",elem_id="output_vid",autoplay=True,show_share_button=True) | |
gr.Examples(examples=demo_exaples, | |
inputs=[input_style_ref, input_prompt, input_type, input_seed, input_style_strength, input_step], | |
outputs=[output_result], | |
fn = infer, | |
) | |
input_end_btn.click(inputs=[input_style_ref, input_prompt, input_type, input_seed, input_style_strength, input_step], | |
outputs=[output_result], | |
fn = infer | |
) | |
demo_iface.queue(max_size=12).launch(show_api=True) |