# OnlyFlow / app.py
import os
import imageio
import numpy as np
import torch
import random
import spaces
import gradio as gr
import torchvision
import torchvision.transforms as T
from einops import rearrange
from huggingface_hub import hf_hub_download
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image
from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
from diffusers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
from onlyflow.models.unet import UNetMotionModel
from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
from tools.optical_flow import get_optical_flow
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.mimsave(path, outputs, fps=fps)
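# Note: `save_videos_grid` expects a (batch, channels, frames, height, width) tensor; each time step is
# tiled into a grid with torchvision.utils.make_grid and the frames are written out with imageio.
# Illustrative (hypothetical) call, assuming pixel values already in [0, 1]:
#   save_videos_grid(torch.rand(2, 3, 8, 64, 64), "samples/grid.mp4", n_rows=2, fps=8)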
css = """
.toolbutton {
margin-bottom: 0em;
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.5em;
}
"""
class AnimateController:
def __init__(self):
# config dirs
self.basedir = os.getcwd()
self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
self.savedir = os.path.join(self.basedir, "samples")
os.makedirs(self.savedir, exist_ok=True)
ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
self.attention_processor_state_dict = ckpt['attention_processor_state_dict']
self.tokenizer = None
self.text_encoder = None
self.vae = None
self.unet = None
self.motion_adapter = None
def update_base_model(self, base_model_id, progress=gr.Progress()):
progress(0, desc="Starting...")
self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
return base_model_id
def update_motion_module(self, motion_module_id, progress=gr.Progress()):
self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)
def animate(
self,
id_base_model,
id_motion_module,
prompt_textbox_positive,
prompt_textbox_negative,
seed_textbox,
input_video,
height,
width,
flow_scale,
cfg,
diffusion_steps,
temporal_ds,
ctx_stride
):
#if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
self.update_base_model(id_base_model)
self.update_motion_module(id_motion_module)
self.unet = UNetMotionModel.from_unet2d(
self.unet,
motion_adapter=self.motion_adapter
)
self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()
self.flow_encoder = FlowEncoder(
downscale_factor=8,
channels=[320, 640, 1280, 1280],
nums_rb=2,
ksize=1,
sk=True,
use_conv=False,
compression_factor=1,
temporal_attention_nhead=8,
positional_embeddings="sinusoidal",
num_positional_embeddings=16,
checkpointing=False
).eval()
self.vae.requires_grad_(False)
self.text_encoder.requires_grad_(False)
self.unet.requires_grad_(False)
self.raft.requires_grad_(False)
self.flow_encoder.requires_grad_(False)
self.unet.set_all_attn(
flow_channels=[320, 640, 1280, 1280],
add_spatial=False,
add_temporal=True,
encoder_only=False,
query_condition=True,
key_value_condition=True,
flow_scale=1.0,
)
self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()
# load the flow encoder weights
pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
self.flow_encoder_state_dict,
strict=False
)
assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0
# load the attention processor weights
_, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
self.attention_processor_state_dict,
strict=False
)
assert len(attention_processor_u) == 0
pipeline = FlowCtrlPipeline(
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
unet=self.unet,
motion_adapter=self.motion_adapter,
flow_encoder=self.flow_encoder,
scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
)
if int(seed_textbox) > 0:
seed = int(seed_textbox)
else:
seed = random.randint(1, int(1e16))
return animate_diffusion(seed, pipeline, self.raft, input_video, prompt_textbox_positive, prompt_textbox_negative, width, height, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride)
@spaces.GPU(duration=150)
def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
savedir = './samples'
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)
raft_model = raft_model.to(device)
pipeline = pipeline.to(device)
pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
print("Video loaded, shape:", pixel_values.shape)
if width_slider/height_slider > pixel_values.shape[3]/pixel_values.shape[2]:
print("Resizing video to fit width cause input video is not wide enough")
temp_height = int(width_slider * pixel_values.shape[2]/pixel_values.shape[3])
temp_width = width_slider
else:
print("Resizing video to fit height cause input video is not tall enough")
temp_height = height_slider
temp_width = int(height_slider * pixel_values.shape[3]/pixel_values.shape[2])
print("Resizing video to:", temp_height, temp_width)
pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)
save_sample_path_input = os.path.join(savedir, f"input.mp4")
pixel_values_save = pixel_values[0] * 255
pixel_values_save = pixel_values_save.cpu()
pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
del pixel_values_save
print("Video loaded, shape:", pixel_values.shape)
flow = get_optical_flow(
raft_model,
(pixel_values * 2) - 1,
pixel_values.shape[1] - 1,
encode_chunk_size=16,
).to('cpu')
sample_flow = (flow_to_image(rearrange(flow[0], "c f h w -> f c h w"))) # N, 3, H, W
save_sample_path_flow = os.path.join(savedir, f"flow.mp4")
sample_flow = (sample_flow).cpu().to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
del sample_flow
original_flow_shape = flow.shape
print("Optical flow computed, shape:", flow.shape)
if flow.shape[2] < 16:
print("Video is too short, padding to 16 frames")
video_length = 16
n = 16 - flow.shape[2]
# create a tensor containing the last frame optical flow repeated n times
to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
flow = torch.cat([flow, to_add], dim=2).to(device)
elif flow.shape[2] > 16:
print("Video is too long, enabling windowing")
print("Enabling model CPU offload")
pipeline.enable_model_cpu_offload()
print("Enabling VAE slicing")
pipeline.enable_vae_slicing()
print("Enabling VAE tiling")
pipeline.enable_vae_tiling()
print("Enabling free noise")
pipeline.enable_free_noise(
context_length=16,
context_stride=context_stride,
)
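# From here on the pipeline generates in overlapping 16-frame windows advanced by `context_stride`
# (diffusers' FreeNoise long-video mode), so the full clip never has to be denoised in one pass.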
import math
def find_divisors(n: int):
"""
Return sorted list of all positive divisors of n.
Uses a sqrt(n) approach for efficiency.
"""
divs = set()
limit = int(math.isqrt(n))
for i in range(1, limit + 1):
if n % i == 0:
divs.add(i)
divs.add(n // i)
return sorted(divs)
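# Illustrative example: find_divisors(12) == [1, 2, 3, 4, 6, 12]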
def multiples_in_range(k: int, min_val: int, max_val: int):
"""
Return all multiples of k within [min_val, max_val].
"""
if k == 0:
return []
# First multiple of k >= min_val
start = ((min_val + k - 1) // k) * k
# Last multiple of k <= max_val
end = (max_val // k) * k
return list(range(start, end + 1, k)) if start <= end else []
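# Illustrative examples: multiples_in_range(4, 10, 25) == [12, 16, 20, 24]
#                        multiples_in_range(5, 21, 24) == []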
def adjust_video_length(original_length: int,
context_stride: int,
chunk_size: int,
temporal_split_size: int) -> int:
"""
Find the minimal video_length >= original_length satisfying:
1) (video_length - 16) is divisible by context_stride.
2) EITHER (2*video_length) is divisible by temporal_split_size
OR (2*video_length) is divisible by chunk_size
(when 2*video_length is not multiple of temporal_split_size).
"""
# Start at least at 16 frames (in practice original_length is usually larger).
candidate = max(original_length, 16)
# Advance candidate so that (candidate - 16) is a multiple of context_stride,
# then keep stepping by context_stride until the divisibility condition below holds.
offset = (candidate - 16) % context_stride
if offset != 0:
candidate += (context_stride - offset)  # jump to the next multiple of the stride
while True:
# Condition: (candidate - 16) is multiple of context_stride (already enforced by stepping)
# Check second part:
# - if (2*candidate) % temporal_split_size == 0, we are good
# - else we require (2*candidate) % chunk_size == 0
twoL = 2 * candidate
if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
return candidate
# Go to next valid candidate
candidate += context_stride
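# Illustrative examples (context_stride=4, chunk_size=2, temporal_split_size=8):
#   adjust_video_length(20, 4, 2, 8) == 20   # (20 - 16) % 4 == 0 and (2 * 20) % 8 == 0
#   adjust_video_length(18, 4, 2, 8) == 20   # 18 is bumped to the next valid length on the stride grid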
def find_valid_configs(original_video_length: int,
width: int,
height: int,
context_stride: int):
"""
Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
subject to the constraints:
1) chunk_size divides temporal_split_size
2) chunk_size divides spatial_split_size
3) chunk_size divides (2 * (width//64) * (height//64))
4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
5) context_stride divides (video_length - 16)
6) 128 <= spatial_split_size <= 512
7) 1 <= temporal_split_size <= 32
8) 1 <= chunk_size <= 16
We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
"""
factor = 2 * (width // 64) * (height // 64)
# 1) find all possible chunk_size values as divisors of factor, in [1..32]
possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]
# For storing results
valid_tuples = []
for chunk_size in possible_chunks:
# 2) generate all spatial_split_size in [480..512] that are multiples of chunk_size
spatial_splits = multiples_in_range(chunk_size, 480, 512)
# 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
temporal_splits = multiples_in_range(chunk_size, 1, 32)
for ssp in spatial_splits:
for tsp in temporal_splits:
# 4) & 5) Adjust video_length minimally to satisfy constraints
final_length = adjust_video_length(original_video_length,
context_stride,
chunk_size,
tsp)
# Now we have a valid (chunk_size, ssp, tsp, final_length)
valid_tuples.append((chunk_size, ssp, tsp, final_length))
return valid_tuples
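# Each returned tuple is (chunk_size, spatial_split_size, temporal_split_size, video_length);
# video_length may exceed the original frame count, in which case the flow is padded further below.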
def find_pareto_optimal(configs):
"""
Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
return the Pareto-optimal subset under the criteria:
- chunk_size: larger is better
- temporal_split_size: larger is better
- video_length: smaller is better
(spatial_split_size is carried through unchanged but not part of the dominance check)
"""
def dominates(A, B):
cA, sA, tA, lA = A
cB, sB, tB, lB = B
# A dominates B if cA >= cB, tA >= tB and lA <= lB,
# with at least one strict inequality (the spatial split size is not compared).
better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
strictly_better = (cA > cB) or (tA > tB) or (lA < lB)
return better_or_equal and strictly_better
pareto = []
for i, cfg_i in enumerate(configs):
# Check if cfg_i is dominated by any cfg_j
is_dominated = False
for j, cfg_j in enumerate(configs):
if i == j:
continue
if dominates(cfg_j, cfg_i):
is_dominated = True
break
if not is_dominated:
pareto.append(cfg_i)
return pareto
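# Example of the dominance check above (the spatial split size is not compared):
#   (4, 512, 16, 32) dominates (2, 512, 8, 48), so only the former would survive if both were candidates.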
print("Finding valid configurations...")
valid_configs = find_valid_configs(
original_video_length=flow.shape[2],
width=width_slider,
height=height_slider,
context_stride=context_stride
)
print("Found", len(valid_configs), "valid configurations")
print("Finding Pareto-optimal configurations...")
pareto_optimal = find_pareto_optimal(valid_configs)
print("Found", pareto_optimal)
criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10)
pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True)
print("Found sorted", pareto_optimal)
solution = pareto_optimal[0]
chunk_size, spatial_split_size, temporal_split_size, video_length = solution
n = video_length - original_flow_shape[2]
to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
flow = torch.cat([flow, to_add], dim=2)
pipeline.enable_free_noise_split_inference(
temporal_split_size=temporal_split_size,
spatial_split_size=spatial_split_size
)
pipeline.unet.enable_forward_chunking(chunk_size)
print("Chunking enabled with chunk size:", chunk_size)
print("Temporal split size:", temporal_split_size)
print("Spatial split size:", spatial_split_size)
print("Context stride:", context_stride)
print("Temporal downscale:", temporal_ds)
print("Video length:", video_length)
print("Flow shape:", flow.shape)
else:
print("Video is just right, no padding or windowing needed")
flow = flow.to(device)
video_length = flow.shape[2]
sample_vid = pipeline(
prompt_textbox,
negative_prompt=negative_prompt_textbox,
optical_flow=flow,
num_inference_steps=diffusion_steps,
guidance_scale=cfg,
width=width_slider,
height=height_slider,
num_frames=video_length,
val_scale_factor_temporal=flow_scale,
generator=generator,
).frames[0]
del flow
if device == "cuda":
torch.cuda.synchronize()
torch.cuda.empty_cache()
save_sample_path_video = os.path.join(savedir, f"sample.mp4")
sample_vid = sample_vid[:original_flow_shape[2]] * 255.
sample_vid = sample_vid.cpu().numpy()
sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)
return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)
controller = AnimateController()
def find_closest_ratio(target_ratio):
width_list = list(reversed(range(256, 1025, 64)))
height_list = list(reversed(range(256, 1025, 64)))
ratio_list = [(h, w, w/h) for h in height_list for w in width_list]
ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
ratio_list.sort(key=lambda x: abs(x[0]*x[1] - 512*512))
return ratio_list[0][:2]
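# Illustrative example: find_closest_ratio(16/9) returns (height, width) == (576, 1024), the only
# 64-multiple pair in [256, 1024] whose ratio matches 16:9 exactly.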
def find_dimension(video):
import av
container = av.open(open(video, 'rb'))
height, width = container.streams.video[0].height, container.streams.video[0].width
target_ratio = width / height
return find_closest_ratio(target_ratio)
def ui():
with gr.Blocks(css=css) as demo:
gr.Markdown(
"""
# <p style="text-align:center;">OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models</p>
Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord<br>
[Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
"""
)
gr.Markdown(
"""
### Quick Start:
1. Select desired `Base Model`.
2. Select a `Motion Module`. We recommend trying `guoyww/animatediff-motion-adapter-v1-5-3` for the best results.
3. Provide `Positive Prompt` and `Negative Prompt`. You are encouraged to refer to each model's webpage on HuggingFace Hub or CivitAI to learn how to write prompts for them.
4. Upload a video to extract optical flow from.
5. Select a `Flow Scale` to modulate the strength of the optical flow conditioning extracted from the input video.
6. Select `CFG` and `Diffusion Steps` values to control prompt adherence and the quality of the generated video.
7. Select a `Temporal Downsample` factor to subsample the frames of the input video.
8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
9. If the video is too long, you can adjust the generation window offset with the `Context Stride` slider.
10. Click `Generate`, wait for about 1-3 minutes, and enjoy the result!
If you run into GPU-limit errors, please try again later once your ZeroGPU quota has reset, or try with a shorter video.
Otherwise, you can also duplicate this space and select a custom GPU plan.
"""
)
with gr.Row():
with gr.Column():
gr.Markdown("# INPUTS")
with gr.Row(equal_height=True, show_progress=True):
base_model = gr.Dropdown(
label="Select or type a base model id",
choices=[
"stable-diffusion-v1-5/stable-diffusion-v1-5",
"digiplay/Photon_v1",
],
interactive=True,
scale=4,
allow_custom_value=True,
show_label=True
)
base_model_btn = gr.Button(value="Update", scale=1, size='lg')
with gr.Row(equal_height=True, show_progress=True):
motion_module = gr.Dropdown(
label="Select or type a motion module id",
choices=[
"guoyww/animatediff-motion-adapter-v1-5-3",
"guoyww/animatediff-motion-adapter-v1-5-2"
],
interactive=True,
scale=4
)
motion_module_btn = gr.Button(value="Update", scale=1, size='lg')
base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])
prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, low quality, nsfw, logo")
flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)
temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)
input_video = gr.Video(label="Input Video", interactive=True)
ctx_stride = gr.State(12)
with gr.Accordion("Advanced", open=False):
use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)
with gr.Row(equal_height=True):
height, width = gr.State(512), gr.State(512)
@gr.render(inputs=[use_custom_dim, input_video])
def render_custom_dim(use_custom_dim, input_video):
if input_video is not None:
loc_height, loc_width = find_dimension(input_video)
else:
loc_height, loc_width = 512, 512
slider_width = gr.Slider(label="Width", value=loc_width, minimum=256, maximum=1024,
step=64, visible=use_custom_dim)
slider_height = gr.Slider(label="Height", value=loc_height, minimum=256, maximum=1024,
step=64, visible=use_custom_dim)
slider_width.change(lambda x: x, inputs=[slider_width], outputs=[width])
slider_height.change(lambda x: x, inputs=[slider_height], outputs=[height])
with gr.Row():
@gr.render(inputs=input_video)
def render_ctx_stride(input_video):
if input_video is not None:
video = open(input_video, 'rb')
import av
container = av.open(video)
num_frames = container.streams.video[0].frames
if num_frames > 17:
stride_slider = gr.Slider(label="Context Stride", value=12, minimum=1, maximum=16, step=1)
stride_slider.input(lambda x: x, inputs=[stride_slider], outputs=[ctx_stride])
if num_frames > 64:
raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")
elif num_frames > 32:
gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")
with gr.Row(equal_height=True):
seed_textbox = gr.Textbox(label="Seed", value='-1')
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
seed_button.click(
fn=lambda: random.randint(1, int(1e16)),
inputs=[],
outputs=[seed_textbox]
)
with gr.Row():
clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')
clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])
with gr.Column():
gr.Markdown("# OUTPUTS")
result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
result_video = gr.Video(label="Generated Animation", interactive=False)
inputs = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
outputs = [result_optical_flow, result_video]
generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)
return demo
if __name__ == "__main__":
demo = ui()
demo.queue(max_size=20)
demo.launch()