Commit 9bb001a · Parent: initial commit

Files changed:
- .gitattributes +35 -0
- .gitignore +1 -0
- README.md +14 -0
- app.py +641 -0
- onlyflow/data/dataset_idx.py +88 -0
- onlyflow/data/dataset_itr.py +81 -0
- onlyflow/models/attention.py +359 -0
- onlyflow/models/attention_processor.py +456 -0
- onlyflow/models/flow_adaptor.py +247 -0
- onlyflow/models/transformer_2d.py +566 -0
- onlyflow/models/unet.py +0 -0
- onlyflow/pipelines/pipeline_animation.py +497 -0
- onlyflow/pipelines/pipeline_animation_long.py +555 -0
- onlyflow/utils/util.py +140 -0
- requirements.txt +11 -0
- tools/optical_flow.py +22 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
.DS_Store
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: OnlyFlow
emoji: 🐢
colorFrom: pink
colorTo: red
sdk: gradio
sdk_version: 5.16.0
app_file: app.py
pinned: false
license: mit
short_description: 'Optical flow based motion conditioned video generation'
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
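Since `sdk: gradio` and `app_file: app.py` are set, the Space simply executes app.py on startup. A minimal local-run sketch (an editor illustration, assuming the packages from requirements.txt are installed; it mirrors the `__main__` block of app.py):

# Editor's sketch of a local launch, not part of the commit.
from app import ui

demo = ui()
demo.queue(max_size=20)
demo.launch()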
app.py
ADDED
@@ -0,0 +1,641 @@
import os

import imageio
import numpy as np
import torch
import random

import spaces

import gradio as gr

import torchvision
import torchvision.transforms as T
from einops import rearrange
from huggingface_hub import hf_hub_download
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image

from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
from diffusers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
from onlyflow.models.unet import UNetMotionModel
from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
from tools.optical_flow import get_optical_flow


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""


class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples")
        os.makedirs(self.savedir, exist_ok=True)

        ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
        self.attention_processor_state_dict = ckpt['attention_processor_state_dict']

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.motion_adapter = None

    def update_base_model(self, base_model_id, progress=gr.Progress()):

        progress(0, desc="Starting...")

        self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")

        return base_model_id

    def update_motion_module(self, motion_module_id, progress=gr.Progress()):
        self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)

    def animate(
            self,
            id_base_model,
            id_motion_module,
            prompt_textbox_positive,
            prompt_textbox_negative,
            seed_textbox,
            input_video,
            height,
            width,
            flow_scale,
            cfg,
            diffusion_steps,
            temporal_ds,
            ctx_stride
    ):
        #if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
        self.update_base_model(id_base_model)
        self.update_motion_module(id_motion_module)

        self.unet = UNetMotionModel.from_unet2d(
            self.unet,
            motion_adapter=self.motion_adapter
        )

        self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()

        self.flow_encoder = FlowEncoder(
            downscale_factor=8,
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False,
            compression_factor=1,
            temporal_attention_nhead=8,
            positional_embeddings="sinusoidal",
            num_positional_embeddings=16,
            checkpointing=False
        ).eval()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.raft.requires_grad_(False)
        self.flow_encoder.requires_grad_(False)

        self.unet.set_all_attn(
            flow_channels=[320, 640, 1280, 1280],
            add_spatial=False,
            add_temporal=True,
            encoder_only=False,
            query_condition=True,
            key_value_condition=True,
            flow_scale=1.0,
        )

        self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()

        # load the flow encoder weights
        pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
            self.flow_encoder_state_dict,
            strict=False
        )
        assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0

        # load the attention processor weights
        _, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
            self.attention_processor_state_dict,
            strict=False
        )
        assert len(attention_processor_u) == 0

        pipeline = FlowCtrlPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            motion_adapter=self.motion_adapter,
            flow_encoder=self.flow_encoder,
            scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
        )

        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(1, int(1e16))

        return animate_diffusion(seed, pipeline, self.raft, input_video, prompt_textbox_positive, prompt_textbox_negative, width, height, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride)


@spaces.GPU(duration=150)
def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
    savedir = './samples'
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)

    raft_model = raft_model.to(device)
    pipeline = pipeline.to(device)

    pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
    print("Video loaded, shape:", pixel_values.shape)
    if width_slider/height_slider > pixel_values.shape[3]/pixel_values.shape[2]:
        print("Resizing video to fit width cause input video is not wide enough")
        temp_height = int(width_slider * pixel_values.shape[2]/pixel_values.shape[3])
        temp_width = width_slider
    else:
        print("Resizing video to fit height cause input video is not tall enough")
        temp_height = height_slider
        temp_width = int(height_slider * pixel_values.shape[3]/pixel_values.shape[2])
    print("Resizing video to:", temp_height, temp_width)
    pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
    pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
    pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)

    save_sample_path_input = os.path.join(savedir, f"input.mp4")
    pixel_values_save = pixel_values[0] * 255
    pixel_values_save = pixel_values_save.cpu()
    pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
    del pixel_values_save

    print("Video loaded, shape:", pixel_values.shape)
    flow = get_optical_flow(
        raft_model,
        (pixel_values * 2) - 1,
        pixel_values.shape[1] - 1,
        encode_chunk_size=16,
    ).to('cpu')

    sample_flow = (flow_to_image(rearrange(flow[0], "c f h w -> f c h w")))  # N, 3, H, W
    save_sample_path_flow = os.path.join(savedir, f"flow.mp4")
    sample_flow = (sample_flow).cpu().to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
    del sample_flow

    original_flow_shape = flow.shape
    print("Optical flow computed, shape:", flow.shape)
    if flow.shape[2] < 16:
        print("Video is too short, padding to 16 frames")
        video_length = 16
        n = 16 - flow.shape[2]
        # create a tensor containing the last frame optical flow repeated n times
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2).to(device)
    elif flow.shape[2] > 16:
        print("Video is too long, enabling windowing")
        print("Enabling model CPU offload")
        pipeline.enable_model_cpu_offload()
        print("Enabling VAE slicing")
        pipeline.enable_vae_slicing()
        print("Enabling VAE tiling")
        pipeline.enable_vae_tiling()

        print("Enabling free noise")
        pipeline.enable_free_noise(
            context_length=16,
            context_stride=context_stride,
        )

        import math

        def find_divisors(n: int):
            """
            Return sorted list of all positive divisors of n.
            Uses a sqrt(n) approach for efficiency.
            """
            divs = set()
            limit = int(math.isqrt(n))
            for i in range(1, limit + 1):
                if n % i == 0:
                    divs.add(i)
                    divs.add(n // i)
            return sorted(divs)

        def multiples_in_range(k: int, min_val: int, max_val: int):
            """
            Return all multiples of k within [min_val, max_val].
            """
            if k == 0:
                return []

            # First multiple of k >= min_val
            start = ((min_val + k - 1) // k) * k
            # Last multiple of k <= max_val
            end = (max_val // k) * k

            return list(range(start, end + 1, k)) if start <= end else []

        def adjust_video_length(original_length: int,
                                context_stride: int,
                                chunk_size: int,
                                temporal_split_size: int) -> int:
            """
            Find the minimal video_length >= original_length satisfying:
              1) (video_length - 16) is divisible by context_stride.
              2) EITHER (2*video_length) is divisible by temporal_split_size
                 OR (2*video_length) is divisible by chunk_size
                 (when 2*video_length is not multiple of temporal_split_size).
            """

            # We start at least at 16 (though in practice original_length likely > 16)
            candidate = max(original_length, 16)

            # We want (candidate - 16) % context_stride == 0
            # so let n be the multiple to step.
            # n is how many times we add `context_stride` beyond 16.
            # This ensures (candidate - 16) is a multiple of context_stride.
            # Then we check the second condition, else keep stepping.

            # If candidate < 16, bump it to 16
            if candidate < 16:
                candidate = 16

            # Make sure we jump to the correct "starting multiple" of context_stride
            offset = (candidate - 16) % context_stride
            if offset != 0:
                candidate += (context_stride - offset)  # jump to the next multiple

            while True:
                # Condition: (candidate - 16) is multiple of context_stride (already enforced by stepping)
                # Check second part:
                #   - if (2*candidate) % temporal_split_size == 0, we are good
                #   - else we require (2*candidate) % chunk_size == 0
                twoL = 2 * candidate
                if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
                    return candidate

                # Go to next valid candidate
                candidate += context_stride

        def find_valid_configs(original_video_length: int,
                               width: int,
                               height: int,
                               context_stride: int):
            """
            Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
            subject to the constraints:
              1) chunk_size divides temporal_split_size
              2) chunk_size divides spatial_split_size
              3) chunk_size divides (2 * (width//64) * (height//64))
              4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
              5) context_stride divides (video_length - 16)
              6) 128 <= spatial_split_size <= 512
              7) 1 <= temporal_split_size <= 32
              8) 1 <= chunk_size <= 16

            We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
            """

            factor = 2 * (width // 64) * (height // 64)

            # 1) find all possible chunk_size as divisors of factor, in [1..16]
            possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]

            # For storing results
            valid_tuples = []

            for chunk_size in possible_chunks:
                # 2) generate all spatial_split_size in [128..512] that are multiples of chunk_size
                spatial_splits = multiples_in_range(chunk_size, 480, 512)

                # 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
                temporal_splits = multiples_in_range(chunk_size, 1, 32)

                for ssp in spatial_splits:
                    for tsp in temporal_splits:
                        # 4) & 5) Adjust video_length minimally to satisfy constraints
                        final_length = adjust_video_length(original_video_length,
                                                           context_stride,
                                                           chunk_size,
                                                           tsp)
                        # Now we have a valid (chunk_size, ssp, tsp, final_length)
                        valid_tuples.append((chunk_size, ssp, tsp, final_length))

            return valid_tuples

        def find_pareto_optimal(configs):
            """
            Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
            return the Pareto-optimal subset under the criteria:
              - chunk_size: larger is better
              - spatial_split_size: larger is better
              - temporal_split_size: larger is better
              - video_length: smaller is better
            """

            def dominates(A, B):
                cA, sA, tA, lA = A
                cB, sB, tB, lB = B

                # A dominates B if:
                #   cA >= cB, sA >= sB, tA >= tB, and lA <= lB
                #   AND at least one of these is a strict inequality.

                better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
                strictly_better = (cA > cB) or (tA > tB) or (lA < lB)

                return better_or_equal and strictly_better

            pareto = []
            for i, cfg_i in enumerate(configs):
                # Check if cfg_i is dominated by any cfg_j
                is_dominated = False
                for j, cfg_j in enumerate(configs):
                    if i == j:
                        continue
                    if dominates(cfg_j, cfg_i):
                        is_dominated = True
                        break
                if not is_dominated:
                    pareto.append(cfg_i)

            return pareto

        print("Finding valid configurations...")
        valid_configs = find_valid_configs(
            original_video_length=flow.shape[2],
            width=width_slider,
            height=height_slider,
            context_stride=context_stride
        )

        print("Found", len(valid_configs), "valid configurations")
        print("Finding Pareto-optimal configurations...")
        pareto_optimal = find_pareto_optimal(valid_configs)

        print("Found", pareto_optimal)

        criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10)
        pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True)

        print("Found sorted", pareto_optimal)

        solution = pareto_optimal[0]
        chunk_size, spatial_split_size, temporal_split_size, video_length = solution

        n = video_length - original_flow_shape[2]
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2)

        pipeline.enable_free_noise_split_inference(
            temporal_split_size=temporal_split_size,
            spatial_split_size=spatial_split_size
        )
        pipeline.unet.enable_forward_chunking(chunk_size)

        print("Chunking enabled with chunk size:", chunk_size)
        print("Temporal split size:", temporal_split_size)
        print("Spatial split size:", spatial_split_size)
        print("Context stride:", context_stride)
        print("Temporal downscale:", temporal_ds)
        print("Video length:", video_length)
        print("Flow shape:", flow.shape)
    else:
        print("Video is just right, no padding or windowing needed")
        flow = flow.to(device)
        video_length = flow.shape[2]

    sample_vid = pipeline(
        prompt_textbox,
        negative_prompt=negative_prompt_textbox,
        optical_flow=flow,
        num_inference_steps=diffusion_steps,
        guidance_scale=cfg,
        width=width_slider,
        height=height_slider,
        num_frames=video_length,
        val_scale_factor_temporal=flow_scale,
        generator=generator,
    ).frames[0]

    del flow
    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    save_sample_path_video = os.path.join(savedir, f"sample.mp4")
    sample_vid = sample_vid[:original_flow_shape[2]] * 255.
    sample_vid = sample_vid.cpu().numpy()
    sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)

    return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)


controller = AnimateController()


def find_closest_ratio(target_ratio):
    width_list = list(reversed(range(256, 1025, 64)))
    height_list = list(reversed(range(256, 1025, 64)))
    ratio_list = [(h, w, w/h) for h in height_list for w in width_list]
    ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
    ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
    ratio_list.sort(key=lambda x: abs(x[0]*x[1] - 512*512))
    return ratio_list[0][:2]


def find_dimension(video):
    import av
    container = av.open(open(video, 'rb'))
    height, width = container.streams.video[0].height, container.streams.video[0].width
    target_ratio = width / height
    return find_closest_ratio(target_ratio)


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # <p style="text-align:center;">OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models</p>
            Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord<br>
            [Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
            """
        )
        gr.Markdown(
            """
            ### Quick Start:

            1. Select desired `Base Model`.
            2. Select `Motion Module`. We recommend trying guoyww/animatediff-motion-adapter-v1-5-3 for the best results.
            3. Provide `Positive Prompt` and `Negative Prompt`. You are encouraged to refer to each model's webpage on HuggingFace Hub or CivitAI to learn how to write prompts for them.
            4. Upload a video to extract optical flow from.
            5. Select a 'Flow Scale' to modulate the input video optical flow conditioning.
            6. Select a 'CFG' and 'Diffusion Steps' to control the quality of the generated video and prompt adherence.
            7. Select a 'Temporal Downsample' to reduce the number of frames in the input video.
            8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
            9. If the video is too long, you can adjust the generation window offset with the `Context Stride` slider.
            10. Click `Generate`, wait for ~1/3 min, and enjoy the result!

            If you have any error concerning GPU limits, please try again later when your ZeroGPU quota is reset, or try with a shorter video.
            Otherwise, you can also duplicate this space and select a custom GPU plan.
            """
        )
        with gr.Row():
            with gr.Column():

                gr.Markdown("# INPUTS")

                with gr.Row(equal_height=True, show_progress=True):
                    base_model = gr.Dropdown(
                        label="Select or type a base model id",
                        choices=[
                            "stable-diffusion-v1-5/stable-diffusion-v1-5",
                            "digiplay/Photon_v1",
                        ],
                        interactive=True,
                        scale=4,
                        allow_custom_value=True,
                        show_label=True
                    )
                    base_model_btn = gr.Button(value="Update", scale=1, size='lg')
                with gr.Row(equal_height=True, show_progress=True):
                    motion_module = gr.Dropdown(
                        label="Select or type a motion module id",
                        choices=[
                            "guoyww/animatediff-motion-adapter-v1-5-3",
                            "guoyww/animatediff-motion-adapter-v1-5-2"
                        ],
                        interactive=True,
                        scale=4
                    )
                    motion_module_btn = gr.Button(value="Update", scale=1, size='lg')

                base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
                motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])

                prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
                prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, low quality, nsfw, logo")

                flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
                diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
                cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)

                temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)

                input_video = gr.Video(label="Input Video", interactive=True)
                ctx_stride = gr.State(12)

                with gr.Accordion("Advanced", open=False):
                    use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)

                    with gr.Row(equal_height=True):

                        height, width = gr.State(512), gr.State(512)

                        @gr.render(inputs=[use_custom_dim, input_video])
                        def render_custom_dim(use_custom_dim, input_video):
                            if input_video is not None:
                                loc_height, loc_width = find_dimension(input_video)
                            else:
                                loc_height, loc_width = 512, 512
                            slider_width = gr.Slider(label="Width", value=loc_width, minimum=256, maximum=1024,
                                                     step=64, visible=use_custom_dim)
                            slider_height = gr.Slider(label="Height", value=loc_height, minimum=256, maximum=1024,
                                                      step=64, visible=use_custom_dim)

                            slider_width.change(lambda x: x, inputs=[slider_width], outputs=[width])
                            slider_height.change(lambda x: x, inputs=[slider_height], outputs=[height])


                    with gr.Row():
                        @gr.render(inputs=input_video)
                        def render_ctx_stride(input_video):
                            if input_video is not None:
                                video = open(input_video, 'rb')
                                import av
                                container = av.open(video)
                                num_frames = container.streams.video[0].frames
                                if num_frames > 17:
                                    stride_slider = gr.Slider(label="Context Stride", value=12, minimum=1, maximum=16, step=1)
                                    stride_slider.input(lambda x: x, inputs=[stride_slider], outputs=[ctx_stride])
                                if num_frames > 32:
                                    gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")
                                elif num_frames > 64:
                                    raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")

                with gr.Row(equal_height=True):
                    seed_textbox = gr.Textbox(label="Seed", value='-1')

                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    seed_button.click(
                        fn=lambda: random.randint(1, int(1e16)),
                        inputs=[],
                        outputs=[seed_textbox]
                    )

                with gr.Row():
                    clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
                    generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')

                clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])

            with gr.Column():

                gr.Markdown("# OUTPUTS")

                result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
                result_video = gr.Video(label="Generated Animation", interactive=False)

                inputs = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
                outputs = [result_optical_flow, result_video]

                generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()
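To make the windowing constraint search in animate_diffusion concrete, here is a small worked example (an editor's illustration that reuses the same arithmetic as find_valid_configs, assuming the default 512x512 resolution and a 40-frame flow with context_stride=12):

# Illustration of the chunk-size constraint used in find_valid_configs (not part of the commit).
width = height = 512
factor = 2 * (width // 64) * (height // 64)      # 2 * 8 * 8 = 128
chunk_candidates = [d for d in range(1, 33) if factor % d == 0]
print(chunk_candidates)                          # [1, 2, 4, 8, 16, 32]

# With context_stride = 12, a 40-frame flow already satisfies (40 - 16) % 12 == 0,
# so adjust_video_length keeps the length at 40 whenever 2 * 40 is divisible by the
# chosen temporal split size or by the chunk size; otherwise it steps up by 12 frames.
print((40 - 16) % 12 == 0)                       # True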
onlyflow/data/dataset_idx.py
ADDED
@@ -0,0 +1,88 @@
import functools
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import wids
from torch.utils.data import DataLoader


def _video_shortener(video_tensor, length, generator=None):
    start = torch.randint(0, video_tensor.shape[0] - length, (1,), generator=generator)
    return video_tensor[start:start + length]


def select_video_extract(length=16, generator=None):
    return functools.partial(_video_shortener, length=length, generator=generator)


def my_collate_fn(batch):
    videos = torch.stack([sample[0] for sample in batch])
    txts = [sample[1] for sample in batch]

    return videos, txts


class WebVidDataset(wids.ShardListDataset):

    def __init__(self, shards, cache_dir, video_length=16, video_size=256, video_length_offset=1, val=False, seed=42,
                 **kwargs):

        self.val = val
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)
        self.generator_init_state = self.generator.get_state()
        super().__init__(shards, cache_dir=cache_dir, keep=True, **kwargs)

        if isinstance(video_size, int):
            video_size = (video_size, video_size)

        self.video_size = video_size

        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        self.transform = transforms.Compose(
            [
                select_video_extract(length=video_length + video_length_offset, generator=self.generator),
                transforms.Resize(size=video_size),
                transforms.RandomCrop(size=video_size) if not self.val else transforms.CenterCrop(size=video_size),
                transforms.RandomHorizontalFlip() if not self.val else transforms.Identity(),
            ]
        )

        self.add_transform(self._make_sample)

    def _make_sample(self, sample):
        if self.val:
            self.generator.set_state(self.generator_init_state)
        video = torchvision.io.read_video(BytesIO(sample[".mp4"].read()), output_format="TCHW", pts_unit='sec')[0]
        label = sample[".txt"]
        return self.transform(video), label


if __name__ == "__main__":

    dataset = WebVidDataset(
        tar_index=0,
        root_path='/users/Etu9/3711799/onlyflow/data/webvid/desc.json',
        video_length=16,
        video_size=256,
        video_length_offset=0,
    )

    sampler = wids.DistributedChunkedSampler(dataset, chunksize=1000, shuffle=True)
    dataloader = DataLoader(
        dataset,
        collate_fn=my_collate_fn,
        batch_size=4,
        sampler=sampler,
        num_workers=4
    )

    for i, (images, labels) in enumerate(dataloader):
        print(i, images.shape, labels)
        if i > 10:
            break
onlyflow/data/dataset_itr.py
ADDED
@@ -0,0 +1,81 @@
import functools
import os
from io import BytesIO

import torch
import torchvision
import torchvision.transforms.v2 as transforms
import webdataset as wds


def _video_shortener(video_tensor, length):
    start = torch.randint(0, video_tensor.shape[0] - length, (1,))
    return video_tensor[start:start + length]


def select_video_extract(length=16):
    return functools.partial(_video_shortener, length=length)


def my_collate_fn(batch):
    output = {}
    for key in batch[0].keys():
        if key == 'video':
            output[key] = torch.stack([sample[key] for sample in batch])
        else:
            output[key] = [sample[key] for sample in batch]

    return output


def map_mp4(sample):
    return torchvision.io.read_video(BytesIO(sample), output_format="TCHW", pts_unit='sec')[0]


def map_txt(sample):
    return sample.decode("utf-8")


class WebVidDataset(wds.DataPipeline):
    def __init__(self, batch_size, tar_index, root_path, video_length=16, video_size=256, video_length_offset=0,
                 horizontal_flip=True, seed=None):

        self.dataset_full_path = os.path.join(root_path, f'webvid-uw-{{{tar_index}}}.tar')

        if isinstance(video_size, int):
            video_size = (video_size, video_size)

        for size in video_size:
            if size % 8 != 0:
                raise ValueError("video_size must be divisible by 8")

        self.pipeline = [
            wds.SimpleShardList('file:' + str(self.dataset_full_path), seed=seed),
            wds.shuffle(50),
            wds.split_by_node,
            wds.tarfile_to_samples(),
            wds.shuffle(100),
            wds.split_by_worker,
            wds.map_dict(
                mp4=map_mp4,
                txt=map_txt,
            ),
            wds.map_dict(
                mp4=transforms.Compose(
                    [
                        select_video_extract(length=video_length + video_length_offset),
                        transforms.Resize(size=video_size),
                        transforms.RandomCrop(size=video_size),
                        transforms.RandomHorizontalFlip() if horizontal_flip else transforms.Identity,
                    ]
                )
            ),
            wds.rename_keys(video="mp4", text='txt', keep_unselected=True),
            wds.batched(batch_size, collation_fn=my_collate_fn, partial=True)
        ]

        super().__init__(self.pipeline)

        self.batch_size = batch_size
        self.video_length = video_length
        self.video_size = video_size
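A usage sketch for this iterable dataset (an editor illustration; the shard range and root_path below are placeholders, and since wds.batched already batches inside the pipeline the loader is built with batch_size=None):

# Editor's usage sketch for the iterable WebVidDataset above; paths are hypothetical.
from torch.utils.data import DataLoader

dataset = WebVidDataset(
    batch_size=4,
    tar_index='0000..0009',        # hypothetical shard range
    root_path='/path/to/webvid',   # hypothetical shard directory
    video_length=16,
    video_size=256,
)
loader = DataLoader(dataset, batch_size=None, num_workers=2)  # batching happens in the pipeline

for batch in loader:
    print(batch['video'].shape, batch['text'][:1])
    break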
onlyflow/models/attention.py
ADDED
@@ -0,0 +1,359 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

import torch
from diffusers.models.attention import GatedSelfAttentionDense, FeedForward, _chunked_feed_forward
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
from diffusers.utils import logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn

from onlyflow.models.attention_processor import Attention

logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
            self,
            dim: int,
            num_attention_heads: int,
            attention_head_dim: int,
            dropout=0.0,
            cross_attention_dim: Optional[int] = None,
            activation_fn: str = "geglu",
            num_embeds_ada_norm: Optional[int] = None,
            attention_bias: bool = False,
            only_cross_attention: bool = False,
            double_self_attention: bool = False,
            upcast_attention: bool = False,
            norm_elementwise_affine: bool = True,
            norm_type: str = "layer_norm",
            # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
            norm_eps: float = 1e-5,
            final_dropout: bool = False,
            attention_type: str = "default",
            positional_embeddings: Optional[str] = None,
            num_positional_embeddings: Optional[int] = None,
            ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
            ada_norm_bias: Optional[int] = None,
            ff_inner_dim: Optional[int] = None,
            ff_bias: bool = True,
            attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_zero":
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        elif norm_type == "ada_norm_continuous":
            self.norm1 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "rms_norm",
            )
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            if norm_type == "ada_norm":
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif norm_type == "ada_norm_continuous":
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none
        else:
            if norm_type == "ada_norm_single":  # For Latte
                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
            else:
                self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if norm_type == "ada_norm_continuous":
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )

        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
            self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        elif norm_type == "layer_norm_i2vgen":
            self.norm3 = None

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if norm_type == "ada_norm_single":
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 1.2 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
            else:
                raise ValueError("Incorrect norm")

            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        # i2vgen doesn't have this norm 🤷♂️
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)

        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
onlyflow/models/attention_processor.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import logging
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.nn.init as init
|
9 |
+
from diffusers.models.attention_processor import Attention as AttentionBase
|
10 |
+
from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor2_0_Base, SpatialNorm, AttnProcessor
|
11 |
+
from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0 as IPAdapterAttnProcessor2_0_Base
|
12 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
@maybe_allow_in_graph
|
18 |
+
class Attention(AttentionBase):
|
19 |
+
r"""
|
20 |
+
A cross attention layer.
|
21 |
+
|
22 |
+
Parameters:
|
23 |
+
query_dim (`int`):
|
24 |
+
The number of channels in the query.
|
25 |
+
cross_attention_dim (`int`, *optional*):
|
26 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
27 |
+
heads (`int`, *optional*, defaults to 8):
|
28 |
+
The number of heads to use for multi-head attention.
|
29 |
+
kv_heads (`int`, *optional*, defaults to `None`):
|
30 |
+
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
|
31 |
+
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
|
32 |
+
Query Attention (MQA) otherwise GQA is used.
|
33 |
+
dim_head (`int`, *optional*, defaults to 64):
|
34 |
+
The number of channels in each head.
|
35 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
36 |
+
The dropout probability to use.
|
37 |
+
bias (`bool`, *optional*, defaults to False):
|
38 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
39 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
40 |
+
Set to `True` to upcast the attention computation to `float32`.
|
41 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
42 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
43 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
44 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
45 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
46 |
+
The number of groups to use for the group norm in the cross attention.
|
47 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
48 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
49 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
50 |
+
The number of groups to use for the group norm in the attention.
|
51 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
52 |
+
The number of channels to use for the spatial normalization.
|
53 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
54 |
+
Set to `True` to use a bias in the output linear layer.
|
55 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
56 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
57 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
58 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
59 |
+
`added_kv_proj_dim` is not `None`.
|
60 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
61 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
62 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
63 |
+
A factor to rescale the output by dividing it with this value.
|
64 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
65 |
+
Set to `True` to add the residual connection to the output.
|
66 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
67 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
68 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
69 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
70 |
+
`AttnProcessor` otherwise.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
query_dim: int,
|
76 |
+
cross_attention_dim: Optional[int] = None,
|
77 |
+
heads: int = 8,
|
78 |
+
kv_heads: Optional[int] = None,
|
79 |
+
dim_head: int = 64,
|
80 |
+
dropout: float = 0.0,
|
81 |
+
bias: bool = False,
|
82 |
+
upcast_attention: bool = False,
|
83 |
+
upcast_softmax: bool = False,
|
84 |
+
cross_attention_norm: Optional[str] = None,
|
85 |
+
cross_attention_norm_num_groups: int = 32,
|
86 |
+
qk_norm: Optional[str] = None,
|
87 |
+
added_kv_proj_dim: Optional[int] = None,
|
88 |
+
added_proj_bias: Optional[bool] = True,
|
89 |
+
norm_num_groups: Optional[int] = None,
|
90 |
+
spatial_norm_dim: Optional[int] = None,
|
91 |
+
out_bias: bool = True,
|
92 |
+
scale_qk: bool = True,
|
93 |
+
only_cross_attention: bool = False,
|
94 |
+
eps: float = 1e-5,
|
95 |
+
rescale_output_factor: float = 1.0,
|
96 |
+
residual_connection: bool = False,
|
97 |
+
_from_deprecated_attn_block: bool = False,
|
98 |
+
processor: Optional["AttnProcessor"] = None,
|
99 |
+
out_dim: int = None,
|
100 |
+
context_pre_only=None,
|
101 |
+
pre_only=False,
|
102 |
+
):
|
103 |
+
nn.Module.__init__(self)
|
104 |
+
|
105 |
+
# To prevent circular import.
|
106 |
+
from diffusers.models.normalization import FP32LayerNorm, RMSNorm
|
107 |
+
|
108 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
109 |
+
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
110 |
+
self.query_dim = query_dim
|
111 |
+
self.use_bias = bias
|
112 |
+
self.is_cross_attention = cross_attention_dim is not None
|
113 |
+
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
114 |
+
self.upcast_attention = upcast_attention
|
115 |
+
self.upcast_softmax = upcast_softmax
|
116 |
+
self.rescale_output_factor = rescale_output_factor
|
117 |
+
self.residual_connection = residual_connection
|
118 |
+
self.dropout = dropout
|
119 |
+
self.fused_projections = False
|
120 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
121 |
+
self.context_pre_only = context_pre_only
|
122 |
+
self.pre_only = pre_only
|
123 |
+
|
124 |
+
# we make use of this private variable to know whether this class is loaded
|
125 |
+
# with an deprecated state dict so that we can convert it on the fly
|
126 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
127 |
+
|
128 |
+
self.scale_qk = scale_qk
|
129 |
+
self.scale = dim_head ** -0.5 if self.scale_qk else 1.0
|
130 |
+
|
131 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
132 |
+
# for slice_size > 0 the attention score computation
|
133 |
+
# is split across the batch axis to save memory
|
134 |
+
# You can set slice_size with `set_attention_slice`
|
135 |
+
self.sliceable_head_dim = heads
|
136 |
+
|
137 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
138 |
+
self.only_cross_attention = only_cross_attention
|
139 |
+
|
140 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
141 |
+
raise ValueError(
|
142 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
143 |
+
)
|
144 |
+
|
145 |
+
if norm_num_groups is not None:
|
146 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
147 |
+
else:
|
148 |
+
self.group_norm = None
|
149 |
+
|
150 |
+
if spatial_norm_dim is not None:
|
151 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
152 |
+
else:
|
153 |
+
self.spatial_norm = None
|
154 |
+
|
155 |
+
if qk_norm is None:
|
156 |
+
self.norm_q = None
|
157 |
+
self.norm_k = None
|
158 |
+
elif qk_norm == "layer_norm":
|
159 |
+
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
|
160 |
+
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
|
161 |
+
elif qk_norm == "fp32_layer_norm":
|
162 |
+
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
163 |
+
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
164 |
+
elif qk_norm == "layer_norm_across_heads":
|
165 |
+
# Lumina applys qk norm across all heads
|
166 |
+
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
167 |
+
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
168 |
+
elif qk_norm == "rms_norm":
|
169 |
+
self.norm_q = RMSNorm(dim_head, eps=eps)
|
170 |
+
self.norm_k = RMSNorm(dim_head, eps=eps)
|
171 |
+
else:
|
172 |
+
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
|
173 |
+
|
174 |
+
if cross_attention_norm is None:
|
175 |
+
self.norm_cross = None
|
176 |
+
elif cross_attention_norm == "layer_norm":
|
177 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
178 |
+
elif cross_attention_norm == "group_norm":
|
179 |
+
if self.added_kv_proj_dim is not None:
|
180 |
+
# The given `encoder_hidden_states` are initially of shape
|
181 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
182 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
183 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
184 |
+
# the number of channels for the group norm.
|
185 |
+
norm_cross_num_channels = added_kv_proj_dim
|
186 |
+
else:
|
187 |
+
norm_cross_num_channels = self.cross_attention_dim
|
188 |
+
|
189 |
+
self.norm_cross = nn.GroupNorm(
|
190 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
raise ValueError(
|
194 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
195 |
+
)
|
196 |
+
|
197 |
+
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
198 |
+
|
199 |
+
if not self.only_cross_attention:
|
200 |
+
# only relevant for the `AddedKVProcessor` classes
|
201 |
+
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
202 |
+
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
203 |
+
else:
|
204 |
+
self.to_k = None
|
205 |
+
self.to_v = None
|
206 |
+
|
207 |
+
self.added_proj_bias = added_proj_bias
|
208 |
+
if self.added_kv_proj_dim is not None:
|
209 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
210 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
211 |
+
if self.context_pre_only is not None:
|
212 |
+
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
213 |
+
|
214 |
+
if not self.pre_only:
|
215 |
+
self.to_out = nn.ModuleList([])
|
216 |
+
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
217 |
+
self.to_out.append(nn.Dropout(dropout))
|
218 |
+
|
219 |
+
if self.context_pre_only is not None and not self.context_pre_only:
|
220 |
+
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
221 |
+
|
222 |
+
if qk_norm is not None and added_kv_proj_dim is not None:
|
223 |
+
if qk_norm == "fp32_layer_norm":
|
224 |
+
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
225 |
+
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
226 |
+
elif qk_norm == "rms_norm":
|
227 |
+
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
228 |
+
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
229 |
+
else:
|
230 |
+
self.norm_added_q = None
|
231 |
+
self.norm_added_k = None
|
232 |
+
|
233 |
+
# set attention processor
|
234 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
235 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
236 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
237 |
+
if processor is None:
|
238 |
+
processor = (
|
239 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
240 |
+
)
|
241 |
+
self.set_processor(processor)
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self,
|
245 |
+
hidden_states: torch.Tensor,
|
246 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
247 |
+
attention_mask: Optional[torch.Tensor] = None,
|
248 |
+
**cross_attention_kwargs,
|
249 |
+
) -> torch.Tensor:
|
250 |
+
r"""
|
251 |
+
The forward method of the `Attention` class.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
hidden_states (`torch.Tensor`):
|
255 |
+
The hidden states of the query.
|
256 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
257 |
+
The hidden states of the encoder.
|
258 |
+
attention_mask (`torch.Tensor`, *optional*):
|
259 |
+
The attention mask to use. If `None`, no mask is applied.
|
260 |
+
**cross_attention_kwargs:
|
261 |
+
Additional keyword arguments to pass along to the cross attention.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
`torch.Tensor`: The output of the attention layer.
|
265 |
+
"""
|
266 |
+
# The `Attention` class can call different attention processors / attention functions
|
267 |
+
# here we simply pass along all tensors to the selected processor class
|
268 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
269 |
+
|
270 |
+
return self.processor(
|
271 |
+
self,
|
272 |
+
hidden_states=hidden_states,
|
273 |
+
encoder_hidden_states=encoder_hidden_states,
|
274 |
+
attention_mask=attention_mask,
|
275 |
+
**cross_attention_kwargs,
|
276 |
+
)
|
277 |
+
|
278 |
+
|
279 |
+
class AttnProcessor2_0(AttnProcessor2_0_Base):
|
280 |
+
def __call__(
|
281 |
+
self,
|
282 |
+
attn: Attention,
|
283 |
+
hidden_states: torch.Tensor,
|
284 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
285 |
+
attention_mask: Optional[torch.Tensor] = None,
|
286 |
+
temb: Optional[torch.Tensor] = None,
|
287 |
+
flow_feature: Optional[torch.Tensor] = None,
|
288 |
+
flow_scale: Optional[float] = None,
|
289 |
+
*args,
|
290 |
+
**kwargs,
|
291 |
+
) -> torch.Tensor:
|
292 |
+
|
293 |
+
old_attn = attn.scale
|
294 |
+
attn.scale *= kwargs.get("attn_scale", 1.0)
|
295 |
+
|
296 |
+
output = super().__call__(
|
297 |
+
attn,
|
298 |
+
hidden_states,
|
299 |
+
encoder_hidden_states=encoder_hidden_states,
|
300 |
+
attention_mask=attention_mask,
|
301 |
+
temb=temb,
|
302 |
+
*args,
|
303 |
+
**kwargs,
|
304 |
+
)
|
305 |
+
|
306 |
+
attn.scale = old_attn
|
307 |
+
return output
|
308 |
+
|
309 |
+
class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessor2_0_Base):
|
310 |
+
def __call__(
|
311 |
+
self,
|
312 |
+
attn: Attention,
|
313 |
+
hidden_states: torch.Tensor,
|
314 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
315 |
+
attention_mask: Optional[torch.Tensor] = None,
|
316 |
+
temb: Optional[torch.Tensor] = None,
|
317 |
+
scale: float = 1.0,
|
318 |
+
ip_adapter_masks: Optional[torch.Tensor] = None,
|
319 |
+
flow_feature: Optional[torch.Tensor] = None,
|
320 |
+
flow_scale: Optional[float] = None,
|
321 |
+
*args,
|
322 |
+
**kwargs,
|
323 |
+
) -> torch.Tensor:
|
324 |
+
return super().__call__(
|
325 |
+
attn=attn,
|
326 |
+
hidden_states=hidden_states,
|
327 |
+
encoder_hidden_states=encoder_hidden_states,
|
328 |
+
attention_mask=attention_mask,
|
329 |
+
temb=temb,
|
330 |
+
scale=scale,
|
331 |
+
ip_adapter_masks=ip_adapter_masks,
|
332 |
+
)
|
333 |
+
|
334 |
+
|
335 |
+
class FlowAdaptorAttnProcessor(nn.Module):
|
336 |
+
def __init__(self,
|
337 |
+
type: str,
|
338 |
+
hidden_size, # dimension of hidden state
|
339 |
+
flow_feature_dim=None, # dimension of the pose feature
|
340 |
+
cross_attention_dim=None, # dimension of the text embedding
|
341 |
+
query_condition=False,
|
342 |
+
key_value_condition=False,
|
343 |
+
flow_scale=1.0
|
344 |
+
):
|
345 |
+
super().__init__()
|
346 |
+
|
347 |
+
self.type = type
|
348 |
+
self.hidden_size = hidden_size
|
349 |
+
self.flow_feature_dim = flow_feature_dim
|
350 |
+
self.cross_attention_dim = cross_attention_dim
|
351 |
+
self.flow_scale = flow_scale
|
352 |
+
self.query_condition = query_condition
|
353 |
+
self.key_value_condition = key_value_condition
|
354 |
+
assert hidden_size == flow_feature_dim
|
355 |
+
if self.query_condition and self.key_value_condition:
|
356 |
+
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
|
357 |
+
init.zeros_(self.qkv_merge.weight)
|
358 |
+
init.zeros_(self.qkv_merge.bias)
|
359 |
+
elif self.query_condition:
|
360 |
+
self.q_merge = nn.Linear(hidden_size, hidden_size)
|
361 |
+
init.zeros_(self.q_merge.weight)
|
362 |
+
init.zeros_(self.q_merge.bias)
|
363 |
+
else:
|
364 |
+
self.kv_merge = nn.Linear(hidden_size, hidden_size)
|
365 |
+
init.zeros_(self.kv_merge.weight)
|
366 |
+
init.zeros_(self.kv_merge.bias)
|
367 |
+
|
368 |
+
def forward(self,
|
369 |
+
attn: Attention,
|
370 |
+
hidden_states,
|
371 |
+
flow_feature,
|
372 |
+
encoder_hidden_states=None,
|
373 |
+
attention_mask=None,
|
374 |
+
temb=None,
|
375 |
+
flow_scale=None,
|
376 |
+
*args,
|
377 |
+
**kwargs,
|
378 |
+
):
|
379 |
+
assert flow_feature is not None
|
380 |
+
flow_embedding_scale = (flow_scale if flow_scale is not None else self.flow_scale)
|
381 |
+
|
382 |
+
residual = hidden_states
|
383 |
+
if attn.spatial_norm is not None:
|
384 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
385 |
+
|
386 |
+
if self.query_condition and self.key_value_condition:
|
387 |
+
assert encoder_hidden_states is None
|
388 |
+
|
389 |
+
if encoder_hidden_states is None:
|
390 |
+
encoder_hidden_states = hidden_states
|
391 |
+
|
392 |
+
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
|
393 |
+
|
394 |
+
if attention_mask is not None:
|
395 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
|
396 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
397 |
+
# (batch, heads, source_length, target_length)
|
398 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
399 |
+
|
400 |
+
if attn.group_norm is not None:
|
401 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
402 |
+
|
403 |
+
if attn.norm_cross:
|
404 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
405 |
+
|
406 |
+
if self.query_condition and self.key_value_condition: # only self attention
|
407 |
+
query_hidden_state = self.qkv_merge(hidden_states + flow_feature) * flow_embedding_scale + hidden_states
|
408 |
+
key_value_hidden_state = query_hidden_state
|
409 |
+
elif self.query_condition:
|
410 |
+
query_hidden_state = self.q_merge(hidden_states + flow_feature) * flow_embedding_scale + hidden_states
|
411 |
+
key_value_hidden_state = encoder_hidden_states
|
412 |
+
else:
|
413 |
+
key_value_hidden_state = self.kv_merge(
|
414 |
+
encoder_hidden_states + flow_feature) * flow_embedding_scale + encoder_hidden_states
|
415 |
+
query_hidden_state = hidden_states
|
416 |
+
|
417 |
+
# original attention
|
418 |
+
key = attn.to_k(key_value_hidden_state)
|
419 |
+
value = attn.to_v(key_value_hidden_state)
|
420 |
+
query = attn.to_q(query_hidden_state)
|
421 |
+
|
422 |
+
inner_dim = key.shape[-1]
|
423 |
+
head_dim = inner_dim // attn.heads
|
424 |
+
|
425 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
426 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
427 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
428 |
+
|
429 |
+
if attn.norm_q is not None:
|
430 |
+
query = attn.norm_q(query)
|
431 |
+
if attn.norm_k is not None:
|
432 |
+
key = attn.norm_k(key)
|
433 |
+
|
434 |
+
hidden_states = F.scaled_dot_product_attention(
|
435 |
+
query, key, value,
|
436 |
+
attn_mask=attention_mask,
|
437 |
+
dropout_p=0.0,
|
438 |
+
is_causal=False,
|
439 |
+
scale=attn.scale * kwargs.get("attn_scale_flow", 1.0),
|
440 |
+
)
|
441 |
+
|
442 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
443 |
+
hidden_states = hidden_states.to(query.dtype)
|
444 |
+
|
445 |
+
# linear proj
|
446 |
+
hidden_states = attn.to_out[0](hidden_states)
|
447 |
+
|
448 |
+
# dropout
|
449 |
+
hidden_states = attn.to_out[1](hidden_states)
|
450 |
+
|
451 |
+
if attn.residual_connection:
|
452 |
+
hidden_states = hidden_states + residual
|
453 |
+
|
454 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
455 |
+
|
456 |
+
return hidden_states
|
onlyflow/models/flow_adaptor.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
from torch.utils import checkpoint
|
7 |
+
|
8 |
+
from onlyflow.models.attention import BasicTransformerBlock
|
9 |
+
|
10 |
+
|
11 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
12 |
+
params = tuple(parameter.parameters())
|
13 |
+
if len(params) > 0:
|
14 |
+
return params[0].dtype
|
15 |
+
|
16 |
+
buffers = tuple(parameter.buffers())
|
17 |
+
if len(buffers) > 0:
|
18 |
+
return buffers[0].dtype
|
19 |
+
|
20 |
+
|
21 |
+
def conv_nd(dims, *args, **kwargs):
|
22 |
+
"""
|
23 |
+
Create a 1D, 2D, or 3D convolution module.
|
24 |
+
"""
|
25 |
+
if dims == 1:
|
26 |
+
return nn.Conv1d(*args, **kwargs)
|
27 |
+
elif dims == 2:
|
28 |
+
return nn.Conv2d(*args, **kwargs)
|
29 |
+
elif dims == 3:
|
30 |
+
return nn.Conv3d(*args, **kwargs)
|
31 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
32 |
+
|
33 |
+
|
34 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
35 |
+
"""
|
36 |
+
Create a 1D, 2D, or 3D average pooling module.
|
37 |
+
"""
|
38 |
+
if dims == 1:
|
39 |
+
return nn.AvgPool1d(*args, **kwargs)
|
40 |
+
elif dims == 2:
|
41 |
+
return nn.AvgPool2d(*args, **kwargs)
|
42 |
+
elif dims == 3:
|
43 |
+
return nn.AvgPool3d(*args, **kwargs)
|
44 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
45 |
+
|
46 |
+
|
47 |
+
class FlowAdaptor(nn.Module):
|
48 |
+
def __init__(self, unet, flow_encoder, ckpt_act=True):
|
49 |
+
super().__init__()
|
50 |
+
self.unet = unet
|
51 |
+
self.flow_encoder = flow_encoder
|
52 |
+
self.ckpt_act = ckpt_act
|
53 |
+
|
54 |
+
def forward(self, noisy_latents, timesteps, encoder_hidden_states, flow_embedding):
|
55 |
+
assert flow_embedding.ndim == 5
|
56 |
+
bs = flow_embedding.shape[0] # b c f h w
|
57 |
+
flow_embedding_features = self.flow_encoder(flow_embedding) # flow_embedding b f c h w
|
58 |
+
flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
|
59 |
+
for x in flow_embedding_features]
|
60 |
+
|
61 |
+
added_cond_kwargs = {'flow_embedding_features': flow_embedding_features}
|
62 |
+
|
63 |
+
noise_pred = self.unet(noisy_latents,
|
64 |
+
timesteps,
|
65 |
+
encoder_hidden_states,
|
66 |
+
added_cond_kwargs=added_cond_kwargs,
|
67 |
+
)
|
68 |
+
|
69 |
+
return noise_pred.sample
|
70 |
+
|
71 |
+
|
72 |
+
class Downsample(nn.Module):
|
73 |
+
"""
|
74 |
+
A downsampling layer with an optional convolution.
|
75 |
+
:param channels: channels in the inputs and outputs.
|
76 |
+
:param use_conv: a bool determining if a convolution is applied.
|
77 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
78 |
+
downsampling occurs in the inner-two dimensions.
|
79 |
+
"""
|
80 |
+
|
81 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
82 |
+
super().__init__()
|
83 |
+
self.channels = channels
|
84 |
+
self.out_channels = out_channels or channels
|
85 |
+
self.use_conv = use_conv
|
86 |
+
self.dims = dims
|
87 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
88 |
+
if use_conv:
|
89 |
+
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
90 |
+
else:
|
91 |
+
assert self.channels == self.out_channels
|
92 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
assert x.shape[1] == self.channels
|
96 |
+
return self.op(x)
|
97 |
+
|
98 |
+
|
99 |
+
class ResnetBlock(nn.Module):
|
100 |
+
|
101 |
+
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
102 |
+
super().__init__()
|
103 |
+
ps = ksize // 2
|
104 |
+
if in_c != out_c or sk == False:
|
105 |
+
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
106 |
+
else:
|
107 |
+
self.in_conv = None
|
108 |
+
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
109 |
+
self.act = nn.ReLU()
|
110 |
+
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
111 |
+
if not sk:
|
112 |
+
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
113 |
+
else:
|
114 |
+
self.skep = None
|
115 |
+
|
116 |
+
self.down = down
|
117 |
+
if self.down:
|
118 |
+
self.down_opt = Downsample(in_c, use_conv=use_conv)
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
if self.down:
|
122 |
+
x = self.down_opt(x)
|
123 |
+
if self.in_conv is not None: # edit
|
124 |
+
x = self.in_conv(x)
|
125 |
+
|
126 |
+
h = self.block1(x)
|
127 |
+
h = self.act(h)
|
128 |
+
h = self.block2(h)
|
129 |
+
if self.skep is not None:
|
130 |
+
return h + self.skep(x)
|
131 |
+
else:
|
132 |
+
return h + x
|
133 |
+
|
134 |
+
|
135 |
+
class PositionalEncoding(nn.Module):
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
d_model,
|
139 |
+
dropout=0.,
|
140 |
+
max_len=32,
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
self.dropout = nn.Dropout(p=dropout)
|
144 |
+
position = torch.arange(max_len).unsqueeze(1)
|
145 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
146 |
+
pe = torch.zeros(1, max_len, d_model)
|
147 |
+
pe[0, :, 0::2, ...] = torch.sin(position * div_term)
|
148 |
+
pe[0, :, 1::2, ...] = torch.cos(position * div_term)
|
149 |
+
pe.unsqueeze_(-1).unsqueeze_(-1)
|
150 |
+
self.register_buffer('pe', pe)
|
151 |
+
|
152 |
+
def forward(self, x):
|
153 |
+
x = x + self.pe[:, :x.size(1), ...]
|
154 |
+
return self.dropout(x)
|
155 |
+
|
156 |
+
|
157 |
+
class FlowEncoder(nn.Module):
|
158 |
+
|
159 |
+
def __init__(self,
|
160 |
+
downscale_factor,
|
161 |
+
channels=None,
|
162 |
+
nums_rb=3,
|
163 |
+
ksize=3,
|
164 |
+
sk=False,
|
165 |
+
use_conv=True,
|
166 |
+
compression_factor=1,
|
167 |
+
temporal_attention_nhead=8,
|
168 |
+
positional_embeddings=None,
|
169 |
+
num_positional_embeddings=16,
|
170 |
+
rescale_output_factor=1.0,
|
171 |
+
checkpointing=False):
|
172 |
+
super(FlowEncoder, self).__init__()
|
173 |
+
if channels is None:
|
174 |
+
channels = [320, 640, 1280, 1280]
|
175 |
+
|
176 |
+
self.checkpointing = checkpointing
|
177 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
178 |
+
self.channels = channels
|
179 |
+
self.nums_rb = nums_rb
|
180 |
+
self.encoder_down_conv_blocks = nn.ModuleList()
|
181 |
+
self.encoder_down_attention_blocks = nn.ModuleList()
|
182 |
+
for i in range(len(channels)):
|
183 |
+
conv_layers = nn.ModuleList()
|
184 |
+
temporal_attention_layers = nn.ModuleList()
|
185 |
+
for j in range(nums_rb):
|
186 |
+
if j == 0 and i != 0:
|
187 |
+
in_dim = channels[i - 1]
|
188 |
+
out_dim = int(channels[i] / compression_factor)
|
189 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
|
190 |
+
elif j == 0:
|
191 |
+
in_dim = channels[0]
|
192 |
+
out_dim = int(channels[i] / compression_factor)
|
193 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
194 |
+
elif j == nums_rb - 1:
|
195 |
+
in_dim = channels[i] / compression_factor
|
196 |
+
out_dim = channels[i]
|
197 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
198 |
+
else:
|
199 |
+
in_dim = int(channels[i] / compression_factor)
|
200 |
+
out_dim = int(channels[i] / compression_factor)
|
201 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
202 |
+
temporal_attention_layer = BasicTransformerBlock(
|
203 |
+
dim=out_dim,
|
204 |
+
num_attention_heads=temporal_attention_nhead,
|
205 |
+
attention_head_dim=int(out_dim / temporal_attention_nhead),
|
206 |
+
dropout=0.0,
|
207 |
+
positional_embeddings=positional_embeddings,
|
208 |
+
num_positional_embeddings=num_positional_embeddings
|
209 |
+
)
|
210 |
+
conv_layers.append(conv_layer)
|
211 |
+
temporal_attention_layers.append(temporal_attention_layer)
|
212 |
+
self.encoder_down_conv_blocks.append(conv_layers)
|
213 |
+
self.encoder_down_attention_blocks.append(temporal_attention_layers)
|
214 |
+
|
215 |
+
self.encoder_conv_in = nn.Conv2d(2 * (downscale_factor ** 2), channels[0], 3, 1, 1)
|
216 |
+
|
217 |
+
@property
|
218 |
+
def dtype(self) -> torch.dtype:
|
219 |
+
"""
|
220 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
221 |
+
"""
|
222 |
+
return get_parameter_dtype(self)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
# unshuffle
|
226 |
+
bs = x.shape[0]
|
227 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
228 |
+
x = self.unshuffle(x)
|
229 |
+
# extract features
|
230 |
+
features = []
|
231 |
+
x = self.encoder_conv_in(x)
|
232 |
+
for i, (res_block, attention_block) in enumerate(
|
233 |
+
zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks)):
|
234 |
+
for j, (res_layer, attention_layer) in enumerate(zip(res_block, attention_block)):
|
235 |
+
if self.checkpointing:
|
236 |
+
x = checkpoint.checkpoint(res_layer, x, use_reentrant=False)
|
237 |
+
else:
|
238 |
+
x = res_layer(x)
|
239 |
+
h, w = x.shape[-2:]
|
240 |
+
x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs)
|
241 |
+
if self.checkpointing:
|
242 |
+
x = checkpoint.checkpoint(attention_layer, x, use_reentrant=False)
|
243 |
+
else:
|
244 |
+
x = attention_layer(x)
|
245 |
+
x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w)
|
246 |
+
features.append(x)
|
247 |
+
return features
|
onlyflow/models/transformer_2d.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from diffusers.configuration_utils import LegacyConfigMixin, register_to_config
|
19 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
20 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
21 |
+
from diffusers.models.modeling_utils import LegacyModelMixin
|
22 |
+
from diffusers.models.normalization import AdaLayerNormSingle
|
23 |
+
from diffusers.utils import deprecate, is_torch_version, logging
|
24 |
+
from torch import nn
|
25 |
+
|
26 |
+
from onlyflow.models.attention import BasicTransformerBlock
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29 |
+
|
30 |
+
|
31 |
+
class Transformer2DModelOutput(Transformer2DModelOutput):
|
32 |
+
def __init__(self, *args, **kwargs):
|
33 |
+
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
|
34 |
+
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
|
35 |
+
super().__init__(*args, **kwargs)
|
36 |
+
|
37 |
+
|
38 |
+
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
39 |
+
"""
|
40 |
+
A 2D Transformer model for image-like data.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
44 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
45 |
+
in_channels (`int`, *optional*):
|
46 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
47 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
48 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
49 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
50 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
51 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
52 |
+
num_vector_embeds (`int`, *optional*):
|
53 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
54 |
+
Includes the class for the masked latent pixel.
|
55 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
56 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
57 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
58 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
59 |
+
added to the hidden states.
|
60 |
+
|
61 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
62 |
+
attention_bias (`bool`, *optional*):
|
63 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
64 |
+
"""
|
65 |
+
|
66 |
+
_supports_gradient_checkpointing = True
|
67 |
+
_no_split_modules = ["BasicTransformerBlock"]
|
68 |
+
|
69 |
+
@register_to_config
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
num_attention_heads: int = 16,
|
73 |
+
attention_head_dim: int = 88,
|
74 |
+
in_channels: Optional[int] = None,
|
75 |
+
out_channels: Optional[int] = None,
|
76 |
+
num_layers: int = 1,
|
77 |
+
dropout: float = 0.0,
|
78 |
+
norm_num_groups: int = 32,
|
79 |
+
cross_attention_dim: Optional[int] = None,
|
80 |
+
attention_bias: bool = False,
|
81 |
+
sample_size: Optional[int] = None,
|
82 |
+
num_vector_embeds: Optional[int] = None,
|
83 |
+
patch_size: Optional[int] = None,
|
84 |
+
activation_fn: str = "geglu",
|
85 |
+
num_embeds_ada_norm: Optional[int] = None,
|
86 |
+
use_linear_projection: bool = False,
|
87 |
+
only_cross_attention: bool = False,
|
88 |
+
double_self_attention: bool = False,
|
89 |
+
upcast_attention: bool = False,
|
90 |
+
norm_type: str = "layer_norm",
|
91 |
+
# 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
92 |
+
norm_elementwise_affine: bool = True,
|
93 |
+
norm_eps: float = 1e-5,
|
94 |
+
attention_type: str = "default",
|
95 |
+
caption_channels: int = None,
|
96 |
+
interpolation_scale: float = None,
|
97 |
+
use_additional_conditions: Optional[bool] = None,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
# Validate inputs.
|
102 |
+
if patch_size is not None:
|
103 |
+
if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
|
104 |
+
raise NotImplementedError(
|
105 |
+
f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
|
106 |
+
)
|
107 |
+
elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
|
108 |
+
raise ValueError(
|
109 |
+
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
|
110 |
+
)
|
111 |
+
|
112 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
113 |
+
# Define whether input is continuous or discrete depending on configuration
|
114 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
115 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
116 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
117 |
+
|
118 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
119 |
+
raise ValueError(
|
120 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
121 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
122 |
+
)
|
123 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
124 |
+
raise ValueError(
|
125 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
126 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
127 |
+
)
|
128 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
129 |
+
raise ValueError(
|
130 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
131 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
132 |
+
)
|
133 |
+
|
134 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
135 |
+
deprecation_message = (
|
136 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
137 |
+
" incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
|
138 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
139 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
140 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
141 |
+
)
|
142 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
143 |
+
norm_type = "ada_norm"
|
144 |
+
|
145 |
+
# Set some common variables used across the board.
|
146 |
+
self.use_linear_projection = use_linear_projection
|
147 |
+
self.interpolation_scale = interpolation_scale
|
148 |
+
self.caption_channels = caption_channels
|
149 |
+
self.num_attention_heads = num_attention_heads
|
150 |
+
self.attention_head_dim = attention_head_dim
|
151 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
152 |
+
self.in_channels = in_channels
|
153 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
154 |
+
self.gradient_checkpointing = False
|
155 |
+
|
156 |
+
if use_additional_conditions is None:
|
157 |
+
if norm_type == "ada_norm_single" and sample_size == 128:
|
158 |
+
use_additional_conditions = True
|
159 |
+
else:
|
160 |
+
use_additional_conditions = False
|
161 |
+
self.use_additional_conditions = use_additional_conditions
|
162 |
+
|
163 |
+
# 2. Initialize the right blocks.
|
164 |
+
# These functions follow a common structure:
|
165 |
+
# a. Initialize the input blocks. b. Initialize the transformer blocks.
|
166 |
+
# c. Initialize the output blocks and other projection blocks when necessary.
|
167 |
+
if self.is_input_continuous:
|
168 |
+
self._init_continuous_input(norm_type=norm_type)
|
169 |
+
elif self.is_input_vectorized:
|
170 |
+
self._init_vectorized_inputs(norm_type=norm_type)
|
171 |
+
elif self.is_input_patches:
|
172 |
+
self._init_patched_inputs(norm_type=norm_type)
|
173 |
+
|
174 |
+
def _init_continuous_input(self, norm_type):
|
175 |
+
self.norm = torch.nn.GroupNorm(
|
176 |
+
num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
|
177 |
+
)
|
178 |
+
if self.use_linear_projection:
|
179 |
+
self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
|
180 |
+
else:
|
181 |
+
self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
|
182 |
+
|
183 |
+
self.transformer_blocks = nn.ModuleList(
|
184 |
+
[
|
185 |
+
BasicTransformerBlock(
|
186 |
+
self.inner_dim,
|
187 |
+
self.config.num_attention_heads,
|
188 |
+
self.config.attention_head_dim,
|
189 |
+
dropout=self.config.dropout,
|
190 |
+
cross_attention_dim=self.config.cross_attention_dim,
|
191 |
+
activation_fn=self.config.activation_fn,
|
192 |
+
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
|
193 |
+
attention_bias=self.config.attention_bias,
|
194 |
+
only_cross_attention=self.config.only_cross_attention,
|
195 |
+
double_self_attention=self.config.double_self_attention,
|
196 |
+
upcast_attention=self.config.upcast_attention,
|
197 |
+
norm_type=norm_type,
|
198 |
+
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
199 |
+
norm_eps=self.config.norm_eps,
|
200 |
+
attention_type=self.config.attention_type,
|
201 |
+
)
|
202 |
+
for _ in range(self.config.num_layers)
|
203 |
+
]
|
204 |
+
)
|
205 |
+
|
206 |
+
if self.use_linear_projection:
|
207 |
+
self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
|
208 |
+
else:
|
209 |
+
self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
|
210 |
+
|
211 |
+
def _init_vectorized_inputs(self, norm_type):
|
212 |
+
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
213 |
+
assert (
|
214 |
+
self.config.num_vector_embeds is not None
|
215 |
+
), "Transformer2DModel over discrete input must provide num_embed"
|
216 |
+
|
217 |
+
self.height = self.config.sample_size
|
218 |
+
self.width = self.config.sample_size
|
219 |
+
self.num_latent_pixels = self.height * self.width
|
220 |
+
|
221 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
222 |
+
num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
|
223 |
+
)
|
224 |
+
|
225 |
+
self.transformer_blocks = nn.ModuleList(
|
226 |
+
[
|
227 |
+
BasicTransformerBlock(
|
228 |
+
self.inner_dim,
|
229 |
+
self.config.num_attention_heads,
|
230 |
+
self.config.attention_head_dim,
|
231 |
+
dropout=self.config.dropout,
|
232 |
+
cross_attention_dim=self.config.cross_attention_dim,
|
233 |
+
activation_fn=self.config.activation_fn,
|
234 |
+
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
|
235 |
+
attention_bias=self.config.attention_bias,
|
236 |
+
only_cross_attention=self.config.only_cross_attention,
|
237 |
+
double_self_attention=self.config.double_self_attention,
|
238 |
+
upcast_attention=self.config.upcast_attention,
|
239 |
+
norm_type=norm_type,
|
240 |
+
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
241 |
+
norm_eps=self.config.norm_eps,
|
242 |
+
attention_type=self.config.attention_type,
|
243 |
+
)
|
244 |
+
for _ in range(self.config.num_layers)
|
245 |
+
]
|
246 |
+
)
|
247 |
+
|
248 |
+
self.norm_out = nn.LayerNorm(self.inner_dim)
|
249 |
+
self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
|
250 |
+
|
251 |
+
def _init_patched_inputs(self, norm_type):
|
252 |
+
assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
253 |
+
|
254 |
+
self.height = self.config.sample_size
|
255 |
+
self.width = self.config.sample_size
|
256 |
+
|
257 |
+
self.patch_size = self.config.patch_size
|
258 |
+
interpolation_scale = (
|
259 |
+
self.config.interpolation_scale
|
260 |
+
if self.config.interpolation_scale is not None
|
261 |
+
else max(self.config.sample_size // 64, 1)
|
262 |
+
)
|
263 |
+
self.pos_embed = PatchEmbed(
|
264 |
+
height=self.config.sample_size,
|
265 |
+
width=self.config.sample_size,
|
266 |
+
patch_size=self.config.patch_size,
|
267 |
+
in_channels=self.in_channels,
|
268 |
+
embed_dim=self.inner_dim,
|
269 |
+
interpolation_scale=interpolation_scale,
|
270 |
+
)
|
271 |
+
|
272 |
+
self.transformer_blocks = nn.ModuleList(
|
273 |
+
[
|
274 |
+
BasicTransformerBlock(
|
275 |
+
self.inner_dim,
|
276 |
+
self.config.num_attention_heads,
|
277 |
+
self.config.attention_head_dim,
|
278 |
+
dropout=self.config.dropout,
|
279 |
+
cross_attention_dim=self.config.cross_attention_dim,
|
280 |
+
activation_fn=self.config.activation_fn,
|
281 |
+
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
|
282 |
+
attention_bias=self.config.attention_bias,
|
283 |
+
only_cross_attention=self.config.only_cross_attention,
|
284 |
+
double_self_attention=self.config.double_self_attention,
|
285 |
+
upcast_attention=self.config.upcast_attention,
|
286 |
+
norm_type=norm_type,
|
287 |
+
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
288 |
+
norm_eps=self.config.norm_eps,
|
289 |
+
attention_type=self.config.attention_type,
|
290 |
+
)
|
291 |
+
for _ in range(self.config.num_layers)
|
292 |
+
]
|
293 |
+
)
|
294 |
+
|
295 |
+
if self.config.norm_type != "ada_norm_single":
|
296 |
+
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
297 |
+
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
298 |
+
self.proj_out_2 = nn.Linear(
|
299 |
+
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
|
300 |
+
)
|
301 |
+
elif self.config.norm_type == "ada_norm_single":
|
302 |
+
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
303 |
+
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5)
|
304 |
+
self.proj_out = nn.Linear(
|
305 |
+
self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
|
306 |
+
)
|
307 |
+
|
308 |
+
# PixArt-Alpha blocks.
|
309 |
+
self.adaln_single = None
|
310 |
+
if self.config.norm_type == "ada_norm_single":
|
311 |
+
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
312 |
+
# additional conditions until we find better name
|
313 |
+
self.adaln_single = AdaLayerNormSingle(
|
314 |
+
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
315 |
+
)
|
316 |
+
|
317 |
+
self.caption_projection = None
|
318 |
+
if self.caption_channels is not None:
|
319 |
+
self.caption_projection = PixArtAlphaTextProjection(
|
320 |
+
in_features=self.caption_channels, hidden_size=self.inner_dim
|
321 |
+
)
|

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
            otherwise a `tuple` where the first element is the sample tensor.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch_size, _, height, width = hidden_states.shape
            residual = hidden_states
            hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
                hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
            )

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                )

        # 3. Output
        if self.is_input_continuous:
            output = self._get_output_for_continuous_inputs(
                hidden_states=hidden_states,
                residual=residual,
                batch_size=batch_size,
                height=height,
                width=width,
                inner_dim=inner_dim,
            )
        elif self.is_input_vectorized:
            output = self._get_output_for_vectorized_inputs(hidden_states)
        elif self.is_input_patches:
            output = self._get_output_for_patched_inputs(
                hidden_states=hidden_states,
                timestep=timestep,
                class_labels=class_labels,
                embedded_timestep=embedded_timestep,
                height=height,
                width=width,
            )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _operate_on_continuous_inputs(self, hidden_states):
        batch, _, height, width = hidden_states.shape
        hidden_states = self.norm(hidden_states)

        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        return hidden_states, inner_dim

    def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
        batch_size = hidden_states.shape[0]
        hidden_states = self.pos_embed(hidden_states)
        embedded_timestep = None

        if self.adaln_single is not None:
            if self.use_additional_conditions and added_cond_kwargs is None:
                raise ValueError(
                    "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
                )
            timestep, embedded_timestep = self.adaln_single(
                timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

        if self.caption_projection is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        return hidden_states, encoder_hidden_states, timestep, embedded_timestep

    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual
        return output

    def _get_output_for_vectorized_inputs(self, hidden_states):
        hidden_states = self.norm_out(hidden_states)
        logits = self.out(hidden_states)
        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
        logits = logits.permute(0, 2, 1)
        # log(p(x_0))
        output = F.log_softmax(logits.double(), dim=1).float()
        return output

    def _get_output_for_patched_inputs(
        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
    ):
        if self.config.norm_type != "ada_norm_single":
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)
        elif self.config.norm_type == "ada_norm_single":
            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states)
            # Modulation
            hidden_states = hidden_states * (1 + scale) + shift
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.squeeze(1)

        # unpatchify
        if self.adaln_single is None:
            height = width = int(hidden_states.shape[1] ** 0.5)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )
        return output
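For reference, a minimal sketch of how this `forward` is exercised on a continuous (feature-map) input follows. It uses the upstream `diffusers.models.Transformer2DModel`, which this file adapts and whose `forward` signature matches the one above; the configuration values and tensor shapes are illustrative assumptions, not settings used by OnlyFlow, and exact behavior may vary with the installed diffusers version.

# Illustrative only: upstream Transformer2DModel with assumed (not OnlyFlow) hyperparameters.
import torch
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,      # inner_dim = 8 * 40 = 320
    in_channels=320,            # continuous-input path (is_input_continuous == True)
    num_layers=1,
    cross_attention_dim=768,    # width of the text-encoder hidden states
)

hidden_states = torch.randn(2, 320, 32, 32)       # (batch, channel, height, width)
encoder_hidden_states = torch.randn(2, 77, 768)   # (batch, sequence len, embed dims)

out = model(hidden_states, encoder_hidden_states=encoder_hidden_states, return_dict=True)
print(out.sample.shape)  # same shape as the input: torch.Size([2, 320, 32, 32])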
onlyflow/models/unet.py
ADDED
The diff for this file is too large to render.
See raw diff
onlyflow/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,497 @@
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py


# TODO: rebase on diffusers/pipelines/animatediff/pipeline_animatediff.py

import copy
from dataclasses import dataclass
from typing import Callable, Optional, Dict, Any
from typing import List, Union

import PIL.Image
import numpy as np
import torch
from diffusers import AnimateDiffPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL
from diffusers.pipelines.animatediff import AnimateDiffPipelineOutput
from diffusers.pipelines.animatediff.pipeline_animatediff import EXAMPLE_DOC_STRING
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import BaseOutput
from diffusers.utils import deprecate, logging, replace_example_docstring
from einops import rearrange
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder
from onlyflow.models.unet import UNetMotionModel

logger = logging.get_logger(__name__)


@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
    frames_no_flow: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
    frames_flow: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


class FlowCtrlPipeline(AnimateDiffPipeline, DiffusionPipeline):
    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNetMotionModel,
                 scheduler: Union[
                     DDIMScheduler,
                     PNDMScheduler,
                     LMSDiscreteScheduler,
                     EulerDiscreteScheduler,
                     EulerAncestralDiscreteScheduler,
                     DPMSolverMultistepScheduler],
                 flow_encoder: FlowEncoder,
                 feature_extractor=None,
                 image_encoder=None,
                 motion_adapter=None,
                 ):

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            motion_adapter=motion_adapter,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )

        # deepcopy the scheduler
        self.scheduler_no_flow = copy.deepcopy(scheduler)

        self.unet = unet

        self.register_modules(
            flow_encoder=flow_encoder
        )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            flow_embedding: torch.FloatTensor = None,

            num_frames: Optional[int] = 16,
            height: Optional[int] = None,
            width: Optional[int] = None,

            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,

            prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            ip_adapter_image: Optional[PipelineImageInput] = None,
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,

            output_type: Optional[str] = "pt",
            return_dict: bool = True,

            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],

            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            motion_cross_attention_kwargs: Optional[Dict[str, Any]] = None,

            clip_skip: Optional[int] = None,
            decode_chunk_size: int = 16,

            val_scale_factor_spatial: float = 1.,
            val_scale_factor_temporal: float = 1.,

            generate_no_flow: bool = False,

            **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated video.
            num_frames (`int`, *optional*, defaults to 16):
                The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
                amounts to 2 seconds of video.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
                `(batch_size, num_channel, num_frames, height, width)`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
                of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            decode_chunk_size (`int`, defaults to `16`):
                The number of frames to decode at a time when calling `decode_latents` method.

        Examples:

        Returns:
            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self.unet.device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_videos_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        single_model_length = num_frames
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        if generate_no_flow:
            latents_no_flow = latents.clone()

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        if isinstance(flow_embedding, list):
            assert all([x.ndim == 5 for x in flow_embedding])
            bs = flow_embedding[0].shape[0]
            flow_embedding_features = []
            for pe in flow_embedding:
                flow_embedding_feature = self.flow_encoder(pe)
                flow_embedding_feature = [rearrange(x, '(b f) c h w -> b c f h w', b=bs) for x in
                                          flow_embedding_feature]
                flow_embedding_features.append(flow_embedding_feature)
        else:
            bs = flow_embedding.shape[0]
            assert flow_embedding.ndim == 5
            flow_embedding_features = self.flow_encoder(flow_embedding)  # bf, c, h, w
            flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
                                       for x in flow_embedding_features]

        # 7. Add image embeds for IP-Adapter
        added_cond_kwargs = {
            "image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None

        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
        for free_init_iter in range(num_free_init_iters):
            if self.free_init_enabled:
                latents, timesteps = self._apply_free_init(
                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                )
                if generate_no_flow:
                    latents_no_flow = latents.clone()

            self._num_timesteps = len(timesteps)
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            if isinstance(flow_embedding_features[0], list):
                flow_embedding_features = [[torch.cat([x, x], dim=0) for x in flow_embedding_feature]
                                           for flow_embedding_feature in flow_embedding_features] \
                    if self.do_classifier_free_guidance else flow_embedding_features
            else:
                flow_embedding_features = [torch.cat([x, x], dim=0) for x in flow_embedding_features] \
                    if self.do_classifier_free_guidance else flow_embedding_features  # [2b c f h w]

            # 8. Denoising loop
            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                for i, t in enumerate(timesteps):

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    if added_cond_kwargs is not None:
                        added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
                    else:
                        added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}

                    if cross_attention_kwargs is not None:
                        cross_attention_kwargs.update({"flow_scale": val_scale_factor_spatial})
                    else:
                        cross_attention_kwargs = {"flow_scale": val_scale_factor_spatial}

                    if motion_cross_attention_kwargs is not None:
                        motion_cross_attention_kwargs.update({"flow_scale": val_scale_factor_temporal})
                    else:
                        motion_cross_attention_kwargs = {"flow_scale": val_scale_factor_temporal}

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        motion_cross_attention_kwargs=motion_cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                    ).sample

                    del latent_model_input

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                        del noise_pred_uncond, noise_pred_text

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                    del noise_pred

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                    # call the callback, if provided
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and i % callback_steps == 0:
                            callback(i, t, latents)

            # 8bis. Denoising loop for the no-flow baseline, run with flow_scale = 0
            if generate_no_flow:
                with self.progress_bar(total=self._num_timesteps) as progress_bar:
                    for i, t in enumerate(timesteps):

                        # expand the latents if we are doing classifier free guidance
                        latent_model_input_no_flow = torch.cat(
                            [latents_no_flow] * 2) if self.do_classifier_free_guidance else latents_no_flow
                        latent_model_input_no_flow = self.scheduler.scale_model_input(latent_model_input_no_flow, t)

                        if added_cond_kwargs is not None:
                            added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
                        else:
                            added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}

                        if cross_attention_kwargs is not None:
                            cross_attention_kwargs.update({"flow_scale": 0.})
                        else:
                            cross_attention_kwargs = {"flow_scale": 0.}

                        if motion_cross_attention_kwargs is not None:
                            motion_cross_attention_kwargs.update({"flow_scale": 0.})
                        else:
                            motion_cross_attention_kwargs = {"flow_scale": 0.}

                        noise_pred_no_flow = self.unet(
                            latent_model_input_no_flow,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            motion_cross_attention_kwargs=motion_cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                        ).sample

                        del latent_model_input_no_flow

                        # perform guidance
                        if self.do_classifier_free_guidance:
                            noise_pred_no_flow_uncond, noise_pred_no_flow_text = noise_pred_no_flow.chunk(2)
                            noise_pred_no_flow = noise_pred_no_flow_uncond + guidance_scale * (
                                    noise_pred_no_flow_text - noise_pred_no_flow_uncond)
                            del noise_pred_no_flow_uncond, noise_pred_no_flow_text

                        # compute the previous noisy sample x_t -> x_t-1
                        latents_no_flow = self.scheduler.step(noise_pred_no_flow, t, latents_no_flow,
                                                              **extra_step_kwargs).prev_sample
                        del noise_pred_no_flow

                        if callback_on_step_end is not None:
                            callback_kwargs = {}
                            for k in callback_on_step_end_tensor_inputs:
                                callback_kwargs[k] = locals()[k]
                            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                            latents = callback_outputs.pop("latents", latents)
                            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds",
                                                                          negative_prompt_embeds)

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or (
                                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                            progress_bar.update()
                            if callback is not None and i % callback_steps == 0:
                                callback(i, t, latents)

        # 9. Post processing
        if output_type == "latent":
            video = latents
            if generate_no_flow:
                video_no_flow = latents_no_flow
        else:
            video_tensor = self.decode_latents(latents, decode_chunk_size)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

            if generate_no_flow:
                video_tensor_no_flow = self.decode_latents(latents_no_flow, decode_chunk_size)
                video_no_flow = self.video_processor.postprocess_video(video=video_tensor_no_flow,
                                                                       output_type=output_type)

        # 10. Offload all models
        self.maybe_free_model_hooks()

        video_no_flow = None if not generate_no_flow else video_no_flow

        if not return_dict:
            return (video, video_no_flow)

        return AnimateDiffPipelineOutput(frames_flow=video, frames_no_flow=video_no_flow)
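For orientation, here is a minimal sketch of driving `FlowCtrlPipeline.__call__` once the pipeline has been assembled. It assumes that `pipe` (a `FlowCtrlPipeline` built from the base SD components, the OnlyFlow `UNetMotionModel` and a `FlowEncoder`, as done in `app.py`) and a 5-D optical-flow tensor `flow` of shape `(batch, channels, frames, height, width)` already exist; the prompt, step count and scale values are illustrative, not taken from this commit.

import torch

# Assumed to be prepared elsewhere (see app.py): `pipe` is a FlowCtrlPipeline instance and
# `flow` is the 5-D optical-flow conditioning tensor checked by the `flow_embedding.ndim == 5` assertion.
generator = torch.Generator(device="cpu").manual_seed(0)   # reproducible sampling

result = pipe(
    prompt="a boat sailing on a stormy sea",               # illustrative prompt
    flow_embedding=flow,
    num_frames=16,
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=generator,
    output_type="pt",
    val_scale_factor_spatial=1.0,    # "flow_scale" passed to the spatial attention layers
    val_scale_factor_temporal=1.0,   # "flow_scale" passed to the motion (temporal) layers
    generate_no_flow=True,           # also run the baseline pass with flow_scale = 0
)

video_with_flow = result.frames_flow        # flow-conditioned frames
video_without_flow = result.frames_no_flow  # baseline frames (None if generate_no_flow=False)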
onlyflow/pipelines/pipeline_animation_long.py
ADDED
@@ -0,0 +1,555 @@
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py


# TODO: rebase on diffusers/pipelines/animatediff/pipeline_animatediff.py

import copy
import gc
from dataclasses import dataclass
from typing import Callable, Optional, Dict, Any, Tuple
from typing import List, Union

import PIL.Image
import numpy as np
import torch
from diffusers import AnimateDiffPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.pipelines.animatediff.pipeline_animatediff import EXAMPLE_DOC_STRING
from diffusers.pipelines.free_noise_utils import AnimateDiffFreeNoiseMixin, SplitInferenceModule
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import BaseOutput
from diffusers.utils import deprecate, logging, replace_example_docstring
from einops import rearrange
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder
from onlyflow.models.unet import UNetMotionModel, AnimateDiffTransformer3D, \
    CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion
from ..models.attention import BasicTransformerBlock

logger = logging.get_logger(__name__)


@dataclass
class FlowCtrlPipelineOutput(BaseOutput):
    r"""
    Output class for AnimateDiff pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`
    """

    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


class FlowCtrlPipeline(AnimateDiffPipeline):
    model_cpu_offload_seq = "text_encoder->flow_encoder->image_encoder->unet->vae"
    _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNetMotionModel,
                 scheduler: Union[
                     DDIMScheduler,
                     PNDMScheduler,
                     LMSDiscreteScheduler,
                     EulerDiscreteScheduler,
                     EulerAncestralDiscreteScheduler,
                     DPMSolverMultistepScheduler],
                 flow_encoder: FlowEncoder,
                 feature_extractor=None,
                 image_encoder=None,
                 motion_adapter=None,
                 ):

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            motion_adapter=motion_adapter,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.register_modules(
            flow_encoder=flow_encoder
        )

    def _enable_split_inference_motion_modules_(
            self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
    ) -> None:
        for motion_module in motion_modules:
            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])

            for i in range(len(motion_module.transformer_blocks)):
                motion_module.transformer_blocks[i] = SplitInferenceModule(
                    motion_module.transformer_blocks[i],
                    spatial_split_size,
                    0,
                    ["hidden_states", "encoder_hidden_states", "cross_attention_kwargs"],
                )

            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])

    def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion]):
        r"""Helper function to enable FreeNoise in transformer blocks."""

        for motion_module in block.motion_modules:
            num_transformer_blocks = len(motion_module.transformer_blocks)

            for i in range(num_transformer_blocks):
                if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
                    motion_module.transformer_blocks[i].set_free_noise_properties(
                        self._free_noise_context_length,
                        self._free_noise_context_stride,
                        self._free_noise_weighting_scheme,
                    )
                else:
                    basic_transfomer_block = motion_module.transformer_blocks[i]

                    motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
                        dim=basic_transfomer_block.dim,
                        num_attention_heads=basic_transfomer_block.num_attention_heads,
                        attention_head_dim=basic_transfomer_block.attention_head_dim,
                        dropout=basic_transfomer_block.dropout,
                        cross_attention_dim=basic_transfomer_block.cross_attention_dim,
                        activation_fn=basic_transfomer_block.activation_fn,
                        attention_bias=basic_transfomer_block.attention_bias,
                        only_cross_attention=basic_transfomer_block.only_cross_attention,
                        double_self_attention=basic_transfomer_block.double_self_attention,
                        positional_embeddings=basic_transfomer_block.positional_embeddings,
                        num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
                        context_length=self._free_noise_context_length,
                        context_stride=self._free_noise_context_stride,
                        weighting_scheme=self._free_noise_weighting_scheme,
                    ).to(device=self._execution_device, dtype=self.dtype)

                    # here i need to copy the attention processor from the basic transformer block to the free noise transformer block
                    motion_module.transformer_blocks[i].attn1 = basic_transfomer_block.attn1
                    motion_module.transformer_blocks[i].attn2 = basic_transfomer_block.attn2

                    motion_module.transformer_blocks[i].load_state_dict(
                        basic_transfomer_block.state_dict(), strict=True
                    )
                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
                        basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
                    )

    def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion, CrossAttnUpBlockMotion]):
        r"""Helper function to disable FreeNoise in transformer blocks."""

        for motion_module in block.motion_modules:
            num_transformer_blocks = len(motion_module.transformer_blocks)

            for i in range(num_transformer_blocks):
                if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
                    free_noise_transfomer_block = motion_module.transformer_blocks[i]

                    motion_module.transformer_blocks[i] = BasicTransformerBlock(
                        dim=free_noise_transfomer_block.dim,
                        num_attention_heads=free_noise_transfomer_block.num_attention_heads,
                        attention_head_dim=free_noise_transfomer_block.attention_head_dim,
                        dropout=free_noise_transfomer_block.dropout,
                        cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
                        activation_fn=free_noise_transfomer_block.activation_fn,
                        attention_bias=free_noise_transfomer_block.attention_bias,
                        only_cross_attention=free_noise_transfomer_block.only_cross_attention,
                        double_self_attention=free_noise_transfomer_block.double_self_attention,
                        positional_embeddings=free_noise_transfomer_block.positional_embeddings,
                        num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
                    ).to(device=self._execution_device, dtype=self.dtype)

                    motion_module.transformer_blocks[i].load_state_dict(
                        free_noise_transfomer_block.state_dict(), strict=True
                    )
                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
                        free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
                    )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
            self,
            prompt: Union[str, List[str]] = None,
            optical_flow: torch.FloatTensor = None,

            num_frames: Optional[int] = 16,
            height: Optional[int] = None,
            width: Optional[int] = None,

            num_inference_steps: int = 50,
            guidance_scale: float = 7.5,
            negative_prompt: Optional[Union[str, List[str]]] = None,
            eta: float = 0.0,
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            latents: Optional[torch.Tensor] = None,

            prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            ip_adapter_image: Optional[PipelineImageInput] = None,
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,

            output_type: Optional[str] = "pt",
            return_dict: bool = True,

            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],

            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            motion_cross_attention_kwargs: Optional[Dict[str, Any]] = None,

            clip_skip: Optional[int] = None,
            decode_chunk_size: int = 16,

            val_scale_factor_spatial: float = 0.,
            val_scale_factor_temporal: float = 0.,

            **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated video.
            num_frames (`int`, *optional*, defaults to 16):
                The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
                amounts to 2 seconds of video.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
                `(batch_size, num_channel, num_frames, height, width)`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
                of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            decode_chunk_size (`int`, defaults to `16`):
                The number of frames to decode at a time when calling `decode_latents` method.

        Examples:

        Returns:
            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, (str, dict)):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        if self.free_noise_enabled:
            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
                prompt=prompt,
                num_frames=num_frames,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                negative_prompt=negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )
        else:
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
                prompt,
                device,
                num_videos_per_prompt,
                self.do_classifier_free_guidance,
                negative_prompt,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                lora_scale=text_encoder_lora_scale,
                clip_skip=self.clip_skip,
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            if self.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

            prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_videos_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        assert optical_flow.ndim == 5
        bs = optical_flow.shape[0]
        if self.free_noise_enabled:
            length = optical_flow.shape[2]
            flow_embedding_features = [
                torch.zeros((bs, length, *test_size.shape[1:]), device=self._execution_device)
                for test_size in self.flow_encoder(optical_flow[:, :, :16].to(self._execution_device))
            ]
            weight_factor = torch.zeros(length, device=self._execution_device)
            for star_idx in range(0, length, self._free_noise_context_stride):
                weight_factor[star_idx:star_idx + self._free_noise_context_length] += 1.0
                infe = self.flow_encoder(optical_flow[:, :, star_idx:star_idx + self._free_noise_context_length].to(self._execution_device))
                for flow_emb, infe_sub in zip(flow_embedding_features, infe):
                    flow_emb[:, star_idx:star_idx + self._free_noise_context_length] += rearrange(infe_sub, '(b f) c h w -> b f c h w', b=bs).to(self._execution_device)

            flow_embedding_features = [flow_emb / weight_factor[None, :, None, None, None] for flow_emb in flow_embedding_features]
            flow_embedding_features = [rearrange(x, 'b f c h w -> b c f h w') for x in flow_embedding_features]
        else:
            flow_embedding_features = self.flow_encoder(optical_flow.to(self._execution_device))  # input b c f h w into bf, c, h, w
            flow_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs).to(self._execution_device)
                                       for x in flow_embedding_features]

        del optical_flow
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()

        # 7. Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
        for free_init_iter in range(num_free_init_iters):
            if self.free_init_enabled:
                latents, timesteps = self._apply_free_init(
                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                )

            self._num_timesteps = len(timesteps)
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

            if isinstance(flow_embedding_features[0], list):
                flow_embedding_features = [[torch.cat([x, x], dim=0) for x in flow_embedding_feature]
                                           for flow_embedding_feature in flow_embedding_features] \
                    if self.do_classifier_free_guidance else flow_embedding_features
            else:
                flow_embedding_features = [torch.cat([x, x], dim=0) for x in flow_embedding_features] \
480 |
+
if self.do_classifier_free_guidance else flow_embedding_features # [2b c f h w]
|
481 |
+
|
482 |
+
# 8. Denoising loop
|
483 |
+
with self.progress_bar(total=self._num_timesteps) as progress_bar:
|
484 |
+
for i, t in enumerate(timesteps):
|
485 |
+
if self.interrupt:
|
486 |
+
continue
|
487 |
+
|
488 |
+
# expand the latents if we are doing classifier free guidance
|
489 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
490 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
491 |
+
|
492 |
+
if added_cond_kwargs is not None:
|
493 |
+
added_cond_kwargs.update({"flow_embedding_features": flow_embedding_features})
|
494 |
+
else:
|
495 |
+
added_cond_kwargs = {"flow_embedding_features": flow_embedding_features}
|
496 |
+
|
497 |
+
if cross_attention_kwargs is not None:
|
498 |
+
cross_attention_kwargs.update({"flow_scale": val_scale_factor_spatial})
|
499 |
+
else:
|
500 |
+
cross_attention_kwargs = {"flow_scale": val_scale_factor_spatial}
|
501 |
+
|
502 |
+
if motion_cross_attention_kwargs is not None:
|
503 |
+
motion_cross_attention_kwargs.update({"flow_scale": val_scale_factor_temporal})
|
504 |
+
else:
|
505 |
+
motion_cross_attention_kwargs = {"flow_scale": val_scale_factor_temporal}
|
506 |
+
|
507 |
+
# predict the noise residual
|
508 |
+
|
509 |
+
noise_pred = self.unet(
|
510 |
+
latent_model_input,
|
511 |
+
t,
|
512 |
+
encoder_hidden_states=prompt_embeds,
|
513 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
514 |
+
motion_cross_attention_kwargs=motion_cross_attention_kwargs,
|
515 |
+
added_cond_kwargs=added_cond_kwargs,
|
516 |
+
).sample
|
517 |
+
|
518 |
+
# perform guidance
|
519 |
+
if self.do_classifier_free_guidance:
|
520 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
521 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
522 |
+
|
523 |
+
# compute the previous noisy sample x_t -> x_t-1
|
524 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
525 |
+
|
526 |
+
if callback_on_step_end is not None:
|
527 |
+
callback_kwargs = {}
|
528 |
+
for k in callback_on_step_end_tensor_inputs:
|
529 |
+
callback_kwargs[k] = locals()[k]
|
530 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
531 |
+
|
532 |
+
latents = callback_outputs.pop("latents", latents)
|
533 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
534 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
535 |
+
|
536 |
+
# call the callback, if provided
|
537 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
538 |
+
progress_bar.update()
|
539 |
+
if callback is not None and i % callback_steps == 0:
|
540 |
+
callback(i, t, latents)
|
541 |
+
|
542 |
+
# 9. Post processing
|
543 |
+
if output_type == "latent":
|
544 |
+
video = latents
|
545 |
+
else:
|
546 |
+
video_tensor = self.decode_latents(latents, decode_chunk_size)
|
547 |
+
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
548 |
+
|
549 |
+
# 10. Offload all models
|
550 |
+
self.maybe_free_model_hooks()
|
551 |
+
|
552 |
+
if not return_dict:
|
553 |
+
return (video,)
|
554 |
+
|
555 |
+
return FlowCtrlPipelineOutput(frames=video)
|
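The FreeNoise branch of the flow encoding above accumulates flow_encoder outputs over overlapping temporal windows and divides each frame by the number of windows that covered it (weight_factor). Below is a minimal standalone sketch of that accumulate-and-normalize pattern, with toy shapes and torch.randn standing in for the real flow encoder; it is an illustration, not code from this commit.

import torch

# Toy re-implementation of the overlapping-window averaging used above:
# each frame's feature becomes the mean over all windows that covered it.
length, context_length, context_stride = 32, 16, 4
feat_dim = 8

accum = torch.zeros(length, feat_dim)
weight = torch.zeros(length)

for start in range(0, length, context_stride):
    stop = min(start + context_length, length)
    # stand-in for the flow encoder applied to one temporal window
    window_feats = torch.randn(stop - start, feat_dim)
    accum[start:stop] += window_feats
    weight[start:stop] += 1.0  # count how many windows cover each frame

averaged = accum / weight[:, None]  # per-frame mean over covering windows
print(averaged.shape)  # torch.Size([32, 8])

Because consecutive windows overlap by context_length - context_stride frames, the division produces a smooth per-frame average instead of hard window boundaries.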
onlyflow/utils/util.py
ADDED
@@ -0,0 +1,140 @@
import atexit
import functools
import importlib
import io
import logging
import os
import sys

import imageio
import numpy as np
import torch
from termcolor import colored


def instantiate_from_config(config, **additional_kwargs):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")

    additional_kwargs.update(config.get("kwargs", dict()))
    return get_obj_from_str(config["target"])(**additional_kwargs)


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def get_video(videos: torch.Tensor, path: str, rescale=False, fps=8):
    if rescale:
        videos = (videos + 1.0) / 2.0  # -1,1 -> 0,1
    videos = (videos * 255).numpy().astype(np.uint8)
    videos = np.transpose(videos, axes=(1, 2, 3, 0))

    binary_object = io.BytesIO()

    imageio.mimsave(binary_object, list(videos), fps=fps, format='gif')

    return binary_object


# Logger utils are copied from detectron2
class _ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    # use 1K buffer if writing to cloud storage
    io = open(filename, "a", buffering=1024 if "://" in filename else -1)
    atexit.register(io.close)
    return io


@functools.lru_cache()
def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = 'AD'
    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
    )

    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        filename = os.path.join(output, "ranks_logs", f"log.{distributed_rank}.txt")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


def format_time(elapsed_time):
    # Time thresholds
    minute = 60
    hour = 60 * minute
    day = 24 * hour

    days, remainder = divmod(elapsed_time, day)
    hours, remainder = divmod(remainder, hour)
    minutes, seconds = divmod(remainder, minute)

    formatted_time = ""

    if days > 0:
        formatted_time += f"{int(days)} days "
    if hours > 0:
        formatted_time += f"{int(hours)} hours "
    if minutes > 0:
        formatted_time += f"{int(minutes)} minutes "
    if seconds > 0:
        formatted_time += f"{seconds:.2f} seconds"

    return formatted_time.strip()
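A hedged usage sketch for the logging and timing helpers above; the output directory and log messages are illustrative, not taken from this repository.

import time

from onlyflow.utils.util import format_time, setup_logger

# rank 0 gets a colored stdout handler plus a per-rank log file under ./outputs/ranks_logs/
logger = setup_logger(output="./outputs", distributed_rank=0, name="AnimateDiff")

start = time.time()
logger.info("generation started")
# ... run a job ...
logger.info(f"finished in {format_time(time.time() - start)}")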
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch
torchvision
diffusers
transformers
accelerate
git+https://github.com/obvious-research/diffusers.git
numpy
einops
imageio
omegaconf
av==12.0.0
tools/optical_flow.py
ADDED
@@ -0,0 +1,22 @@
import torch
from einops import rearrange


@torch.no_grad()
def get_optical_flow(raft_model, pixel_values, video_length, encode_chunk_size=48, num_flow_updates=14):
    imgs_1 = pixel_values[:, :-1]
    imgs_2 = pixel_values[:, 1:]
    imgs_1 = rearrange(imgs_1, "b f c h w -> (b f) c h w")
    imgs_2 = rearrange(imgs_2, "b f c h w -> (b f) c h w")

    flow_embedding = []

    for i in range(0, imgs_1.shape[0], encode_chunk_size):
        imgs_1_chunk = imgs_1[i:i + encode_chunk_size]
        imgs_2_chunk = imgs_2[i:i + encode_chunk_size]
        flow_embedding_chunk = raft_model(imgs_1_chunk, imgs_2_chunk, num_flow_updates)[-1]
        flow_embedding.append(flow_embedding_chunk)

    flow_embedding = torch.cat(flow_embedding).contiguous()
    flow_embedding = rearrange(flow_embedding, "(b f) c h w -> b c f h w", f=video_length)

    return flow_embedding
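A hedged sketch of driving get_optical_flow with torchvision's RAFT. The choice of raft_small, the random clip, and the 256x256 resolution are placeholder assumptions (the Space may load a different RAFT variant); the [-1, 1] input range and the divisible-by-8 spatial size follow torchvision's documented RAFT requirements.

import torch
from torchvision.models.optical_flow import Raft_Small_Weights, raft_small

from tools.optical_flow import get_optical_flow

device = "cuda" if torch.cuda.is_available() else "cpu"
raft = raft_small(weights=Raft_Small_Weights.DEFAULT).eval().to(device)

# b f c h w clip scaled to [-1, 1]; 17 frames yield 16 consecutive-frame flows
clip = torch.rand(1, 17, 3, 256, 256, device=device) * 2 - 1
flow = get_optical_flow(raft, clip, video_length=16, encode_chunk_size=8)
print(flow.shape)  # torch.Size([1, 2, 16, 256, 256])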