import gradio as gr
import cv2
import torch
import numpy as np
import os
from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
from pipeline_cogvideox_controlnet_5b_i2v_instruction2 import ControlCogVideoXPipeline
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLCogVideoX
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.schedulers import CogVideoXDDIMScheduler
from omegaconf import OmegaConf
from einops import rearrange
import decord
from typing import List
from tqdm import tqdm
import PIL
import torch.nn.functional as F
from torchvision import transforms
import spaces
from huggingface_hub import snapshot_download
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def download_model():
    snapshot_download(repo_id="THUDM/CogVideoX-5b-I2V", local_dir="./cogvideox-5b-i2v")
    print("Download completed")

def download_model_senorita():
    snapshot_download(repo_id="PengWeixuanSZU/Senorita-2M", local_dir="./senorita-2m")
    print("Download completed")

download_model()
download_model_senorita()
print("All downloads completed successfully!")
def get_prompt(file: str):
    with open(file, 'r') as f:
        a = f.readlines()
    return a  # a[0]: positive prompt, a[1]: negative prompt
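# Each assets/*.txt file is expected to hold two lines (the strings below are illustrative only):
#   line 1: "Turn the cat into a tiger."  -> positive (instruction) prompt
#   line 2: "blurry, low quality"         -> negative prompt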
def unwarp_model(state_dict):
    # Strip the "module." prefix that DistributedDataParallel adds to parameter names.
    new_state_dict = {}
    for key in state_dict:
        new_state_dict[key.split('module.')[1]] = state_dict[key]
    return new_state_dict
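# Example: "module.patch_embed.proj.weight" -> "patch_embed.proj.weight".
# (Assumes every key carries the "module." prefix, as in a DDP-saved checkpoint.)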
def init_pipe():
    i2v = True
    if i2v:
        key = "i2v"
    else:
        key = "t2v"

    noise_scheduler = CogVideoXDDIMScheduler(
        **OmegaConf.to_container(
            OmegaConf.load(f"./cogvideox-5b-{key}/scheduler/scheduler_config.json")
        )
    )
    text_encoder = T5EncoderModel.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16)
    tokenizer = T5Tokenizer.from_pretrained(f"./cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)

    # Main transformer: 32 input channels in i2v mode (16 video latent + 16 first-frame latent), 16 for t2v.
    config = OmegaConf.to_container(
        OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
    )
    if i2v:
        config["in_channels"] = 32
    else:
        config["in_channels"] = 16
    transformer = CogVideoXTransformer3DModel(**config)

    # ControlNet branch: same backbone config, but only 6 layers and a 16-channel control input.
    control_config = OmegaConf.to_container(
        OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
    )
    if i2v:
        control_config["in_channels"] = 32
    else:
        control_config["in_channels"] = 16
    control_config['num_layers'] = 6
    control_config['control_in_channels'] = 16
    controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)

    # Load the fine-tuned weights released with Senorita-2M.
    all_state_dicts = torch.load("./senorita-2m/models_half/ff_controlnet_half.pth", map_location="cpu", weights_only=True)
    transformer_state_dict = unwarp_model(all_state_dicts["transformer_state_dict"])
    controlnet_transformer_state_dict = unwarp_model(all_state_dicts["controlnet_transformer_state_dict"])
    transformer.load_state_dict(transformer_state_dict, strict=True)
    controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)

    transformer = transformer.half()
    controlnet_transformer = controlnet_transformer.half()

    vae = vae.eval()
    text_encoder = text_encoder.eval()
    transformer = transformer.eval()
    controlnet_transformer = controlnet_transformer.eval()

    pipe = ControlCogVideoXPipeline(tokenizer,
                                    text_encoder,
                                    vae,
                                    transformer,
                                    noise_scheduler,
                                    controlnet_transformer,
                                    )
    # Reduce peak GPU memory: slice/tile the VAE and offload idle modules to CPU.
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()
    return pipe
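# How the two conditions are used (see inference() below): `video_condition` feeds the
# source-video latents to the 6-layer ControlNet branch, while `video_condition2` holds the
# first-frame latents (zero-padded over time) that are concatenated with the denoised latents
# along the channel dimension -- which is why in_channels is 32 in i2v mode.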
def inference(source_images,
              target_images,
              text_prompt, negative_prompt,
              pipe, vae, guidance_scale,
              h, w, random_seed) -> List[PIL.Image.Image]:
    torch.manual_seed(random_seed)
    pipe.vae.to(DEVICE)
    pipe.transformer.to(DEVICE)
    pipe.controlnet_transformer.to(DEVICE)

    # Normalize uint8 frames from [0, 255] to [-1, 1].
    source_pixel_values = source_images / 127.5 - 1.0
    source_pixel_values = source_pixel_values.to(torch.float16).to(DEVICE)
    if target_images is not None:
        target_pixel_values = target_images / 127.5 - 1.0
        target_pixel_values = target_pixel_values.to(torch.float16).to(DEVICE)

    bsz, f, h, w, c = source_pixel_values.shape
    with torch.no_grad():
        # Encode the source video into VAE latents (the ControlNet condition).
        source_pixel_values = rearrange(source_pixel_values, "b f w h c -> b c f w h")
        source_latents = vae.encode(source_pixel_values).latent_dist.sample()
        source_latents = source_latents.to(torch.float16)
        source_latents = source_latents * vae.config.scaling_factor
        source_latents = rearrange(source_latents, "b c f h w -> b f c h w")

        if target_images is not None:
            # Encode only the first target frame, then zero-pad it over time so it can be
            # concatenated with the video latents along the channel dimension.
            target_pixel_values = rearrange(target_pixel_values, "b f w h c -> b c f w h")
            images = target_pixel_values[:, :, :1, ...]
            image_latents = vae.encode(images).latent_dist.sample()
            image_latents = image_latents.to(torch.float16)
            image_latents = image_latents * vae.config.scaling_factor
            image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
            image_latents = torch.cat([image_latents, torch.zeros_like(source_latents)[:, 1:]], dim=1)
            latents = torch.cat([image_latents, source_latents], dim=2)
        else:
            image_latents = None
            latents = source_latents

        a = time.perf_counter()
        video = pipe(
            prompt=text_prompt,
            negative_prompt=negative_prompt,
            video_condition=source_latents,    # input to the ControlNet branch
            video_condition2=image_latents,    # concatenated with the latents channel-wise
            height=h,
            width=w,
            num_frames=f,
            num_inference_steps=20,
            interval=6,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device=DEVICE).manual_seed(random_seed)
        ).frames[0]
        b = time.perf_counter()
        print(f"Denoised 20 steps in {b - a:.1f}s")
    return video
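# Rough shape walk-through for the default settings (an estimate assuming the standard
# CogVideoX VAE compression of 8x spatially and 4x temporally):
#   source_images [1, 33, 448, 768, 3]  ->  source_latents [1, 9, 16, 56, 96]
#   first-frame image_latents are zero-padded to the same 9 latent frames before concatenation.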
def process_video(video_file, image_file, positive_prompt, negative_prompt, guidance, random_seed, choice, progress=gr.Progress(track_tqdm=True)) -> str:
    # The radio choice is the requested frame count; each shard generates 33 frames.
    if choice == 33:
        video_shard = 1
    elif choice == 65:
        video_shard = 2
    pipe = PIPE
    h = 448
    w = 768
    frames_per_shard = 33

    # Load and resize the first-frame guidance image (BGR -> RGB).
    image = cv2.imread(image_file)
    resized_image = cv2.resize(image, (768, 448))
    resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(resized_image)

    # Load the first 33 frames of the source video and resize them.
    vr = decord.VideoReader(video_file)
    frames = vr.get_batch(list(range(33))).asnumpy()
    _, src_h, src_w, _ = frames.shape
    resized_frames = [cv2.resize(frame, (768, 448)) for frame in frames]
    images = torch.from_numpy(np.array(resized_frames))

    target_path = "outputvideo"
    source_images = images[None, ...]
    target_images = image[None, None, ...]
    video: List[PIL.Image.Image] = []
    for i in progress.tqdm(range(video_shard)):
        if i > 0:  # use the last generated frame as first-frame guidance for the next shard
            first_frame = transforms.ToTensor()(video[-1])
            first_frame = first_frame * 255.0
            first_frame = rearrange(first_frame, "c w h -> w h c")
            target_images = first_frame[None, None, ...]
        video += inference(source_images,
                           target_images, positive_prompt,
                           negative_prompt, pipe, pipe.vae,
                           guidance,
                           h, w, random_seed)

    # Restore the source aspect ratio before exporting.
    video = [image.resize((int(src_w / src_h * 448), 448)) for image in video]
    os.makedirs(f"./{target_path}", exist_ok=True)
    output_path: str = f"./{target_path}/output_{video_file[-5]}.mp4"
    export_to_video(video, output_path, fps=8)
    return output_path
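# Argument order used by the Gradio examples below (illustrative; normally invoked by the UI):
#   process_video("assets/0.mp4", "assets/0_edit.png",
#                 positive_prompt, negative_prompt, guidance=4, random_seed=0, choice=33)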
PIPE = init_pipe()
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Señorita-2M: A High-Quality Instruction-based Dataset for General Video Editing by Video Specialists
        [Paper](https://arxiv.org/abs/2502.06734) | [Code](https://github.com/zibojia/SENORITA) | [Huggingface](https://huggingface.co/datasets/SENORITADATASET/Senorita)

        <small>This is the official implementation of Señorita. The original model requires 50 denoising steps to generate a video.
        However, due to GPU usage limits on Hugging Face Spaces, we have reduced the number of denoising steps to 20, which takes about 240s to generate one video.
        As a result, the output quality may be slightly affected. Thank you for your understanding! This UI was made by [PengWeixuanSZU](https://huggingface.co/PengWeixuanSZU).</small>
        """
    )
    with gr.Row():
        video_input = gr.Video(label="Video input")
        image_input = gr.Image(type="filepath", label="First frame guidance")
    with gr.Row():
        with gr.Column():
            positive_prompt = gr.Textbox(label="Positive prompt", value="")
            negative_prompt = gr.Textbox(label="Negative prompt", value="")
            seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=0, label="Seed")
            guidance_slider = gr.Slider(minimum=1, maximum=10, value=4, label="Guidance")
            choice = gr.Radio(choices=[33, 65], label="Frame number", value=33)
        with gr.Column():
            video_output = gr.Video(label="Video output")
    with gr.Row():
        submit_button = gr.Button("Generate")
        submit_button.click(fn=process_video, inputs=[video_input, image_input, positive_prompt, negative_prompt, guidance_slider, seed, choice], outputs=video_output)
    with gr.Row():
        gr.Examples(
            [
                ["assets/0.mp4", "assets/0_edit.png", get_prompt("assets/0.txt")[0], get_prompt("assets/0.txt")[1], 4, 0, 33],
                ["assets/1.mp4", "assets/1_edit.png", get_prompt("assets/1.txt")[0], get_prompt("assets/1.txt")[1], 4, 0, 33],
                ["assets/2.mp4", "assets/2_edit.png", get_prompt("assets/2.txt")[0], get_prompt("assets/2.txt")[1], 4, 0, 33],
                ["assets/3.mp4", "assets/3_edit.png", get_prompt("assets/3.txt")[0], get_prompt("assets/3.txt")[1], 4, 0, 33],
                ["assets/4.mp4", "assets/4_edit.png", get_prompt("assets/4.txt")[0], get_prompt("assets/4.txt")[1], 4, 0, 33],
                ["assets/5.mp4", "assets/5_edit.png", get_prompt("assets/5.txt")[0], get_prompt("assets/5.txt")[1], 4, 0, 33],
                ["assets/6.mp4", "assets/6_edit.png", get_prompt("assets/6.txt")[0], get_prompt("assets/6.txt")[1], 4, 0, 33],
                ["assets/7.mp4", "assets/7_edit.png", get_prompt("assets/7.txt")[0], get_prompt("assets/7.txt")[1], 4, 0, 33],
                ["assets/8.mp4", "assets/8_edit.png", get_prompt("assets/8.txt")[0], get_prompt("assets/8.txt")[1], 4, 0, 33]
            ],
            inputs=[video_input, image_input, positive_prompt, negative_prompt, guidance_slider, seed, choice],
            outputs=video_output,
            fn=process_video,
            cache_examples=False
        )

demo.queue().launch()