import os
import imageio
import numpy as np
import torch
import random
import spaces
import gradio as gr
import torchvision
import torchvision.transforms as T
from einops import rearrange
from huggingface_hub import hf_hub_download
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image

from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
from diffusers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
from onlyflow.models.unet import UNetMotionModel
from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
from tools.optical_flow import get_optical_flow


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""


class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples")
        os.makedirs(self.savedir, exist_ok=True)

        ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
        self.attention_processor_state_dict = ckpt['attention_processor_state_dict']

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.motion_adapter = None

    def update_base_model(self, base_model_id, progress=gr.Progress()):
        progress(0, desc="Starting...")
        self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
        return base_model_id

    def update_motion_module(self, motion_module_id, progress=gr.Progress()):
        self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)

    def animate(
            self,
            id_base_model,
            id_motion_module,
            prompt_textbox_positive,
            prompt_textbox_negative,
            seed_textbox,
            input_video,
            height,
            width,
            flow_scale,
            cfg,
            diffusion_steps,
            temporal_ds,
            ctx_stride
    ):

        # if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
        self.update_base_model(id_base_model)
        self.update_motion_module(id_motion_module)

        self.unet = UNetMotionModel.from_unet2d(
            self.unet,
            motion_adapter=self.motion_adapter
        )

        self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()

        self.flow_encoder = FlowEncoder(
            downscale_factor=8,
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False,
            compression_factor=1,
            temporal_attention_nhead=8,
            positional_embeddings="sinusoidal",
            num_positional_embeddings=16,
            checkpointing=False
        ).eval()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.raft.requires_grad_(False)
        self.flow_encoder.requires_grad_(False)

        self.unet.set_all_attn(
            flow_channels=[320, 640, 1280, 1280],
            add_spatial=False,
            add_temporal=True,
            encoder_only=False,
            query_condition=True,
            key_value_condition=True,
            flow_scale=1.0,
        )

        self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()

        # load the flow encoder weights
        pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
            self.flow_encoder_state_dict,
            strict=False
        )
        assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0

        # load the attention processor weights
        _, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
            self.attention_processor_state_dict,
            strict=False
        )
        assert len(attention_processor_u) == 0

        pipeline = FlowCtrlPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            motion_adapter=self.motion_adapter,
            flow_encoder=self.flow_encoder,
            scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
        )

        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(1, int(1e16))

        return animate_diffusion(seed, pipeline, self.raft, input_video,
                                 prompt_textbox_positive, prompt_textbox_negative,
                                 width, height, flow_scale, cfg, diffusion_steps,
                                 temporal_ds, ctx_stride)


@spaces.GPU(duration=150)
def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox,
                      width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
    savedir = './samples'
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)

    raft_model = raft_model.to(device)
    pipeline = pipeline.to(device)

    pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
    print("Video loaded, shape:", pixel_values.shape)

    if width_slider / height_slider > pixel_values.shape[3] / pixel_values.shape[2]:
        print("Resizing video to fit the target width because the input video is not wide enough")
        temp_height = int(width_slider * pixel_values.shape[2] / pixel_values.shape[3])
        temp_width = width_slider
    else:
        print("Resizing video to fit the target height because the input video is not tall enough")
        temp_height = height_slider
        temp_width = int(height_slider * pixel_values.shape[3] / pixel_values.shape[2])

    print("Resizing video to:", temp_height, temp_width)
    pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
    pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
    pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)

    save_sample_path_input = os.path.join(savedir, "input.mp4")
    pixel_values_save = pixel_values[0] * 255
    pixel_values_save = pixel_values_save.cpu()
    pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
    del pixel_values_save

    print("Video preprocessed, shape:", pixel_values.shape)

    flow = get_optical_flow(
        raft_model,
        (pixel_values * 2) - 1,
        pixel_values.shape[1] - 1,
        encode_chunk_size=16,
    ).to('cpu')

    sample_flow = flow_to_image(rearrange(flow[0], "c f h w -> f c h w"))  # N, 3, H, W
    save_sample_path_flow = os.path.join(savedir, "flow.mp4")
    sample_flow = sample_flow.cpu().to(torch.uint8).permute(0, 2, 3, 1)
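    # Shape note (inferred from the rearrange/permute calls above): `flow` holds the RAFT
    # displacement fields as (B, 2, F-1, H, W), while `sample_flow` is their RGB visualization
    # laid out as (F-1, H, W, 3) uint8 frames, the (T, H, W, C) layout expected by
    # torchvision.io.write_video below.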
    torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
    del sample_flow

    original_flow_shape = flow.shape
    print("Optical flow computed, shape:", flow.shape)

    if flow.shape[2] < 16:
        print("Video is too short, padding to 16 frames")
        video_length = 16
        n = 16 - flow.shape[2]
        # create a tensor containing the last-frame optical flow repeated n times
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2).to(device)
    elif flow.shape[2] > 16:
        print("Video is too long, enabling windowing")

        print("Enabling model CPU offload")
        pipeline.enable_model_cpu_offload()
        print("Enabling VAE slicing")
        pipeline.enable_vae_slicing()
        print("Enabling VAE tiling")
        pipeline.enable_vae_tiling()

        print("Enabling free noise")
        pipeline.enable_free_noise(
            context_length=16,
            context_stride=context_stride,
        )

        import math

        def find_divisors(n: int):
            """
            Return the sorted list of all positive divisors of n.
            Uses a sqrt(n) approach for efficiency.
            """
            divs = set()
            limit = int(math.isqrt(n))
            for i in range(1, limit + 1):
                if n % i == 0:
                    divs.add(i)
                    divs.add(n // i)
            return sorted(divs)

        def multiples_in_range(k: int, min_val: int, max_val: int):
            """
            Return all multiples of k within [min_val, max_val].
            """
            if k == 0:
                return []
            # First multiple of k >= min_val
            start = ((min_val + k - 1) // k) * k
            # Last multiple of k <= max_val
            end = (max_val // k) * k
            return list(range(start, end + 1, k)) if start <= end else []

        def adjust_video_length(original_length: int, context_stride: int, chunk_size: int,
                                temporal_split_size: int) -> int:
            """
            Find the minimal video_length >= original_length satisfying:
              1) (video_length - 16) is divisible by context_stride.
              2) EITHER (2*video_length) is divisible by temporal_split_size,
                 OR (2*video_length) is divisible by chunk_size
                 (when 2*video_length is not a multiple of temporal_split_size).
            """
            # We start at least at 16 (though in practice original_length is likely > 16)
            candidate = max(original_length, 16)

            # We want (candidate - 16) % context_stride == 0, so we step the candidate in
            # increments of `context_stride` beyond 16, which keeps (candidate - 16) a
            # multiple of context_stride.
            # Then we check the second condition, else keep stepping.
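            # Illustrative walk-through (hypothetical values, not taken from the app):
            # with original_length=20, context_stride=12, chunk_size=8, temporal_split_size=12,
            # candidate starts at 20 and is bumped to 28 so that (28 - 16) % 12 == 0; then
            # 2*28 = 56 is not divisible by 12 but is divisible by 8, so 28 is returned.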
            # If candidate < 16, bump it to 16
            if candidate < 16:
                candidate = 16

            # Make sure we jump to the correct "starting multiple" of context_stride
            offset = (candidate - 16) % context_stride
            if offset != 0:
                candidate += (context_stride - offset)  # jump to the next multiple

            while True:
                # Condition: (candidate - 16) is a multiple of context_stride
                # (already enforced by the stepping above).
                # Check the second part:
                #   - if (2*candidate) % temporal_split_size == 0, we are good
                #   - else we require (2*candidate) % chunk_size == 0
                twoL = 2 * candidate
                if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
                    return candidate

                # Go to the next valid candidate
                candidate += context_stride

        def find_valid_configs(original_video_length: int, width: int, height: int, context_stride: int):
            """
            Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
            subject to the constraints:
              1) chunk_size divides temporal_split_size
              2) chunk_size divides spatial_split_size
              3) chunk_size divides (2 * (width//64) * (height//64))
              4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
              5) context_stride divides (video_length - 16)
              6) 480 <= spatial_split_size <= 512
              7) 1 <= temporal_split_size <= 32
              8) 1 <= chunk_size <= 32

            We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
            """
            factor = 2 * (width // 64) * (height // 64)

            # 1) find all possible chunk_size values as divisors of factor, in [1..32]
            possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]

            # For storing results
            valid_tuples = []

            for chunk_size in possible_chunks:
                # 2) generate all spatial_split_size in [480..512] that are multiples of chunk_size
                spatial_splits = multiples_in_range(chunk_size, 480, 512)
                # 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
                temporal_splits = multiples_in_range(chunk_size, 1, 32)

                for ssp in spatial_splits:
                    for tsp in temporal_splits:
                        # 4) & 5) Adjust video_length minimally to satisfy the remaining constraints
                        final_length = adjust_video_length(original_video_length, context_stride,
                                                           chunk_size, tsp)
                        # Now we have a valid (chunk_size, ssp, tsp, final_length)
                        valid_tuples.append((chunk_size, ssp, tsp, final_length))

            return valid_tuples

        def find_pareto_optimal(configs):
            """
            Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
            return the Pareto-optimal subset under the criteria:
              - chunk_size: larger is better
              - spatial_split_size: larger is better
              - temporal_split_size: larger is better
              - video_length: smaller is better
            """

            def dominates(A, B):
                cA, sA, tA, lA = A
                cB, sB, tB, lB = B
                # A dominates B if:
                #   cA >= cB, sA >= sB, tA >= tB, and lA <= lB,
                #   AND at least one of these is a strict inequality.
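                # (In the check below, spatial_split_size (sA/sB) is not actually compared:
                #  only chunk size, temporal split size and video length enter the dominance test.)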
                better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
                strictly_better = (cA > cB) or (tA > tB) or (lA < lB)
                return better_or_equal and strictly_better

            pareto = []
            for i, cfg_i in enumerate(configs):
                # Check if cfg_i is dominated by any cfg_j
                is_dominated = False
                for j, cfg_j in enumerate(configs):
                    if i == j:
                        continue
                    if dominates(cfg_j, cfg_i):
                        is_dominated = True
                        break
                if not is_dominated:
                    pareto.append(cfg_i)

            return pareto

        print("Finding valid configurations...")
        valid_configs = find_valid_configs(
            original_video_length=flow.shape[2],
            width=width_slider,
            height=height_slider,
            context_stride=context_stride
        )
        print("Found", len(valid_configs), "valid configurations")

        print("Finding Pareto-optimal configurations...")
        pareto_optimal = find_pareto_optimal(valid_configs)
        print("Found", pareto_optimal)

        criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10)
        pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True)
        print("Found sorted", pareto_optimal)

        solution = pareto_optimal[0]
        chunk_size, spatial_split_size, temporal_split_size, video_length = solution

        n = video_length - original_flow_shape[2]
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2)

        pipeline.enable_free_noise_split_inference(
            temporal_split_size=temporal_split_size,
            spatial_split_size=spatial_split_size
        )
        pipeline.unet.enable_forward_chunking(chunk_size)

        print("Chunking enabled with chunk size:", chunk_size)
        print("Temporal split size:", temporal_split_size)
        print("Spatial split size:", spatial_split_size)
        print("Context stride:", context_stride)
        print("Temporal downscale:", temporal_ds)
        print("Video length:", video_length)
        print("Flow shape:", flow.shape)
    else:
        print("Video is just right, no padding or windowing needed")

    flow = flow.to(device)
    video_length = flow.shape[2]

    sample_vid = pipeline(
        prompt_textbox,
        negative_prompt=negative_prompt_textbox,
        optical_flow=flow,
        num_inference_steps=diffusion_steps,
        guidance_scale=cfg,
        width=width_slider,
        height=height_slider,
        num_frames=video_length,
        val_scale_factor_temporal=flow_scale,
        generator=generator,
    ).frames[0]

    del flow
    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    save_sample_path_video = os.path.join(savedir, "sample.mp4")
    sample_vid = sample_vid[:original_flow_shape[2]] * 255.
    sample_vid = sample_vid.cpu().numpy()
    sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)

    return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)


controller = AnimateController()


def find_closest_ratio(target_ratio):
    width_list = list(reversed(range(256, 1025, 64)))
    height_list = list(reversed(range(256, 1025, 64)))
    ratio_list = [(h, w, w / h) for h in height_list for w in width_list]
    ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
    ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
    ratio_list.sort(key=lambda x: abs(x[0] * x[1] - 512 * 512))
    return ratio_list[0][:2]


def find_dimension(video):
    import av
    container = av.open(open(video, 'rb'))
    height, width = container.streams.video[0].height, container.streams.video[0].width
    target_ratio = width / height
    return find_closest_ratio(target_ratio)


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models

            Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord
            [Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
            """
        )
        gr.Markdown(
            """
            ### Quick Start:
            1. Select the desired `Base Model`.
            2. Select a `Motion Module`. We recommend trying guoyww/animatediff-motion-adapter-v1-5-3 for the best results.
            3. Provide a `Positive Prompt` and a `Negative Prompt`. You are encouraged to refer to each model's webpage on the HuggingFace Hub or CivitAI to learn how to write prompts for it.
            4. Upload a video to extract the optical flow from.
            5. Select a `Flow Scale` to modulate the optical flow conditioning from the input video.
            6. Select a `CFG` and `Diffusion Steps` to control the quality of the generated video and prompt adherence.
            7. Select a `Temporal Downsample` to reduce the number of frames in the input video.
            8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
            9. If the video is too long, you can adjust the generation window offset with the `Context Stride` slider.
            10. Click `Generate`, wait for ~1-3 min, and enjoy the result!

            If you hit an error concerning GPU limits, please try again later once your ZeroGPU quota has been reset, or try with a shorter video. Otherwise, you can also duplicate this space and select a custom GPU plan.
            """
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("# INPUTS")
                with gr.Row(equal_height=True, show_progress=True):
                    base_model = gr.Dropdown(
                        label="Select or type a base model id",
                        choices=[
                            "stable-diffusion-v1-5/stable-diffusion-v1-5",
                            "digiplay/Photon_v1",
                        ],
                        interactive=True,
                        scale=4,
                        allow_custom_value=True,
                        show_label=True
                    )
                    base_model_btn = gr.Button(value="Update", scale=1, size='lg')
                with gr.Row(equal_height=True, show_progress=True):
                    motion_module = gr.Dropdown(
                        label="Select or type a motion module id",
                        choices=[
                            "guoyww/animatediff-motion-adapter-v1-5-3",
                            "guoyww/animatediff-motion-adapter-v1-5-2"
                        ],
                        interactive=True,
                        scale=4,
                        allow_custom_value=True
                    )
                    motion_module_btn = gr.Button(value="Update", scale=1, size='lg')

                base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
                motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])

                prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
                prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2,
                                                     value="worst quality, low quality, nsfw, logo")

                flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
                diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
                cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)
                temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)

                input_video = gr.Video(label="Input Video", interactive=True)

                ctx_stride = gr.State(12)

                with gr.Accordion("Advanced", open=False):
                    use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)
                    with gr.Row(equal_height=True):
                        height, width = gr.State(512), gr.State(512)

                        @gr.render(inputs=[use_custom_dim, input_video])
                        def render_custom_dim(use_custom_dim, input_video):
                            if input_video is not None:
                                loc_height, loc_width = find_dimension(input_video)
                            else:
                                loc_height, loc_width = 512, 512
                            slider_width = gr.Slider(label="Width", value=loc_width,
                                                     minimum=256, maximum=1024, step=64,
                                                     visible=use_custom_dim)
                            slider_height = gr.Slider(label="Height", value=loc_height,
                                                      minimum=256, maximum=1024, step=64,
                                                      visible=use_custom_dim)
                            slider_width.change(lambda x: x, inputs=[slider_width],
                                                outputs=[width])
                            slider_height.change(lambda x: x, inputs=[slider_height],
                                                 outputs=[height])

                    with gr.Row():
                        @gr.render(inputs=input_video)
                        def render_ctx_stride(input_video):
                            if input_video is not None:
                                video = open(input_video, 'rb')
                                import av
                                container = av.open(video)
                                num_frames = container.streams.video[0].frames
                                if num_frames > 17:
                                    stride_slider = gr.Slider(label="Context Stride", value=12,
                                                              minimum=1, maximum=16, step=1)
                                    stride_slider.input(lambda x: x, inputs=[stride_slider],
                                                        outputs=[ctx_stride])
                                    if num_frames > 64:
                                        raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")
                                    elif num_frames > 32:
                                        gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")

                with gr.Row(equal_height=True):
                    seed_textbox = gr.Textbox(label="Seed", value='-1')
                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    seed_button.click(
                        fn=lambda: random.randint(1, int(1e16)),
                        inputs=[],
                        outputs=[seed_textbox]
                    )

                with gr.Row():
                    clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
                    generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')
                clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive,
                               prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])

            with gr.Column():
                gr.Markdown("# OUTPUTS")
                result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
                result_video = gr.Video(label="Generated Animation", interactive=False)

        inputs = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox,
                  input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
        outputs = [result_optical_flow, result_video]

        generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()
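    # Note: demo.launch() serves on localhost by default. To expose the app from a container
    # or remote machine, one could pass standard Gradio arguments such as
    # demo.launch(server_name="0.0.0.0", server_port=7860); this is illustrative only and
    # not part of the original configuration.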