Spaces:
Running
Running
bubbliiiing
commited on
Commit
·
19fe404
1
Parent(s):
6e9356e
Create Code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +40 -0
- config/easyanimate_image_magvit_v2.yaml +8 -0
- config/easyanimate_image_normal_v1.yaml +8 -0
- config/easyanimate_image_slicevae_v3.yaml +9 -0
- config/easyanimate_video_casual_motion_module_v1.yaml +27 -0
- config/easyanimate_video_long_sequence_v1.yaml +14 -0
- config/easyanimate_video_magvit_motion_module_v2.yaml +26 -0
- config/easyanimate_video_motion_module_v1.yaml +24 -0
- config/easyanimate_video_slicevae_motion_module_v3.yaml +27 -0
- easyanimate/__init__.py +0 -0
- easyanimate/api/api.py +96 -0
- easyanimate/api/post_infer.py +94 -0
- easyanimate/data/bucket_sampler.py +379 -0
- easyanimate/data/dataset_image.py +76 -0
- easyanimate/data/dataset_image_video.py +241 -0
- easyanimate/data/dataset_video.py +262 -0
- easyanimate/models/attention.py +1299 -0
- easyanimate/models/autoencoder_magvit.py +503 -0
- easyanimate/models/motion_module.py +575 -0
- easyanimate/models/patch.py +426 -0
- easyanimate/models/transformer2d.py +555 -0
- easyanimate/models/transformer3d.py +738 -0
- easyanimate/pipeline/pipeline_easyanimate.py +847 -0
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +984 -0
- easyanimate/pipeline/pipeline_pixart_magvit.py +983 -0
- easyanimate/ui/ui.py +818 -0
- easyanimate/utils/__init__.py +0 -0
- easyanimate/utils/diffusion_utils.py +92 -0
- easyanimate/utils/gaussian_diffusion.py +1008 -0
- easyanimate/utils/lora_utils.py +476 -0
- easyanimate/utils/respace.py +131 -0
- easyanimate/utils/utils.py +64 -0
- easyanimate/vae/LICENSE +82 -0
- easyanimate/vae/README.md +63 -0
- easyanimate/vae/README_zh-CN.md +63 -0
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml +62 -0
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml +65 -0
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml +66 -0
- easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml +66 -0
- easyanimate/vae/environment.yaml +29 -0
- easyanimate/vae/ldm/data/__init__.py +0 -0
- easyanimate/vae/ldm/data/base.py +25 -0
- easyanimate/vae/ldm/data/dataset_callback.py +25 -0
- easyanimate/vae/ldm/data/dataset_image_video.py +281 -0
- easyanimate/vae/ldm/lr_scheduler.py +98 -0
- easyanimate/vae/ldm/modules/diffusionmodules/__init__.py +0 -0
- easyanimate/vae/ldm/modules/diffusionmodules/model.py +701 -0
- easyanimate/vae/ldm/modules/diffusionmodules/util.py +268 -0
- easyanimate/vae/ldm/modules/distributions/__init__.py +0 -0
- easyanimate/vae/ldm/modules/distributions/distributions.py +92 -0
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
from easyanimate.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
|
4 |
+
from easyanimate.ui.ui import ui_modelscope, ui, ui_huggingface
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
# Choose the ui mode
|
8 |
+
ui_mode = "huggingface"
|
9 |
+
# Server ip
|
10 |
+
server_name = "0.0.0.0"
|
11 |
+
server_port = 7860
|
12 |
+
|
13 |
+
# Params below is used when ui_mode = "modelscope"
|
14 |
+
edition = "v2"
|
15 |
+
config_path = "config/easyanimate_video_magvit_motion_module_v2.yaml"
|
16 |
+
model_name = "models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
|
17 |
+
savedir_sample = "samples"
|
18 |
+
|
19 |
+
if ui_mode == "modelscope":
|
20 |
+
demo, controller = ui_modelscope(edition, config_path, model_name, savedir_sample)
|
21 |
+
elif ui_mode == "huggingface":
|
22 |
+
demo, controller = ui_huggingface(edition, config_path, model_name, savedir_sample)
|
23 |
+
else:
|
24 |
+
demo, controller = ui()
|
25 |
+
|
26 |
+
# launch gradio
|
27 |
+
app, _, _ = demo.queue(status_update_rate=1).launch(
|
28 |
+
server_name=server_name,
|
29 |
+
server_port=server_port,
|
30 |
+
prevent_thread_lock=True
|
31 |
+
)
|
32 |
+
|
33 |
+
# launch api
|
34 |
+
infer_forward_api(None, app, controller)
|
35 |
+
update_diffusion_transformer_api(None, app, controller)
|
36 |
+
update_edition_api(None, app, controller)
|
37 |
+
|
38 |
+
# not close the python
|
39 |
+
while True:
|
40 |
+
time.sleep(5)
|
config/easyanimate_image_magvit_v2.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
noise_scheduler_kwargs:
|
2 |
+
beta_start: 0.0001
|
3 |
+
beta_end: 0.02
|
4 |
+
beta_schedule: "linear"
|
5 |
+
steps_offset: 1
|
6 |
+
|
7 |
+
vae_kwargs:
|
8 |
+
enable_magvit: true
|
config/easyanimate_image_normal_v1.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
noise_scheduler_kwargs:
|
2 |
+
beta_start: 0.0001
|
3 |
+
beta_end: 0.02
|
4 |
+
beta_schedule: "linear"
|
5 |
+
steps_offset: 1
|
6 |
+
|
7 |
+
vae_kwargs:
|
8 |
+
enable_magvit: false
|
config/easyanimate_image_slicevae_v3.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
noise_scheduler_kwargs:
|
2 |
+
beta_start: 0.0001
|
3 |
+
beta_end: 0.02
|
4 |
+
beta_schedule: "linear"
|
5 |
+
steps_offset: 1
|
6 |
+
|
7 |
+
vae_kwargs:
|
8 |
+
enable_magvit: true
|
9 |
+
slice_compression_vae: true
|
config/easyanimate_video_casual_motion_module_v1.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_additional_kwargs:
|
2 |
+
patch_3d: false
|
3 |
+
fake_3d: false
|
4 |
+
casual_3d: true
|
5 |
+
casual_3d_upsampler_index: [16, 20]
|
6 |
+
time_patch_size: 4
|
7 |
+
basic_block_type: "motionmodule"
|
8 |
+
time_position_encoding_before_transformer: false
|
9 |
+
motion_module_type: "VanillaGrid"
|
10 |
+
|
11 |
+
motion_module_kwargs:
|
12 |
+
num_attention_heads: 8
|
13 |
+
num_transformer_block: 1
|
14 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
15 |
+
temporal_position_encoding: true
|
16 |
+
temporal_position_encoding_max_len: 4096
|
17 |
+
temporal_attention_dim_div: 1
|
18 |
+
block_size: 2
|
19 |
+
|
20 |
+
noise_scheduler_kwargs:
|
21 |
+
beta_start: 0.0001
|
22 |
+
beta_end: 0.02
|
23 |
+
beta_schedule: "linear"
|
24 |
+
steps_offset: 1
|
25 |
+
|
26 |
+
vae_kwargs:
|
27 |
+
enable_magvit: false
|
config/easyanimate_video_long_sequence_v1.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_additional_kwargs:
|
2 |
+
patch_3d: false
|
3 |
+
fake_3d: false
|
4 |
+
basic_block_type: "selfattentiontemporal"
|
5 |
+
time_position_encoding_before_transformer: true
|
6 |
+
|
7 |
+
noise_scheduler_kwargs:
|
8 |
+
beta_start: 0.0001
|
9 |
+
beta_end: 0.02
|
10 |
+
beta_schedule: "linear"
|
11 |
+
steps_offset: 1
|
12 |
+
|
13 |
+
vae_kwargs:
|
14 |
+
enable_magvit: false
|
config/easyanimate_video_magvit_motion_module_v2.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_additional_kwargs:
|
2 |
+
patch_3d: false
|
3 |
+
fake_3d: false
|
4 |
+
basic_block_type: "motionmodule"
|
5 |
+
time_position_encoding_before_transformer: false
|
6 |
+
motion_module_type: "Vanilla"
|
7 |
+
enable_uvit: true
|
8 |
+
|
9 |
+
motion_module_kwargs:
|
10 |
+
num_attention_heads: 8
|
11 |
+
num_transformer_block: 1
|
12 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
13 |
+
temporal_position_encoding: true
|
14 |
+
temporal_position_encoding_max_len: 4096
|
15 |
+
temporal_attention_dim_div: 1
|
16 |
+
block_size: 1
|
17 |
+
|
18 |
+
noise_scheduler_kwargs:
|
19 |
+
beta_start: 0.0001
|
20 |
+
beta_end: 0.02
|
21 |
+
beta_schedule: "linear"
|
22 |
+
steps_offset: 1
|
23 |
+
|
24 |
+
vae_kwargs:
|
25 |
+
enable_magvit: true
|
26 |
+
mini_batch_encoder: 9
|
config/easyanimate_video_motion_module_v1.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_additional_kwargs:
|
2 |
+
patch_3d: false
|
3 |
+
fake_3d: false
|
4 |
+
basic_block_type: "motionmodule"
|
5 |
+
time_position_encoding_before_transformer: false
|
6 |
+
motion_module_type: "VanillaGrid"
|
7 |
+
|
8 |
+
motion_module_kwargs:
|
9 |
+
num_attention_heads: 8
|
10 |
+
num_transformer_block: 1
|
11 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
12 |
+
temporal_position_encoding: true
|
13 |
+
temporal_position_encoding_max_len: 4096
|
14 |
+
temporal_attention_dim_div: 1
|
15 |
+
block_size: 2
|
16 |
+
|
17 |
+
noise_scheduler_kwargs:
|
18 |
+
beta_start: 0.0001
|
19 |
+
beta_end: 0.02
|
20 |
+
beta_schedule: "linear"
|
21 |
+
steps_offset: 1
|
22 |
+
|
23 |
+
vae_kwargs:
|
24 |
+
enable_magvit: false
|
config/easyanimate_video_slicevae_motion_module_v3.yaml
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformer_additional_kwargs:
|
2 |
+
patch_3d: false
|
3 |
+
fake_3d: false
|
4 |
+
basic_block_type: "motionmodule"
|
5 |
+
time_position_encoding_before_transformer: false
|
6 |
+
motion_module_type: "Vanilla"
|
7 |
+
enable_uvit: true
|
8 |
+
|
9 |
+
motion_module_kwargs:
|
10 |
+
num_attention_heads: 8
|
11 |
+
num_transformer_block: 1
|
12 |
+
attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
|
13 |
+
temporal_position_encoding: true
|
14 |
+
temporal_position_encoding_max_len: 4096
|
15 |
+
temporal_attention_dim_div: 1
|
16 |
+
block_size: 1
|
17 |
+
|
18 |
+
noise_scheduler_kwargs:
|
19 |
+
beta_start: 0.0001
|
20 |
+
beta_end: 0.02
|
21 |
+
beta_schedule: "linear"
|
22 |
+
steps_offset: 1
|
23 |
+
|
24 |
+
vae_kwargs:
|
25 |
+
enable_magvit: true
|
26 |
+
slice_compression_vae: true
|
27 |
+
mini_batch_encoder: 8
|
easyanimate/__init__.py
ADDED
File without changes
|
easyanimate/api/api.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import base64
|
3 |
+
import torch
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from fastapi import FastAPI
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
# Function to encode a file to Base64
|
10 |
+
def encode_file_to_base64(file_path):
|
11 |
+
with open(file_path, "rb") as file:
|
12 |
+
# Encode the data to Base64
|
13 |
+
file_base64 = base64.b64encode(file.read())
|
14 |
+
return file_base64
|
15 |
+
|
16 |
+
def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
|
17 |
+
@app.post("/easyanimate/update_edition")
|
18 |
+
def _update_edition_api(
|
19 |
+
datas: dict,
|
20 |
+
):
|
21 |
+
edition = datas.get('edition', 'v2')
|
22 |
+
|
23 |
+
try:
|
24 |
+
controller.update_edition(
|
25 |
+
edition
|
26 |
+
)
|
27 |
+
comment = "Success"
|
28 |
+
except Exception as e:
|
29 |
+
torch.cuda.empty_cache()
|
30 |
+
comment = f"Error. error information is {str(e)}"
|
31 |
+
|
32 |
+
return {"message": comment}
|
33 |
+
|
34 |
+
def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
|
35 |
+
@app.post("/easyanimate/update_diffusion_transformer")
|
36 |
+
def _update_diffusion_transformer_api(
|
37 |
+
datas: dict,
|
38 |
+
):
|
39 |
+
diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
|
40 |
+
|
41 |
+
try:
|
42 |
+
controller.update_diffusion_transformer(
|
43 |
+
diffusion_transformer_path
|
44 |
+
)
|
45 |
+
comment = "Success"
|
46 |
+
except Exception as e:
|
47 |
+
torch.cuda.empty_cache()
|
48 |
+
comment = f"Error. error information is {str(e)}"
|
49 |
+
|
50 |
+
return {"message": comment}
|
51 |
+
|
52 |
+
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
53 |
+
@app.post("/easyanimate/infer_forward")
|
54 |
+
def _infer_forward_api(
|
55 |
+
datas: dict,
|
56 |
+
):
|
57 |
+
base_model_path = datas.get('base_model_path', 'none')
|
58 |
+
motion_module_path = datas.get('motion_module_path', 'none')
|
59 |
+
lora_model_path = datas.get('lora_model_path', 'none')
|
60 |
+
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
61 |
+
prompt_textbox = datas.get('prompt_textbox', None)
|
62 |
+
negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
|
63 |
+
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
64 |
+
sample_step_slider = datas.get('sample_step_slider', 30)
|
65 |
+
width_slider = datas.get('width_slider', 672)
|
66 |
+
height_slider = datas.get('height_slider', 384)
|
67 |
+
is_image = datas.get('is_image', False)
|
68 |
+
length_slider = datas.get('length_slider', 144)
|
69 |
+
cfg_scale_slider = datas.get('cfg_scale_slider', 6)
|
70 |
+
seed_textbox = datas.get("seed_textbox", 43)
|
71 |
+
|
72 |
+
try:
|
73 |
+
save_sample_path, comment = controller.generate(
|
74 |
+
"",
|
75 |
+
base_model_path,
|
76 |
+
motion_module_path,
|
77 |
+
lora_model_path,
|
78 |
+
lora_alpha_slider,
|
79 |
+
prompt_textbox,
|
80 |
+
negative_prompt_textbox,
|
81 |
+
sampler_dropdown,
|
82 |
+
sample_step_slider,
|
83 |
+
width_slider,
|
84 |
+
height_slider,
|
85 |
+
is_image,
|
86 |
+
length_slider,
|
87 |
+
cfg_scale_slider,
|
88 |
+
seed_textbox,
|
89 |
+
is_api = True,
|
90 |
+
)
|
91 |
+
except Exception as e:
|
92 |
+
torch.cuda.empty_cache()
|
93 |
+
save_sample_path = ""
|
94 |
+
comment = f"Error. error information is {str(e)}"
|
95 |
+
|
96 |
+
return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
|
easyanimate/api/post_infer.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
from datetime import datetime
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import requests
|
10 |
+
import base64
|
11 |
+
|
12 |
+
|
13 |
+
def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
|
14 |
+
datas = json.dumps({
|
15 |
+
"diffusion_transformer_path": diffusion_transformer_path
|
16 |
+
})
|
17 |
+
r = requests.post(f'{url}/easyanimate/update_diffusion_transformer', data=datas, timeout=1500)
|
18 |
+
data = r.content.decode('utf-8')
|
19 |
+
return data
|
20 |
+
|
21 |
+
def post_update_edition(edition, url='http://0.0.0.0:7860'):
|
22 |
+
datas = json.dumps({
|
23 |
+
"edition": edition
|
24 |
+
})
|
25 |
+
r = requests.post(f'{url}/easyanimate/update_edition', data=datas, timeout=1500)
|
26 |
+
data = r.content.decode('utf-8')
|
27 |
+
return data
|
28 |
+
|
29 |
+
def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
|
30 |
+
datas = json.dumps({
|
31 |
+
"base_model_path": "none",
|
32 |
+
"motion_module_path": "none",
|
33 |
+
"lora_model_path": "none",
|
34 |
+
"lora_alpha_slider": 0.55,
|
35 |
+
"prompt_textbox": "This video shows Mount saint helens, washington - the stunning scenery of a rocky mountains during golden hours - wide shot. A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea.",
|
36 |
+
"negative_prompt_textbox": "Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera",
|
37 |
+
"sampler_dropdown": "Euler",
|
38 |
+
"sample_step_slider": 30,
|
39 |
+
"width_slider": 672,
|
40 |
+
"height_slider": 384,
|
41 |
+
"is_image": is_image,
|
42 |
+
"length_slider": length_slider,
|
43 |
+
"cfg_scale_slider": 6,
|
44 |
+
"seed_textbox": 43,
|
45 |
+
})
|
46 |
+
r = requests.post(f'{url}/easyanimate/infer_forward', data=datas, timeout=1500)
|
47 |
+
data = r.content.decode('utf-8')
|
48 |
+
return data
|
49 |
+
|
50 |
+
if __name__ == '__main__':
|
51 |
+
# initiate time
|
52 |
+
now_date = datetime.now()
|
53 |
+
time_start = time.time()
|
54 |
+
|
55 |
+
# -------------------------- #
|
56 |
+
# Step 1: update edition
|
57 |
+
# -------------------------- #
|
58 |
+
edition = "v2"
|
59 |
+
outputs = post_update_edition(edition)
|
60 |
+
print('Output update edition: ', outputs)
|
61 |
+
|
62 |
+
# -------------------------- #
|
63 |
+
# Step 2: update edition
|
64 |
+
# -------------------------- #
|
65 |
+
diffusion_transformer_path = "/your-path/EasyAnimate/models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
|
66 |
+
outputs = post_diffusion_transformer(diffusion_transformer_path)
|
67 |
+
print('Output update edition: ', outputs)
|
68 |
+
|
69 |
+
# -------------------------- #
|
70 |
+
# Step 3: infer
|
71 |
+
# -------------------------- #
|
72 |
+
is_image = False
|
73 |
+
length_slider = 27
|
74 |
+
outputs = post_infer(is_image, length_slider)
|
75 |
+
|
76 |
+
# Get decoded data
|
77 |
+
outputs = json.loads(outputs)
|
78 |
+
base64_encoding = outputs["base64_encoding"]
|
79 |
+
decoded_data = base64.b64decode(base64_encoding)
|
80 |
+
|
81 |
+
if is_image or length_slider == 1:
|
82 |
+
file_path = "1.png"
|
83 |
+
else:
|
84 |
+
file_path = "1.mp4"
|
85 |
+
with open(file_path, "wb") as file:
|
86 |
+
file.write(decoded_data)
|
87 |
+
|
88 |
+
# End of record time
|
89 |
+
# The calculated time difference is the execution time of the program, expressed in seconds / s
|
90 |
+
time_end = time.time()
|
91 |
+
time_sum = (time_end - time_start) % 60
|
92 |
+
print('# --------------------------------------------------------- #')
|
93 |
+
print(f'# Total expenditure: {time_sum}s')
|
94 |
+
print('# --------------------------------------------------------- #')
|
easyanimate/data/bucket_sampler.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
|
4 |
+
Sized, TypeVar, Union)
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import BatchSampler, Dataset, Sampler
|
11 |
+
|
12 |
+
ASPECT_RATIO_512 = {
|
13 |
+
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
|
14 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
15 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
16 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
17 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
18 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
19 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
20 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
21 |
+
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
|
22 |
+
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
|
23 |
+
}
|
24 |
+
ASPECT_RATIO_RANDOM_CROP_512 = {
|
25 |
+
'0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
|
26 |
+
'0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
|
27 |
+
'0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
|
28 |
+
'1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
|
29 |
+
'2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
|
30 |
+
}
|
31 |
+
ASPECT_RATIO_RANDOM_CROP_PROB = [
|
32 |
+
1, 2,
|
33 |
+
4, 4, 4, 4,
|
34 |
+
8, 8, 8,
|
35 |
+
4, 4, 4, 4,
|
36 |
+
2, 1
|
37 |
+
]
|
38 |
+
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
|
39 |
+
|
40 |
+
def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
|
41 |
+
aspect_ratio = height / width
|
42 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
43 |
+
return ratios[closest_ratio], float(closest_ratio)
|
44 |
+
|
45 |
+
def get_image_size_without_loading(path):
|
46 |
+
with Image.open(path) as img:
|
47 |
+
return img.size # (width, height)
|
48 |
+
|
49 |
+
class RandomSampler(Sampler[int]):
|
50 |
+
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
|
51 |
+
|
52 |
+
If with replacement, then user can specify :attr:`num_samples` to draw.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
data_source (Dataset): dataset to sample from
|
56 |
+
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
|
57 |
+
num_samples (int): number of samples to draw, default=`len(dataset)`.
|
58 |
+
generator (Generator): Generator used in sampling.
|
59 |
+
"""
|
60 |
+
|
61 |
+
data_source: Sized
|
62 |
+
replacement: bool
|
63 |
+
|
64 |
+
def __init__(self, data_source: Sized, replacement: bool = False,
|
65 |
+
num_samples: Optional[int] = None, generator=None) -> None:
|
66 |
+
self.data_source = data_source
|
67 |
+
self.replacement = replacement
|
68 |
+
self._num_samples = num_samples
|
69 |
+
self.generator = generator
|
70 |
+
self._pos_start = 0
|
71 |
+
|
72 |
+
if not isinstance(self.replacement, bool):
|
73 |
+
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
|
74 |
+
|
75 |
+
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
|
76 |
+
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
|
77 |
+
|
78 |
+
@property
|
79 |
+
def num_samples(self) -> int:
|
80 |
+
# dataset size might change at runtime
|
81 |
+
if self._num_samples is None:
|
82 |
+
return len(self.data_source)
|
83 |
+
return self._num_samples
|
84 |
+
|
85 |
+
def __iter__(self) -> Iterator[int]:
|
86 |
+
n = len(self.data_source)
|
87 |
+
if self.generator is None:
|
88 |
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
89 |
+
generator = torch.Generator()
|
90 |
+
generator.manual_seed(seed)
|
91 |
+
else:
|
92 |
+
generator = self.generator
|
93 |
+
|
94 |
+
if self.replacement:
|
95 |
+
for _ in range(self.num_samples // 32):
|
96 |
+
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
|
97 |
+
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
|
98 |
+
else:
|
99 |
+
for _ in range(self.num_samples // n):
|
100 |
+
xx = torch.randperm(n, generator=generator).tolist()
|
101 |
+
if self._pos_start >= n:
|
102 |
+
self._pos_start = 0
|
103 |
+
print("xx top 10", xx[:10], self._pos_start)
|
104 |
+
for idx in range(self._pos_start, n):
|
105 |
+
yield xx[idx]
|
106 |
+
self._pos_start = (self._pos_start + 1) % n
|
107 |
+
self._pos_start = 0
|
108 |
+
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
|
109 |
+
|
110 |
+
def __len__(self) -> int:
|
111 |
+
return self.num_samples
|
112 |
+
|
113 |
+
class AspectRatioBatchImageSampler(BatchSampler):
|
114 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
sampler (Sampler): Base sampler.
|
118 |
+
dataset (Dataset): Dataset providing data information.
|
119 |
+
batch_size (int): Size of mini-batch.
|
120 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
121 |
+
its size would be less than ``batch_size``.
|
122 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
123 |
+
"""
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
sampler: Sampler,
|
127 |
+
dataset: Dataset,
|
128 |
+
batch_size: int,
|
129 |
+
train_folder: str = None,
|
130 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
131 |
+
drop_last: bool = False,
|
132 |
+
config=None,
|
133 |
+
**kwargs
|
134 |
+
) -> None:
|
135 |
+
if not isinstance(sampler, Sampler):
|
136 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
137 |
+
f'but got {sampler}')
|
138 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
139 |
+
raise ValueError('batch_size should be a positive integer value, '
|
140 |
+
f'but got batch_size={batch_size}')
|
141 |
+
self.sampler = sampler
|
142 |
+
self.dataset = dataset
|
143 |
+
self.train_folder = train_folder
|
144 |
+
self.batch_size = batch_size
|
145 |
+
self.aspect_ratios = aspect_ratios
|
146 |
+
self.drop_last = drop_last
|
147 |
+
self.config = config
|
148 |
+
# buckets for each aspect ratio
|
149 |
+
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
|
150 |
+
# [str(k) for k, v in aspect_ratios]
|
151 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
152 |
+
|
153 |
+
def __iter__(self):
|
154 |
+
for idx in self.sampler:
|
155 |
+
try:
|
156 |
+
image_dict = self.dataset[idx]
|
157 |
+
|
158 |
+
width, height = image_dict.get("width", None), image_dict.get("height", None)
|
159 |
+
if width is None or height is None:
|
160 |
+
image_id, name = image_dict['file_path'], image_dict['text']
|
161 |
+
if self.train_folder is None:
|
162 |
+
image_dir = image_id
|
163 |
+
else:
|
164 |
+
image_dir = os.path.join(self.train_folder, image_id)
|
165 |
+
|
166 |
+
width, height = get_image_size_without_loading(image_dir)
|
167 |
+
|
168 |
+
ratio = height / width # self.dataset[idx]
|
169 |
+
else:
|
170 |
+
height = int(height)
|
171 |
+
width = int(width)
|
172 |
+
ratio = height / width # self.dataset[idx]
|
173 |
+
except Exception as e:
|
174 |
+
print(e)
|
175 |
+
continue
|
176 |
+
# find the closest aspect ratio
|
177 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
178 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
179 |
+
continue
|
180 |
+
bucket = self._aspect_ratio_buckets[closest_ratio]
|
181 |
+
bucket.append(idx)
|
182 |
+
# yield a batch of indices in the same aspect ratio group
|
183 |
+
if len(bucket) == self.batch_size:
|
184 |
+
yield bucket[:]
|
185 |
+
del bucket[:]
|
186 |
+
|
187 |
+
class AspectRatioBatchSampler(BatchSampler):
|
188 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
sampler (Sampler): Base sampler.
|
192 |
+
dataset (Dataset): Dataset providing data information.
|
193 |
+
batch_size (int): Size of mini-batch.
|
194 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
195 |
+
its size would be less than ``batch_size``.
|
196 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
197 |
+
"""
|
198 |
+
def __init__(
|
199 |
+
self,
|
200 |
+
sampler: Sampler,
|
201 |
+
dataset: Dataset,
|
202 |
+
batch_size: int,
|
203 |
+
video_folder: str = None,
|
204 |
+
train_data_format: str = "webvid",
|
205 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
206 |
+
drop_last: bool = False,
|
207 |
+
config=None,
|
208 |
+
**kwargs
|
209 |
+
) -> None:
|
210 |
+
if not isinstance(sampler, Sampler):
|
211 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
212 |
+
f'but got {sampler}')
|
213 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
214 |
+
raise ValueError('batch_size should be a positive integer value, '
|
215 |
+
f'but got batch_size={batch_size}')
|
216 |
+
self.sampler = sampler
|
217 |
+
self.dataset = dataset
|
218 |
+
self.video_folder = video_folder
|
219 |
+
self.train_data_format = train_data_format
|
220 |
+
self.batch_size = batch_size
|
221 |
+
self.aspect_ratios = aspect_ratios
|
222 |
+
self.drop_last = drop_last
|
223 |
+
self.config = config
|
224 |
+
# buckets for each aspect ratio
|
225 |
+
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
|
226 |
+
# [str(k) for k, v in aspect_ratios]
|
227 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
228 |
+
|
229 |
+
def __iter__(self):
|
230 |
+
for idx in self.sampler:
|
231 |
+
try:
|
232 |
+
video_dict = self.dataset[idx]
|
233 |
+
width, more = video_dict.get("width", None), video_dict.get("height", None)
|
234 |
+
|
235 |
+
if width is None or height is None:
|
236 |
+
if self.train_data_format == "normal":
|
237 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
238 |
+
if self.video_folder is None:
|
239 |
+
video_dir = video_id
|
240 |
+
else:
|
241 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
242 |
+
else:
|
243 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
244 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
245 |
+
cap = cv2.VideoCapture(video_dir)
|
246 |
+
|
247 |
+
# 获取视频尺寸
|
248 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
|
249 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
|
250 |
+
|
251 |
+
ratio = height / width # self.dataset[idx]
|
252 |
+
else:
|
253 |
+
height = int(height)
|
254 |
+
width = int(width)
|
255 |
+
ratio = height / width # self.dataset[idx]
|
256 |
+
except Exception as e:
|
257 |
+
print(e)
|
258 |
+
continue
|
259 |
+
# find the closest aspect ratio
|
260 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
261 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
262 |
+
continue
|
263 |
+
bucket = self._aspect_ratio_buckets[closest_ratio]
|
264 |
+
bucket.append(idx)
|
265 |
+
# yield a batch of indices in the same aspect ratio group
|
266 |
+
if len(bucket) == self.batch_size:
|
267 |
+
yield bucket[:]
|
268 |
+
del bucket[:]
|
269 |
+
|
270 |
+
class AspectRatioBatchImageVideoSampler(BatchSampler):
|
271 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
sampler (Sampler): Base sampler.
|
275 |
+
dataset (Dataset): Dataset providing data information.
|
276 |
+
batch_size (int): Size of mini-batch.
|
277 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
278 |
+
its size would be less than ``batch_size``.
|
279 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
280 |
+
"""
|
281 |
+
|
282 |
+
def __init__(self,
|
283 |
+
sampler: Sampler,
|
284 |
+
dataset: Dataset,
|
285 |
+
batch_size: int,
|
286 |
+
train_folder: str = None,
|
287 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
288 |
+
drop_last: bool = False
|
289 |
+
) -> None:
|
290 |
+
if not isinstance(sampler, Sampler):
|
291 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
292 |
+
f'but got {sampler}')
|
293 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
294 |
+
raise ValueError('batch_size should be a positive integer value, '
|
295 |
+
f'but got batch_size={batch_size}')
|
296 |
+
self.sampler = sampler
|
297 |
+
self.dataset = dataset
|
298 |
+
self.train_folder = train_folder
|
299 |
+
self.batch_size = batch_size
|
300 |
+
self.aspect_ratios = aspect_ratios
|
301 |
+
self.drop_last = drop_last
|
302 |
+
|
303 |
+
# buckets for each aspect ratio
|
304 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
305 |
+
self.bucket = {
|
306 |
+
'image':{ratio: [] for ratio in aspect_ratios},
|
307 |
+
'video':{ratio: [] for ratio in aspect_ratios}
|
308 |
+
}
|
309 |
+
|
310 |
+
def __iter__(self):
|
311 |
+
for idx in self.sampler:
|
312 |
+
content_type = self.dataset[idx].get('type', 'image')
|
313 |
+
if content_type == 'image':
|
314 |
+
try:
|
315 |
+
image_dict = self.dataset[idx]
|
316 |
+
|
317 |
+
width, height = image_dict.get("width", None), image_dict.get("height", None)
|
318 |
+
if width is None or height is None:
|
319 |
+
image_id, name = image_dict['file_path'], image_dict['text']
|
320 |
+
if self.train_folder is None:
|
321 |
+
image_dir = image_id
|
322 |
+
else:
|
323 |
+
image_dir = os.path.join(self.train_folder, image_id)
|
324 |
+
|
325 |
+
width, height = get_image_size_without_loading(image_dir)
|
326 |
+
|
327 |
+
ratio = height / width # self.dataset[idx]
|
328 |
+
else:
|
329 |
+
height = int(height)
|
330 |
+
width = int(width)
|
331 |
+
ratio = height / width # self.dataset[idx]
|
332 |
+
except Exception as e:
|
333 |
+
print(e)
|
334 |
+
continue
|
335 |
+
# find the closest aspect ratio
|
336 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
337 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
338 |
+
continue
|
339 |
+
bucket = self.bucket['image'][closest_ratio]
|
340 |
+
bucket.append(idx)
|
341 |
+
# yield a batch of indices in the same aspect ratio group
|
342 |
+
if len(bucket) == self.batch_size:
|
343 |
+
yield bucket[:]
|
344 |
+
del bucket[:]
|
345 |
+
else:
|
346 |
+
try:
|
347 |
+
video_dict = self.dataset[idx]
|
348 |
+
width, height = video_dict.get("width", None), video_dict.get("height", None)
|
349 |
+
|
350 |
+
if width is None or height is None:
|
351 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
352 |
+
if self.train_folder is None:
|
353 |
+
video_dir = video_id
|
354 |
+
else:
|
355 |
+
video_dir = os.path.join(self.train_folder, video_id)
|
356 |
+
cap = cv2.VideoCapture(video_dir)
|
357 |
+
|
358 |
+
# 获取视频尺寸
|
359 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
|
360 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
|
361 |
+
|
362 |
+
ratio = height / width # self.dataset[idx]
|
363 |
+
else:
|
364 |
+
height = int(height)
|
365 |
+
width = int(width)
|
366 |
+
ratio = height / width # self.dataset[idx]
|
367 |
+
except Exception as e:
|
368 |
+
print(e)
|
369 |
+
continue
|
370 |
+
# find the closest aspect ratio
|
371 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
372 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
373 |
+
continue
|
374 |
+
bucket = self.bucket['video'][closest_ratio]
|
375 |
+
bucket.append(idx)
|
376 |
+
# yield a batch of indices in the same aspect ratio group
|
377 |
+
if len(bucket) == self.batch_size:
|
378 |
+
yield bucket[:]
|
379 |
+
del bucket[:]
|
easyanimate/data/dataset_image.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
from PIL import Image
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
|
11 |
+
|
12 |
+
class CC15M(Dataset):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
json_path,
|
16 |
+
video_folder=None,
|
17 |
+
resolution=512,
|
18 |
+
enable_bucket=False,
|
19 |
+
):
|
20 |
+
print(f"loading annotations from {json_path} ...")
|
21 |
+
self.dataset = json.load(open(json_path, 'r'))
|
22 |
+
self.length = len(self.dataset)
|
23 |
+
print(f"data scale: {self.length}")
|
24 |
+
|
25 |
+
self.enable_bucket = enable_bucket
|
26 |
+
self.video_folder = video_folder
|
27 |
+
|
28 |
+
resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
|
29 |
+
self.pixel_transforms = transforms.Compose([
|
30 |
+
transforms.Resize(resolution[0]),
|
31 |
+
transforms.CenterCrop(resolution),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
34 |
+
])
|
35 |
+
|
36 |
+
def get_batch(self, idx):
|
37 |
+
video_dict = self.dataset[idx]
|
38 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
39 |
+
|
40 |
+
if self.video_folder is None:
|
41 |
+
video_dir = video_id
|
42 |
+
else:
|
43 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
44 |
+
|
45 |
+
pixel_values = Image.open(video_dir).convert("RGB")
|
46 |
+
return pixel_values, name
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return self.length
|
50 |
+
|
51 |
+
def __getitem__(self, idx):
|
52 |
+
while True:
|
53 |
+
try:
|
54 |
+
pixel_values, name = self.get_batch(idx)
|
55 |
+
break
|
56 |
+
except Exception as e:
|
57 |
+
print(e)
|
58 |
+
idx = random.randint(0, self.length-1)
|
59 |
+
|
60 |
+
if not self.enable_bucket:
|
61 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
62 |
+
else:
|
63 |
+
pixel_values = np.array(pixel_values)
|
64 |
+
|
65 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
66 |
+
return sample
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
dataset = CC15M(
|
70 |
+
csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
|
71 |
+
resolution=512,
|
72 |
+
)
|
73 |
+
|
74 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
75 |
+
for idx, batch in enumerate(dataloader):
|
76 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
easyanimate/data/dataset_image_video.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from threading import Thread
|
8 |
+
|
9 |
+
import albumentations
|
10 |
+
import cv2
|
11 |
+
import gc
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torchvision.transforms as transforms
|
15 |
+
from func_timeout import func_timeout, FunctionTimedOut
|
16 |
+
from decord import VideoReader
|
17 |
+
from PIL import Image
|
18 |
+
from torch.utils.data import BatchSampler, Sampler
|
19 |
+
from torch.utils.data.dataset import Dataset
|
20 |
+
from contextlib import contextmanager
|
21 |
+
|
22 |
+
VIDEO_READER_TIMEOUT = 20
|
23 |
+
|
24 |
+
class ImageVideoSampler(BatchSampler):
|
25 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
sampler (Sampler): Base sampler.
|
29 |
+
dataset (Dataset): Dataset providing data information.
|
30 |
+
batch_size (int): Size of mini-batch.
|
31 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
32 |
+
its size would be less than ``batch_size``.
|
33 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
sampler: Sampler,
|
38 |
+
dataset: Dataset,
|
39 |
+
batch_size: int,
|
40 |
+
drop_last: bool = False
|
41 |
+
) -> None:
|
42 |
+
if not isinstance(sampler, Sampler):
|
43 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
44 |
+
f'but got {sampler}')
|
45 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
46 |
+
raise ValueError('batch_size should be a positive integer value, '
|
47 |
+
f'but got batch_size={batch_size}')
|
48 |
+
self.sampler = sampler
|
49 |
+
self.dataset = dataset
|
50 |
+
self.batch_size = batch_size
|
51 |
+
self.drop_last = drop_last
|
52 |
+
|
53 |
+
# buckets for each aspect ratio
|
54 |
+
self.bucket = {'image':[], 'video':[]}
|
55 |
+
|
56 |
+
def __iter__(self):
|
57 |
+
for idx in self.sampler:
|
58 |
+
content_type = self.dataset.dataset[idx].get('type', 'image')
|
59 |
+
self.bucket[content_type].append(idx)
|
60 |
+
|
61 |
+
# yield a batch of indices in the same aspect ratio group
|
62 |
+
if len(self.bucket['video']) == self.batch_size:
|
63 |
+
bucket = self.bucket['video']
|
64 |
+
yield bucket[:]
|
65 |
+
del bucket[:]
|
66 |
+
elif len(self.bucket['image']) == self.batch_size:
|
67 |
+
bucket = self.bucket['image']
|
68 |
+
yield bucket[:]
|
69 |
+
del bucket[:]
|
70 |
+
|
71 |
+
@contextmanager
|
72 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
73 |
+
vr = VideoReader(*args, **kwargs)
|
74 |
+
try:
|
75 |
+
yield vr
|
76 |
+
finally:
|
77 |
+
del vr
|
78 |
+
gc.collect()
|
79 |
+
|
80 |
+
def get_video_reader_batch(video_reader, batch_index):
|
81 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
82 |
+
return frames
|
83 |
+
|
84 |
+
class ImageVideoDataset(Dataset):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
ann_path, data_root=None,
|
88 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
89 |
+
image_sample_size=512,
|
90 |
+
video_repeat=0,
|
91 |
+
text_drop_ratio=0.001,
|
92 |
+
enable_bucket=False,
|
93 |
+
video_length_drop_start=0.1,
|
94 |
+
video_length_drop_end=0.9,
|
95 |
+
):
|
96 |
+
# Loading annotations from files
|
97 |
+
print(f"loading annotations from {ann_path} ...")
|
98 |
+
if ann_path.endswith('.csv'):
|
99 |
+
with open(ann_path, 'r') as csvfile:
|
100 |
+
dataset = list(csv.DictReader(csvfile))
|
101 |
+
elif ann_path.endswith('.json'):
|
102 |
+
dataset = json.load(open(ann_path))
|
103 |
+
|
104 |
+
self.data_root = data_root
|
105 |
+
|
106 |
+
# It's used to balance num of images and videos.
|
107 |
+
self.dataset = []
|
108 |
+
for data in dataset:
|
109 |
+
if data.get('type', 'image') != 'video':
|
110 |
+
self.dataset.append(data)
|
111 |
+
if video_repeat > 0:
|
112 |
+
for _ in range(video_repeat):
|
113 |
+
for data in dataset:
|
114 |
+
if data.get('type', 'image') == 'video':
|
115 |
+
self.dataset.append(data)
|
116 |
+
del dataset
|
117 |
+
|
118 |
+
self.length = len(self.dataset)
|
119 |
+
print(f"data scale: {self.length}")
|
120 |
+
# TODO: enable bucket training
|
121 |
+
self.enable_bucket = enable_bucket
|
122 |
+
self.text_drop_ratio = text_drop_ratio
|
123 |
+
self.video_length_drop_start = video_length_drop_start
|
124 |
+
self.video_length_drop_end = video_length_drop_end
|
125 |
+
|
126 |
+
# Video params
|
127 |
+
self.video_sample_stride = video_sample_stride
|
128 |
+
self.video_sample_n_frames = video_sample_n_frames
|
129 |
+
video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
130 |
+
self.video_transforms = transforms.Compose(
|
131 |
+
[
|
132 |
+
transforms.Resize(video_sample_size[0]),
|
133 |
+
transforms.CenterCrop(video_sample_size),
|
134 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
135 |
+
]
|
136 |
+
)
|
137 |
+
|
138 |
+
# Image params
|
139 |
+
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
|
140 |
+
self.image_transforms = transforms.Compose([
|
141 |
+
transforms.Resize(min(self.image_sample_size)),
|
142 |
+
transforms.CenterCrop(self.image_sample_size),
|
143 |
+
transforms.ToTensor(),
|
144 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
145 |
+
])
|
146 |
+
|
147 |
+
def get_batch(self, idx):
|
148 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
149 |
+
|
150 |
+
if data_info.get('type', 'image')=='video':
|
151 |
+
video_id, text = data_info['file_path'], data_info['text']
|
152 |
+
|
153 |
+
if self.data_root is None:
|
154 |
+
video_dir = video_id
|
155 |
+
else:
|
156 |
+
video_dir = os.path.join(self.data_root, video_id)
|
157 |
+
|
158 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
159 |
+
min_sample_n_frames = min(
|
160 |
+
self.video_sample_n_frames,
|
161 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start))
|
162 |
+
)
|
163 |
+
if min_sample_n_frames == 0:
|
164 |
+
raise ValueError(f"No Frames in video.")
|
165 |
+
|
166 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
167 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
168 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
|
169 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
170 |
+
|
171 |
+
try:
|
172 |
+
sample_args = (video_reader, batch_index)
|
173 |
+
pixel_values = func_timeout(
|
174 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
175 |
+
)
|
176 |
+
except FunctionTimedOut:
|
177 |
+
raise ValueError(f"Read {idx} timeout.")
|
178 |
+
except Exception as e:
|
179 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
180 |
+
|
181 |
+
if not self.enable_bucket:
|
182 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
183 |
+
pixel_values = pixel_values / 255.
|
184 |
+
del video_reader
|
185 |
+
else:
|
186 |
+
pixel_values = pixel_values
|
187 |
+
|
188 |
+
if not self.enable_bucket:
|
189 |
+
pixel_values = self.video_transforms(pixel_values)
|
190 |
+
|
191 |
+
# Random use no text generation
|
192 |
+
if random.random() < self.text_drop_ratio:
|
193 |
+
text = ''
|
194 |
+
return pixel_values, text, 'video'
|
195 |
+
else:
|
196 |
+
image_path, text = data_info['file_path'], data_info['text']
|
197 |
+
if self.data_root is not None:
|
198 |
+
image_path = os.path.join(self.data_root, image_path)
|
199 |
+
image = Image.open(image_path).convert('RGB')
|
200 |
+
if not self.enable_bucket:
|
201 |
+
image = self.image_transforms(image).unsqueeze(0)
|
202 |
+
else:
|
203 |
+
image = np.expand_dims(np.array(image), 0)
|
204 |
+
if random.random() < self.text_drop_ratio:
|
205 |
+
text = ''
|
206 |
+
return image, text, 'image'
|
207 |
+
|
208 |
+
def __len__(self):
|
209 |
+
return self.length
|
210 |
+
|
211 |
+
def __getitem__(self, idx):
|
212 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
213 |
+
data_type = data_info.get('type', 'image')
|
214 |
+
while True:
|
215 |
+
sample = {}
|
216 |
+
try:
|
217 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
218 |
+
data_type_local = data_info_local.get('type', 'image')
|
219 |
+
if data_type_local != data_type:
|
220 |
+
raise ValueError("data_type_local != data_type")
|
221 |
+
|
222 |
+
pixel_values, name, data_type = self.get_batch(idx)
|
223 |
+
sample["pixel_values"] = pixel_values
|
224 |
+
sample["text"] = name
|
225 |
+
sample["data_type"] = data_type
|
226 |
+
sample["idx"] = idx
|
227 |
+
|
228 |
+
if len(sample) > 0:
|
229 |
+
break
|
230 |
+
except Exception as e:
|
231 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
232 |
+
idx = random.randint(0, self.length-1)
|
233 |
+
return sample
|
234 |
+
|
235 |
+
if __name__ == "__main__":
|
236 |
+
dataset = ImageVideoDataset(
|
237 |
+
ann_path="test.json"
|
238 |
+
)
|
239 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
|
240 |
+
for idx, batch in enumerate(dataloader):
|
241 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
easyanimate/data/dataset_video.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import gc
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from threading import Thread
|
10 |
+
|
11 |
+
import albumentations
|
12 |
+
import cv2
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torchvision.transforms as transforms
|
16 |
+
from decord import VideoReader
|
17 |
+
from einops import rearrange
|
18 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
19 |
+
from PIL import Image
|
20 |
+
from torch.utils.data import BatchSampler, Sampler
|
21 |
+
from torch.utils.data.dataset import Dataset
|
22 |
+
|
23 |
+
VIDEO_READER_TIMEOUT = 20
|
24 |
+
|
25 |
+
def get_random_mask(shape):
|
26 |
+
f, c, h, w = shape
|
27 |
+
|
28 |
+
mask_index = np.random.randint(0, 4)
|
29 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
30 |
+
if mask_index == 0:
|
31 |
+
mask[1:, :, :, :] = 1
|
32 |
+
elif mask_index == 1:
|
33 |
+
mask_frame_index = 1
|
34 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
35 |
+
elif mask_index == 2:
|
36 |
+
center_x = torch.randint(0, w, (1,)).item()
|
37 |
+
center_y = torch.randint(0, h, (1,)).item()
|
38 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
39 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
40 |
+
|
41 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
42 |
+
end_x = min(center_x + block_size_x // 2, w)
|
43 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
44 |
+
end_y = min(center_y + block_size_y // 2, h)
|
45 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
46 |
+
elif mask_index == 3:
|
47 |
+
center_x = torch.randint(0, w, (1,)).item()
|
48 |
+
center_y = torch.randint(0, h, (1,)).item()
|
49 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
50 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
51 |
+
|
52 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
53 |
+
end_x = min(center_x + block_size_x // 2, w)
|
54 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
55 |
+
end_y = min(center_y + block_size_y // 2, h)
|
56 |
+
|
57 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
58 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
59 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
60 |
+
else:
|
61 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
62 |
+
return mask
|
63 |
+
|
64 |
+
|
65 |
+
@contextmanager
|
66 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
67 |
+
vr = VideoReader(*args, **kwargs)
|
68 |
+
try:
|
69 |
+
yield vr
|
70 |
+
finally:
|
71 |
+
del vr
|
72 |
+
gc.collect()
|
73 |
+
|
74 |
+
|
75 |
+
def get_video_reader_batch(video_reader, batch_index):
|
76 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
77 |
+
return frames
|
78 |
+
|
79 |
+
|
80 |
+
class WebVid10M(Dataset):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
csv_path, video_folder,
|
84 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
85 |
+
enable_bucket=False, enable_inpaint=False, is_image=False,
|
86 |
+
):
|
87 |
+
print(f"loading annotations from {csv_path} ...")
|
88 |
+
with open(csv_path, 'r') as csvfile:
|
89 |
+
self.dataset = list(csv.DictReader(csvfile))
|
90 |
+
self.length = len(self.dataset)
|
91 |
+
print(f"data scale: {self.length}")
|
92 |
+
|
93 |
+
self.video_folder = video_folder
|
94 |
+
self.sample_stride = sample_stride
|
95 |
+
self.sample_n_frames = sample_n_frames
|
96 |
+
self.enable_bucket = enable_bucket
|
97 |
+
self.enable_inpaint = enable_inpaint
|
98 |
+
self.is_image = is_image
|
99 |
+
|
100 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
101 |
+
self.pixel_transforms = transforms.Compose([
|
102 |
+
transforms.Resize(sample_size[0]),
|
103 |
+
transforms.CenterCrop(sample_size),
|
104 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
105 |
+
])
|
106 |
+
|
107 |
+
def get_batch(self, idx):
|
108 |
+
video_dict = self.dataset[idx]
|
109 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
110 |
+
|
111 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
112 |
+
video_reader = VideoReader(video_dir)
|
113 |
+
video_length = len(video_reader)
|
114 |
+
|
115 |
+
if not self.is_image:
|
116 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
117 |
+
start_idx = random.randint(0, video_length - clip_length)
|
118 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
119 |
+
else:
|
120 |
+
batch_index = [random.randint(0, video_length - 1)]
|
121 |
+
|
122 |
+
if not self.enable_bucket:
|
123 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
124 |
+
pixel_values = pixel_values / 255.
|
125 |
+
del video_reader
|
126 |
+
else:
|
127 |
+
pixel_values = video_reader.get_batch(batch_index).asnumpy()
|
128 |
+
|
129 |
+
if self.is_image:
|
130 |
+
pixel_values = pixel_values[0]
|
131 |
+
return pixel_values, name
|
132 |
+
|
133 |
+
def __len__(self):
|
134 |
+
return self.length
|
135 |
+
|
136 |
+
def __getitem__(self, idx):
|
137 |
+
while True:
|
138 |
+
try:
|
139 |
+
pixel_values, name = self.get_batch(idx)
|
140 |
+
break
|
141 |
+
|
142 |
+
except Exception as e:
|
143 |
+
print("Error info:", e)
|
144 |
+
idx = random.randint(0, self.length-1)
|
145 |
+
|
146 |
+
if not self.enable_bucket:
|
147 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
148 |
+
if self.enable_inpaint:
|
149 |
+
mask = get_random_mask(pixel_values.size())
|
150 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
151 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
152 |
+
else:
|
153 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
154 |
+
return sample
|
155 |
+
|
156 |
+
|
157 |
+
class VideoDataset(Dataset):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
json_path, video_folder=None,
|
161 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
162 |
+
enable_bucket=False, enable_inpaint=False
|
163 |
+
):
|
164 |
+
print(f"loading annotations from {json_path} ...")
|
165 |
+
self.dataset = json.load(open(json_path, 'r'))
|
166 |
+
self.length = len(self.dataset)
|
167 |
+
print(f"data scale: {self.length}")
|
168 |
+
|
169 |
+
self.video_folder = video_folder
|
170 |
+
self.sample_stride = sample_stride
|
171 |
+
self.sample_n_frames = sample_n_frames
|
172 |
+
self.enable_bucket = enable_bucket
|
173 |
+
self.enable_inpaint = enable_inpaint
|
174 |
+
|
175 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
176 |
+
self.pixel_transforms = transforms.Compose(
|
177 |
+
[
|
178 |
+
transforms.Resize(sample_size[0]),
|
179 |
+
transforms.CenterCrop(sample_size),
|
180 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
181 |
+
]
|
182 |
+
)
|
183 |
+
|
184 |
+
def get_batch(self, idx):
|
185 |
+
video_dict = self.dataset[idx]
|
186 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
187 |
+
|
188 |
+
if self.video_folder is None:
|
189 |
+
video_dir = video_id
|
190 |
+
else:
|
191 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
192 |
+
|
193 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
194 |
+
video_length = len(video_reader)
|
195 |
+
|
196 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
197 |
+
start_idx = random.randint(0, video_length - clip_length)
|
198 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
199 |
+
|
200 |
+
try:
|
201 |
+
sample_args = (video_reader, batch_index)
|
202 |
+
pixel_values = func_timeout(
|
203 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
204 |
+
)
|
205 |
+
except FunctionTimedOut:
|
206 |
+
raise ValueError(f"Read {idx} timeout.")
|
207 |
+
except Exception as e:
|
208 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
209 |
+
|
210 |
+
if not self.enable_bucket:
|
211 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
212 |
+
pixel_values = pixel_values / 255.
|
213 |
+
del video_reader
|
214 |
+
else:
|
215 |
+
pixel_values = pixel_values
|
216 |
+
|
217 |
+
return pixel_values, name
|
218 |
+
|
219 |
+
def __len__(self):
|
220 |
+
return self.length
|
221 |
+
|
222 |
+
def __getitem__(self, idx):
|
223 |
+
while True:
|
224 |
+
try:
|
225 |
+
pixel_values, name = self.get_batch(idx)
|
226 |
+
break
|
227 |
+
|
228 |
+
except Exception as e:
|
229 |
+
print("Error info:", e)
|
230 |
+
idx = random.randint(0, self.length-1)
|
231 |
+
|
232 |
+
if not self.enable_bucket:
|
233 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
234 |
+
if self.enable_inpaint:
|
235 |
+
mask = get_random_mask(pixel_values.size())
|
236 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
237 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
238 |
+
else:
|
239 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
240 |
+
return sample
|
241 |
+
|
242 |
+
|
243 |
+
if __name__ == "__main__":
|
244 |
+
if 1:
|
245 |
+
dataset = VideoDataset(
|
246 |
+
json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
|
247 |
+
sample_size=256,
|
248 |
+
sample_stride=4, sample_n_frames=16,
|
249 |
+
)
|
250 |
+
|
251 |
+
if 0:
|
252 |
+
dataset = WebVid10M(
|
253 |
+
csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
|
254 |
+
video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
|
255 |
+
sample_size=256,
|
256 |
+
sample_stride=4, sample_n_frames=16,
|
257 |
+
is_image=False,
|
258 |
+
)
|
259 |
+
|
260 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
261 |
+
for idx, batch in enumerate(dataloader):
|
262 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
easyanimate/models/attention.py
ADDED
@@ -0,0 +1,1299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import torch.nn.init as init
|
20 |
+
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
|
21 |
+
from diffusers.models.attention import AdaLayerNorm, FeedForward
|
22 |
+
from diffusers.models.attention_processor import Attention
|
23 |
+
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
|
24 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
25 |
+
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
|
26 |
+
from diffusers.utils import USE_PEFT_BACKEND
|
27 |
+
from diffusers.utils.import_utils import is_xformers_available
|
28 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
29 |
+
from einops import rearrange, repeat
|
30 |
+
from torch import nn
|
31 |
+
|
32 |
+
from .motion_module import get_motion_module
|
33 |
+
|
34 |
+
if is_xformers_available():
|
35 |
+
import xformers
|
36 |
+
import xformers.ops
|
37 |
+
else:
|
38 |
+
xformers = None
|
39 |
+
|
40 |
+
|
41 |
+
@maybe_allow_in_graph
|
42 |
+
class GatedSelfAttentionDense(nn.Module):
|
43 |
+
r"""
|
44 |
+
A gated self-attention dense layer that combines visual features and object features.
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
query_dim (`int`): The number of channels in the query.
|
48 |
+
context_dim (`int`): The number of channels in the context.
|
49 |
+
n_heads (`int`): The number of heads to use for attention.
|
50 |
+
d_head (`int`): The number of channels in each head.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
# we need a linear projection since we need cat visual feature and obj feature
|
57 |
+
self.linear = nn.Linear(context_dim, query_dim)
|
58 |
+
|
59 |
+
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
|
60 |
+
self.ff = FeedForward(query_dim, activation_fn="geglu")
|
61 |
+
|
62 |
+
self.norm1 = nn.LayerNorm(query_dim)
|
63 |
+
self.norm2 = nn.LayerNorm(query_dim)
|
64 |
+
|
65 |
+
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
|
66 |
+
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
|
67 |
+
|
68 |
+
self.enabled = True
|
69 |
+
|
70 |
+
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
71 |
+
if not self.enabled:
|
72 |
+
return x
|
73 |
+
|
74 |
+
n_visual = x.shape[1]
|
75 |
+
objs = self.linear(objs)
|
76 |
+
|
77 |
+
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
|
78 |
+
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
|
79 |
+
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
def zero_module(module):
|
84 |
+
# Zero out the parameters of a module and return it.
|
85 |
+
for p in module.parameters():
|
86 |
+
p.detach().zero_()
|
87 |
+
return module
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
class KVCompressionCrossAttention(nn.Module):
|
92 |
+
r"""
|
93 |
+
A cross attention layer.
|
94 |
+
|
95 |
+
Parameters:
|
96 |
+
query_dim (`int`): The number of channels in the query.
|
97 |
+
cross_attention_dim (`int`, *optional*):
|
98 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
99 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
100 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
101 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
102 |
+
bias (`bool`, *optional*, defaults to False):
|
103 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(
|
107 |
+
self,
|
108 |
+
query_dim: int,
|
109 |
+
cross_attention_dim: Optional[int] = None,
|
110 |
+
heads: int = 8,
|
111 |
+
dim_head: int = 64,
|
112 |
+
dropout: float = 0.0,
|
113 |
+
bias=False,
|
114 |
+
upcast_attention: bool = False,
|
115 |
+
upcast_softmax: bool = False,
|
116 |
+
added_kv_proj_dim: Optional[int] = None,
|
117 |
+
norm_num_groups: Optional[int] = None,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
inner_dim = dim_head * heads
|
121 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
122 |
+
self.upcast_attention = upcast_attention
|
123 |
+
self.upcast_softmax = upcast_softmax
|
124 |
+
|
125 |
+
self.scale = dim_head**-0.5
|
126 |
+
|
127 |
+
self.heads = heads
|
128 |
+
# for slice_size > 0 the attention score computation
|
129 |
+
# is split across the batch axis to save memory
|
130 |
+
# You can set slice_size with `set_attention_slice`
|
131 |
+
self.sliceable_head_dim = heads
|
132 |
+
self._slice_size = None
|
133 |
+
self._use_memory_efficient_attention_xformers = True
|
134 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
135 |
+
|
136 |
+
if norm_num_groups is not None:
|
137 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
138 |
+
else:
|
139 |
+
self.group_norm = None
|
140 |
+
|
141 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
142 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
143 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
144 |
+
|
145 |
+
if self.added_kv_proj_dim is not None:
|
146 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
147 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
148 |
+
|
149 |
+
self.kv_compression = nn.Conv2d(
|
150 |
+
query_dim,
|
151 |
+
query_dim,
|
152 |
+
groups=query_dim,
|
153 |
+
kernel_size=2,
|
154 |
+
stride=2,
|
155 |
+
bias=True
|
156 |
+
)
|
157 |
+
self.kv_compression_norm = nn.LayerNorm(query_dim)
|
158 |
+
init.constant_(self.kv_compression.weight, 1 / 4)
|
159 |
+
if self.kv_compression.bias is not None:
|
160 |
+
init.constant_(self.kv_compression.bias, 0)
|
161 |
+
|
162 |
+
self.to_out = nn.ModuleList([])
|
163 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
164 |
+
self.to_out.append(nn.Dropout(dropout))
|
165 |
+
|
166 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
167 |
+
batch_size, seq_len, dim = tensor.shape
|
168 |
+
head_size = self.heads
|
169 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
170 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
171 |
+
return tensor
|
172 |
+
|
173 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
174 |
+
batch_size, seq_len, dim = tensor.shape
|
175 |
+
head_size = self.heads
|
176 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
177 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
178 |
+
return tensor
|
179 |
+
|
180 |
+
def set_attention_slice(self, slice_size):
|
181 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
182 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
183 |
+
|
184 |
+
self._slice_size = slice_size
|
185 |
+
|
186 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
|
187 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
188 |
+
|
189 |
+
encoder_hidden_states = encoder_hidden_states
|
190 |
+
|
191 |
+
if self.group_norm is not None:
|
192 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
193 |
+
|
194 |
+
query = self.to_q(hidden_states)
|
195 |
+
dim = query.shape[-1]
|
196 |
+
query = self.reshape_heads_to_batch_dim(query)
|
197 |
+
|
198 |
+
if self.added_kv_proj_dim is not None:
|
199 |
+
key = self.to_k(hidden_states)
|
200 |
+
value = self.to_v(hidden_states)
|
201 |
+
|
202 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
203 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
204 |
+
|
205 |
+
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
206 |
+
key = self.kv_compression(key)
|
207 |
+
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
|
208 |
+
key = self.kv_compression_norm(key)
|
209 |
+
key = key.to(query.dtype)
|
210 |
+
|
211 |
+
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
212 |
+
value = self.kv_compression(value)
|
213 |
+
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
|
214 |
+
value = self.kv_compression_norm(value)
|
215 |
+
value = value.to(query.dtype)
|
216 |
+
|
217 |
+
key = self.reshape_heads_to_batch_dim(key)
|
218 |
+
value = self.reshape_heads_to_batch_dim(value)
|
219 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
220 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
221 |
+
|
222 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
223 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
224 |
+
else:
|
225 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
226 |
+
key = self.to_k(encoder_hidden_states)
|
227 |
+
value = self.to_v(encoder_hidden_states)
|
228 |
+
|
229 |
+
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
230 |
+
key = self.kv_compression(key)
|
231 |
+
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
|
232 |
+
key = self.kv_compression_norm(key)
|
233 |
+
key = key.to(query.dtype)
|
234 |
+
|
235 |
+
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
|
236 |
+
value = self.kv_compression(value)
|
237 |
+
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
|
238 |
+
value = self.kv_compression_norm(value)
|
239 |
+
value = value.to(query.dtype)
|
240 |
+
|
241 |
+
key = self.reshape_heads_to_batch_dim(key)
|
242 |
+
value = self.reshape_heads_to_batch_dim(value)
|
243 |
+
|
244 |
+
if attention_mask is not None:
|
245 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
246 |
+
target_length = query.shape[1]
|
247 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
248 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
249 |
+
|
250 |
+
# attention, what we cannot get enough of
|
251 |
+
if self._use_memory_efficient_attention_xformers:
|
252 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
253 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
254 |
+
hidden_states = hidden_states.to(query.dtype)
|
255 |
+
else:
|
256 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
257 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
258 |
+
else:
|
259 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
260 |
+
|
261 |
+
# linear proj
|
262 |
+
hidden_states = self.to_out[0](hidden_states)
|
263 |
+
|
264 |
+
# dropout
|
265 |
+
hidden_states = self.to_out[1](hidden_states)
|
266 |
+
return hidden_states
|
267 |
+
|
268 |
+
def _attention(self, query, key, value, attention_mask=None):
|
269 |
+
if self.upcast_attention:
|
270 |
+
query = query.float()
|
271 |
+
key = key.float()
|
272 |
+
|
273 |
+
attention_scores = torch.baddbmm(
|
274 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
275 |
+
query,
|
276 |
+
key.transpose(-1, -2),
|
277 |
+
beta=0,
|
278 |
+
alpha=self.scale,
|
279 |
+
)
|
280 |
+
|
281 |
+
if attention_mask is not None:
|
282 |
+
attention_scores = attention_scores + attention_mask
|
283 |
+
|
284 |
+
if self.upcast_softmax:
|
285 |
+
attention_scores = attention_scores.float()
|
286 |
+
|
287 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
288 |
+
|
289 |
+
# cast back to the original dtype
|
290 |
+
attention_probs = attention_probs.to(value.dtype)
|
291 |
+
|
292 |
+
# compute attention output
|
293 |
+
hidden_states = torch.bmm(attention_probs, value)
|
294 |
+
|
295 |
+
# reshape hidden_states
|
296 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
297 |
+
return hidden_states
|
298 |
+
|
299 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
300 |
+
batch_size_attention = query.shape[0]
|
301 |
+
hidden_states = torch.zeros(
|
302 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
303 |
+
)
|
304 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
305 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
306 |
+
start_idx = i * slice_size
|
307 |
+
end_idx = (i + 1) * slice_size
|
308 |
+
|
309 |
+
query_slice = query[start_idx:end_idx]
|
310 |
+
key_slice = key[start_idx:end_idx]
|
311 |
+
|
312 |
+
if self.upcast_attention:
|
313 |
+
query_slice = query_slice.float()
|
314 |
+
key_slice = key_slice.float()
|
315 |
+
|
316 |
+
attn_slice = torch.baddbmm(
|
317 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
318 |
+
query_slice,
|
319 |
+
key_slice.transpose(-1, -2),
|
320 |
+
beta=0,
|
321 |
+
alpha=self.scale,
|
322 |
+
)
|
323 |
+
|
324 |
+
if attention_mask is not None:
|
325 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
326 |
+
|
327 |
+
if self.upcast_softmax:
|
328 |
+
attn_slice = attn_slice.float()
|
329 |
+
|
330 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
331 |
+
|
332 |
+
# cast back to the original dtype
|
333 |
+
attn_slice = attn_slice.to(value.dtype)
|
334 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
335 |
+
|
336 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
337 |
+
|
338 |
+
# reshape hidden_states
|
339 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
340 |
+
return hidden_states
|
341 |
+
|
342 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
343 |
+
# TODO attention_mask
|
344 |
+
query = query.contiguous()
|
345 |
+
key = key.contiguous()
|
346 |
+
value = value.contiguous()
|
347 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
348 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
349 |
+
return hidden_states
|
350 |
+
|
351 |
+
|
352 |
+
@maybe_allow_in_graph
|
353 |
+
class TemporalTransformerBlock(nn.Module):
|
354 |
+
r"""
|
355 |
+
A Temporal Transformer block.
|
356 |
+
|
357 |
+
Parameters:
|
358 |
+
dim (`int`): The number of channels in the input and output.
|
359 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
360 |
+
attention_head_dim (`int`): The number of channels in each head.
|
361 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
362 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
363 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
364 |
+
num_embeds_ada_norm (:
|
365 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
366 |
+
attention_bias (:
|
367 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
368 |
+
only_cross_attention (`bool`, *optional*):
|
369 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
370 |
+
double_self_attention (`bool`, *optional*):
|
371 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
372 |
+
upcast_attention (`bool`, *optional*):
|
373 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
374 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
375 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
376 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
377 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
378 |
+
final_dropout (`bool` *optional*, defaults to False):
|
379 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
380 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
381 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
382 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
383 |
+
The type of positional embeddings to apply to.
|
384 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
385 |
+
The maximum number of positional embeddings to apply.
|
386 |
+
"""
|
387 |
+
|
388 |
+
def __init__(
|
389 |
+
self,
|
390 |
+
dim: int,
|
391 |
+
num_attention_heads: int,
|
392 |
+
attention_head_dim: int,
|
393 |
+
dropout=0.0,
|
394 |
+
cross_attention_dim: Optional[int] = None,
|
395 |
+
activation_fn: str = "geglu",
|
396 |
+
num_embeds_ada_norm: Optional[int] = None,
|
397 |
+
attention_bias: bool = False,
|
398 |
+
only_cross_attention: bool = False,
|
399 |
+
double_self_attention: bool = False,
|
400 |
+
upcast_attention: bool = False,
|
401 |
+
norm_elementwise_affine: bool = True,
|
402 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
403 |
+
norm_eps: float = 1e-5,
|
404 |
+
final_dropout: bool = False,
|
405 |
+
attention_type: str = "default",
|
406 |
+
positional_embeddings: Optional[str] = None,
|
407 |
+
num_positional_embeddings: Optional[int] = None,
|
408 |
+
# kv compression
|
409 |
+
kvcompression: Optional[bool] = False,
|
410 |
+
# motion module kwargs
|
411 |
+
motion_module_type = "VanillaGrid",
|
412 |
+
motion_module_kwargs = None,
|
413 |
+
):
|
414 |
+
super().__init__()
|
415 |
+
self.only_cross_attention = only_cross_attention
|
416 |
+
|
417 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
418 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
419 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
420 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
421 |
+
|
422 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
423 |
+
raise ValueError(
|
424 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
425 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
426 |
+
)
|
427 |
+
|
428 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
429 |
+
raise ValueError(
|
430 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
431 |
+
)
|
432 |
+
|
433 |
+
if positional_embeddings == "sinusoidal":
|
434 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
435 |
+
else:
|
436 |
+
self.pos_embed = None
|
437 |
+
|
438 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
439 |
+
# 1. Self-Attn
|
440 |
+
if self.use_ada_layer_norm:
|
441 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
442 |
+
elif self.use_ada_layer_norm_zero:
|
443 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
444 |
+
else:
|
445 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
446 |
+
|
447 |
+
self.kvcompression = kvcompression
|
448 |
+
if kvcompression:
|
449 |
+
self.attn1 = KVCompressionCrossAttention(
|
450 |
+
query_dim=dim,
|
451 |
+
heads=num_attention_heads,
|
452 |
+
dim_head=attention_head_dim,
|
453 |
+
dropout=dropout,
|
454 |
+
bias=attention_bias,
|
455 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
456 |
+
upcast_attention=upcast_attention,
|
457 |
+
)
|
458 |
+
else:
|
459 |
+
self.attn1 = Attention(
|
460 |
+
query_dim=dim,
|
461 |
+
heads=num_attention_heads,
|
462 |
+
dim_head=attention_head_dim,
|
463 |
+
dropout=dropout,
|
464 |
+
bias=attention_bias,
|
465 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
466 |
+
upcast_attention=upcast_attention,
|
467 |
+
)
|
468 |
+
print(self.attn1)
|
469 |
+
|
470 |
+
self.attn_temporal = get_motion_module(
|
471 |
+
in_channels = dim,
|
472 |
+
motion_module_type = motion_module_type,
|
473 |
+
motion_module_kwargs = motion_module_kwargs,
|
474 |
+
)
|
475 |
+
|
476 |
+
# 2. Cross-Attn
|
477 |
+
if cross_attention_dim is not None or double_self_attention:
|
478 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
479 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
480 |
+
# the second cross attention block.
|
481 |
+
self.norm2 = (
|
482 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
483 |
+
if self.use_ada_layer_norm
|
484 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
485 |
+
)
|
486 |
+
self.attn2 = Attention(
|
487 |
+
query_dim=dim,
|
488 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
489 |
+
heads=num_attention_heads,
|
490 |
+
dim_head=attention_head_dim,
|
491 |
+
dropout=dropout,
|
492 |
+
bias=attention_bias,
|
493 |
+
upcast_attention=upcast_attention,
|
494 |
+
) # is self-attn if encoder_hidden_states is none
|
495 |
+
else:
|
496 |
+
self.norm2 = None
|
497 |
+
self.attn2 = None
|
498 |
+
|
499 |
+
# 3. Feed-forward
|
500 |
+
if not self.use_ada_layer_norm_single:
|
501 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
502 |
+
|
503 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
504 |
+
|
505 |
+
# 4. Fuser
|
506 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
507 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
508 |
+
|
509 |
+
# 5. Scale-shift for PixArt-Alpha.
|
510 |
+
if self.use_ada_layer_norm_single:
|
511 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
512 |
+
|
513 |
+
# let chunk size default to None
|
514 |
+
self._chunk_size = None
|
515 |
+
self._chunk_dim = 0
|
516 |
+
|
517 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
518 |
+
# Sets chunk feed-forward
|
519 |
+
self._chunk_size = chunk_size
|
520 |
+
self._chunk_dim = dim
|
521 |
+
|
522 |
+
def forward(
|
523 |
+
self,
|
524 |
+
hidden_states: torch.FloatTensor,
|
525 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
526 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
527 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
528 |
+
timestep: Optional[torch.LongTensor] = None,
|
529 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
530 |
+
class_labels: Optional[torch.LongTensor] = None,
|
531 |
+
num_frames: int = 16,
|
532 |
+
height: int = 32,
|
533 |
+
width: int = 32,
|
534 |
+
) -> torch.FloatTensor:
|
535 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
536 |
+
# 0. Self-Attention
|
537 |
+
batch_size = hidden_states.shape[0]
|
538 |
+
|
539 |
+
if self.use_ada_layer_norm:
|
540 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
541 |
+
elif self.use_ada_layer_norm_zero:
|
542 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
543 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
544 |
+
)
|
545 |
+
elif self.use_layer_norm:
|
546 |
+
norm_hidden_states = self.norm1(hidden_states)
|
547 |
+
elif self.use_ada_layer_norm_single:
|
548 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
549 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
550 |
+
).chunk(6, dim=1)
|
551 |
+
norm_hidden_states = self.norm1(hidden_states)
|
552 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
553 |
+
norm_hidden_states = norm_hidden_states.squeeze(1)
|
554 |
+
else:
|
555 |
+
raise ValueError("Incorrect norm used")
|
556 |
+
|
557 |
+
if self.pos_embed is not None:
|
558 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
559 |
+
|
560 |
+
# 1. Retrieve lora scale.
|
561 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
562 |
+
|
563 |
+
# 2. Prepare GLIGEN inputs
|
564 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
565 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
566 |
+
|
567 |
+
norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
|
568 |
+
if self.kvcompression:
|
569 |
+
attn_output = self.attn1(
|
570 |
+
norm_hidden_states,
|
571 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
572 |
+
attention_mask=attention_mask,
|
573 |
+
num_frames=1,
|
574 |
+
height=height,
|
575 |
+
width=width,
|
576 |
+
**cross_attention_kwargs,
|
577 |
+
)
|
578 |
+
else:
|
579 |
+
attn_output = self.attn1(
|
580 |
+
norm_hidden_states,
|
581 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
582 |
+
attention_mask=attention_mask,
|
583 |
+
**cross_attention_kwargs,
|
584 |
+
)
|
585 |
+
attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
|
586 |
+
if self.use_ada_layer_norm_zero:
|
587 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
588 |
+
elif self.use_ada_layer_norm_single:
|
589 |
+
attn_output = gate_msa * attn_output
|
590 |
+
|
591 |
+
hidden_states = attn_output + hidden_states
|
592 |
+
if hidden_states.ndim == 4:
|
593 |
+
hidden_states = hidden_states.squeeze(1)
|
594 |
+
|
595 |
+
# 2.5 GLIGEN Control
|
596 |
+
if gligen_kwargs is not None:
|
597 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
598 |
+
|
599 |
+
# 2.75. Temp-Attention
|
600 |
+
if self.attn_temporal is not None:
|
601 |
+
attn_output = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=num_frames, h=height, w=width)
|
602 |
+
attn_output = self.attn_temporal(attn_output)
|
603 |
+
hidden_states = rearrange(attn_output, "b c f h w -> b (f h w) c")
|
604 |
+
|
605 |
+
# 3. Cross-Attention
|
606 |
+
if self.attn2 is not None:
|
607 |
+
if self.use_ada_layer_norm:
|
608 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
609 |
+
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
610 |
+
norm_hidden_states = self.norm2(hidden_states)
|
611 |
+
elif self.use_ada_layer_norm_single:
|
612 |
+
# For PixArt norm2 isn't applied here:
|
613 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
614 |
+
norm_hidden_states = hidden_states
|
615 |
+
else:
|
616 |
+
raise ValueError("Incorrect norm")
|
617 |
+
|
618 |
+
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
|
619 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
620 |
+
|
621 |
+
attn_output = self.attn2(
|
622 |
+
norm_hidden_states,
|
623 |
+
encoder_hidden_states=encoder_hidden_states,
|
624 |
+
attention_mask=encoder_attention_mask,
|
625 |
+
**cross_attention_kwargs,
|
626 |
+
)
|
627 |
+
hidden_states = attn_output + hidden_states
|
628 |
+
|
629 |
+
# 4. Feed-forward
|
630 |
+
if not self.use_ada_layer_norm_single:
|
631 |
+
norm_hidden_states = self.norm3(hidden_states)
|
632 |
+
|
633 |
+
if self.use_ada_layer_norm_zero:
|
634 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
635 |
+
|
636 |
+
if self.use_ada_layer_norm_single:
|
637 |
+
norm_hidden_states = self.norm2(hidden_states)
|
638 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
639 |
+
|
640 |
+
if self._chunk_size is not None:
|
641 |
+
# "feed_forward_chunk_size" can be used to save memory
|
642 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
643 |
+
raise ValueError(
|
644 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
645 |
+
)
|
646 |
+
|
647 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
648 |
+
ff_output = torch.cat(
|
649 |
+
[
|
650 |
+
self.ff(hid_slice, scale=lora_scale)
|
651 |
+
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
652 |
+
],
|
653 |
+
dim=self._chunk_dim,
|
654 |
+
)
|
655 |
+
else:
|
656 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
657 |
+
|
658 |
+
if self.use_ada_layer_norm_zero:
|
659 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
660 |
+
elif self.use_ada_layer_norm_single:
|
661 |
+
ff_output = gate_mlp * ff_output
|
662 |
+
|
663 |
+
hidden_states = ff_output + hidden_states
|
664 |
+
if hidden_states.ndim == 4:
|
665 |
+
hidden_states = hidden_states.squeeze(1)
|
666 |
+
|
667 |
+
return hidden_states
|
668 |
+
|
669 |
+
|
670 |
+
@maybe_allow_in_graph
|
671 |
+
class SelfAttentionTemporalTransformerBlock(nn.Module):
|
672 |
+
r"""
|
673 |
+
A Temporal Transformer block.
|
674 |
+
|
675 |
+
Parameters:
|
676 |
+
dim (`int`): The number of channels in the input and output.
|
677 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
678 |
+
attention_head_dim (`int`): The number of channels in each head.
|
679 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
680 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
681 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
682 |
+
num_embeds_ada_norm (:
|
683 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
684 |
+
attention_bias (:
|
685 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
686 |
+
only_cross_attention (`bool`, *optional*):
|
687 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
688 |
+
double_self_attention (`bool`, *optional*):
|
689 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
690 |
+
upcast_attention (`bool`, *optional*):
|
691 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
692 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
693 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
694 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
695 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
696 |
+
final_dropout (`bool` *optional*, defaults to False):
|
697 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
698 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
699 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
700 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
701 |
+
The type of positional embeddings to apply to.
|
702 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
703 |
+
The maximum number of positional embeddings to apply.
|
704 |
+
"""
|
705 |
+
|
706 |
+
def __init__(
|
707 |
+
self,
|
708 |
+
dim: int,
|
709 |
+
num_attention_heads: int,
|
710 |
+
attention_head_dim: int,
|
711 |
+
dropout=0.0,
|
712 |
+
cross_attention_dim: Optional[int] = None,
|
713 |
+
activation_fn: str = "geglu",
|
714 |
+
num_embeds_ada_norm: Optional[int] = None,
|
715 |
+
attention_bias: bool = False,
|
716 |
+
only_cross_attention: bool = False,
|
717 |
+
double_self_attention: bool = False,
|
718 |
+
upcast_attention: bool = False,
|
719 |
+
norm_elementwise_affine: bool = True,
|
720 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
721 |
+
norm_eps: float = 1e-5,
|
722 |
+
final_dropout: bool = False,
|
723 |
+
attention_type: str = "default",
|
724 |
+
positional_embeddings: Optional[str] = None,
|
725 |
+
num_positional_embeddings: Optional[int] = None,
|
726 |
+
):
|
727 |
+
super().__init__()
|
728 |
+
self.only_cross_attention = only_cross_attention
|
729 |
+
|
730 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
731 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
732 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
733 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
734 |
+
|
735 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
736 |
+
raise ValueError(
|
737 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
738 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
739 |
+
)
|
740 |
+
|
741 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
742 |
+
raise ValueError(
|
743 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
744 |
+
)
|
745 |
+
|
746 |
+
if positional_embeddings == "sinusoidal":
|
747 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
748 |
+
else:
|
749 |
+
self.pos_embed = None
|
750 |
+
|
751 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
752 |
+
# 1. Self-Attn
|
753 |
+
if self.use_ada_layer_norm:
|
754 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
755 |
+
elif self.use_ada_layer_norm_zero:
|
756 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
757 |
+
else:
|
758 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
759 |
+
|
760 |
+
self.attn1 = Attention(
|
761 |
+
query_dim=dim,
|
762 |
+
heads=num_attention_heads,
|
763 |
+
dim_head=attention_head_dim,
|
764 |
+
dropout=dropout,
|
765 |
+
bias=attention_bias,
|
766 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
767 |
+
upcast_attention=upcast_attention,
|
768 |
+
)
|
769 |
+
|
770 |
+
# 2. Cross-Attn
|
771 |
+
if cross_attention_dim is not None or double_self_attention:
|
772 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
773 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
774 |
+
# the second cross attention block.
|
775 |
+
self.norm2 = (
|
776 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
777 |
+
if self.use_ada_layer_norm
|
778 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
779 |
+
)
|
780 |
+
self.attn2 = Attention(
|
781 |
+
query_dim=dim,
|
782 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
783 |
+
heads=num_attention_heads,
|
784 |
+
dim_head=attention_head_dim,
|
785 |
+
dropout=dropout,
|
786 |
+
bias=attention_bias,
|
787 |
+
upcast_attention=upcast_attention,
|
788 |
+
) # is self-attn if encoder_hidden_states is none
|
789 |
+
else:
|
790 |
+
self.norm2 = None
|
791 |
+
self.attn2 = None
|
792 |
+
|
793 |
+
# 3. Feed-forward
|
794 |
+
if not self.use_ada_layer_norm_single:
|
795 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
796 |
+
|
797 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
798 |
+
|
799 |
+
# 4. Fuser
|
800 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
801 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
802 |
+
|
803 |
+
# 5. Scale-shift for PixArt-Alpha.
|
804 |
+
if self.use_ada_layer_norm_single:
|
805 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
806 |
+
|
807 |
+
# let chunk size default to None
|
808 |
+
self._chunk_size = None
|
809 |
+
self._chunk_dim = 0
|
810 |
+
|
811 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
812 |
+
# Sets chunk feed-forward
|
813 |
+
self._chunk_size = chunk_size
|
814 |
+
self._chunk_dim = dim
|
815 |
+
|
816 |
+
def forward(
|
817 |
+
self,
|
818 |
+
hidden_states: torch.FloatTensor,
|
819 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
820 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
821 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
822 |
+
timestep: Optional[torch.LongTensor] = None,
|
823 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
824 |
+
class_labels: Optional[torch.LongTensor] = None,
|
825 |
+
) -> torch.FloatTensor:
|
826 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
827 |
+
# 0. Self-Attention
|
828 |
+
batch_size = hidden_states.shape[0]
|
829 |
+
|
830 |
+
if self.use_ada_layer_norm:
|
831 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
832 |
+
elif self.use_ada_layer_norm_zero:
|
833 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
834 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
835 |
+
)
|
836 |
+
elif self.use_layer_norm:
|
837 |
+
norm_hidden_states = self.norm1(hidden_states)
|
838 |
+
elif self.use_ada_layer_norm_single:
|
839 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
840 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
841 |
+
).chunk(6, dim=1)
|
842 |
+
norm_hidden_states = self.norm1(hidden_states)
|
843 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
844 |
+
norm_hidden_states = norm_hidden_states.squeeze(1)
|
845 |
+
else:
|
846 |
+
raise ValueError("Incorrect norm used")
|
847 |
+
|
848 |
+
if self.pos_embed is not None:
|
849 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
850 |
+
|
851 |
+
# 1. Retrieve lora scale.
|
852 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
853 |
+
|
854 |
+
# 2. Prepare GLIGEN inputs
|
855 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
856 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
857 |
+
|
858 |
+
attn_output = self.attn1(
|
859 |
+
norm_hidden_states,
|
860 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
861 |
+
attention_mask=attention_mask,
|
862 |
+
**cross_attention_kwargs,
|
863 |
+
)
|
864 |
+
|
865 |
+
if self.use_ada_layer_norm_zero:
|
866 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
867 |
+
elif self.use_ada_layer_norm_single:
|
868 |
+
attn_output = gate_msa * attn_output
|
869 |
+
|
870 |
+
hidden_states = attn_output + hidden_states
|
871 |
+
if hidden_states.ndim == 4:
|
872 |
+
hidden_states = hidden_states.squeeze(1)
|
873 |
+
|
874 |
+
# 2.5 GLIGEN Control
|
875 |
+
if gligen_kwargs is not None:
|
876 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
877 |
+
|
878 |
+
# 3. Cross-Attention
|
879 |
+
if self.attn2 is not None:
|
880 |
+
if self.use_ada_layer_norm:
|
881 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
882 |
+
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
883 |
+
norm_hidden_states = self.norm2(hidden_states)
|
884 |
+
elif self.use_ada_layer_norm_single:
|
885 |
+
# For PixArt norm2 isn't applied here:
|
886 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
887 |
+
norm_hidden_states = hidden_states
|
888 |
+
else:
|
889 |
+
raise ValueError("Incorrect norm")
|
890 |
+
|
891 |
+
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
|
892 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
893 |
+
|
894 |
+
attn_output = self.attn2(
|
895 |
+
norm_hidden_states,
|
896 |
+
encoder_hidden_states=encoder_hidden_states,
|
897 |
+
attention_mask=encoder_attention_mask,
|
898 |
+
**cross_attention_kwargs,
|
899 |
+
)
|
900 |
+
hidden_states = attn_output + hidden_states
|
901 |
+
|
902 |
+
# 4. Feed-forward
|
903 |
+
if not self.use_ada_layer_norm_single:
|
904 |
+
norm_hidden_states = self.norm3(hidden_states)
|
905 |
+
|
906 |
+
if self.use_ada_layer_norm_zero:
|
907 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
908 |
+
|
909 |
+
if self.use_ada_layer_norm_single:
|
910 |
+
norm_hidden_states = self.norm2(hidden_states)
|
911 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
912 |
+
|
913 |
+
if self._chunk_size is not None:
|
914 |
+
# "feed_forward_chunk_size" can be used to save memory
|
915 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
916 |
+
raise ValueError(
|
917 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
918 |
+
)
|
919 |
+
|
920 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
921 |
+
ff_output = torch.cat(
|
922 |
+
[
|
923 |
+
self.ff(hid_slice, scale=lora_scale)
|
924 |
+
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
925 |
+
],
|
926 |
+
dim=self._chunk_dim,
|
927 |
+
)
|
928 |
+
else:
|
929 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
930 |
+
|
931 |
+
if self.use_ada_layer_norm_zero:
|
932 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
933 |
+
elif self.use_ada_layer_norm_single:
|
934 |
+
ff_output = gate_mlp * ff_output
|
935 |
+
|
936 |
+
hidden_states = ff_output + hidden_states
|
937 |
+
if hidden_states.ndim == 4:
|
938 |
+
hidden_states = hidden_states.squeeze(1)
|
939 |
+
|
940 |
+
return hidden_states
|
941 |
+
|
942 |
+
|
943 |
+
@maybe_allow_in_graph
|
944 |
+
class KVCompressionTransformerBlock(nn.Module):
|
945 |
+
r"""
|
946 |
+
A Temporal Transformer block.
|
947 |
+
|
948 |
+
Parameters:
|
949 |
+
dim (`int`): The number of channels in the input and output.
|
950 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
951 |
+
attention_head_dim (`int`): The number of channels in each head.
|
952 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
953 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
954 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
955 |
+
num_embeds_ada_norm (:
|
956 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
957 |
+
attention_bias (:
|
958 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
959 |
+
only_cross_attention (`bool`, *optional*):
|
960 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
961 |
+
double_self_attention (`bool`, *optional*):
|
962 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
963 |
+
upcast_attention (`bool`, *optional*):
|
964 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
965 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
966 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
967 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
968 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
969 |
+
final_dropout (`bool` *optional*, defaults to False):
|
970 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
971 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
972 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
973 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
974 |
+
The type of positional embeddings to apply to.
|
975 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
976 |
+
The maximum number of positional embeddings to apply.
|
977 |
+
"""
|
978 |
+
|
979 |
+
def __init__(
|
980 |
+
self,
|
981 |
+
dim: int,
|
982 |
+
num_attention_heads: int,
|
983 |
+
attention_head_dim: int,
|
984 |
+
dropout=0.0,
|
985 |
+
cross_attention_dim: Optional[int] = None,
|
986 |
+
activation_fn: str = "geglu",
|
987 |
+
num_embeds_ada_norm: Optional[int] = None,
|
988 |
+
attention_bias: bool = False,
|
989 |
+
only_cross_attention: bool = False,
|
990 |
+
double_self_attention: bool = False,
|
991 |
+
upcast_attention: bool = False,
|
992 |
+
norm_elementwise_affine: bool = True,
|
993 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
|
994 |
+
norm_eps: float = 1e-5,
|
995 |
+
final_dropout: bool = False,
|
996 |
+
attention_type: str = "default",
|
997 |
+
positional_embeddings: Optional[str] = None,
|
998 |
+
num_positional_embeddings: Optional[int] = None,
|
999 |
+
kvcompression: Optional[bool] = False,
|
1000 |
+
):
|
1001 |
+
super().__init__()
|
1002 |
+
self.only_cross_attention = only_cross_attention
|
1003 |
+
|
1004 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
1005 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
1006 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
1007 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
1008 |
+
|
1009 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
1010 |
+
raise ValueError(
|
1011 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
1012 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
1013 |
+
)
|
1014 |
+
|
1015 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
1016 |
+
raise ValueError(
|
1017 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
if positional_embeddings == "sinusoidal":
|
1021 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
1022 |
+
else:
|
1023 |
+
self.pos_embed = None
|
1024 |
+
|
1025 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
1026 |
+
# 1. Self-Attn
|
1027 |
+
if self.use_ada_layer_norm:
|
1028 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
1029 |
+
elif self.use_ada_layer_norm_zero:
|
1030 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
1031 |
+
else:
|
1032 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
1033 |
+
|
1034 |
+
self.kvcompression = kvcompression
|
1035 |
+
if kvcompression:
|
1036 |
+
self.attn1 = KVCompressionCrossAttention(
|
1037 |
+
query_dim=dim,
|
1038 |
+
heads=num_attention_heads,
|
1039 |
+
dim_head=attention_head_dim,
|
1040 |
+
dropout=dropout,
|
1041 |
+
bias=attention_bias,
|
1042 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
1043 |
+
upcast_attention=upcast_attention,
|
1044 |
+
)
|
1045 |
+
else:
|
1046 |
+
self.attn1 = Attention(
|
1047 |
+
query_dim=dim,
|
1048 |
+
heads=num_attention_heads,
|
1049 |
+
dim_head=attention_head_dim,
|
1050 |
+
dropout=dropout,
|
1051 |
+
bias=attention_bias,
|
1052 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
1053 |
+
upcast_attention=upcast_attention,
|
1054 |
+
)
|
1055 |
+
print(self.attn1)
|
1056 |
+
|
1057 |
+
# 2. Cross-Attn
|
1058 |
+
if cross_attention_dim is not None or double_self_attention:
|
1059 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
1060 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
1061 |
+
# the second cross attention block.
|
1062 |
+
self.norm2 = (
|
1063 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
1064 |
+
if self.use_ada_layer_norm
|
1065 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
1066 |
+
)
|
1067 |
+
self.attn2 = Attention(
|
1068 |
+
query_dim=dim,
|
1069 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
1070 |
+
heads=num_attention_heads,
|
1071 |
+
dim_head=attention_head_dim,
|
1072 |
+
dropout=dropout,
|
1073 |
+
bias=attention_bias,
|
1074 |
+
upcast_attention=upcast_attention,
|
1075 |
+
) # is self-attn if encoder_hidden_states is none
|
1076 |
+
else:
|
1077 |
+
self.norm2 = None
|
1078 |
+
self.attn2 = None
|
1079 |
+
|
1080 |
+
# 3. Feed-forward
|
1081 |
+
if not self.use_ada_layer_norm_single:
|
1082 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
1083 |
+
|
1084 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
1085 |
+
|
1086 |
+
# 4. Fuser
|
1087 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
1088 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
1089 |
+
|
1090 |
+
# 5. Scale-shift for PixArt-Alpha.
|
1091 |
+
if self.use_ada_layer_norm_single:
|
1092 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
1093 |
+
|
1094 |
+
# let chunk size default to None
|
1095 |
+
self._chunk_size = None
|
1096 |
+
self._chunk_dim = 0
|
1097 |
+
|
1098 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
1099 |
+
# Sets chunk feed-forward
|
1100 |
+
self._chunk_size = chunk_size
|
1101 |
+
self._chunk_dim = dim
|
1102 |
+
|
1103 |
+
def forward(
|
1104 |
+
self,
|
1105 |
+
hidden_states: torch.FloatTensor,
|
1106 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1107 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1108 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1109 |
+
timestep: Optional[torch.LongTensor] = None,
|
1110 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
1111 |
+
class_labels: Optional[torch.LongTensor] = None,
|
1112 |
+
num_frames: int = 16,
|
1113 |
+
height: int = 32,
|
1114 |
+
width: int = 32,
|
1115 |
+
use_reentrant: bool = False,
|
1116 |
+
) -> torch.FloatTensor:
|
1117 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
1118 |
+
# 0. Self-Attention
|
1119 |
+
batch_size = hidden_states.shape[0]
|
1120 |
+
|
1121 |
+
if self.use_ada_layer_norm:
|
1122 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
1123 |
+
elif self.use_ada_layer_norm_zero:
|
1124 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
1125 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
1126 |
+
)
|
1127 |
+
elif self.use_layer_norm:
|
1128 |
+
norm_hidden_states = self.norm1(hidden_states)
|
1129 |
+
elif self.use_ada_layer_norm_single:
|
1130 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
1131 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
1132 |
+
).chunk(6, dim=1)
|
1133 |
+
norm_hidden_states = self.norm1(hidden_states)
|
1134 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
1135 |
+
norm_hidden_states = norm_hidden_states.squeeze(1)
|
1136 |
+
else:
|
1137 |
+
raise ValueError("Incorrect norm used")
|
1138 |
+
|
1139 |
+
if self.pos_embed is not None:
|
1140 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
1141 |
+
|
1142 |
+
# 1. Retrieve lora scale.
|
1143 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1144 |
+
|
1145 |
+
# 2. Prepare GLIGEN inputs
|
1146 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
1147 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
1148 |
+
|
1149 |
+
if self.kvcompression:
|
1150 |
+
attn_output = self.attn1(
|
1151 |
+
norm_hidden_states,
|
1152 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
1153 |
+
attention_mask=attention_mask,
|
1154 |
+
num_frames=num_frames,
|
1155 |
+
height=height,
|
1156 |
+
width=width,
|
1157 |
+
**cross_attention_kwargs,
|
1158 |
+
)
|
1159 |
+
else:
|
1160 |
+
attn_output = self.attn1(
|
1161 |
+
norm_hidden_states,
|
1162 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
1163 |
+
attention_mask=attention_mask,
|
1164 |
+
**cross_attention_kwargs,
|
1165 |
+
)
|
1166 |
+
|
1167 |
+
if self.use_ada_layer_norm_zero:
|
1168 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
1169 |
+
elif self.use_ada_layer_norm_single:
|
1170 |
+
attn_output = gate_msa * attn_output
|
1171 |
+
|
1172 |
+
hidden_states = attn_output + hidden_states
|
1173 |
+
if hidden_states.ndim == 4:
|
1174 |
+
hidden_states = hidden_states.squeeze(1)
|
1175 |
+
|
1176 |
+
# 2.5 GLIGEN Control
|
1177 |
+
if gligen_kwargs is not None:
|
1178 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
1179 |
+
|
1180 |
+
# 3. Cross-Attention
|
1181 |
+
if self.attn2 is not None:
|
1182 |
+
if self.use_ada_layer_norm:
|
1183 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
1184 |
+
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
|
1185 |
+
norm_hidden_states = self.norm2(hidden_states)
|
1186 |
+
elif self.use_ada_layer_norm_single:
|
1187 |
+
# For PixArt norm2 isn't applied here:
|
1188 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
1189 |
+
norm_hidden_states = hidden_states
|
1190 |
+
else:
|
1191 |
+
raise ValueError("Incorrect norm")
|
1192 |
+
|
1193 |
+
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
|
1194 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
1195 |
+
|
1196 |
+
attn_output = self.attn2(
|
1197 |
+
norm_hidden_states,
|
1198 |
+
encoder_hidden_states=encoder_hidden_states,
|
1199 |
+
attention_mask=encoder_attention_mask,
|
1200 |
+
**cross_attention_kwargs,
|
1201 |
+
)
|
1202 |
+
hidden_states = attn_output + hidden_states
|
1203 |
+
|
1204 |
+
# 4. Feed-forward
|
1205 |
+
if not self.use_ada_layer_norm_single:
|
1206 |
+
norm_hidden_states = self.norm3(hidden_states)
|
1207 |
+
|
1208 |
+
if self.use_ada_layer_norm_zero:
|
1209 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
1210 |
+
|
1211 |
+
if self.use_ada_layer_norm_single:
|
1212 |
+
norm_hidden_states = self.norm2(hidden_states)
|
1213 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
1214 |
+
|
1215 |
+
if self._chunk_size is not None:
|
1216 |
+
# "feed_forward_chunk_size" can be used to save memory
|
1217 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
1218 |
+
raise ValueError(
|
1219 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
1223 |
+
ff_output = torch.cat(
|
1224 |
+
[
|
1225 |
+
self.ff(hid_slice, scale=lora_scale)
|
1226 |
+
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
1227 |
+
],
|
1228 |
+
dim=self._chunk_dim,
|
1229 |
+
)
|
1230 |
+
else:
|
1231 |
+
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
|
1232 |
+
|
1233 |
+
if self.use_ada_layer_norm_zero:
|
1234 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
1235 |
+
elif self.use_ada_layer_norm_single:
|
1236 |
+
ff_output = gate_mlp * ff_output
|
1237 |
+
|
1238 |
+
hidden_states = ff_output + hidden_states
|
1239 |
+
if hidden_states.ndim == 4:
|
1240 |
+
hidden_states = hidden_states.squeeze(1)
|
1241 |
+
|
1242 |
+
return hidden_states
|
1243 |
+
|
1244 |
+
|
1245 |
+
class FeedForward(nn.Module):
|
1246 |
+
r"""
|
1247 |
+
A feed-forward layer.
|
1248 |
+
|
1249 |
+
Parameters:
|
1250 |
+
dim (`int`): The number of channels in the input.
|
1251 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
1252 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
1253 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
1254 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
1255 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
1256 |
+
"""
|
1257 |
+
|
1258 |
+
def __init__(
|
1259 |
+
self,
|
1260 |
+
dim: int,
|
1261 |
+
dim_out: Optional[int] = None,
|
1262 |
+
mult: int = 4,
|
1263 |
+
dropout: float = 0.0,
|
1264 |
+
activation_fn: str = "geglu",
|
1265 |
+
final_dropout: bool = False,
|
1266 |
+
):
|
1267 |
+
super().__init__()
|
1268 |
+
inner_dim = int(dim * mult)
|
1269 |
+
dim_out = dim_out if dim_out is not None else dim
|
1270 |
+
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
1271 |
+
|
1272 |
+
if activation_fn == "gelu":
|
1273 |
+
act_fn = GELU(dim, inner_dim)
|
1274 |
+
if activation_fn == "gelu-approximate":
|
1275 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
1276 |
+
elif activation_fn == "geglu":
|
1277 |
+
act_fn = GEGLU(dim, inner_dim)
|
1278 |
+
elif activation_fn == "geglu-approximate":
|
1279 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
1280 |
+
|
1281 |
+
self.net = nn.ModuleList([])
|
1282 |
+
# project in
|
1283 |
+
self.net.append(act_fn)
|
1284 |
+
# project dropout
|
1285 |
+
self.net.append(nn.Dropout(dropout))
|
1286 |
+
# project out
|
1287 |
+
self.net.append(linear_cls(inner_dim, dim_out))
|
1288 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
1289 |
+
if final_dropout:
|
1290 |
+
self.net.append(nn.Dropout(dropout))
|
1291 |
+
|
1292 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
1293 |
+
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
1294 |
+
for module in self.net:
|
1295 |
+
if isinstance(module, compatible_cls):
|
1296 |
+
hidden_states = module(hidden_states, scale)
|
1297 |
+
else:
|
1298 |
+
hidden_states = module(hidden_states)
|
1299 |
+
return hidden_states
|
easyanimate/models/autoencoder_magvit.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from diffusers.loaders import FromOriginalVAEMixin
|
21 |
+
from diffusers.models.attention_processor import (
|
22 |
+
ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
|
23 |
+
AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
|
24 |
+
from diffusers.models.autoencoders.vae import (DecoderOutput,
|
25 |
+
DiagonalGaussianDistribution)
|
26 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
29 |
+
from torch import nn
|
30 |
+
|
31 |
+
from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
|
32 |
+
from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
|
33 |
+
|
34 |
+
|
35 |
+
def str_eval(item):
|
36 |
+
if type(item) == str:
|
37 |
+
return eval(item)
|
38 |
+
else:
|
39 |
+
return item
|
40 |
+
|
41 |
+
class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
42 |
+
r"""
|
43 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
44 |
+
|
45 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
46 |
+
for all models (such as downloading or saving).
|
47 |
+
|
48 |
+
Parameters:
|
49 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
50 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
51 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
52 |
+
Tuple of downsample block types.
|
53 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
54 |
+
Tuple of upsample block types.
|
55 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
56 |
+
Tuple of block output channels.
|
57 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
58 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
59 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
60 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
61 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
62 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
63 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
64 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
65 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
66 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
67 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
68 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
69 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
70 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
71 |
+
"""
|
72 |
+
|
73 |
+
_supports_gradient_checkpointing = True
|
74 |
+
|
75 |
+
@register_to_config
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
in_channels: int = 3,
|
79 |
+
out_channels: int = 3,
|
80 |
+
ch = 128,
|
81 |
+
ch_mult = [ 1,2,4,4 ],
|
82 |
+
use_gc_blocks = None,
|
83 |
+
down_block_types: tuple = None,
|
84 |
+
up_block_types: tuple = None,
|
85 |
+
mid_block_type: str = "MidBlock3D",
|
86 |
+
mid_block_use_attention: bool = True,
|
87 |
+
mid_block_attention_type: str = "3d",
|
88 |
+
mid_block_num_attention_heads: int = 1,
|
89 |
+
layers_per_block: int = 2,
|
90 |
+
act_fn: str = "silu",
|
91 |
+
num_attention_heads: int = 1,
|
92 |
+
latent_channels: int = 4,
|
93 |
+
norm_num_groups: int = 32,
|
94 |
+
scaling_factor: float = 0.1825,
|
95 |
+
slice_compression_vae=False,
|
96 |
+
mini_batch_encoder=9,
|
97 |
+
mini_batch_decoder=3,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
down_block_types = str_eval(down_block_types)
|
101 |
+
up_block_types = str_eval(up_block_types)
|
102 |
+
self.encoder = omnigen_Mag_Encoder(
|
103 |
+
in_channels=in_channels,
|
104 |
+
out_channels=latent_channels,
|
105 |
+
down_block_types=down_block_types,
|
106 |
+
ch = ch,
|
107 |
+
ch_mult = ch_mult,
|
108 |
+
use_gc_blocks=use_gc_blocks,
|
109 |
+
mid_block_type=mid_block_type,
|
110 |
+
mid_block_use_attention=mid_block_use_attention,
|
111 |
+
mid_block_attention_type=mid_block_attention_type,
|
112 |
+
mid_block_num_attention_heads=mid_block_num_attention_heads,
|
113 |
+
layers_per_block=layers_per_block,
|
114 |
+
norm_num_groups=norm_num_groups,
|
115 |
+
act_fn=act_fn,
|
116 |
+
num_attention_heads=num_attention_heads,
|
117 |
+
double_z=True,
|
118 |
+
slice_compression_vae=slice_compression_vae,
|
119 |
+
mini_batch_encoder=mini_batch_encoder,
|
120 |
+
)
|
121 |
+
|
122 |
+
self.decoder = omnigen_Mag_Decoder(
|
123 |
+
in_channels=latent_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
up_block_types=up_block_types,
|
126 |
+
ch = ch,
|
127 |
+
ch_mult = ch_mult,
|
128 |
+
use_gc_blocks=use_gc_blocks,
|
129 |
+
mid_block_type=mid_block_type,
|
130 |
+
mid_block_use_attention=mid_block_use_attention,
|
131 |
+
mid_block_attention_type=mid_block_attention_type,
|
132 |
+
mid_block_num_attention_heads=mid_block_num_attention_heads,
|
133 |
+
layers_per_block=layers_per_block,
|
134 |
+
norm_num_groups=norm_num_groups,
|
135 |
+
act_fn=act_fn,
|
136 |
+
num_attention_heads=num_attention_heads,
|
137 |
+
slice_compression_vae=slice_compression_vae,
|
138 |
+
mini_batch_decoder=mini_batch_decoder,
|
139 |
+
)
|
140 |
+
|
141 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
142 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
143 |
+
|
144 |
+
self.slice_compression_vae = slice_compression_vae
|
145 |
+
self.mini_batch_encoder = mini_batch_encoder
|
146 |
+
self.mini_batch_decoder = mini_batch_decoder
|
147 |
+
self.use_slicing = False
|
148 |
+
self.use_tiling = False
|
149 |
+
self.tile_sample_min_size = 256
|
150 |
+
self.tile_overlap_factor = 0.25
|
151 |
+
self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
|
152 |
+
self.scaling_factor = scaling_factor
|
153 |
+
|
154 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
155 |
+
if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
|
156 |
+
module.gradient_checkpointing = value
|
157 |
+
|
158 |
+
@property
|
159 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
160 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
161 |
+
r"""
|
162 |
+
Returns:
|
163 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
164 |
+
indexed by its weight name.
|
165 |
+
"""
|
166 |
+
# set recursively
|
167 |
+
processors = {}
|
168 |
+
|
169 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
170 |
+
if hasattr(module, "get_processor"):
|
171 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
172 |
+
|
173 |
+
for sub_name, child in module.named_children():
|
174 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
175 |
+
|
176 |
+
return processors
|
177 |
+
|
178 |
+
for name, module in self.named_children():
|
179 |
+
fn_recursive_add_processors(name, module, processors)
|
180 |
+
|
181 |
+
return processors
|
182 |
+
|
183 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
184 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
185 |
+
r"""
|
186 |
+
Sets the attention processor to use to compute attention.
|
187 |
+
|
188 |
+
Parameters:
|
189 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
190 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
191 |
+
for **all** `Attention` layers.
|
192 |
+
|
193 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
194 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
195 |
+
|
196 |
+
"""
|
197 |
+
count = len(self.attn_processors.keys())
|
198 |
+
|
199 |
+
if isinstance(processor, dict) and len(processor) != count:
|
200 |
+
raise ValueError(
|
201 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
202 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
203 |
+
)
|
204 |
+
|
205 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
206 |
+
if hasattr(module, "set_processor"):
|
207 |
+
if not isinstance(processor, dict):
|
208 |
+
module.set_processor(processor)
|
209 |
+
else:
|
210 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
211 |
+
|
212 |
+
for sub_name, child in module.named_children():
|
213 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
214 |
+
|
215 |
+
for name, module in self.named_children():
|
216 |
+
fn_recursive_attn_processor(name, module, processor)
|
217 |
+
|
218 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
219 |
+
def set_default_attn_processor(self):
|
220 |
+
"""
|
221 |
+
Disables custom attention processors and sets the default attention implementation.
|
222 |
+
"""
|
223 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
224 |
+
processor = AttnAddedKVProcessor()
|
225 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
226 |
+
processor = AttnProcessor()
|
227 |
+
else:
|
228 |
+
raise ValueError(
|
229 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
230 |
+
)
|
231 |
+
|
232 |
+
self.set_attn_processor(processor)
|
233 |
+
|
234 |
+
@apply_forward_hook
|
235 |
+
def encode(
|
236 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
237 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
238 |
+
"""
|
239 |
+
Encode a batch of images into latents.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
x (`torch.FloatTensor`): Input batch of images.
|
243 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
244 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
248 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
249 |
+
"""
|
250 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
251 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
252 |
+
|
253 |
+
if self.use_slicing and x.shape[0] > 1:
|
254 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
255 |
+
h = torch.cat(encoded_slices)
|
256 |
+
else:
|
257 |
+
h = self.encoder(x)
|
258 |
+
|
259 |
+
moments = self.quant_conv(h)
|
260 |
+
posterior = DiagonalGaussianDistribution(moments)
|
261 |
+
|
262 |
+
if not return_dict:
|
263 |
+
return (posterior,)
|
264 |
+
|
265 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
266 |
+
|
267 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
268 |
+
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
269 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
270 |
+
z = self.post_quant_conv(z)
|
271 |
+
dec = self.decoder(z)
|
272 |
+
|
273 |
+
if not return_dict:
|
274 |
+
return (dec,)
|
275 |
+
|
276 |
+
return DecoderOutput(sample=dec)
|
277 |
+
|
278 |
+
@apply_forward_hook
|
279 |
+
def decode(
|
280 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
281 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
282 |
+
"""
|
283 |
+
Decode a batch of images.
|
284 |
+
|
285 |
+
Args:
|
286 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
287 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
288 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
292 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
293 |
+
returned.
|
294 |
+
|
295 |
+
"""
|
296 |
+
if self.use_slicing and z.shape[0] > 1:
|
297 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
298 |
+
decoded = torch.cat(decoded_slices)
|
299 |
+
else:
|
300 |
+
decoded = self._decode(z).sample
|
301 |
+
|
302 |
+
if not return_dict:
|
303 |
+
return (decoded,)
|
304 |
+
|
305 |
+
return DecoderOutput(sample=decoded)
|
306 |
+
|
307 |
+
def blend_v(
|
308 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
309 |
+
) -> torch.Tensor:
|
310 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
311 |
+
for y in range(blend_extent):
|
312 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
313 |
+
1 - y / blend_extent
|
314 |
+
) + b[:, :, :, y, :] * (y / blend_extent)
|
315 |
+
return b
|
316 |
+
|
317 |
+
def blend_h(
|
318 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
319 |
+
) -> torch.Tensor:
|
320 |
+
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
321 |
+
for x in range(blend_extent):
|
322 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
323 |
+
1 - x / blend_extent
|
324 |
+
) + b[:, :, :, :, x] * (x / blend_extent)
|
325 |
+
return b
|
326 |
+
|
327 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
328 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
329 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
330 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
331 |
+
|
332 |
+
# Split the image into 512x512 tiles and encode them separately.
|
333 |
+
rows = []
|
334 |
+
for i in range(0, x.shape[3], overlap_size):
|
335 |
+
row = []
|
336 |
+
for j in range(0, x.shape[4], overlap_size):
|
337 |
+
tile = x[
|
338 |
+
:,
|
339 |
+
:,
|
340 |
+
:,
|
341 |
+
i : i + self.tile_sample_min_size,
|
342 |
+
j : j + self.tile_sample_min_size,
|
343 |
+
]
|
344 |
+
tile = self.encoder(tile)
|
345 |
+
tile = self.quant_conv(tile)
|
346 |
+
row.append(tile)
|
347 |
+
rows.append(row)
|
348 |
+
result_rows = []
|
349 |
+
for i, row in enumerate(rows):
|
350 |
+
result_row = []
|
351 |
+
for j, tile in enumerate(row):
|
352 |
+
# blend the above tile and the left tile
|
353 |
+
# to the current tile and add the current tile to the result row
|
354 |
+
if i > 0:
|
355 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
356 |
+
if j > 0:
|
357 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
358 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
359 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
360 |
+
|
361 |
+
moments = torch.cat(result_rows, dim=3)
|
362 |
+
posterior = DiagonalGaussianDistribution(moments)
|
363 |
+
|
364 |
+
if not return_dict:
|
365 |
+
return (posterior,)
|
366 |
+
|
367 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
368 |
+
|
369 |
+
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
370 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
371 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
372 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
373 |
+
|
374 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
375 |
+
# The tiles have an overlap to avoid seams between tiles.
|
376 |
+
rows = []
|
377 |
+
for i in range(0, z.shape[3], overlap_size):
|
378 |
+
row = []
|
379 |
+
for j in range(0, z.shape[4], overlap_size):
|
380 |
+
tile = z[
|
381 |
+
:,
|
382 |
+
:,
|
383 |
+
:,
|
384 |
+
i : i + self.tile_latent_min_size,
|
385 |
+
j : j + self.tile_latent_min_size,
|
386 |
+
]
|
387 |
+
tile = self.post_quant_conv(tile)
|
388 |
+
decoded = self.decoder(tile)
|
389 |
+
row.append(decoded)
|
390 |
+
rows.append(row)
|
391 |
+
result_rows = []
|
392 |
+
for i, row in enumerate(rows):
|
393 |
+
result_row = []
|
394 |
+
for j, tile in enumerate(row):
|
395 |
+
# blend the above tile and the left tile
|
396 |
+
# to the current tile and add the current tile to the result row
|
397 |
+
if i > 0:
|
398 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
399 |
+
if j > 0:
|
400 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
401 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
402 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
403 |
+
|
404 |
+
dec = torch.cat(result_rows, dim=3)
|
405 |
+
if not return_dict:
|
406 |
+
return (dec,)
|
407 |
+
|
408 |
+
return DecoderOutput(sample=dec)
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
sample: torch.FloatTensor,
|
413 |
+
sample_posterior: bool = False,
|
414 |
+
return_dict: bool = True,
|
415 |
+
generator: Optional[torch.Generator] = None,
|
416 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
417 |
+
r"""
|
418 |
+
Args:
|
419 |
+
sample (`torch.FloatTensor`): Input sample.
|
420 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
421 |
+
Whether to sample from the posterior.
|
422 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
423 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
424 |
+
"""
|
425 |
+
x = sample
|
426 |
+
posterior = self.encode(x).latent_dist
|
427 |
+
if sample_posterior:
|
428 |
+
z = posterior.sample(generator=generator)
|
429 |
+
else:
|
430 |
+
z = posterior.mode()
|
431 |
+
dec = self.decode(z).sample
|
432 |
+
|
433 |
+
if not return_dict:
|
434 |
+
return (dec,)
|
435 |
+
|
436 |
+
return DecoderOutput(sample=dec)
|
437 |
+
|
438 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
439 |
+
def fuse_qkv_projections(self):
|
440 |
+
"""
|
441 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
442 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
443 |
+
|
444 |
+
<Tip warning={true}>
|
445 |
+
|
446 |
+
This API is 🧪 experimental.
|
447 |
+
|
448 |
+
</Tip>
|
449 |
+
"""
|
450 |
+
self.original_attn_processors = None
|
451 |
+
|
452 |
+
for _, attn_processor in self.attn_processors.items():
|
453 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
454 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
455 |
+
|
456 |
+
self.original_attn_processors = self.attn_processors
|
457 |
+
|
458 |
+
for module in self.modules():
|
459 |
+
if isinstance(module, Attention):
|
460 |
+
module.fuse_projections(fuse=True)
|
461 |
+
|
462 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
463 |
+
def unfuse_qkv_projections(self):
|
464 |
+
"""Disables the fused QKV projection if enabled.
|
465 |
+
|
466 |
+
<Tip warning={true}>
|
467 |
+
|
468 |
+
This API is 🧪 experimental.
|
469 |
+
|
470 |
+
</Tip>
|
471 |
+
|
472 |
+
"""
|
473 |
+
if self.original_attn_processors is not None:
|
474 |
+
self.set_attn_processor(self.original_attn_processors)
|
475 |
+
|
476 |
+
@classmethod
|
477 |
+
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
|
478 |
+
import json
|
479 |
+
import os
|
480 |
+
if subfolder is not None:
|
481 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
482 |
+
|
483 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
484 |
+
if not os.path.isfile(config_file):
|
485 |
+
raise RuntimeError(f"{config_file} does not exist")
|
486 |
+
with open(config_file, "r") as f:
|
487 |
+
config = json.load(f)
|
488 |
+
|
489 |
+
model = cls.from_config(config, **vae_additional_kwargs)
|
490 |
+
from diffusers.utils import WEIGHTS_NAME
|
491 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
492 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
493 |
+
if os.path.exists(model_file_safetensors):
|
494 |
+
from safetensors.torch import load_file, safe_open
|
495 |
+
state_dict = load_file(model_file_safetensors)
|
496 |
+
else:
|
497 |
+
if not os.path.isfile(model_file):
|
498 |
+
raise RuntimeError(f"{model_file} does not exist")
|
499 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
500 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
501 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
502 |
+
print(m, u)
|
503 |
+
return model
|
easyanimate/models/motion_module.py
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
|
2 |
+
"""
|
3 |
+
import math
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from diffusers.models.attention import FeedForward
|
9 |
+
from diffusers.utils.import_utils import is_xformers_available
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
if is_xformers_available():
|
14 |
+
import xformers
|
15 |
+
import xformers.ops
|
16 |
+
else:
|
17 |
+
xformers = None
|
18 |
+
|
19 |
+
class CrossAttention(nn.Module):
|
20 |
+
r"""
|
21 |
+
A cross attention layer.
|
22 |
+
|
23 |
+
Parameters:
|
24 |
+
query_dim (`int`): The number of channels in the query.
|
25 |
+
cross_attention_dim (`int`, *optional*):
|
26 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
27 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
28 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
29 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
30 |
+
bias (`bool`, *optional*, defaults to False):
|
31 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
query_dim: int,
|
37 |
+
cross_attention_dim: Optional[int] = None,
|
38 |
+
heads: int = 8,
|
39 |
+
dim_head: int = 64,
|
40 |
+
dropout: float = 0.0,
|
41 |
+
bias=False,
|
42 |
+
upcast_attention: bool = False,
|
43 |
+
upcast_softmax: bool = False,
|
44 |
+
added_kv_proj_dim: Optional[int] = None,
|
45 |
+
norm_num_groups: Optional[int] = None,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
inner_dim = dim_head * heads
|
49 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
50 |
+
self.upcast_attention = upcast_attention
|
51 |
+
self.upcast_softmax = upcast_softmax
|
52 |
+
|
53 |
+
self.scale = dim_head**-0.5
|
54 |
+
|
55 |
+
self.heads = heads
|
56 |
+
# for slice_size > 0 the attention score computation
|
57 |
+
# is split across the batch axis to save memory
|
58 |
+
# You can set slice_size with `set_attention_slice`
|
59 |
+
self.sliceable_head_dim = heads
|
60 |
+
self._slice_size = None
|
61 |
+
self._use_memory_efficient_attention_xformers = False
|
62 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
63 |
+
|
64 |
+
if norm_num_groups is not None:
|
65 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
66 |
+
else:
|
67 |
+
self.group_norm = None
|
68 |
+
|
69 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
70 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
71 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
72 |
+
|
73 |
+
if self.added_kv_proj_dim is not None:
|
74 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
75 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
76 |
+
|
77 |
+
self.to_out = nn.ModuleList([])
|
78 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
79 |
+
self.to_out.append(nn.Dropout(dropout))
|
80 |
+
|
81 |
+
def set_use_memory_efficient_attention_xformers(
|
82 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
83 |
+
) -> None:
|
84 |
+
self._use_memory_efficient_attention_xformers = valid
|
85 |
+
|
86 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
87 |
+
batch_size, seq_len, dim = tensor.shape
|
88 |
+
head_size = self.heads
|
89 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
90 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
91 |
+
return tensor
|
92 |
+
|
93 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
94 |
+
batch_size, seq_len, dim = tensor.shape
|
95 |
+
head_size = self.heads
|
96 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
97 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
98 |
+
return tensor
|
99 |
+
|
100 |
+
def set_attention_slice(self, slice_size):
|
101 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
102 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
103 |
+
|
104 |
+
self._slice_size = slice_size
|
105 |
+
|
106 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
107 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
108 |
+
|
109 |
+
encoder_hidden_states = encoder_hidden_states
|
110 |
+
|
111 |
+
if self.group_norm is not None:
|
112 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
113 |
+
|
114 |
+
query = self.to_q(hidden_states)
|
115 |
+
dim = query.shape[-1]
|
116 |
+
query = self.reshape_heads_to_batch_dim(query)
|
117 |
+
|
118 |
+
if self.added_kv_proj_dim is not None:
|
119 |
+
key = self.to_k(hidden_states)
|
120 |
+
value = self.to_v(hidden_states)
|
121 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
122 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
123 |
+
|
124 |
+
key = self.reshape_heads_to_batch_dim(key)
|
125 |
+
value = self.reshape_heads_to_batch_dim(value)
|
126 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
127 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
128 |
+
|
129 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
130 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
131 |
+
else:
|
132 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
133 |
+
key = self.to_k(encoder_hidden_states)
|
134 |
+
value = self.to_v(encoder_hidden_states)
|
135 |
+
|
136 |
+
key = self.reshape_heads_to_batch_dim(key)
|
137 |
+
value = self.reshape_heads_to_batch_dim(value)
|
138 |
+
|
139 |
+
if attention_mask is not None:
|
140 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
141 |
+
target_length = query.shape[1]
|
142 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
143 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
144 |
+
|
145 |
+
# attention, what we cannot get enough of
|
146 |
+
if self._use_memory_efficient_attention_xformers:
|
147 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
148 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
149 |
+
hidden_states = hidden_states.to(query.dtype)
|
150 |
+
else:
|
151 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
152 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
153 |
+
else:
|
154 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
155 |
+
|
156 |
+
# linear proj
|
157 |
+
hidden_states = self.to_out[0](hidden_states)
|
158 |
+
|
159 |
+
# dropout
|
160 |
+
hidden_states = self.to_out[1](hidden_states)
|
161 |
+
return hidden_states
|
162 |
+
|
163 |
+
def _attention(self, query, key, value, attention_mask=None):
|
164 |
+
if self.upcast_attention:
|
165 |
+
query = query.float()
|
166 |
+
key = key.float()
|
167 |
+
|
168 |
+
attention_scores = torch.baddbmm(
|
169 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
170 |
+
query,
|
171 |
+
key.transpose(-1, -2),
|
172 |
+
beta=0,
|
173 |
+
alpha=self.scale,
|
174 |
+
)
|
175 |
+
|
176 |
+
if attention_mask is not None:
|
177 |
+
attention_scores = attention_scores + attention_mask
|
178 |
+
|
179 |
+
if self.upcast_softmax:
|
180 |
+
attention_scores = attention_scores.float()
|
181 |
+
|
182 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
183 |
+
|
184 |
+
# cast back to the original dtype
|
185 |
+
attention_probs = attention_probs.to(value.dtype)
|
186 |
+
|
187 |
+
# compute attention output
|
188 |
+
hidden_states = torch.bmm(attention_probs, value)
|
189 |
+
|
190 |
+
# reshape hidden_states
|
191 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
192 |
+
return hidden_states
|
193 |
+
|
194 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
195 |
+
batch_size_attention = query.shape[0]
|
196 |
+
hidden_states = torch.zeros(
|
197 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
198 |
+
)
|
199 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
200 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
201 |
+
start_idx = i * slice_size
|
202 |
+
end_idx = (i + 1) * slice_size
|
203 |
+
|
204 |
+
query_slice = query[start_idx:end_idx]
|
205 |
+
key_slice = key[start_idx:end_idx]
|
206 |
+
|
207 |
+
if self.upcast_attention:
|
208 |
+
query_slice = query_slice.float()
|
209 |
+
key_slice = key_slice.float()
|
210 |
+
|
211 |
+
attn_slice = torch.baddbmm(
|
212 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
213 |
+
query_slice,
|
214 |
+
key_slice.transpose(-1, -2),
|
215 |
+
beta=0,
|
216 |
+
alpha=self.scale,
|
217 |
+
)
|
218 |
+
|
219 |
+
if attention_mask is not None:
|
220 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
221 |
+
|
222 |
+
if self.upcast_softmax:
|
223 |
+
attn_slice = attn_slice.float()
|
224 |
+
|
225 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
226 |
+
|
227 |
+
# cast back to the original dtype
|
228 |
+
attn_slice = attn_slice.to(value.dtype)
|
229 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
230 |
+
|
231 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
232 |
+
|
233 |
+
# reshape hidden_states
|
234 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
235 |
+
return hidden_states
|
236 |
+
|
237 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
238 |
+
# TODO attention_mask
|
239 |
+
query = query.contiguous()
|
240 |
+
key = key.contiguous()
|
241 |
+
value = value.contiguous()
|
242 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
243 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
244 |
+
return hidden_states
|
245 |
+
|
246 |
+
def zero_module(module):
|
247 |
+
# Zero out the parameters of a module and return it.
|
248 |
+
for p in module.parameters():
|
249 |
+
p.detach().zero_()
|
250 |
+
return module
|
251 |
+
|
252 |
+
def get_motion_module(
|
253 |
+
in_channels,
|
254 |
+
motion_module_type: str,
|
255 |
+
motion_module_kwargs: dict,
|
256 |
+
):
|
257 |
+
if motion_module_type == "Vanilla":
|
258 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
259 |
+
elif motion_module_type == "VanillaGrid":
|
260 |
+
return VanillaTemporalModule(in_channels=in_channels, grid=True, **motion_module_kwargs,)
|
261 |
+
else:
|
262 |
+
raise ValueError
|
263 |
+
|
264 |
+
class VanillaTemporalModule(nn.Module):
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
in_channels,
|
268 |
+
num_attention_heads = 8,
|
269 |
+
num_transformer_block = 2,
|
270 |
+
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
|
271 |
+
cross_frame_attention_mode = None,
|
272 |
+
temporal_position_encoding = False,
|
273 |
+
temporal_position_encoding_max_len = 4096,
|
274 |
+
temporal_attention_dim_div = 1,
|
275 |
+
zero_initialize = True,
|
276 |
+
block_size = 1,
|
277 |
+
grid = False,
|
278 |
+
):
|
279 |
+
super().__init__()
|
280 |
+
|
281 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
282 |
+
in_channels=in_channels,
|
283 |
+
num_attention_heads=num_attention_heads,
|
284 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
285 |
+
num_layers=num_transformer_block,
|
286 |
+
attention_block_types=attention_block_types,
|
287 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
288 |
+
temporal_position_encoding=temporal_position_encoding,
|
289 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
290 |
+
grid=grid,
|
291 |
+
block_size=block_size,
|
292 |
+
)
|
293 |
+
if zero_initialize:
|
294 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
295 |
+
|
296 |
+
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
|
297 |
+
hidden_states = input_tensor
|
298 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
299 |
+
|
300 |
+
output = hidden_states
|
301 |
+
return output
|
302 |
+
|
303 |
+
class TemporalTransformer3DModel(nn.Module):
|
304 |
+
def __init__(
|
305 |
+
self,
|
306 |
+
in_channels,
|
307 |
+
num_attention_heads,
|
308 |
+
attention_head_dim,
|
309 |
+
|
310 |
+
num_layers,
|
311 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
312 |
+
dropout = 0.0,
|
313 |
+
norm_num_groups = 32,
|
314 |
+
cross_attention_dim = 768,
|
315 |
+
activation_fn = "geglu",
|
316 |
+
attention_bias = False,
|
317 |
+
upcast_attention = False,
|
318 |
+
|
319 |
+
cross_frame_attention_mode = None,
|
320 |
+
temporal_position_encoding = False,
|
321 |
+
temporal_position_encoding_max_len = 4096,
|
322 |
+
grid = False,
|
323 |
+
block_size = 1,
|
324 |
+
):
|
325 |
+
super().__init__()
|
326 |
+
|
327 |
+
inner_dim = num_attention_heads * attention_head_dim
|
328 |
+
|
329 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
330 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
331 |
+
|
332 |
+
self.block_size = block_size
|
333 |
+
self.transformer_blocks = nn.ModuleList(
|
334 |
+
[
|
335 |
+
TemporalTransformerBlock(
|
336 |
+
dim=inner_dim,
|
337 |
+
num_attention_heads=num_attention_heads,
|
338 |
+
attention_head_dim=attention_head_dim,
|
339 |
+
attention_block_types=attention_block_types,
|
340 |
+
dropout=dropout,
|
341 |
+
norm_num_groups=norm_num_groups,
|
342 |
+
cross_attention_dim=cross_attention_dim,
|
343 |
+
activation_fn=activation_fn,
|
344 |
+
attention_bias=attention_bias,
|
345 |
+
upcast_attention=upcast_attention,
|
346 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
347 |
+
temporal_position_encoding=temporal_position_encoding,
|
348 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
349 |
+
block_size=block_size,
|
350 |
+
grid=grid,
|
351 |
+
)
|
352 |
+
for d in range(num_layers)
|
353 |
+
]
|
354 |
+
)
|
355 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
356 |
+
|
357 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
358 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
359 |
+
video_length = hidden_states.shape[2]
|
360 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
361 |
+
|
362 |
+
batch, channel, height, weight = hidden_states.shape
|
363 |
+
residual = hidden_states
|
364 |
+
|
365 |
+
hidden_states = self.norm(hidden_states)
|
366 |
+
inner_dim = hidden_states.shape[1]
|
367 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
368 |
+
hidden_states = self.proj_in(hidden_states)
|
369 |
+
|
370 |
+
# Transformer Blocks
|
371 |
+
for block in self.transformer_blocks:
|
372 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, height=height, weight=weight)
|
373 |
+
|
374 |
+
# output
|
375 |
+
hidden_states = self.proj_out(hidden_states)
|
376 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
377 |
+
|
378 |
+
output = hidden_states + residual
|
379 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
380 |
+
|
381 |
+
return output
|
382 |
+
|
383 |
+
class TemporalTransformerBlock(nn.Module):
|
384 |
+
def __init__(
|
385 |
+
self,
|
386 |
+
dim,
|
387 |
+
num_attention_heads,
|
388 |
+
attention_head_dim,
|
389 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
390 |
+
dropout = 0.0,
|
391 |
+
norm_num_groups = 32,
|
392 |
+
cross_attention_dim = 768,
|
393 |
+
activation_fn = "geglu",
|
394 |
+
attention_bias = False,
|
395 |
+
upcast_attention = False,
|
396 |
+
cross_frame_attention_mode = None,
|
397 |
+
temporal_position_encoding = False,
|
398 |
+
temporal_position_encoding_max_len = 4096,
|
399 |
+
block_size = 1,
|
400 |
+
grid = False,
|
401 |
+
):
|
402 |
+
super().__init__()
|
403 |
+
|
404 |
+
attention_blocks = []
|
405 |
+
norms = []
|
406 |
+
|
407 |
+
for block_name in attention_block_types:
|
408 |
+
attention_blocks.append(
|
409 |
+
VersatileAttention(
|
410 |
+
attention_mode=block_name.split("_")[0],
|
411 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
412 |
+
|
413 |
+
query_dim=dim,
|
414 |
+
heads=num_attention_heads,
|
415 |
+
dim_head=attention_head_dim,
|
416 |
+
dropout=dropout,
|
417 |
+
bias=attention_bias,
|
418 |
+
upcast_attention=upcast_attention,
|
419 |
+
|
420 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
421 |
+
temporal_position_encoding=temporal_position_encoding,
|
422 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
423 |
+
block_size=block_size,
|
424 |
+
grid=grid,
|
425 |
+
)
|
426 |
+
)
|
427 |
+
norms.append(nn.LayerNorm(dim))
|
428 |
+
|
429 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
430 |
+
self.norms = nn.ModuleList(norms)
|
431 |
+
|
432 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
433 |
+
self.ff_norm = nn.LayerNorm(dim)
|
434 |
+
|
435 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
|
436 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
437 |
+
norm_hidden_states = norm(hidden_states)
|
438 |
+
hidden_states = attention_block(
|
439 |
+
norm_hidden_states,
|
440 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
441 |
+
video_length=video_length,
|
442 |
+
height=height,
|
443 |
+
weight=weight,
|
444 |
+
) + hidden_states
|
445 |
+
|
446 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
447 |
+
|
448 |
+
output = hidden_states
|
449 |
+
return output
|
450 |
+
|
451 |
+
class PositionalEncoding(nn.Module):
|
452 |
+
def __init__(
|
453 |
+
self,
|
454 |
+
d_model,
|
455 |
+
dropout = 0.,
|
456 |
+
max_len = 4096
|
457 |
+
):
|
458 |
+
super().__init__()
|
459 |
+
self.dropout = nn.Dropout(p=dropout)
|
460 |
+
position = torch.arange(max_len).unsqueeze(1)
|
461 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
462 |
+
pe = torch.zeros(1, max_len, d_model)
|
463 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
464 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
465 |
+
self.register_buffer('pe', pe)
|
466 |
+
|
467 |
+
def forward(self, x):
|
468 |
+
x = x + self.pe[:, :x.size(1)]
|
469 |
+
return self.dropout(x)
|
470 |
+
|
471 |
+
class VersatileAttention(CrossAttention):
|
472 |
+
def __init__(
|
473 |
+
self,
|
474 |
+
attention_mode = None,
|
475 |
+
cross_frame_attention_mode = None,
|
476 |
+
temporal_position_encoding = False,
|
477 |
+
temporal_position_encoding_max_len = 4096,
|
478 |
+
grid = False,
|
479 |
+
block_size = 1,
|
480 |
+
*args, **kwargs
|
481 |
+
):
|
482 |
+
super().__init__(*args, **kwargs)
|
483 |
+
assert attention_mode == "Temporal"
|
484 |
+
|
485 |
+
self.attention_mode = attention_mode
|
486 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
487 |
+
|
488 |
+
self.block_size = block_size
|
489 |
+
self.grid = grid
|
490 |
+
self.pos_encoder = PositionalEncoding(
|
491 |
+
kwargs["query_dim"],
|
492 |
+
dropout=0.,
|
493 |
+
max_len=temporal_position_encoding_max_len
|
494 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
495 |
+
|
496 |
+
def extra_repr(self):
|
497 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
498 |
+
|
499 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
|
500 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
501 |
+
|
502 |
+
if self.attention_mode == "Temporal":
|
503 |
+
# for add pos_encoder
|
504 |
+
_, before_d, _c = hidden_states.size()
|
505 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
506 |
+
if self.pos_encoder is not None:
|
507 |
+
hidden_states = self.pos_encoder(hidden_states)
|
508 |
+
|
509 |
+
if self.grid:
|
510 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
|
511 |
+
hidden_states = rearrange(hidden_states, "b f (h w) c -> b f h w c", h=height, w=weight)
|
512 |
+
|
513 |
+
hidden_states = rearrange(hidden_states, "b f (h n) (w m) c -> (b h w) (f n m) c", n=self.block_size, m=self.block_size)
|
514 |
+
d = before_d // self.block_size // self.block_size
|
515 |
+
else:
|
516 |
+
d = before_d
|
517 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
518 |
+
else:
|
519 |
+
raise NotImplementedError
|
520 |
+
|
521 |
+
encoder_hidden_states = encoder_hidden_states
|
522 |
+
|
523 |
+
if self.group_norm is not None:
|
524 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
525 |
+
|
526 |
+
query = self.to_q(hidden_states)
|
527 |
+
dim = query.shape[-1]
|
528 |
+
query = self.reshape_heads_to_batch_dim(query)
|
529 |
+
|
530 |
+
if self.added_kv_proj_dim is not None:
|
531 |
+
raise NotImplementedError
|
532 |
+
|
533 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
534 |
+
key = self.to_k(encoder_hidden_states)
|
535 |
+
value = self.to_v(encoder_hidden_states)
|
536 |
+
|
537 |
+
key = self.reshape_heads_to_batch_dim(key)
|
538 |
+
value = self.reshape_heads_to_batch_dim(value)
|
539 |
+
|
540 |
+
if attention_mask is not None:
|
541 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
542 |
+
target_length = query.shape[1]
|
543 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
544 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
545 |
+
|
546 |
+
bs = 512
|
547 |
+
new_hidden_states = []
|
548 |
+
for i in range(0, query.shape[0], bs):
|
549 |
+
# attention, what we cannot get enough of
|
550 |
+
if self._use_memory_efficient_attention_xformers:
|
551 |
+
hidden_states = self._memory_efficient_attention_xformers(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
|
552 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
553 |
+
hidden_states = hidden_states.to(query.dtype)
|
554 |
+
else:
|
555 |
+
if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
|
556 |
+
hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
|
557 |
+
else:
|
558 |
+
hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
|
559 |
+
new_hidden_states.append(hidden_states)
|
560 |
+
hidden_states = torch.cat(new_hidden_states, dim = 0)
|
561 |
+
|
562 |
+
# linear proj
|
563 |
+
hidden_states = self.to_out[0](hidden_states)
|
564 |
+
|
565 |
+
# dropout
|
566 |
+
hidden_states = self.to_out[1](hidden_states)
|
567 |
+
|
568 |
+
if self.attention_mode == "Temporal":
|
569 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
570 |
+
if self.grid:
|
571 |
+
hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
|
572 |
+
hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
|
573 |
+
hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
|
574 |
+
|
575 |
+
return hidden_states
|
easyanimate/models/patch.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch.nn.init as init
|
7 |
+
import math
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
def get_2d_sincos_pos_embed(
|
13 |
+
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
|
14 |
+
):
|
15 |
+
"""
|
16 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
17 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
18 |
+
"""
|
19 |
+
if isinstance(grid_size, int):
|
20 |
+
grid_size = (grid_size, grid_size)
|
21 |
+
|
22 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
23 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
24 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
25 |
+
grid = np.stack(grid, axis=0)
|
26 |
+
|
27 |
+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
28 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
29 |
+
if cls_token and extra_tokens > 0:
|
30 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
31 |
+
return pos_embed
|
32 |
+
|
33 |
+
|
34 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
35 |
+
if embed_dim % 2 != 0:
|
36 |
+
raise ValueError("embed_dim must be divisible by 2")
|
37 |
+
|
38 |
+
# use half of dimensions to encode grid_h
|
39 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
40 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
41 |
+
|
42 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
43 |
+
return emb
|
44 |
+
|
45 |
+
|
46 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
47 |
+
"""
|
48 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
49 |
+
"""
|
50 |
+
if embed_dim % 2 != 0:
|
51 |
+
raise ValueError("embed_dim must be divisible by 2")
|
52 |
+
|
53 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
54 |
+
omega /= embed_dim / 2.0
|
55 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
56 |
+
|
57 |
+
pos = pos.reshape(-1) # (M,)
|
58 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
59 |
+
|
60 |
+
emb_sin = np.sin(out) # (M, D/2)
|
61 |
+
emb_cos = np.cos(out) # (M, D/2)
|
62 |
+
|
63 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
64 |
+
return emb
|
65 |
+
|
66 |
+
class Patch1D(nn.Module):
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
channels: int,
|
70 |
+
use_conv: bool = False,
|
71 |
+
out_channels: Optional[int] = None,
|
72 |
+
stride: int = 2,
|
73 |
+
padding: int = 0,
|
74 |
+
name: str = "conv",
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
self.channels = channels
|
78 |
+
self.out_channels = out_channels or channels
|
79 |
+
self.use_conv = use_conv
|
80 |
+
self.padding = padding
|
81 |
+
self.name = name
|
82 |
+
|
83 |
+
if use_conv:
|
84 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
|
85 |
+
init.constant_(self.conv.weight, 0.0)
|
86 |
+
with torch.no_grad():
|
87 |
+
for i in range(len(self.conv.weight)): self.conv.weight[i, i] = 1 / stride
|
88 |
+
init.constant_(self.conv.bias, 0.0)
|
89 |
+
else:
|
90 |
+
assert self.channels == self.out_channels
|
91 |
+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
92 |
+
|
93 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
94 |
+
assert inputs.shape[1] == self.channels
|
95 |
+
return self.conv(inputs)
|
96 |
+
|
97 |
+
class UnPatch1D(nn.Module):
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
channels: int,
|
101 |
+
use_conv: bool = False,
|
102 |
+
use_conv_transpose: bool = False,
|
103 |
+
out_channels: Optional[int] = None,
|
104 |
+
name: str = "conv",
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
self.channels = channels
|
108 |
+
self.out_channels = out_channels or channels
|
109 |
+
self.use_conv = use_conv
|
110 |
+
self.use_conv_transpose = use_conv_transpose
|
111 |
+
self.name = name
|
112 |
+
|
113 |
+
self.conv = None
|
114 |
+
if use_conv_transpose:
|
115 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
116 |
+
elif use_conv:
|
117 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
118 |
+
|
119 |
+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
120 |
+
assert inputs.shape[1] == self.channels
|
121 |
+
if self.use_conv_transpose:
|
122 |
+
return self.conv(inputs)
|
123 |
+
|
124 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
125 |
+
|
126 |
+
if self.use_conv:
|
127 |
+
outputs = self.conv(outputs)
|
128 |
+
|
129 |
+
return outputs
|
130 |
+
|
131 |
+
class Upsampler(nn.Module):
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
spatial_upsample_factor: int = 1,
|
135 |
+
temporal_upsample_factor: int = 1,
|
136 |
+
):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
self.spatial_upsample_factor = spatial_upsample_factor
|
140 |
+
self.temporal_upsample_factor = temporal_upsample_factor
|
141 |
+
|
142 |
+
class TemporalUpsampler3D(Upsampler):
|
143 |
+
def __init__(self):
|
144 |
+
super().__init__(
|
145 |
+
spatial_upsample_factor=1,
|
146 |
+
temporal_upsample_factor=2,
|
147 |
+
)
|
148 |
+
|
149 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
150 |
+
if x.shape[2] > 1:
|
151 |
+
first_frame, x = x[:, :, :1], x[:, :, 1:]
|
152 |
+
x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
|
153 |
+
x = torch.cat([first_frame, x], dim=2)
|
154 |
+
return x
|
155 |
+
|
156 |
+
def cast_tuple(t, length = 1):
|
157 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
158 |
+
|
159 |
+
def divisible_by(num, den):
|
160 |
+
return (num % den) == 0
|
161 |
+
|
162 |
+
def is_odd(n):
|
163 |
+
return not divisible_by(n, 2)
|
164 |
+
|
165 |
+
class CausalConv3d(nn.Conv3d):
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
in_channels: int,
|
169 |
+
out_channels: int,
|
170 |
+
kernel_size=3, # : int | tuple[int, int, int],
|
171 |
+
stride=1, # : int | tuple[int, int, int] = 1,
|
172 |
+
padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
|
173 |
+
dilation=1, # : int | tuple[int, int, int] = 1,
|
174 |
+
**kwargs,
|
175 |
+
):
|
176 |
+
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
|
177 |
+
assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
|
178 |
+
|
179 |
+
stride = stride if isinstance(stride, tuple) else (stride,) * 3
|
180 |
+
assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
|
181 |
+
|
182 |
+
dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
|
183 |
+
assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
|
184 |
+
|
185 |
+
t_ks, h_ks, w_ks = kernel_size
|
186 |
+
_, h_stride, w_stride = stride
|
187 |
+
t_dilation, h_dilation, w_dilation = dilation
|
188 |
+
|
189 |
+
t_pad = (t_ks - 1) * t_dilation
|
190 |
+
# TODO: align with SD
|
191 |
+
if padding is None:
|
192 |
+
h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
|
193 |
+
w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
|
194 |
+
elif isinstance(padding, int):
|
195 |
+
h_pad = w_pad = padding
|
196 |
+
else:
|
197 |
+
assert NotImplementedError
|
198 |
+
|
199 |
+
self.temporal_padding = t_pad
|
200 |
+
|
201 |
+
super().__init__(
|
202 |
+
in_channels=in_channels,
|
203 |
+
out_channels=out_channels,
|
204 |
+
kernel_size=kernel_size,
|
205 |
+
stride=stride,
|
206 |
+
dilation=dilation,
|
207 |
+
padding=(0, h_pad, w_pad),
|
208 |
+
**kwargs,
|
209 |
+
)
|
210 |
+
|
211 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
212 |
+
# x: (B, C, T, H, W)
|
213 |
+
x = F.pad(
|
214 |
+
x,
|
215 |
+
pad=(0, 0, 0, 0, self.temporal_padding, 0),
|
216 |
+
mode="replicate", # TODO: check if this is necessary
|
217 |
+
)
|
218 |
+
return super().forward(x)
|
219 |
+
|
220 |
+
class PatchEmbed3D(nn.Module):
|
221 |
+
"""3D Image to Patch Embedding"""
|
222 |
+
|
223 |
+
def __init__(
|
224 |
+
self,
|
225 |
+
height=224,
|
226 |
+
width=224,
|
227 |
+
patch_size=16,
|
228 |
+
time_patch_size=4,
|
229 |
+
in_channels=3,
|
230 |
+
embed_dim=768,
|
231 |
+
layer_norm=False,
|
232 |
+
flatten=True,
|
233 |
+
bias=True,
|
234 |
+
interpolation_scale=1,
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
239 |
+
self.flatten = flatten
|
240 |
+
self.layer_norm = layer_norm
|
241 |
+
|
242 |
+
self.proj = nn.Conv3d(
|
243 |
+
in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias
|
244 |
+
)
|
245 |
+
if layer_norm:
|
246 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
247 |
+
else:
|
248 |
+
self.norm = None
|
249 |
+
|
250 |
+
self.patch_size = patch_size
|
251 |
+
# See:
|
252 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
253 |
+
self.height, self.width = height // patch_size, width // patch_size
|
254 |
+
self.base_size = height // patch_size
|
255 |
+
self.interpolation_scale = interpolation_scale
|
256 |
+
pos_embed = get_2d_sincos_pos_embed(
|
257 |
+
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
258 |
+
)
|
259 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
260 |
+
|
261 |
+
def forward(self, latent):
|
262 |
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
263 |
+
|
264 |
+
latent = self.proj(latent)
|
265 |
+
latent = rearrange(latent, "b c f h w -> (b f) c h w")
|
266 |
+
if self.flatten:
|
267 |
+
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
268 |
+
if self.layer_norm:
|
269 |
+
latent = self.norm(latent)
|
270 |
+
# Interpolate positional embeddings if needed.
|
271 |
+
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
272 |
+
if self.height != height or self.width != width:
|
273 |
+
pos_embed = get_2d_sincos_pos_embed(
|
274 |
+
embed_dim=self.pos_embed.shape[-1],
|
275 |
+
grid_size=(height, width),
|
276 |
+
base_size=self.base_size,
|
277 |
+
interpolation_scale=self.interpolation_scale,
|
278 |
+
)
|
279 |
+
pos_embed = torch.from_numpy(pos_embed)
|
280 |
+
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
281 |
+
else:
|
282 |
+
pos_embed = self.pos_embed
|
283 |
+
|
284 |
+
return (latent + pos_embed).to(latent.dtype)
|
285 |
+
|
286 |
+
class PatchEmbedF3D(nn.Module):
|
287 |
+
"""Fake 3D Image to Patch Embedding"""
|
288 |
+
|
289 |
+
def __init__(
|
290 |
+
self,
|
291 |
+
height=224,
|
292 |
+
width=224,
|
293 |
+
patch_size=16,
|
294 |
+
in_channels=3,
|
295 |
+
embed_dim=768,
|
296 |
+
layer_norm=False,
|
297 |
+
flatten=True,
|
298 |
+
bias=True,
|
299 |
+
interpolation_scale=1,
|
300 |
+
):
|
301 |
+
super().__init__()
|
302 |
+
|
303 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
304 |
+
self.flatten = flatten
|
305 |
+
self.layer_norm = layer_norm
|
306 |
+
|
307 |
+
self.proj = nn.Conv2d(
|
308 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
309 |
+
)
|
310 |
+
self.proj_t = Patch1D(
|
311 |
+
embed_dim, True, stride=patch_size
|
312 |
+
)
|
313 |
+
if layer_norm:
|
314 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
315 |
+
else:
|
316 |
+
self.norm = None
|
317 |
+
|
318 |
+
self.patch_size = patch_size
|
319 |
+
# See:
|
320 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
321 |
+
self.height, self.width = height // patch_size, width // patch_size
|
322 |
+
self.base_size = height // patch_size
|
323 |
+
self.interpolation_scale = interpolation_scale
|
324 |
+
pos_embed = get_2d_sincos_pos_embed(
|
325 |
+
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
326 |
+
)
|
327 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
328 |
+
|
329 |
+
def forward(self, latent):
|
330 |
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
331 |
+
b, c, f, h, w = latent.size()
|
332 |
+
latent = rearrange(latent, "b c f h w -> (b f) c h w")
|
333 |
+
latent = self.proj(latent)
|
334 |
+
latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)
|
335 |
+
|
336 |
+
latent = rearrange(latent, "b c f h w -> (b h w) c f")
|
337 |
+
latent = self.proj_t(latent)
|
338 |
+
latent = rearrange(latent, "(b h w) c f -> b c f h w", h=h//2, w=w//2)
|
339 |
+
|
340 |
+
latent = rearrange(latent, "b c f h w -> (b f) c h w")
|
341 |
+
if self.flatten:
|
342 |
+
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
343 |
+
if self.layer_norm:
|
344 |
+
latent = self.norm(latent)
|
345 |
+
|
346 |
+
# Interpolate positional embeddings if needed.
|
347 |
+
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
348 |
+
if self.height != height or self.width != width:
|
349 |
+
pos_embed = get_2d_sincos_pos_embed(
|
350 |
+
embed_dim=self.pos_embed.shape[-1],
|
351 |
+
grid_size=(height, width),
|
352 |
+
base_size=self.base_size,
|
353 |
+
interpolation_scale=self.interpolation_scale,
|
354 |
+
)
|
355 |
+
pos_embed = torch.from_numpy(pos_embed)
|
356 |
+
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
357 |
+
else:
|
358 |
+
pos_embed = self.pos_embed
|
359 |
+
|
360 |
+
return (latent + pos_embed).to(latent.dtype)
|
361 |
+
|
362 |
+
class CasualPatchEmbed3D(nn.Module):
|
363 |
+
"""3D Image to Patch Embedding"""
|
364 |
+
|
365 |
+
def __init__(
|
366 |
+
self,
|
367 |
+
height=224,
|
368 |
+
width=224,
|
369 |
+
patch_size=16,
|
370 |
+
time_patch_size=4,
|
371 |
+
in_channels=3,
|
372 |
+
embed_dim=768,
|
373 |
+
layer_norm=False,
|
374 |
+
flatten=True,
|
375 |
+
bias=True,
|
376 |
+
interpolation_scale=1,
|
377 |
+
):
|
378 |
+
super().__init__()
|
379 |
+
|
380 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
381 |
+
self.flatten = flatten
|
382 |
+
self.layer_norm = layer_norm
|
383 |
+
|
384 |
+
self.proj = CausalConv3d(
|
385 |
+
in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias, padding=None
|
386 |
+
)
|
387 |
+
if layer_norm:
|
388 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
389 |
+
else:
|
390 |
+
self.norm = None
|
391 |
+
|
392 |
+
self.patch_size = patch_size
|
393 |
+
# See:
|
394 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
395 |
+
self.height, self.width = height // patch_size, width // patch_size
|
396 |
+
self.base_size = height // patch_size
|
397 |
+
self.interpolation_scale = interpolation_scale
|
398 |
+
pos_embed = get_2d_sincos_pos_embed(
|
399 |
+
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
400 |
+
)
|
401 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
402 |
+
|
403 |
+
def forward(self, latent):
|
404 |
+
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
405 |
+
|
406 |
+
latent = self.proj(latent)
|
407 |
+
latent = rearrange(latent, "b c f h w -> (b f) c h w")
|
408 |
+
if self.flatten:
|
409 |
+
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
410 |
+
if self.layer_norm:
|
411 |
+
latent = self.norm(latent)
|
412 |
+
# Interpolate positional embeddings if needed.
|
413 |
+
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
414 |
+
if self.height != height or self.width != width:
|
415 |
+
pos_embed = get_2d_sincos_pos_embed(
|
416 |
+
embed_dim=self.pos_embed.shape[-1],
|
417 |
+
grid_size=(height, width),
|
418 |
+
base_size=self.base_size,
|
419 |
+
interpolation_scale=self.interpolation_scale,
|
420 |
+
)
|
421 |
+
pos_embed = torch.from_numpy(pos_embed)
|
422 |
+
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
423 |
+
else:
|
424 |
+
pos_embed = self.pos_embed
|
425 |
+
|
426 |
+
return (latent + pos_embed).to(latent.dtype)
|
easyanimate/models/transformer2d.py
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import json
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, Optional
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch.nn.init as init
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.models.attention import BasicTransformerBlock
|
25 |
+
from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
|
26 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
from diffusers.models.normalization import AdaLayerNormSingle
|
29 |
+
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
|
30 |
+
is_torch_version)
|
31 |
+
from einops import rearrange
|
32 |
+
from torch import nn
|
33 |
+
|
34 |
+
try:
|
35 |
+
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
36 |
+
except:
|
37 |
+
from diffusers.models.embeddings import \
|
38 |
+
CaptionProjection as PixArtAlphaTextProjection
|
39 |
+
|
40 |
+
from .attention import (KVCompressionTransformerBlock,
|
41 |
+
SelfAttentionTemporalTransformerBlock,
|
42 |
+
TemporalTransformerBlock)
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class Transformer2DModelOutput(BaseOutput):
|
47 |
+
"""
|
48 |
+
The output of [`Transformer2DModel`].
|
49 |
+
|
50 |
+
Args:
|
51 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
52 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
53 |
+
distributions for the unnoised latent pixels.
|
54 |
+
"""
|
55 |
+
|
56 |
+
sample: torch.FloatTensor
|
57 |
+
|
58 |
+
|
59 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
60 |
+
"""
|
61 |
+
A 2D Transformer model for image-like data.
|
62 |
+
|
63 |
+
Parameters:
|
64 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
65 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
66 |
+
in_channels (`int`, *optional*):
|
67 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
68 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
69 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
70 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
71 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
72 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
73 |
+
num_vector_embeds (`int`, *optional*):
|
74 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
75 |
+
Includes the class for the masked latent pixel.
|
76 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
77 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
78 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
79 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
80 |
+
added to the hidden states.
|
81 |
+
|
82 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
83 |
+
attention_bias (`bool`, *optional*):
|
84 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
85 |
+
"""
|
86 |
+
_supports_gradient_checkpointing = True
|
87 |
+
|
88 |
+
@register_to_config
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
num_attention_heads: int = 16,
|
92 |
+
attention_head_dim: int = 88,
|
93 |
+
in_channels: Optional[int] = None,
|
94 |
+
out_channels: Optional[int] = None,
|
95 |
+
num_layers: int = 1,
|
96 |
+
dropout: float = 0.0,
|
97 |
+
norm_num_groups: int = 32,
|
98 |
+
cross_attention_dim: Optional[int] = None,
|
99 |
+
attention_bias: bool = False,
|
100 |
+
sample_size: Optional[int] = None,
|
101 |
+
num_vector_embeds: Optional[int] = None,
|
102 |
+
patch_size: Optional[int] = None,
|
103 |
+
activation_fn: str = "geglu",
|
104 |
+
num_embeds_ada_norm: Optional[int] = None,
|
105 |
+
use_linear_projection: bool = False,
|
106 |
+
only_cross_attention: bool = False,
|
107 |
+
double_self_attention: bool = False,
|
108 |
+
upcast_attention: bool = False,
|
109 |
+
norm_type: str = "layer_norm",
|
110 |
+
norm_elementwise_affine: bool = True,
|
111 |
+
norm_eps: float = 1e-5,
|
112 |
+
attention_type: str = "default",
|
113 |
+
caption_channels: int = None,
|
114 |
+
# block type
|
115 |
+
basic_block_type: str = "basic",
|
116 |
+
):
|
117 |
+
super().__init__()
|
118 |
+
self.use_linear_projection = use_linear_projection
|
119 |
+
self.num_attention_heads = num_attention_heads
|
120 |
+
self.attention_head_dim = attention_head_dim
|
121 |
+
self.basic_block_type = basic_block_type
|
122 |
+
inner_dim = num_attention_heads * attention_head_dim
|
123 |
+
|
124 |
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
125 |
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
126 |
+
|
127 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
128 |
+
# Define whether input is continuous or discrete depending on configuration
|
129 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
130 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
131 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
132 |
+
|
133 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
134 |
+
deprecation_message = (
|
135 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
136 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
137 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
138 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
139 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
140 |
+
)
|
141 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
142 |
+
norm_type = "ada_norm"
|
143 |
+
|
144 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
145 |
+
raise ValueError(
|
146 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
147 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
148 |
+
)
|
149 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
150 |
+
raise ValueError(
|
151 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
152 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
153 |
+
)
|
154 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
155 |
+
raise ValueError(
|
156 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
157 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
158 |
+
)
|
159 |
+
|
160 |
+
# 2. Define input layers
|
161 |
+
if self.is_input_continuous:
|
162 |
+
self.in_channels = in_channels
|
163 |
+
|
164 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
165 |
+
if use_linear_projection:
|
166 |
+
self.proj_in = linear_cls(in_channels, inner_dim)
|
167 |
+
else:
|
168 |
+
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
169 |
+
elif self.is_input_vectorized:
|
170 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
171 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
172 |
+
|
173 |
+
self.height = sample_size
|
174 |
+
self.width = sample_size
|
175 |
+
self.num_vector_embeds = num_vector_embeds
|
176 |
+
self.num_latent_pixels = self.height * self.width
|
177 |
+
|
178 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
179 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
180 |
+
)
|
181 |
+
elif self.is_input_patches:
|
182 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
183 |
+
|
184 |
+
self.height = sample_size
|
185 |
+
self.width = sample_size
|
186 |
+
|
187 |
+
self.patch_size = patch_size
|
188 |
+
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
189 |
+
interpolation_scale = max(interpolation_scale, 1)
|
190 |
+
self.pos_embed = PatchEmbed(
|
191 |
+
height=sample_size,
|
192 |
+
width=sample_size,
|
193 |
+
patch_size=patch_size,
|
194 |
+
in_channels=in_channels,
|
195 |
+
embed_dim=inner_dim,
|
196 |
+
interpolation_scale=interpolation_scale,
|
197 |
+
)
|
198 |
+
|
199 |
+
basic_block = {
|
200 |
+
"basic": BasicTransformerBlock,
|
201 |
+
"kvcompression": KVCompressionTransformerBlock,
|
202 |
+
}[self.basic_block_type]
|
203 |
+
if self.basic_block_type == "kvcompression":
|
204 |
+
self.transformer_blocks = nn.ModuleList(
|
205 |
+
[
|
206 |
+
basic_block(
|
207 |
+
inner_dim,
|
208 |
+
num_attention_heads,
|
209 |
+
attention_head_dim,
|
210 |
+
dropout=dropout,
|
211 |
+
cross_attention_dim=cross_attention_dim,
|
212 |
+
activation_fn=activation_fn,
|
213 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
214 |
+
attention_bias=attention_bias,
|
215 |
+
only_cross_attention=only_cross_attention,
|
216 |
+
double_self_attention=double_self_attention,
|
217 |
+
upcast_attention=upcast_attention,
|
218 |
+
norm_type=norm_type,
|
219 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
220 |
+
norm_eps=norm_eps,
|
221 |
+
attention_type=attention_type,
|
222 |
+
kvcompression=False if d < 14 else True,
|
223 |
+
)
|
224 |
+
for d in range(num_layers)
|
225 |
+
]
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
# 3. Define transformers blocks
|
229 |
+
self.transformer_blocks = nn.ModuleList(
|
230 |
+
[
|
231 |
+
BasicTransformerBlock(
|
232 |
+
inner_dim,
|
233 |
+
num_attention_heads,
|
234 |
+
attention_head_dim,
|
235 |
+
dropout=dropout,
|
236 |
+
cross_attention_dim=cross_attention_dim,
|
237 |
+
activation_fn=activation_fn,
|
238 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
239 |
+
attention_bias=attention_bias,
|
240 |
+
only_cross_attention=only_cross_attention,
|
241 |
+
double_self_attention=double_self_attention,
|
242 |
+
upcast_attention=upcast_attention,
|
243 |
+
norm_type=norm_type,
|
244 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
245 |
+
norm_eps=norm_eps,
|
246 |
+
attention_type=attention_type,
|
247 |
+
)
|
248 |
+
for d in range(num_layers)
|
249 |
+
]
|
250 |
+
)
|
251 |
+
|
252 |
+
# 4. Define output layers
|
253 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
254 |
+
if self.is_input_continuous:
|
255 |
+
# TODO: should use out_channels for continuous projections
|
256 |
+
if use_linear_projection:
|
257 |
+
self.proj_out = linear_cls(inner_dim, in_channels)
|
258 |
+
else:
|
259 |
+
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
260 |
+
elif self.is_input_vectorized:
|
261 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
262 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
263 |
+
elif self.is_input_patches and norm_type != "ada_norm_single":
|
264 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
265 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
266 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
267 |
+
elif self.is_input_patches and norm_type == "ada_norm_single":
|
268 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
269 |
+
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
270 |
+
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
271 |
+
|
272 |
+
# 5. PixArt-Alpha blocks.
|
273 |
+
self.adaln_single = None
|
274 |
+
self.use_additional_conditions = False
|
275 |
+
if norm_type == "ada_norm_single":
|
276 |
+
self.use_additional_conditions = self.config.sample_size == 128
|
277 |
+
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
278 |
+
# additional conditions until we find better name
|
279 |
+
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
280 |
+
|
281 |
+
self.caption_projection = None
|
282 |
+
if caption_channels is not None:
|
283 |
+
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
284 |
+
|
285 |
+
self.gradient_checkpointing = False
|
286 |
+
|
287 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
288 |
+
if hasattr(module, "gradient_checkpointing"):
|
289 |
+
module.gradient_checkpointing = value
|
290 |
+
|
291 |
+
def forward(
|
292 |
+
self,
|
293 |
+
hidden_states: torch.Tensor,
|
294 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
295 |
+
timestep: Optional[torch.LongTensor] = None,
|
296 |
+
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
297 |
+
class_labels: Optional[torch.LongTensor] = None,
|
298 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
299 |
+
attention_mask: Optional[torch.Tensor] = None,
|
300 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
301 |
+
return_dict: bool = True,
|
302 |
+
):
|
303 |
+
"""
|
304 |
+
The [`Transformer2DModel`] forward method.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
308 |
+
Input `hidden_states`.
|
309 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
310 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
311 |
+
self-attention.
|
312 |
+
timestep ( `torch.LongTensor`, *optional*):
|
313 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
314 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
315 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
316 |
+
`AdaLayerZeroNorm`.
|
317 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
318 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
319 |
+
`self.processor` in
|
320 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
321 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
322 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
323 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
324 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
325 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
326 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
327 |
+
|
328 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
329 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
330 |
+
|
331 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
332 |
+
above. This bias will be added to the cross-attention scores.
|
333 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
334 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
335 |
+
tuple.
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
339 |
+
`tuple` where the first element is the sample tensor.
|
340 |
+
"""
|
341 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
342 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
343 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
344 |
+
# expects mask of shape:
|
345 |
+
# [batch, key_tokens]
|
346 |
+
# adds singleton query_tokens dimension:
|
347 |
+
# [batch, 1, key_tokens]
|
348 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
349 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
350 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
351 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
352 |
+
# assume that mask is expressed as:
|
353 |
+
# (1 = keep, 0 = discard)
|
354 |
+
# convert mask into a bias that can be added to attention scores:
|
355 |
+
# (keep = +0, discard = -10000.0)
|
356 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
357 |
+
attention_mask = attention_mask.unsqueeze(1)
|
358 |
+
|
359 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
360 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
361 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
|
362 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
363 |
+
|
364 |
+
# Retrieve lora scale.
|
365 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
366 |
+
|
367 |
+
# 1. Input
|
368 |
+
if self.is_input_continuous:
|
369 |
+
batch, _, height, width = hidden_states.shape
|
370 |
+
residual = hidden_states
|
371 |
+
|
372 |
+
hidden_states = self.norm(hidden_states)
|
373 |
+
if not self.use_linear_projection:
|
374 |
+
hidden_states = (
|
375 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
376 |
+
if not USE_PEFT_BACKEND
|
377 |
+
else self.proj_in(hidden_states)
|
378 |
+
)
|
379 |
+
inner_dim = hidden_states.shape[1]
|
380 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
381 |
+
else:
|
382 |
+
inner_dim = hidden_states.shape[1]
|
383 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
384 |
+
hidden_states = (
|
385 |
+
self.proj_in(hidden_states, scale=lora_scale)
|
386 |
+
if not USE_PEFT_BACKEND
|
387 |
+
else self.proj_in(hidden_states)
|
388 |
+
)
|
389 |
+
|
390 |
+
elif self.is_input_vectorized:
|
391 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
392 |
+
elif self.is_input_patches:
|
393 |
+
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
394 |
+
hidden_states = self.pos_embed(hidden_states)
|
395 |
+
|
396 |
+
if self.adaln_single is not None:
|
397 |
+
if self.use_additional_conditions and added_cond_kwargs is None:
|
398 |
+
raise ValueError(
|
399 |
+
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
400 |
+
)
|
401 |
+
batch_size = hidden_states.shape[0]
|
402 |
+
timestep, embedded_timestep = self.adaln_single(
|
403 |
+
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
404 |
+
)
|
405 |
+
|
406 |
+
# 2. Blocks
|
407 |
+
if self.caption_projection is not None:
|
408 |
+
batch_size = hidden_states.shape[0]
|
409 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
410 |
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
411 |
+
|
412 |
+
for block in self.transformer_blocks:
|
413 |
+
if self.training and self.gradient_checkpointing:
|
414 |
+
args = {
|
415 |
+
"basic": [],
|
416 |
+
"kvcompression": [1, height, width],
|
417 |
+
}[self.basic_block_type]
|
418 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
419 |
+
block,
|
420 |
+
hidden_states,
|
421 |
+
attention_mask,
|
422 |
+
encoder_hidden_states,
|
423 |
+
encoder_attention_mask,
|
424 |
+
timestep,
|
425 |
+
cross_attention_kwargs,
|
426 |
+
class_labels,
|
427 |
+
*args,
|
428 |
+
use_reentrant=False,
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
kwargs = {
|
432 |
+
"basic": {},
|
433 |
+
"kvcompression": {"num_frames":1, "height":height, "width":width},
|
434 |
+
}[self.basic_block_type]
|
435 |
+
hidden_states = block(
|
436 |
+
hidden_states,
|
437 |
+
attention_mask=attention_mask,
|
438 |
+
encoder_hidden_states=encoder_hidden_states,
|
439 |
+
encoder_attention_mask=encoder_attention_mask,
|
440 |
+
timestep=timestep,
|
441 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
442 |
+
class_labels=class_labels,
|
443 |
+
**kwargs
|
444 |
+
)
|
445 |
+
|
446 |
+
# 3. Output
|
447 |
+
if self.is_input_continuous:
|
448 |
+
if not self.use_linear_projection:
|
449 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
450 |
+
hidden_states = (
|
451 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
452 |
+
if not USE_PEFT_BACKEND
|
453 |
+
else self.proj_out(hidden_states)
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
hidden_states = (
|
457 |
+
self.proj_out(hidden_states, scale=lora_scale)
|
458 |
+
if not USE_PEFT_BACKEND
|
459 |
+
else self.proj_out(hidden_states)
|
460 |
+
)
|
461 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
462 |
+
|
463 |
+
output = hidden_states + residual
|
464 |
+
elif self.is_input_vectorized:
|
465 |
+
hidden_states = self.norm_out(hidden_states)
|
466 |
+
logits = self.out(hidden_states)
|
467 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
468 |
+
logits = logits.permute(0, 2, 1)
|
469 |
+
|
470 |
+
# log(p(x_0))
|
471 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
472 |
+
|
473 |
+
if self.is_input_patches:
|
474 |
+
if self.config.norm_type != "ada_norm_single":
|
475 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
476 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
477 |
+
)
|
478 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
479 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
480 |
+
hidden_states = self.proj_out_2(hidden_states)
|
481 |
+
elif self.config.norm_type == "ada_norm_single":
|
482 |
+
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
483 |
+
hidden_states = self.norm_out(hidden_states)
|
484 |
+
# Modulation
|
485 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
486 |
+
hidden_states = self.proj_out(hidden_states)
|
487 |
+
hidden_states = hidden_states.squeeze(1)
|
488 |
+
|
489 |
+
# unpatchify
|
490 |
+
if self.adaln_single is None:
|
491 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
492 |
+
hidden_states = hidden_states.reshape(
|
493 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
494 |
+
)
|
495 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
496 |
+
output = hidden_states.reshape(
|
497 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
498 |
+
)
|
499 |
+
|
500 |
+
if not return_dict:
|
501 |
+
return (output,)
|
502 |
+
|
503 |
+
return Transformer2DModelOutput(sample=output)
|
504 |
+
|
505 |
+
@classmethod
|
506 |
+
def from_pretrained(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
|
507 |
+
if subfolder is not None:
|
508 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
509 |
+
print(f"loaded 2D transformer's pretrained weights from {pretrained_model_path} ...")
|
510 |
+
|
511 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
512 |
+
if not os.path.isfile(config_file):
|
513 |
+
raise RuntimeError(f"{config_file} does not exist")
|
514 |
+
with open(config_file, "r") as f:
|
515 |
+
config = json.load(f)
|
516 |
+
|
517 |
+
from diffusers.utils import WEIGHTS_NAME
|
518 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
519 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
520 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
521 |
+
if os.path.exists(model_file_safetensors):
|
522 |
+
from safetensors.torch import load_file, safe_open
|
523 |
+
state_dict = load_file(model_file_safetensors)
|
524 |
+
else:
|
525 |
+
if not os.path.isfile(model_file):
|
526 |
+
raise RuntimeError(f"{model_file} does not exist")
|
527 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
528 |
+
|
529 |
+
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
|
530 |
+
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
|
531 |
+
state_dict['pos_embed.proj.weight'] = torch.tile(state_dict['proj_out.weight'], [1, 2, 1, 1])
|
532 |
+
|
533 |
+
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
|
534 |
+
new_shape = model.state_dict()['proj_out.weight'].size()
|
535 |
+
state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
|
536 |
+
|
537 |
+
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
|
538 |
+
new_shape = model.state_dict()['proj_out.bias'].size()
|
539 |
+
state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
|
540 |
+
|
541 |
+
tmp_state_dict = {}
|
542 |
+
for key in state_dict:
|
543 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
544 |
+
tmp_state_dict[key] = state_dict[key]
|
545 |
+
else:
|
546 |
+
print(key, "Size don't match, skip")
|
547 |
+
state_dict = tmp_state_dict
|
548 |
+
|
549 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
550 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
551 |
+
|
552 |
+
params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
|
553 |
+
print(f"### Postion Parameters: {sum(params) / 1e6} M")
|
554 |
+
|
555 |
+
return model
|
easyanimate/models/transformer3d.py
ADDED
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import json
|
15 |
+
import math
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Any, Dict, Optional
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
import torch.nn.init as init
|
24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
25 |
+
from diffusers.models.attention import BasicTransformerBlock
|
26 |
+
from diffusers.models.embeddings import PatchEmbed, Timesteps, TimestepEmbedding
|
27 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
29 |
+
from diffusers.models.normalization import AdaLayerNormSingle
|
30 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
|
31 |
+
from einops import rearrange
|
32 |
+
from torch import nn
|
33 |
+
from typing import Dict, Optional, Tuple
|
34 |
+
|
35 |
+
from .attention import (SelfAttentionTemporalTransformerBlock,
|
36 |
+
TemporalTransformerBlock)
|
37 |
+
from .patch import Patch1D, PatchEmbed3D, PatchEmbedF3D, UnPatch1D, TemporalUpsampler3D, CasualPatchEmbed3D
|
38 |
+
|
39 |
+
try:
|
40 |
+
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
41 |
+
except:
|
42 |
+
from diffusers.models.embeddings import \
|
43 |
+
CaptionProjection as PixArtAlphaTextProjection
|
44 |
+
|
45 |
+
def zero_module(module):
|
46 |
+
# Zero out the parameters of a module and return it.
|
47 |
+
for p in module.parameters():
|
48 |
+
p.detach().zero_()
|
49 |
+
return module
|
50 |
+
|
51 |
+
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
52 |
+
"""
|
53 |
+
For PixArt-Alpha.
|
54 |
+
|
55 |
+
Reference:
|
56 |
+
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
self.outdim = size_emb_dim
|
63 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
64 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
65 |
+
|
66 |
+
self.use_additional_conditions = use_additional_conditions
|
67 |
+
if use_additional_conditions:
|
68 |
+
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
69 |
+
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
70 |
+
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
71 |
+
|
72 |
+
self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
|
73 |
+
self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
|
74 |
+
|
75 |
+
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
|
76 |
+
timesteps_proj = self.time_proj(timestep)
|
77 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
78 |
+
|
79 |
+
if self.use_additional_conditions:
|
80 |
+
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
81 |
+
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
82 |
+
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
|
83 |
+
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
|
84 |
+
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
|
85 |
+
else:
|
86 |
+
conditioning = timesteps_emb
|
87 |
+
|
88 |
+
return conditioning
|
89 |
+
|
90 |
+
class AdaLayerNormSingle(nn.Module):
|
91 |
+
r"""
|
92 |
+
Norm layer adaptive layer norm single (adaLN-single).
|
93 |
+
|
94 |
+
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
95 |
+
|
96 |
+
Parameters:
|
97 |
+
embedding_dim (`int`): The size of each embedding vector.
|
98 |
+
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
|
102 |
+
super().__init__()
|
103 |
+
|
104 |
+
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
105 |
+
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
|
106 |
+
)
|
107 |
+
|
108 |
+
self.silu = nn.SiLU()
|
109 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
110 |
+
|
111 |
+
def forward(
|
112 |
+
self,
|
113 |
+
timestep: torch.Tensor,
|
114 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
115 |
+
batch_size: Optional[int] = None,
|
116 |
+
hidden_dtype: Optional[torch.dtype] = None,
|
117 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
118 |
+
# No modulation happening here.
|
119 |
+
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
|
120 |
+
return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
121 |
+
|
122 |
+
|
123 |
+
class TimePositionalEncoding(nn.Module):
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
d_model,
|
127 |
+
dropout = 0.,
|
128 |
+
max_len = 24
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
self.dropout = nn.Dropout(p=dropout)
|
132 |
+
position = torch.arange(max_len).unsqueeze(1)
|
133 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
134 |
+
pe = torch.zeros(1, max_len, d_model)
|
135 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
136 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
137 |
+
self.register_buffer('pe', pe)
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
b, c, f, h, w = x.size()
|
141 |
+
x = rearrange(x, "b c f h w -> (b h w) f c")
|
142 |
+
x = x + self.pe[:, :x.size(1)]
|
143 |
+
x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
|
144 |
+
return self.dropout(x)
|
145 |
+
|
146 |
+
@dataclass
|
147 |
+
class Transformer3DModelOutput(BaseOutput):
|
148 |
+
"""
|
149 |
+
The output of [`Transformer2DModel`].
|
150 |
+
|
151 |
+
Args:
|
152 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
153 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
154 |
+
distributions for the unnoised latent pixels.
|
155 |
+
"""
|
156 |
+
|
157 |
+
sample: torch.FloatTensor
|
158 |
+
|
159 |
+
|
160 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
161 |
+
"""
|
162 |
+
A 3D Transformer model for image-like data.
|
163 |
+
|
164 |
+
Parameters:
|
165 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
166 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
167 |
+
in_channels (`int`, *optional*):
|
168 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
169 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
170 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
171 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
172 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
173 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
174 |
+
num_vector_embeds (`int`, *optional*):
|
175 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
176 |
+
Includes the class for the masked latent pixel.
|
177 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
178 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
179 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
180 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
181 |
+
added to the hidden states.
|
182 |
+
|
183 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
184 |
+
attention_bias (`bool`, *optional*):
|
185 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
186 |
+
"""
|
187 |
+
|
188 |
+
_supports_gradient_checkpointing = True
|
189 |
+
|
190 |
+
@register_to_config
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
num_attention_heads: int = 16,
|
194 |
+
attention_head_dim: int = 88,
|
195 |
+
in_channels: Optional[int] = None,
|
196 |
+
out_channels: Optional[int] = None,
|
197 |
+
num_layers: int = 1,
|
198 |
+
dropout: float = 0.0,
|
199 |
+
norm_num_groups: int = 32,
|
200 |
+
cross_attention_dim: Optional[int] = None,
|
201 |
+
attention_bias: bool = False,
|
202 |
+
sample_size: Optional[int] = None,
|
203 |
+
num_vector_embeds: Optional[int] = None,
|
204 |
+
patch_size: Optional[int] = None,
|
205 |
+
activation_fn: str = "geglu",
|
206 |
+
num_embeds_ada_norm: Optional[int] = None,
|
207 |
+
use_linear_projection: bool = False,
|
208 |
+
only_cross_attention: bool = False,
|
209 |
+
double_self_attention: bool = False,
|
210 |
+
upcast_attention: bool = False,
|
211 |
+
norm_type: str = "layer_norm",
|
212 |
+
norm_elementwise_affine: bool = True,
|
213 |
+
norm_eps: float = 1e-5,
|
214 |
+
attention_type: str = "default",
|
215 |
+
caption_channels: int = None,
|
216 |
+
# block type
|
217 |
+
basic_block_type: str = "motionmodule",
|
218 |
+
# enable_uvit
|
219 |
+
enable_uvit: bool = False,
|
220 |
+
|
221 |
+
# 3d patch params
|
222 |
+
patch_3d: bool = False,
|
223 |
+
fake_3d: bool = False,
|
224 |
+
time_patch_size: Optional[int] = None,
|
225 |
+
|
226 |
+
casual_3d: bool = False,
|
227 |
+
casual_3d_upsampler_index: Optional[list] = None,
|
228 |
+
|
229 |
+
# motion module kwargs
|
230 |
+
motion_module_type = "VanillaGrid",
|
231 |
+
motion_module_kwargs = None,
|
232 |
+
|
233 |
+
# time position encoding
|
234 |
+
time_position_encoding_before_transformer = False
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
self.use_linear_projection = use_linear_projection
|
238 |
+
self.num_attention_heads = num_attention_heads
|
239 |
+
self.attention_head_dim = attention_head_dim
|
240 |
+
self.enable_uvit = enable_uvit
|
241 |
+
inner_dim = num_attention_heads * attention_head_dim
|
242 |
+
self.basic_block_type = basic_block_type
|
243 |
+
self.patch_3d = patch_3d
|
244 |
+
self.fake_3d = fake_3d
|
245 |
+
self.casual_3d = casual_3d
|
246 |
+
self.casual_3d_upsampler_index = casual_3d_upsampler_index
|
247 |
+
|
248 |
+
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
249 |
+
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
250 |
+
|
251 |
+
assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
|
252 |
+
|
253 |
+
self.height = sample_size
|
254 |
+
self.width = sample_size
|
255 |
+
|
256 |
+
self.patch_size = patch_size
|
257 |
+
self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
|
258 |
+
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
259 |
+
interpolation_scale = max(interpolation_scale, 1)
|
260 |
+
|
261 |
+
if self.casual_3d:
|
262 |
+
self.pos_embed = CasualPatchEmbed3D(
|
263 |
+
height=sample_size,
|
264 |
+
width=sample_size,
|
265 |
+
patch_size=patch_size,
|
266 |
+
time_patch_size=self.time_patch_size,
|
267 |
+
in_channels=in_channels,
|
268 |
+
embed_dim=inner_dim,
|
269 |
+
interpolation_scale=interpolation_scale,
|
270 |
+
)
|
271 |
+
elif self.patch_3d:
|
272 |
+
if self.fake_3d:
|
273 |
+
self.pos_embed = PatchEmbedF3D(
|
274 |
+
height=sample_size,
|
275 |
+
width=sample_size,
|
276 |
+
patch_size=patch_size,
|
277 |
+
in_channels=in_channels,
|
278 |
+
embed_dim=inner_dim,
|
279 |
+
interpolation_scale=interpolation_scale,
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
self.pos_embed = PatchEmbed3D(
|
283 |
+
height=sample_size,
|
284 |
+
width=sample_size,
|
285 |
+
patch_size=patch_size,
|
286 |
+
time_patch_size=self.time_patch_size,
|
287 |
+
in_channels=in_channels,
|
288 |
+
embed_dim=inner_dim,
|
289 |
+
interpolation_scale=interpolation_scale,
|
290 |
+
)
|
291 |
+
else:
|
292 |
+
self.pos_embed = PatchEmbed(
|
293 |
+
height=sample_size,
|
294 |
+
width=sample_size,
|
295 |
+
patch_size=patch_size,
|
296 |
+
in_channels=in_channels,
|
297 |
+
embed_dim=inner_dim,
|
298 |
+
interpolation_scale=interpolation_scale,
|
299 |
+
)
|
300 |
+
|
301 |
+
# 3. Define transformers blocks
|
302 |
+
if self.basic_block_type == "motionmodule":
|
303 |
+
self.transformer_blocks = nn.ModuleList(
|
304 |
+
[
|
305 |
+
TemporalTransformerBlock(
|
306 |
+
inner_dim,
|
307 |
+
num_attention_heads,
|
308 |
+
attention_head_dim,
|
309 |
+
dropout=dropout,
|
310 |
+
cross_attention_dim=cross_attention_dim,
|
311 |
+
activation_fn=activation_fn,
|
312 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
313 |
+
attention_bias=attention_bias,
|
314 |
+
only_cross_attention=only_cross_attention,
|
315 |
+
double_self_attention=double_self_attention,
|
316 |
+
upcast_attention=upcast_attention,
|
317 |
+
norm_type=norm_type,
|
318 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
319 |
+
norm_eps=norm_eps,
|
320 |
+
attention_type=attention_type,
|
321 |
+
motion_module_type=motion_module_type,
|
322 |
+
motion_module_kwargs=motion_module_kwargs,
|
323 |
+
)
|
324 |
+
for d in range(num_layers)
|
325 |
+
]
|
326 |
+
)
|
327 |
+
elif self.basic_block_type == "kvcompression_motionmodule":
|
328 |
+
self.transformer_blocks = nn.ModuleList(
|
329 |
+
[
|
330 |
+
TemporalTransformerBlock(
|
331 |
+
inner_dim,
|
332 |
+
num_attention_heads,
|
333 |
+
attention_head_dim,
|
334 |
+
dropout=dropout,
|
335 |
+
cross_attention_dim=cross_attention_dim,
|
336 |
+
activation_fn=activation_fn,
|
337 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
338 |
+
attention_bias=attention_bias,
|
339 |
+
only_cross_attention=only_cross_attention,
|
340 |
+
double_self_attention=double_self_attention,
|
341 |
+
upcast_attention=upcast_attention,
|
342 |
+
norm_type=norm_type,
|
343 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
344 |
+
norm_eps=norm_eps,
|
345 |
+
attention_type=attention_type,
|
346 |
+
kvcompression=False if d < 14 else True,
|
347 |
+
motion_module_type=motion_module_type,
|
348 |
+
motion_module_kwargs=motion_module_kwargs,
|
349 |
+
)
|
350 |
+
for d in range(num_layers)
|
351 |
+
]
|
352 |
+
)
|
353 |
+
elif self.basic_block_type == "selfattentiontemporal":
|
354 |
+
self.transformer_blocks = nn.ModuleList(
|
355 |
+
[
|
356 |
+
SelfAttentionTemporalTransformerBlock(
|
357 |
+
inner_dim,
|
358 |
+
num_attention_heads,
|
359 |
+
attention_head_dim,
|
360 |
+
dropout=dropout,
|
361 |
+
cross_attention_dim=cross_attention_dim,
|
362 |
+
activation_fn=activation_fn,
|
363 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
364 |
+
attention_bias=attention_bias,
|
365 |
+
only_cross_attention=only_cross_attention,
|
366 |
+
double_self_attention=double_self_attention,
|
367 |
+
upcast_attention=upcast_attention,
|
368 |
+
norm_type=norm_type,
|
369 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
370 |
+
norm_eps=norm_eps,
|
371 |
+
attention_type=attention_type,
|
372 |
+
)
|
373 |
+
for d in range(num_layers)
|
374 |
+
]
|
375 |
+
)
|
376 |
+
else:
|
377 |
+
self.transformer_blocks = nn.ModuleList(
|
378 |
+
[
|
379 |
+
BasicTransformerBlock(
|
380 |
+
inner_dim,
|
381 |
+
num_attention_heads,
|
382 |
+
attention_head_dim,
|
383 |
+
dropout=dropout,
|
384 |
+
cross_attention_dim=cross_attention_dim,
|
385 |
+
activation_fn=activation_fn,
|
386 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
387 |
+
attention_bias=attention_bias,
|
388 |
+
only_cross_attention=only_cross_attention,
|
389 |
+
double_self_attention=double_self_attention,
|
390 |
+
upcast_attention=upcast_attention,
|
391 |
+
norm_type=norm_type,
|
392 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
393 |
+
norm_eps=norm_eps,
|
394 |
+
attention_type=attention_type,
|
395 |
+
)
|
396 |
+
for d in range(num_layers)
|
397 |
+
]
|
398 |
+
)
|
399 |
+
|
400 |
+
if self.casual_3d:
|
401 |
+
self.unpatch1d = TemporalUpsampler3D()
|
402 |
+
elif self.patch_3d and self.fake_3d:
|
403 |
+
self.unpatch1d = UnPatch1D(inner_dim, True)
|
404 |
+
|
405 |
+
if self.enable_uvit:
|
406 |
+
self.long_connect_fc = nn.ModuleList(
|
407 |
+
[
|
408 |
+
nn.Linear(inner_dim, inner_dim, True) for d in range(13)
|
409 |
+
]
|
410 |
+
)
|
411 |
+
for index in range(13):
|
412 |
+
self.long_connect_fc[index] = zero_module(self.long_connect_fc[index])
|
413 |
+
|
414 |
+
# 4. Define output layers
|
415 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
416 |
+
if norm_type != "ada_norm_single":
|
417 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
418 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
419 |
+
if self.patch_3d and not self.fake_3d:
|
420 |
+
self.proj_out_2 = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
|
421 |
+
else:
|
422 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
423 |
+
elif norm_type == "ada_norm_single":
|
424 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
425 |
+
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
426 |
+
if self.patch_3d and not self.fake_3d:
|
427 |
+
self.proj_out = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
|
428 |
+
else:
|
429 |
+
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
430 |
+
|
431 |
+
# 5. PixArt-Alpha blocks.
|
432 |
+
self.adaln_single = None
|
433 |
+
self.use_additional_conditions = False
|
434 |
+
if norm_type == "ada_norm_single":
|
435 |
+
self.use_additional_conditions = self.config.sample_size == 128
|
436 |
+
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
437 |
+
# additional conditions until we find better name
|
438 |
+
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
439 |
+
|
440 |
+
self.caption_projection = None
|
441 |
+
if caption_channels is not None:
|
442 |
+
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
443 |
+
|
444 |
+
self.gradient_checkpointing = False
|
445 |
+
|
446 |
+
self.time_position_encoding_before_transformer = time_position_encoding_before_transformer
|
447 |
+
if self.time_position_encoding_before_transformer:
|
448 |
+
self.t_pos = TimePositionalEncoding(max_len = 4096, d_model = inner_dim)
|
449 |
+
|
450 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
451 |
+
if hasattr(module, "gradient_checkpointing"):
|
452 |
+
module.gradient_checkpointing = value
|
453 |
+
|
454 |
+
def forward(
|
455 |
+
self,
|
456 |
+
hidden_states: torch.Tensor,
|
457 |
+
inpaint_latents: torch.Tensor = None,
|
458 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
459 |
+
timestep: Optional[torch.LongTensor] = None,
|
460 |
+
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
461 |
+
class_labels: Optional[torch.LongTensor] = None,
|
462 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
463 |
+
attention_mask: Optional[torch.Tensor] = None,
|
464 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
465 |
+
return_dict: bool = True,
|
466 |
+
):
|
467 |
+
"""
|
468 |
+
The [`Transformer2DModel`] forward method.
|
469 |
+
|
470 |
+
Args:
|
471 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
472 |
+
Input `hidden_states`.
|
473 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
474 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
475 |
+
self-attention.
|
476 |
+
timestep ( `torch.LongTensor`, *optional*):
|
477 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
478 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
479 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
480 |
+
`AdaLayerZeroNorm`.
|
481 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
482 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
483 |
+
`self.processor` in
|
484 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
485 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
486 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
487 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
488 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
489 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
490 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
491 |
+
|
492 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
493 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
494 |
+
|
495 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
496 |
+
above. This bias will be added to the cross-attention scores.
|
497 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
498 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
499 |
+
tuple.
|
500 |
+
|
501 |
+
Returns:
|
502 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer3DModelOutput`] is returned, otherwise a
|
503 |
+
`tuple` where the first element is the sample tensor.
|
504 |
+
"""
|
505 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
506 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
507 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
508 |
+
# expects mask of shape:
|
509 |
+
# [batch, key_tokens]
|
510 |
+
# adds singleton query_tokens dimension:
|
511 |
+
# [batch, 1, key_tokens]
|
512 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
513 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
514 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
515 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
516 |
+
# assume that mask is expressed as:
|
517 |
+
# (1 = keep, 0 = discard)
|
518 |
+
# convert mask into a bias that can be added to attention scores:
|
519 |
+
# (keep = +0, discard = -10000.0)
|
520 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
521 |
+
attention_mask = attention_mask.unsqueeze(1)
|
522 |
+
|
523 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
524 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
525 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
|
526 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
527 |
+
|
528 |
+
if inpaint_latents is not None:
|
529 |
+
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
|
530 |
+
# 1. Input
|
531 |
+
if self.casual_3d:
|
532 |
+
video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
533 |
+
elif self.patch_3d:
|
534 |
+
video_length, height, width = hidden_states.shape[-3] // self.time_patch_size, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
535 |
+
else:
|
536 |
+
video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
537 |
+
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
|
538 |
+
|
539 |
+
hidden_states = self.pos_embed(hidden_states)
|
540 |
+
if self.adaln_single is not None:
|
541 |
+
if self.use_additional_conditions and added_cond_kwargs is None:
|
542 |
+
raise ValueError(
|
543 |
+
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
544 |
+
)
|
545 |
+
batch_size = hidden_states.shape[0] // video_length
|
546 |
+
timestep, embedded_timestep = self.adaln_single(
|
547 |
+
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
548 |
+
)
|
549 |
+
hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
|
550 |
+
|
551 |
+
# hidden_states
|
552 |
+
# bs, c, f, h, w => b (f h w ) c
|
553 |
+
if self.time_position_encoding_before_transformer:
|
554 |
+
hidden_states = self.t_pos(hidden_states)
|
555 |
+
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
556 |
+
|
557 |
+
# 2. Blocks
|
558 |
+
if self.caption_projection is not None:
|
559 |
+
batch_size = hidden_states.shape[0]
|
560 |
+
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
561 |
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
562 |
+
|
563 |
+
skips = []
|
564 |
+
skip_index = 0
|
565 |
+
for index, block in enumerate(self.transformer_blocks):
|
566 |
+
if self.enable_uvit:
|
567 |
+
if index >= 15:
|
568 |
+
long_connect = self.long_connect_fc[skip_index](skips.pop())
|
569 |
+
hidden_states = hidden_states + long_connect
|
570 |
+
skip_index += 1
|
571 |
+
|
572 |
+
if self.casual_3d_upsampler_index is not None and index in self.casual_3d_upsampler_index:
|
573 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
|
574 |
+
hidden_states = self.unpatch1d(hidden_states)
|
575 |
+
video_length = (video_length - 1) * 2 + 1
|
576 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=video_length, h=height, w=width)
|
577 |
+
|
578 |
+
if self.training and self.gradient_checkpointing:
|
579 |
+
|
580 |
+
def create_custom_forward(module, return_dict=None):
|
581 |
+
def custom_forward(*inputs):
|
582 |
+
if return_dict is not None:
|
583 |
+
return module(*inputs, return_dict=return_dict)
|
584 |
+
else:
|
585 |
+
return module(*inputs)
|
586 |
+
|
587 |
+
return custom_forward
|
588 |
+
|
589 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
590 |
+
args = {
|
591 |
+
"basic": [],
|
592 |
+
"motionmodule": [video_length, height, width],
|
593 |
+
"selfattentiontemporal": [video_length, height, width],
|
594 |
+
"kvcompression_motionmodule": [video_length, height, width],
|
595 |
+
}[self.basic_block_type]
|
596 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
597 |
+
create_custom_forward(block),
|
598 |
+
hidden_states,
|
599 |
+
attention_mask,
|
600 |
+
encoder_hidden_states,
|
601 |
+
encoder_attention_mask,
|
602 |
+
timestep,
|
603 |
+
cross_attention_kwargs,
|
604 |
+
class_labels,
|
605 |
+
*args,
|
606 |
+
**ckpt_kwargs,
|
607 |
+
)
|
608 |
+
else:
|
609 |
+
kwargs = {
|
610 |
+
"basic": {},
|
611 |
+
"motionmodule": {"num_frames":video_length, "height":height, "width":width},
|
612 |
+
"selfattentiontemporal": {"num_frames":video_length, "height":height, "width":width},
|
613 |
+
"kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
|
614 |
+
}[self.basic_block_type]
|
615 |
+
hidden_states = block(
|
616 |
+
hidden_states,
|
617 |
+
attention_mask=attention_mask,
|
618 |
+
encoder_hidden_states=encoder_hidden_states,
|
619 |
+
encoder_attention_mask=encoder_attention_mask,
|
620 |
+
timestep=timestep,
|
621 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
622 |
+
class_labels=class_labels,
|
623 |
+
**kwargs
|
624 |
+
)
|
625 |
+
|
626 |
+
if self.enable_uvit:
|
627 |
+
if index < 13:
|
628 |
+
skips.append(hidden_states)
|
629 |
+
|
630 |
+
if self.fake_3d and self.patch_3d:
|
631 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> (b h w) c f", f=video_length, w=width, h=height)
|
632 |
+
hidden_states = self.unpatch1d(hidden_states)
|
633 |
+
hidden_states = rearrange(hidden_states, "(b h w) c f -> b (f h w) c", w=width, h=height)
|
634 |
+
|
635 |
+
# 3. Output
|
636 |
+
if self.config.norm_type != "ada_norm_single":
|
637 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
638 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
639 |
+
)
|
640 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
641 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
642 |
+
hidden_states = self.proj_out_2(hidden_states)
|
643 |
+
elif self.config.norm_type == "ada_norm_single":
|
644 |
+
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
645 |
+
hidden_states = self.norm_out(hidden_states)
|
646 |
+
# Modulation
|
647 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
648 |
+
hidden_states = self.proj_out(hidden_states)
|
649 |
+
hidden_states = hidden_states.squeeze(1)
|
650 |
+
|
651 |
+
# unpatchify
|
652 |
+
if self.adaln_single is None:
|
653 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
654 |
+
if self.patch_3d:
|
655 |
+
if self.fake_3d:
|
656 |
+
hidden_states = hidden_states.reshape(
|
657 |
+
shape=(-1, video_length * self.patch_size, height, width, self.patch_size, self.patch_size, self.out_channels)
|
658 |
+
)
|
659 |
+
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
|
660 |
+
else:
|
661 |
+
hidden_states = hidden_states.reshape(
|
662 |
+
shape=(-1, video_length, height, width, self.time_patch_size, self.patch_size, self.patch_size, self.out_channels)
|
663 |
+
)
|
664 |
+
hidden_states = torch.einsum("nfhwopqc->ncfohpwq", hidden_states)
|
665 |
+
output = hidden_states.reshape(
|
666 |
+
shape=(-1, self.out_channels, video_length * self.time_patch_size, height * self.patch_size, width * self.patch_size)
|
667 |
+
)
|
668 |
+
else:
|
669 |
+
hidden_states = hidden_states.reshape(
|
670 |
+
shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
|
671 |
+
)
|
672 |
+
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
|
673 |
+
output = hidden_states.reshape(
|
674 |
+
shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
|
675 |
+
)
|
676 |
+
|
677 |
+
if not return_dict:
|
678 |
+
return (output,)
|
679 |
+
|
680 |
+
return Transformer3DModelOutput(sample=output)
|
681 |
+
|
682 |
+
@classmethod
|
683 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
|
684 |
+
if subfolder is not None:
|
685 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
686 |
+
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
|
687 |
+
|
688 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
689 |
+
if not os.path.isfile(config_file):
|
690 |
+
raise RuntimeError(f"{config_file} does not exist")
|
691 |
+
with open(config_file, "r") as f:
|
692 |
+
config = json.load(f)
|
693 |
+
|
694 |
+
from diffusers.utils import WEIGHTS_NAME
|
695 |
+
model = cls.from_config(config, **transformer_additional_kwargs)
|
696 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
697 |
+
model_file_safetensors = model_file.replace(".bin", ".safetensors")
|
698 |
+
if os.path.exists(model_file_safetensors):
|
699 |
+
from safetensors.torch import load_file, safe_open
|
700 |
+
state_dict = load_file(model_file_safetensors)
|
701 |
+
else:
|
702 |
+
if not os.path.isfile(model_file):
|
703 |
+
raise RuntimeError(f"{model_file} does not exist")
|
704 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
705 |
+
|
706 |
+
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
|
707 |
+
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
|
708 |
+
if len(new_shape) == 5:
|
709 |
+
state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
|
710 |
+
state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
|
711 |
+
else:
|
712 |
+
model.state_dict()['pos_embed.proj.weight'][:, :4, :, :] = state_dict['pos_embed.proj.weight']
|
713 |
+
model.state_dict()['pos_embed.proj.weight'][:, 4:, :, :] = 0
|
714 |
+
state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
|
715 |
+
|
716 |
+
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
|
717 |
+
new_shape = model.state_dict()['proj_out.weight'].size()
|
718 |
+
state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
|
719 |
+
|
720 |
+
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
|
721 |
+
new_shape = model.state_dict()['proj_out.bias'].size()
|
722 |
+
state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
|
723 |
+
|
724 |
+
tmp_state_dict = {}
|
725 |
+
for key in state_dict:
|
726 |
+
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
|
727 |
+
tmp_state_dict[key] = state_dict[key]
|
728 |
+
else:
|
729 |
+
print(key, "Size don't match, skip")
|
730 |
+
state_dict = tmp_state_dict
|
731 |
+
|
732 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
733 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
734 |
+
|
735 |
+
params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
|
736 |
+
print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
|
737 |
+
|
738 |
+
return model
|
easyanimate/pipeline/pipeline_easyanimate.py
ADDED
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import html
|
16 |
+
import inspect
|
17 |
+
import copy
|
18 |
+
import re
|
19 |
+
import urllib.parse as ul
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Callable, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
26 |
+
from diffusers.image_processor import VaeImageProcessor
|
27 |
+
from diffusers.models import AutoencoderKL
|
28 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
29 |
+
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
30 |
+
is_bs4_available, is_ftfy_available, logging,
|
31 |
+
replace_example_docstring)
|
32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
33 |
+
from einops import rearrange
|
34 |
+
from tqdm import tqdm
|
35 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
36 |
+
|
37 |
+
from ..models.transformer3d import Transformer3DModel
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
if is_bs4_available():
|
42 |
+
from bs4 import BeautifulSoup
|
43 |
+
|
44 |
+
if is_ftfy_available():
|
45 |
+
import ftfy
|
46 |
+
|
47 |
+
|
48 |
+
EXAMPLE_DOC_STRING = """
|
49 |
+
Examples:
|
50 |
+
```py
|
51 |
+
>>> import torch
|
52 |
+
>>> from diffusers import EasyAnimatePipeline
|
53 |
+
|
54 |
+
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
|
55 |
+
>>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
56 |
+
>>> # Enable memory optimizations.
|
57 |
+
>>> pipe.enable_model_cpu_offload()
|
58 |
+
|
59 |
+
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
60 |
+
>>> image = pipe(prompt).images[0]
|
61 |
+
```
|
62 |
+
"""
|
63 |
+
|
64 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
65 |
+
def retrieve_timesteps(
|
66 |
+
scheduler,
|
67 |
+
num_inference_steps: Optional[int] = None,
|
68 |
+
device: Optional[Union[str, torch.device]] = None,
|
69 |
+
timesteps: Optional[List[int]] = None,
|
70 |
+
**kwargs,
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
74 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
scheduler (`SchedulerMixin`):
|
78 |
+
The scheduler to get timesteps from.
|
79 |
+
num_inference_steps (`int`):
|
80 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
81 |
+
`timesteps` must be `None`.
|
82 |
+
device (`str` or `torch.device`, *optional*):
|
83 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
84 |
+
timesteps (`List[int]`, *optional*):
|
85 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
86 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
87 |
+
must be `None`.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
91 |
+
second element is the number of inference steps.
|
92 |
+
"""
|
93 |
+
if timesteps is not None:
|
94 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
95 |
+
if not accepts_timesteps:
|
96 |
+
raise ValueError(
|
97 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
98 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
99 |
+
)
|
100 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
101 |
+
timesteps = scheduler.timesteps
|
102 |
+
num_inference_steps = len(timesteps)
|
103 |
+
else:
|
104 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
105 |
+
timesteps = scheduler.timesteps
|
106 |
+
return timesteps, num_inference_steps
|
107 |
+
|
108 |
+
@dataclass
|
109 |
+
class EasyAnimatePipelineOutput(BaseOutput):
|
110 |
+
videos: Union[torch.Tensor, np.ndarray]
|
111 |
+
|
112 |
+
class EasyAnimatePipeline(DiffusionPipeline):
|
113 |
+
r"""
|
114 |
+
Pipeline for text-to-image generation using PixArt-Alpha.
|
115 |
+
|
116 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
117 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
118 |
+
|
119 |
+
Args:
|
120 |
+
vae ([`AutoencoderKL`]):
|
121 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
122 |
+
text_encoder ([`T5EncoderModel`]):
|
123 |
+
Frozen text-encoder. PixArt-Alpha uses
|
124 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
125 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
126 |
+
tokenizer (`T5Tokenizer`):
|
127 |
+
Tokenizer of class
|
128 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
129 |
+
transformer ([`Transformer3DModel`]):
|
130 |
+
A text conditioned `Transformer3DModel` to denoise the encoded image latents.
|
131 |
+
scheduler ([`SchedulerMixin`]):
|
132 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
133 |
+
"""
|
134 |
+
bad_punct_regex = re.compile(
|
135 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
136 |
+
) # noqa
|
137 |
+
|
138 |
+
_optional_components = ["tokenizer", "text_encoder"]
|
139 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
140 |
+
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
tokenizer: T5Tokenizer,
|
144 |
+
text_encoder: T5EncoderModel,
|
145 |
+
vae: AutoencoderKL,
|
146 |
+
transformer: Transformer3DModel,
|
147 |
+
scheduler: DPMSolverMultistepScheduler,
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
|
151 |
+
self.register_modules(
|
152 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
153 |
+
)
|
154 |
+
|
155 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
156 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
157 |
+
|
158 |
+
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
159 |
+
def mask_text_embeddings(self, emb, mask):
|
160 |
+
if emb.shape[0] == 1:
|
161 |
+
keep_index = mask.sum().item()
|
162 |
+
return emb[:, :, :keep_index, :], keep_index
|
163 |
+
else:
|
164 |
+
masked_feature = emb * mask[:, None, :, None]
|
165 |
+
return masked_feature, emb.shape[2]
|
166 |
+
|
167 |
+
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
168 |
+
def encode_prompt(
|
169 |
+
self,
|
170 |
+
prompt: Union[str, List[str]],
|
171 |
+
do_classifier_free_guidance: bool = True,
|
172 |
+
negative_prompt: str = "",
|
173 |
+
num_images_per_prompt: int = 1,
|
174 |
+
device: Optional[torch.device] = None,
|
175 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
176 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
177 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
178 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
179 |
+
clean_caption: bool = False,
|
180 |
+
max_sequence_length: int = 120,
|
181 |
+
**kwargs,
|
182 |
+
):
|
183 |
+
r"""
|
184 |
+
Encodes the prompt into text encoder hidden states.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
prompt (`str` or `List[str]`, *optional*):
|
188 |
+
prompt to be encoded
|
189 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
190 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
191 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
192 |
+
PixArt-Alpha, this should be "".
|
193 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
194 |
+
whether to use classifier free guidance or not
|
195 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
196 |
+
number of images that should be generated per prompt
|
197 |
+
device: (`torch.device`, *optional*):
|
198 |
+
torch device to place the resulting embeddings on
|
199 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
200 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
201 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
202 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
203 |
+
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
204 |
+
string.
|
205 |
+
clean_caption (`bool`, defaults to `False`):
|
206 |
+
If `True`, the function will preprocess and clean the provided caption before encoding.
|
207 |
+
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
|
208 |
+
"""
|
209 |
+
|
210 |
+
if "mask_feature" in kwargs:
|
211 |
+
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
212 |
+
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
213 |
+
|
214 |
+
if device is None:
|
215 |
+
device = self._execution_device
|
216 |
+
|
217 |
+
if prompt is not None and isinstance(prompt, str):
|
218 |
+
batch_size = 1
|
219 |
+
elif prompt is not None and isinstance(prompt, list):
|
220 |
+
batch_size = len(prompt)
|
221 |
+
else:
|
222 |
+
batch_size = prompt_embeds.shape[0]
|
223 |
+
|
224 |
+
# See Section 3.1. of the paper.
|
225 |
+
max_length = max_sequence_length
|
226 |
+
|
227 |
+
if prompt_embeds is None:
|
228 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
229 |
+
text_inputs = self.tokenizer(
|
230 |
+
prompt,
|
231 |
+
padding="max_length",
|
232 |
+
max_length=max_length,
|
233 |
+
truncation=True,
|
234 |
+
add_special_tokens=True,
|
235 |
+
return_tensors="pt",
|
236 |
+
)
|
237 |
+
text_input_ids = text_inputs.input_ids
|
238 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
239 |
+
|
240 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
241 |
+
text_input_ids, untruncated_ids
|
242 |
+
):
|
243 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
244 |
+
logger.warning(
|
245 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
246 |
+
f" {max_length} tokens: {removed_text}"
|
247 |
+
)
|
248 |
+
|
249 |
+
prompt_attention_mask = text_inputs.attention_mask
|
250 |
+
prompt_attention_mask = prompt_attention_mask.to(device)
|
251 |
+
|
252 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
253 |
+
prompt_embeds = prompt_embeds[0]
|
254 |
+
|
255 |
+
if self.text_encoder is not None:
|
256 |
+
dtype = self.text_encoder.dtype
|
257 |
+
elif self.transformer is not None:
|
258 |
+
dtype = self.transformer.dtype
|
259 |
+
else:
|
260 |
+
dtype = None
|
261 |
+
|
262 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
263 |
+
|
264 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
265 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
266 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
267 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
268 |
+
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
269 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
270 |
+
|
271 |
+
# get unconditional embeddings for classifier free guidance
|
272 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
273 |
+
uncond_tokens = [negative_prompt] * batch_size
|
274 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
275 |
+
max_length = prompt_embeds.shape[1]
|
276 |
+
uncond_input = self.tokenizer(
|
277 |
+
uncond_tokens,
|
278 |
+
padding="max_length",
|
279 |
+
max_length=max_length,
|
280 |
+
truncation=True,
|
281 |
+
return_attention_mask=True,
|
282 |
+
add_special_tokens=True,
|
283 |
+
return_tensors="pt",
|
284 |
+
)
|
285 |
+
negative_prompt_attention_mask = uncond_input.attention_mask
|
286 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
287 |
+
|
288 |
+
negative_prompt_embeds = self.text_encoder(
|
289 |
+
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
290 |
+
)
|
291 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
292 |
+
|
293 |
+
if do_classifier_free_guidance:
|
294 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
295 |
+
seq_len = negative_prompt_embeds.shape[1]
|
296 |
+
|
297 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
298 |
+
|
299 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
300 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
301 |
+
|
302 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
303 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
304 |
+
else:
|
305 |
+
negative_prompt_embeds = None
|
306 |
+
negative_prompt_attention_mask = None
|
307 |
+
|
308 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
309 |
+
|
310 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
311 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
312 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
313 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
314 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
315 |
+
# and should be between [0, 1]
|
316 |
+
|
317 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
318 |
+
extra_step_kwargs = {}
|
319 |
+
if accepts_eta:
|
320 |
+
extra_step_kwargs["eta"] = eta
|
321 |
+
|
322 |
+
# check if the scheduler accepts generator
|
323 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
324 |
+
if accepts_generator:
|
325 |
+
extra_step_kwargs["generator"] = generator
|
326 |
+
return extra_step_kwargs
|
327 |
+
|
328 |
+
def check_inputs(
|
329 |
+
self,
|
330 |
+
prompt,
|
331 |
+
height,
|
332 |
+
width,
|
333 |
+
negative_prompt,
|
334 |
+
callback_steps,
|
335 |
+
prompt_embeds=None,
|
336 |
+
negative_prompt_embeds=None,
|
337 |
+
):
|
338 |
+
if height % 8 != 0 or width % 8 != 0:
|
339 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
340 |
+
|
341 |
+
if (callback_steps is None) or (
|
342 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
343 |
+
):
|
344 |
+
raise ValueError(
|
345 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
346 |
+
f" {type(callback_steps)}."
|
347 |
+
)
|
348 |
+
|
349 |
+
if prompt is not None and prompt_embeds is not None:
|
350 |
+
raise ValueError(
|
351 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
352 |
+
" only forward one of the two."
|
353 |
+
)
|
354 |
+
elif prompt is None and prompt_embeds is None:
|
355 |
+
raise ValueError(
|
356 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
357 |
+
)
|
358 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
359 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
360 |
+
|
361 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
362 |
+
raise ValueError(
|
363 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
364 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
365 |
+
)
|
366 |
+
|
367 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
368 |
+
raise ValueError(
|
369 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
370 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
371 |
+
)
|
372 |
+
|
373 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
374 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
375 |
+
raise ValueError(
|
376 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
377 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
378 |
+
f" {negative_prompt_embeds.shape}."
|
379 |
+
)
|
380 |
+
|
381 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
382 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
383 |
+
if clean_caption and not is_bs4_available():
|
384 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
385 |
+
logger.warn("Setting `clean_caption` to False...")
|
386 |
+
clean_caption = False
|
387 |
+
|
388 |
+
if clean_caption and not is_ftfy_available():
|
389 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
390 |
+
logger.warn("Setting `clean_caption` to False...")
|
391 |
+
clean_caption = False
|
392 |
+
|
393 |
+
if not isinstance(text, (tuple, list)):
|
394 |
+
text = [text]
|
395 |
+
|
396 |
+
def process(text: str):
|
397 |
+
if clean_caption:
|
398 |
+
text = self._clean_caption(text)
|
399 |
+
text = self._clean_caption(text)
|
400 |
+
else:
|
401 |
+
text = text.lower().strip()
|
402 |
+
return text
|
403 |
+
|
404 |
+
return [process(t) for t in text]
|
405 |
+
|
406 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
407 |
+
def _clean_caption(self, caption):
|
408 |
+
caption = str(caption)
|
409 |
+
caption = ul.unquote_plus(caption)
|
410 |
+
caption = caption.strip().lower()
|
411 |
+
caption = re.sub("<person>", "person", caption)
|
412 |
+
# urls:
|
413 |
+
caption = re.sub(
|
414 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
415 |
+
"",
|
416 |
+
caption,
|
417 |
+
) # regex for urls
|
418 |
+
caption = re.sub(
|
419 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
420 |
+
"",
|
421 |
+
caption,
|
422 |
+
) # regex for urls
|
423 |
+
# html:
|
424 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
425 |
+
|
426 |
+
# @<nickname>
|
427 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
428 |
+
|
429 |
+
# 31C0—31EF CJK Strokes
|
430 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
431 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
432 |
+
# 3300—33FF CJK Compatibility
|
433 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
434 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
435 |
+
# 4E00—9FFF CJK Unified Ideographs
|
436 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
437 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
438 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
439 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
440 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
441 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
442 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
443 |
+
#######################################################
|
444 |
+
|
445 |
+
# все виды тире / all types of dash --> "-"
|
446 |
+
caption = re.sub(
|
447 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
448 |
+
"-",
|
449 |
+
caption,
|
450 |
+
)
|
451 |
+
|
452 |
+
# кавычки к одному стандарту
|
453 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
454 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
455 |
+
|
456 |
+
# "
|
457 |
+
caption = re.sub(r""?", "", caption)
|
458 |
+
# &
|
459 |
+
caption = re.sub(r"&", "", caption)
|
460 |
+
|
461 |
+
# ip adresses:
|
462 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
463 |
+
|
464 |
+
# article ids:
|
465 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
466 |
+
|
467 |
+
# \n
|
468 |
+
caption = re.sub(r"\\n", " ", caption)
|
469 |
+
|
470 |
+
# "#123"
|
471 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
472 |
+
# "#12345.."
|
473 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
474 |
+
# "123456.."
|
475 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
476 |
+
# filenames:
|
477 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
478 |
+
|
479 |
+
#
|
480 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
481 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
482 |
+
|
483 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
484 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
485 |
+
|
486 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
487 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
488 |
+
if len(re.findall(regex2, caption)) > 3:
|
489 |
+
caption = re.sub(regex2, " ", caption)
|
490 |
+
|
491 |
+
caption = ftfy.fix_text(caption)
|
492 |
+
caption = html.unescape(html.unescape(caption))
|
493 |
+
|
494 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
495 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
496 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
497 |
+
|
498 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
499 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
500 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
501 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
502 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
503 |
+
|
504 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
505 |
+
|
506 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
507 |
+
|
508 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
509 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
510 |
+
caption = re.sub(r"\s+", " ", caption)
|
511 |
+
|
512 |
+
caption.strip()
|
513 |
+
|
514 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
515 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
516 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
517 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
518 |
+
|
519 |
+
return caption.strip()
|
520 |
+
|
521 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
522 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
523 |
+
if self.vae.quant_conv.weight.ndim==5:
|
524 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
525 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
526 |
+
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
527 |
+
else:
|
528 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
529 |
+
|
530 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
531 |
+
raise ValueError(
|
532 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
533 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
534 |
+
)
|
535 |
+
|
536 |
+
if latents is None:
|
537 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
538 |
+
else:
|
539 |
+
latents = latents.to(device)
|
540 |
+
|
541 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
542 |
+
latents = latents * self.scheduler.init_noise_sigma
|
543 |
+
return latents
|
544 |
+
|
545 |
+
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
|
546 |
+
if video.size()[2] <= mini_batch_encoder:
|
547 |
+
return video
|
548 |
+
prefix_index_before = mini_batch_encoder // 2
|
549 |
+
prefix_index_after = mini_batch_encoder - prefix_index_before
|
550 |
+
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
|
551 |
+
|
552 |
+
if self.vae.slice_compression_vae:
|
553 |
+
latents = self.vae.encode(pixel_values)[0]
|
554 |
+
latents = latents.sample()
|
555 |
+
else:
|
556 |
+
new_pixel_values = []
|
557 |
+
for i in range(0, pixel_values.shape[2], mini_batch_encoder):
|
558 |
+
with torch.no_grad():
|
559 |
+
pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
|
560 |
+
pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
|
561 |
+
pixel_values_bs = pixel_values_bs.sample()
|
562 |
+
new_pixel_values.append(pixel_values_bs)
|
563 |
+
latents = torch.cat(new_pixel_values, dim = 2)
|
564 |
+
|
565 |
+
if self.vae.slice_compression_vae:
|
566 |
+
middle_video = self.vae.decode(latents)[0]
|
567 |
+
else:
|
568 |
+
middle_video = []
|
569 |
+
for i in range(0, latents.shape[2], mini_batch_decoder):
|
570 |
+
with torch.no_grad():
|
571 |
+
start_index = i
|
572 |
+
end_index = i + mini_batch_decoder
|
573 |
+
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
574 |
+
middle_video.append(latents_bs)
|
575 |
+
middle_video = torch.cat(middle_video, 2)
|
576 |
+
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
|
577 |
+
return video
|
578 |
+
|
579 |
+
def decode_latents(self, latents):
|
580 |
+
video_length = latents.shape[2]
|
581 |
+
latents = 1 / 0.18215 * latents
|
582 |
+
if self.vae.quant_conv.weight.ndim==5:
|
583 |
+
mini_batch_encoder = self.vae.mini_batch_encoder
|
584 |
+
mini_batch_decoder = self.vae.mini_batch_decoder
|
585 |
+
if self.vae.slice_compression_vae:
|
586 |
+
video = self.vae.decode(latents)[0]
|
587 |
+
else:
|
588 |
+
video = []
|
589 |
+
for i in range(0, latents.shape[2], mini_batch_decoder):
|
590 |
+
with torch.no_grad():
|
591 |
+
start_index = i
|
592 |
+
end_index = i + mini_batch_decoder
|
593 |
+
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
594 |
+
video.append(latents_bs)
|
595 |
+
video = torch.cat(video, 2)
|
596 |
+
video = video.clamp(-1, 1)
|
597 |
+
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
|
598 |
+
else:
|
599 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
600 |
+
video = []
|
601 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
602 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
603 |
+
video = torch.cat(video)
|
604 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
605 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
606 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
607 |
+
video = video.cpu().float().numpy()
|
608 |
+
return video
|
609 |
+
|
610 |
+
@torch.no_grad()
|
611 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
612 |
+
def __call__(
|
613 |
+
self,
|
614 |
+
prompt: Union[str, List[str]] = None,
|
615 |
+
video_length: Optional[int] = None,
|
616 |
+
negative_prompt: str = "",
|
617 |
+
num_inference_steps: int = 20,
|
618 |
+
timesteps: List[int] = None,
|
619 |
+
guidance_scale: float = 4.5,
|
620 |
+
num_images_per_prompt: Optional[int] = 1,
|
621 |
+
height: Optional[int] = None,
|
622 |
+
width: Optional[int] = None,
|
623 |
+
eta: float = 0.0,
|
624 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
625 |
+
latents: Optional[torch.FloatTensor] = None,
|
626 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
627 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
628 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
629 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
630 |
+
output_type: Optional[str] = "latent",
|
631 |
+
return_dict: bool = True,
|
632 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
633 |
+
callback_steps: int = 1,
|
634 |
+
clean_caption: bool = True,
|
635 |
+
max_sequence_length: int = 120,
|
636 |
+
**kwargs,
|
637 |
+
) -> Union[EasyAnimatePipelineOutput, Tuple]:
|
638 |
+
"""
|
639 |
+
Function invoked when calling the pipeline for generation.
|
640 |
+
|
641 |
+
Args:
|
642 |
+
prompt (`str` or `List[str]`, *optional*):
|
643 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
644 |
+
instead.
|
645 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
646 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
647 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
648 |
+
less than `1`).
|
649 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
650 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
651 |
+
expense of slower inference.
|
652 |
+
timesteps (`List[int]`, *optional*):
|
653 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
654 |
+
timesteps are used. Must be in descending order.
|
655 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
656 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
657 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
658 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
659 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
660 |
+
usually at the expense of lower image quality.
|
661 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
662 |
+
The number of images to generate per prompt.
|
663 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
664 |
+
The height in pixels of the generated image.
|
665 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
666 |
+
The width in pixels of the generated image.
|
667 |
+
eta (`float`, *optional*, defaults to 0.0):
|
668 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
669 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
670 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
671 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
672 |
+
to make generation deterministic.
|
673 |
+
latents (`torch.FloatTensor`, *optional*):
|
674 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
675 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
676 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
677 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
678 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
679 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
680 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
681 |
+
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
682 |
+
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
683 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
684 |
+
The output format of the generate image. Choose between
|
685 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
686 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
687 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
688 |
+
callback (`Callable`, *optional*):
|
689 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
690 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
691 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
692 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
693 |
+
called at every step.
|
694 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
695 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
696 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
697 |
+
prompt.
|
698 |
+
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
699 |
+
|
700 |
+
Examples:
|
701 |
+
|
702 |
+
Returns:
|
703 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
704 |
+
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
705 |
+
returned where the first element is a list with the generated images
|
706 |
+
"""
|
707 |
+
# 1. Check inputs. Raise error if not correct
|
708 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
709 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
710 |
+
|
711 |
+
# 2. Default height and width to transformer
|
712 |
+
if prompt is not None and isinstance(prompt, str):
|
713 |
+
batch_size = 1
|
714 |
+
elif prompt is not None and isinstance(prompt, list):
|
715 |
+
batch_size = len(prompt)
|
716 |
+
else:
|
717 |
+
batch_size = prompt_embeds.shape[0]
|
718 |
+
|
719 |
+
device = self._execution_device
|
720 |
+
|
721 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
722 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
723 |
+
# corresponds to doing no classifier free guidance.
|
724 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
725 |
+
|
726 |
+
# 3. Encode input prompt
|
727 |
+
(
|
728 |
+
prompt_embeds,
|
729 |
+
prompt_attention_mask,
|
730 |
+
negative_prompt_embeds,
|
731 |
+
negative_prompt_attention_mask,
|
732 |
+
) = self.encode_prompt(
|
733 |
+
prompt,
|
734 |
+
do_classifier_free_guidance,
|
735 |
+
negative_prompt=negative_prompt,
|
736 |
+
num_images_per_prompt=num_images_per_prompt,
|
737 |
+
device=device,
|
738 |
+
prompt_embeds=prompt_embeds,
|
739 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
740 |
+
prompt_attention_mask=prompt_attention_mask,
|
741 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
742 |
+
clean_caption=clean_caption,
|
743 |
+
max_sequence_length=max_sequence_length,
|
744 |
+
)
|
745 |
+
if do_classifier_free_guidance:
|
746 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
747 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
748 |
+
|
749 |
+
# 4. Prepare timesteps
|
750 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
751 |
+
|
752 |
+
# 5. Prepare latents.
|
753 |
+
latent_channels = self.transformer.config.in_channels
|
754 |
+
latents = self.prepare_latents(
|
755 |
+
batch_size * num_images_per_prompt,
|
756 |
+
latent_channels,
|
757 |
+
video_length,
|
758 |
+
height,
|
759 |
+
width,
|
760 |
+
prompt_embeds.dtype,
|
761 |
+
device,
|
762 |
+
generator,
|
763 |
+
latents,
|
764 |
+
)
|
765 |
+
|
766 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
767 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
768 |
+
|
769 |
+
# 6.1 Prepare micro-conditions.
|
770 |
+
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
771 |
+
if self.transformer.config.sample_size == 128:
|
772 |
+
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
773 |
+
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
774 |
+
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
775 |
+
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
776 |
+
|
777 |
+
if do_classifier_free_guidance:
|
778 |
+
resolution = torch.cat([resolution, resolution], dim=0)
|
779 |
+
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
780 |
+
|
781 |
+
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
782 |
+
|
783 |
+
# 7. Denoising loop
|
784 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
785 |
+
|
786 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
787 |
+
for i, t in enumerate(timesteps):
|
788 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
789 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
790 |
+
|
791 |
+
current_timestep = t
|
792 |
+
if not torch.is_tensor(current_timestep):
|
793 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
794 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
795 |
+
is_mps = latent_model_input.device.type == "mps"
|
796 |
+
if isinstance(current_timestep, float):
|
797 |
+
dtype = torch.float32 if is_mps else torch.float64
|
798 |
+
else:
|
799 |
+
dtype = torch.int32 if is_mps else torch.int64
|
800 |
+
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
801 |
+
elif len(current_timestep.shape) == 0:
|
802 |
+
current_timestep = current_timestep[None].to(latent_model_input.device)
|
803 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
804 |
+
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
805 |
+
|
806 |
+
# predict noise model_output
|
807 |
+
noise_pred = self.transformer(
|
808 |
+
latent_model_input,
|
809 |
+
encoder_hidden_states=prompt_embeds,
|
810 |
+
encoder_attention_mask=prompt_attention_mask,
|
811 |
+
timestep=current_timestep,
|
812 |
+
added_cond_kwargs=added_cond_kwargs,
|
813 |
+
return_dict=False,
|
814 |
+
)[0]
|
815 |
+
|
816 |
+
# perform guidance
|
817 |
+
if do_classifier_free_guidance:
|
818 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
819 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
820 |
+
|
821 |
+
# learned sigma
|
822 |
+
if self.transformer.config.out_channels // 2 == latent_channels:
|
823 |
+
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
824 |
+
else:
|
825 |
+
noise_pred = noise_pred
|
826 |
+
|
827 |
+
# compute previous image: x_t -> x_t-1
|
828 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
829 |
+
|
830 |
+
# call the callback, if provided
|
831 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
832 |
+
progress_bar.update()
|
833 |
+
if callback is not None and i % callback_steps == 0:
|
834 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
835 |
+
callback(step_idx, t, latents)
|
836 |
+
|
837 |
+
# Post-processing
|
838 |
+
video = self.decode_latents(latents)
|
839 |
+
|
840 |
+
# Convert to tensor
|
841 |
+
if output_type == "latent":
|
842 |
+
video = torch.from_numpy(video)
|
843 |
+
|
844 |
+
if not return_dict:
|
845 |
+
return video
|
846 |
+
|
847 |
+
return EasyAnimatePipelineOutput(videos=video)
|
easyanimate/pipeline/pipeline_easyanimate_inpaint.py
ADDED
@@ -0,0 +1,984 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import html
|
16 |
+
import inspect
|
17 |
+
import re
|
18 |
+
import copy
|
19 |
+
import urllib.parse as ul
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Callable, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from diffusers import DiffusionPipeline, ImagePipelineOutput
|
26 |
+
from diffusers.image_processor import VaeImageProcessor
|
27 |
+
from diffusers.models import AutoencoderKL
|
28 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
29 |
+
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
|
30 |
+
is_bs4_available, is_ftfy_available, logging,
|
31 |
+
replace_example_docstring)
|
32 |
+
from diffusers.utils.torch_utils import randn_tensor
|
33 |
+
from einops import rearrange
|
34 |
+
from tqdm import tqdm
|
35 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
36 |
+
|
37 |
+
from ..models.transformer3d import Transformer3DModel
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
if is_bs4_available():
|
42 |
+
from bs4 import BeautifulSoup
|
43 |
+
|
44 |
+
if is_ftfy_available():
|
45 |
+
import ftfy
|
46 |
+
|
47 |
+
|
48 |
+
EXAMPLE_DOC_STRING = """
|
49 |
+
Examples:
|
50 |
+
```py
|
51 |
+
>>> import torch
|
52 |
+
>>> from diffusers import EasyAnimatePipeline
|
53 |
+
|
54 |
+
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
|
55 |
+
>>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
56 |
+
>>> # Enable memory optimizations.
|
57 |
+
>>> pipe.enable_model_cpu_offload()
|
58 |
+
|
59 |
+
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
60 |
+
>>> image = pipe(prompt).images[0]
|
61 |
+
```
|
62 |
+
"""
|
63 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
64 |
+
def retrieve_latents(encoder_output, generator):
|
65 |
+
if hasattr(encoder_output, "latent_dist"):
|
66 |
+
return encoder_output.latent_dist.sample(generator)
|
67 |
+
elif hasattr(encoder_output, "latents"):
|
68 |
+
return encoder_output.latents
|
69 |
+
else:
|
70 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
71 |
+
|
72 |
+
@dataclass
|
73 |
+
class EasyAnimatePipelineOutput(BaseOutput):
|
74 |
+
videos: Union[torch.Tensor, np.ndarray]
|
75 |
+
|
76 |
+
class EasyAnimateInpaintPipeline(DiffusionPipeline):
|
77 |
+
r"""
|
78 |
+
Pipeline for text-to-image generation using PixArt-Alpha.
|
79 |
+
|
80 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
81 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
82 |
+
|
83 |
+
Args:
|
84 |
+
vae ([`AutoencoderKL`]):
|
85 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
86 |
+
text_encoder ([`T5EncoderModel`]):
|
87 |
+
Frozen text-encoder. PixArt-Alpha uses
|
88 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
89 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
90 |
+
tokenizer (`T5Tokenizer`):
|
91 |
+
Tokenizer of class
|
92 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
93 |
+
transformer ([`Transformer3DModel`]):
|
94 |
+
A text conditioned `Transformer3DModel` to denoise the encoded image latents.
|
95 |
+
scheduler ([`SchedulerMixin`]):
|
96 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
97 |
+
"""
|
98 |
+
bad_punct_regex = re.compile(
|
99 |
+
r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
|
100 |
+
) # noqa
|
101 |
+
|
102 |
+
_optional_components = ["tokenizer", "text_encoder"]
|
103 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
104 |
+
|
105 |
+
def __init__(
|
106 |
+
self,
|
107 |
+
tokenizer: T5Tokenizer,
|
108 |
+
text_encoder: T5EncoderModel,
|
109 |
+
vae: AutoencoderKL,
|
110 |
+
transformer: Transformer3DModel,
|
111 |
+
scheduler: DPMSolverMultistepScheduler,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
self.register_modules(
|
116 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
117 |
+
)
|
118 |
+
|
119 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
120 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=True)
|
121 |
+
self.mask_processor = VaeImageProcessor(
|
122 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
123 |
+
)
|
124 |
+
|
125 |
+
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
126 |
+
def mask_text_embeddings(self, emb, mask):
|
127 |
+
if emb.shape[0] == 1:
|
128 |
+
keep_index = mask.sum().item()
|
129 |
+
return emb[:, :, :keep_index, :], keep_index
|
130 |
+
else:
|
131 |
+
masked_feature = emb * mask[:, None, :, None]
|
132 |
+
return masked_feature, emb.shape[2]
|
133 |
+
|
134 |
+
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
135 |
+
def encode_prompt(
|
136 |
+
self,
|
137 |
+
prompt: Union[str, List[str]],
|
138 |
+
do_classifier_free_guidance: bool = True,
|
139 |
+
negative_prompt: str = "",
|
140 |
+
num_images_per_prompt: int = 1,
|
141 |
+
device: Optional[torch.device] = None,
|
142 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
143 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
144 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
145 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
146 |
+
clean_caption: bool = False,
|
147 |
+
max_sequence_length: int = 120,
|
148 |
+
**kwargs,
|
149 |
+
):
|
150 |
+
r"""
|
151 |
+
Encodes the prompt into text encoder hidden states.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
prompt (`str` or `List[str]`, *optional*):
|
155 |
+
prompt to be encoded
|
156 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
157 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
158 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
159 |
+
PixArt-Alpha, this should be "".
|
160 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
161 |
+
whether to use classifier free guidance or not
|
162 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
163 |
+
number of images that should be generated per prompt
|
164 |
+
device: (`torch.device`, *optional*):
|
165 |
+
torch device to place the resulting embeddings on
|
166 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
167 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
168 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
169 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
170 |
+
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
171 |
+
string.
|
172 |
+
clean_caption (`bool`, defaults to `False`):
|
173 |
+
If `True`, the function will preprocess and clean the provided caption before encoding.
|
174 |
+
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
|
175 |
+
"""
|
176 |
+
|
177 |
+
if "mask_feature" in kwargs:
|
178 |
+
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
179 |
+
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
180 |
+
|
181 |
+
if device is None:
|
182 |
+
device = self._execution_device
|
183 |
+
|
184 |
+
if prompt is not None and isinstance(prompt, str):
|
185 |
+
batch_size = 1
|
186 |
+
elif prompt is not None and isinstance(prompt, list):
|
187 |
+
batch_size = len(prompt)
|
188 |
+
else:
|
189 |
+
batch_size = prompt_embeds.shape[0]
|
190 |
+
|
191 |
+
# See Section 3.1. of the paper.
|
192 |
+
max_length = max_sequence_length
|
193 |
+
|
194 |
+
if prompt_embeds is None:
|
195 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
196 |
+
text_inputs = self.tokenizer(
|
197 |
+
prompt,
|
198 |
+
padding="max_length",
|
199 |
+
max_length=max_length,
|
200 |
+
truncation=True,
|
201 |
+
add_special_tokens=True,
|
202 |
+
return_tensors="pt",
|
203 |
+
)
|
204 |
+
text_input_ids = text_inputs.input_ids
|
205 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
206 |
+
|
207 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
208 |
+
text_input_ids, untruncated_ids
|
209 |
+
):
|
210 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
211 |
+
logger.warning(
|
212 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
213 |
+
f" {max_length} tokens: {removed_text}"
|
214 |
+
)
|
215 |
+
|
216 |
+
prompt_attention_mask = text_inputs.attention_mask
|
217 |
+
prompt_attention_mask = prompt_attention_mask.to(device)
|
218 |
+
|
219 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
220 |
+
prompt_embeds = prompt_embeds[0]
|
221 |
+
|
222 |
+
if self.text_encoder is not None:
|
223 |
+
dtype = self.text_encoder.dtype
|
224 |
+
elif self.transformer is not None:
|
225 |
+
dtype = self.transformer.dtype
|
226 |
+
else:
|
227 |
+
dtype = None
|
228 |
+
|
229 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
230 |
+
|
231 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
232 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
233 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
234 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
235 |
+
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
236 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
237 |
+
|
238 |
+
# get unconditional embeddings for classifier free guidance
|
239 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
240 |
+
uncond_tokens = [negative_prompt] * batch_size
|
241 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
242 |
+
max_length = prompt_embeds.shape[1]
|
243 |
+
uncond_input = self.tokenizer(
|
244 |
+
uncond_tokens,
|
245 |
+
padding="max_length",
|
246 |
+
max_length=max_length,
|
247 |
+
truncation=True,
|
248 |
+
return_attention_mask=True,
|
249 |
+
add_special_tokens=True,
|
250 |
+
return_tensors="pt",
|
251 |
+
)
|
252 |
+
negative_prompt_attention_mask = uncond_input.attention_mask
|
253 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
254 |
+
|
255 |
+
negative_prompt_embeds = self.text_encoder(
|
256 |
+
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
257 |
+
)
|
258 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
259 |
+
|
260 |
+
if do_classifier_free_guidance:
|
261 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
262 |
+
seq_len = negative_prompt_embeds.shape[1]
|
263 |
+
|
264 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
265 |
+
|
266 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
267 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
268 |
+
|
269 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
270 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
271 |
+
else:
|
272 |
+
negative_prompt_embeds = None
|
273 |
+
negative_prompt_attention_mask = None
|
274 |
+
|
275 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
276 |
+
|
277 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
278 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
279 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
280 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
281 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
282 |
+
# and should be between [0, 1]
|
283 |
+
|
284 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
285 |
+
extra_step_kwargs = {}
|
286 |
+
if accepts_eta:
|
287 |
+
extra_step_kwargs["eta"] = eta
|
288 |
+
|
289 |
+
# check if the scheduler accepts generator
|
290 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
291 |
+
if accepts_generator:
|
292 |
+
extra_step_kwargs["generator"] = generator
|
293 |
+
return extra_step_kwargs
|
294 |
+
|
295 |
+
def check_inputs(
|
296 |
+
self,
|
297 |
+
prompt,
|
298 |
+
height,
|
299 |
+
width,
|
300 |
+
negative_prompt,
|
301 |
+
callback_steps,
|
302 |
+
prompt_embeds=None,
|
303 |
+
negative_prompt_embeds=None,
|
304 |
+
):
|
305 |
+
if height % 8 != 0 or width % 8 != 0:
|
306 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
307 |
+
|
308 |
+
if (callback_steps is None) or (
|
309 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
310 |
+
):
|
311 |
+
raise ValueError(
|
312 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
313 |
+
f" {type(callback_steps)}."
|
314 |
+
)
|
315 |
+
|
316 |
+
if prompt is not None and prompt_embeds is not None:
|
317 |
+
raise ValueError(
|
318 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
319 |
+
" only forward one of the two."
|
320 |
+
)
|
321 |
+
elif prompt is None and prompt_embeds is None:
|
322 |
+
raise ValueError(
|
323 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
324 |
+
)
|
325 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
326 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
327 |
+
|
328 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
329 |
+
raise ValueError(
|
330 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
331 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
332 |
+
)
|
333 |
+
|
334 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
335 |
+
raise ValueError(
|
336 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
337 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
338 |
+
)
|
339 |
+
|
340 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
341 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
342 |
+
raise ValueError(
|
343 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
344 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
345 |
+
f" {negative_prompt_embeds.shape}."
|
346 |
+
)
|
347 |
+
|
348 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
349 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
350 |
+
if clean_caption and not is_bs4_available():
|
351 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
352 |
+
logger.warn("Setting `clean_caption` to False...")
|
353 |
+
clean_caption = False
|
354 |
+
|
355 |
+
if clean_caption and not is_ftfy_available():
|
356 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
357 |
+
logger.warn("Setting `clean_caption` to False...")
|
358 |
+
clean_caption = False
|
359 |
+
|
360 |
+
if not isinstance(text, (tuple, list)):
|
361 |
+
text = [text]
|
362 |
+
|
363 |
+
def process(text: str):
|
364 |
+
if clean_caption:
|
365 |
+
text = self._clean_caption(text)
|
366 |
+
text = self._clean_caption(text)
|
367 |
+
else:
|
368 |
+
text = text.lower().strip()
|
369 |
+
return text
|
370 |
+
|
371 |
+
return [process(t) for t in text]
|
372 |
+
|
373 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
374 |
+
def _clean_caption(self, caption):
|
375 |
+
caption = str(caption)
|
376 |
+
caption = ul.unquote_plus(caption)
|
377 |
+
caption = caption.strip().lower()
|
378 |
+
caption = re.sub("<person>", "person", caption)
|
379 |
+
# urls:
|
380 |
+
caption = re.sub(
|
381 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
382 |
+
"",
|
383 |
+
caption,
|
384 |
+
) # regex for urls
|
385 |
+
caption = re.sub(
|
386 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
387 |
+
"",
|
388 |
+
caption,
|
389 |
+
) # regex for urls
|
390 |
+
# html:
|
391 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
392 |
+
|
393 |
+
# @<nickname>
|
394 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
395 |
+
|
396 |
+
# 31C0—31EF CJK Strokes
|
397 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
398 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
399 |
+
# 3300—33FF CJK Compatibility
|
400 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
401 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
402 |
+
# 4E00—9FFF CJK Unified Ideographs
|
403 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
404 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
405 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
406 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
407 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
408 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
409 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
410 |
+
#######################################################
|
411 |
+
|
412 |
+
# все виды тире / all types of dash --> "-"
|
413 |
+
caption = re.sub(
|
414 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
415 |
+
"-",
|
416 |
+
caption,
|
417 |
+
)
|
418 |
+
|
419 |
+
# кавычки к одному стандарту
|
420 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
421 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
422 |
+
|
423 |
+
# "
|
424 |
+
caption = re.sub(r""?", "", caption)
|
425 |
+
# &
|
426 |
+
caption = re.sub(r"&", "", caption)
|
427 |
+
|
428 |
+
# ip adresses:
|
429 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
430 |
+
|
431 |
+
# article ids:
|
432 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
433 |
+
|
434 |
+
# \n
|
435 |
+
caption = re.sub(r"\\n", " ", caption)
|
436 |
+
|
437 |
+
# "#123"
|
438 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
439 |
+
# "#12345.."
|
440 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
441 |
+
# "123456.."
|
442 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
443 |
+
# filenames:
|
444 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
445 |
+
|
446 |
+
#
|
447 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
448 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
449 |
+
|
450 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
451 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
452 |
+
|
453 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
454 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
455 |
+
if len(re.findall(regex2, caption)) > 3:
|
456 |
+
caption = re.sub(regex2, " ", caption)
|
457 |
+
|
458 |
+
caption = ftfy.fix_text(caption)
|
459 |
+
caption = html.unescape(html.unescape(caption))
|
460 |
+
|
461 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
462 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
463 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
464 |
+
|
465 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
466 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
467 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
468 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
469 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
470 |
+
|
471 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
472 |
+
|
473 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
474 |
+
|
475 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
476 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
477 |
+
caption = re.sub(r"\s+", " ", caption)
|
478 |
+
|
479 |
+
caption.strip()
|
480 |
+
|
481 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
482 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
483 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
484 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
485 |
+
|
486 |
+
return caption.strip()
|
487 |
+
|
488 |
+
def prepare_latents(
|
489 |
+
self,
|
490 |
+
batch_size,
|
491 |
+
num_channels_latents,
|
492 |
+
height,
|
493 |
+
width,
|
494 |
+
video_length,
|
495 |
+
dtype,
|
496 |
+
device,
|
497 |
+
generator,
|
498 |
+
latents=None,
|
499 |
+
video=None,
|
500 |
+
timestep=None,
|
501 |
+
is_strength_max=True,
|
502 |
+
return_noise=False,
|
503 |
+
return_video_latents=False,
|
504 |
+
):
|
505 |
+
if self.vae.quant_conv.weight.ndim==5:
|
506 |
+
shape = (batch_size, num_channels_latents, int(video_length // 5 * 2) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
507 |
+
else:
|
508 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
509 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
510 |
+
raise ValueError(
|
511 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
512 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
513 |
+
)
|
514 |
+
|
515 |
+
if return_video_latents or (latents is None and not is_strength_max):
|
516 |
+
video = video.to(device=device, dtype=dtype)
|
517 |
+
|
518 |
+
if video.shape[1] == 4:
|
519 |
+
video_latents = video
|
520 |
+
else:
|
521 |
+
video_length = video.shape[2]
|
522 |
+
video = rearrange(video, "b c f h w -> (b f) c h w")
|
523 |
+
video_latents = self._encode_vae_image(image=video, generator=generator)
|
524 |
+
video_latents = rearrange(video_latents, "(b f) c h w -> b c f h w", f=video_length)
|
525 |
+
video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1)
|
526 |
+
|
527 |
+
if latents is None:
|
528 |
+
rand_device = "cpu" if device.type == "mps" else device
|
529 |
+
|
530 |
+
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
531 |
+
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
532 |
+
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
|
533 |
+
else:
|
534 |
+
noise = latents.to(device)
|
535 |
+
if latents.shape != shape:
|
536 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
537 |
+
latents = latents.to(device)
|
538 |
+
|
539 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
540 |
+
latents = latents * self.scheduler.init_noise_sigma
|
541 |
+
outputs = (latents,)
|
542 |
+
|
543 |
+
if return_noise:
|
544 |
+
outputs += (noise,)
|
545 |
+
|
546 |
+
if return_video_latents:
|
547 |
+
outputs += (video_latents,)
|
548 |
+
|
549 |
+
return outputs
|
550 |
+
|
551 |
+
def decode_latents(self, latents):
|
552 |
+
video_length = latents.shape[2]
|
553 |
+
latents = 1 / 0.18215 * latents
|
554 |
+
if self.vae.quant_conv.weight.ndim==5:
|
555 |
+
mini_batch_decoder = 2
|
556 |
+
# Decoder
|
557 |
+
video = []
|
558 |
+
for i in range(0, latents.shape[2], mini_batch_decoder):
|
559 |
+
with torch.no_grad():
|
560 |
+
start_index = i
|
561 |
+
end_index = i + mini_batch_decoder
|
562 |
+
latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
|
563 |
+
video.append(latents_bs)
|
564 |
+
|
565 |
+
# Smooth
|
566 |
+
mini_batch_encoder = 5
|
567 |
+
video = torch.cat(video, 2).cpu()
|
568 |
+
for i in range(mini_batch_encoder, video.shape[2], mini_batch_encoder):
|
569 |
+
origin_before = copy.deepcopy(video[:, :, i - 1, :, :])
|
570 |
+
origin_after = copy.deepcopy(video[:, :, i, :, :])
|
571 |
+
|
572 |
+
video[:, :, i - 1, :, :] = origin_before * 0.75 + origin_after * 0.25
|
573 |
+
video[:, :, i, :, :] = origin_before * 0.25 + origin_after * 0.75
|
574 |
+
video = video.clamp(-1, 1)
|
575 |
+
else:
|
576 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
577 |
+
# video = self.vae.decode(latents).sample
|
578 |
+
video = []
|
579 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
580 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
581 |
+
video = torch.cat(video)
|
582 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
583 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
584 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
585 |
+
video = video.cpu().float().numpy()
|
586 |
+
return video
|
587 |
+
|
588 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
589 |
+
if isinstance(generator, list):
|
590 |
+
image_latents = [
|
591 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
592 |
+
for i in range(image.shape[0])
|
593 |
+
]
|
594 |
+
image_latents = torch.cat(image_latents, dim=0)
|
595 |
+
else:
|
596 |
+
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
597 |
+
|
598 |
+
image_latents = self.vae.config.scaling_factor * image_latents
|
599 |
+
|
600 |
+
return image_latents
|
601 |
+
|
602 |
+
def prepare_mask_latents(
|
603 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
604 |
+
):
|
605 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
606 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
607 |
+
# and half precision
|
608 |
+
video_length = mask.shape[2]
|
609 |
+
|
610 |
+
mask = mask.to(device=device, dtype=self.vae.dtype)
|
611 |
+
if self.vae.quant_conv.weight.ndim==5:
|
612 |
+
bs = 1
|
613 |
+
new_mask = []
|
614 |
+
for i in range(0, mask.shape[0], bs):
|
615 |
+
mini_batch = 5
|
616 |
+
new_mask_mini_batch = []
|
617 |
+
for j in range(0, mask.shape[2], mini_batch):
|
618 |
+
mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
|
619 |
+
mask_bs = self.vae.encode(mask_bs)[0]
|
620 |
+
mask_bs = mask_bs.sample()
|
621 |
+
new_mask_mini_batch.append(mask_bs)
|
622 |
+
new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
|
623 |
+
new_mask.append(new_mask_mini_batch)
|
624 |
+
mask = torch.cat(new_mask, dim = 0)
|
625 |
+
mask = mask * 0.1825
|
626 |
+
|
627 |
+
else:
|
628 |
+
if mask.shape[1] == 4:
|
629 |
+
mask = mask
|
630 |
+
else:
|
631 |
+
video_length = mask.shape[2]
|
632 |
+
mask = rearrange(mask, "b c f h w -> (b f) c h w")
|
633 |
+
mask = self._encode_vae_image(mask, generator=generator)
|
634 |
+
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
|
635 |
+
|
636 |
+
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
|
637 |
+
if self.vae.quant_conv.weight.ndim==5:
|
638 |
+
bs = 1
|
639 |
+
new_mask_pixel_values = []
|
640 |
+
for i in range(0, masked_image.shape[0], bs):
|
641 |
+
mini_batch = 5
|
642 |
+
new_mask_pixel_values_mini_batch = []
|
643 |
+
for j in range(0, masked_image.shape[2], mini_batch):
|
644 |
+
mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
|
645 |
+
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
|
646 |
+
mask_pixel_values_bs = mask_pixel_values_bs.sample()
|
647 |
+
new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
|
648 |
+
new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
|
649 |
+
new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
|
650 |
+
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
|
651 |
+
masked_image_latents = masked_image_latents * 0.1825
|
652 |
+
|
653 |
+
else:
|
654 |
+
if masked_image.shape[1] == 4:
|
655 |
+
masked_image_latents = masked_image
|
656 |
+
else:
|
657 |
+
video_length = mask.shape[2]
|
658 |
+
masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
|
659 |
+
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
660 |
+
masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
|
661 |
+
|
662 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
663 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
664 |
+
return mask, masked_image_latents
|
665 |
+
|
666 |
+
@torch.no_grad()
|
667 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
668 |
+
def __call__(
|
669 |
+
self,
|
670 |
+
prompt: Union[str, List[str]] = None,
|
671 |
+
video_length: Optional[int] = None,
|
672 |
+
video: Union[torch.FloatTensor] = None,
|
673 |
+
mask_video: Union[torch.FloatTensor] = None,
|
674 |
+
masked_video_latents: Union[torch.FloatTensor] = None,
|
675 |
+
negative_prompt: str = "",
|
676 |
+
num_inference_steps: int = 20,
|
677 |
+
timesteps: List[int] = None,
|
678 |
+
guidance_scale: float = 4.5,
|
679 |
+
num_images_per_prompt: Optional[int] = 1,
|
680 |
+
height: Optional[int] = None,
|
681 |
+
width: Optional[int] = None,
|
682 |
+
strength: float = 1.0,
|
683 |
+
eta: float = 0.0,
|
684 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
685 |
+
latents: Optional[torch.FloatTensor] = None,
|
686 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
687 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
688 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
689 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
690 |
+
output_type: Optional[str] = "latent",
|
691 |
+
return_dict: bool = True,
|
692 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
693 |
+
callback_steps: int = 1,
|
694 |
+
clean_caption: bool = True,
|
695 |
+
mask_feature: bool = True,
|
696 |
+
max_sequence_length: int = 120
|
697 |
+
) -> Union[EasyAnimatePipelineOutput, Tuple]:
|
698 |
+
"""
|
699 |
+
Function invoked when calling the pipeline for generation.
|
700 |
+
|
701 |
+
Args:
|
702 |
+
prompt (`str` or `List[str]`, *optional*):
|
703 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
704 |
+
instead.
|
705 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
706 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
707 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
708 |
+
less than `1`).
|
709 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
710 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
711 |
+
expense of slower inference.
|
712 |
+
timesteps (`List[int]`, *optional*):
|
713 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
714 |
+
timesteps are used. Must be in descending order.
|
715 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
716 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
717 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
718 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
719 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
720 |
+
usually at the expense of lower image quality.
|
721 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
722 |
+
The number of images to generate per prompt.
|
723 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
724 |
+
The height in pixels of the generated image.
|
725 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
726 |
+
The width in pixels of the generated image.
|
727 |
+
eta (`float`, *optional*, defaults to 0.0):
|
728 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
729 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
730 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
731 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
732 |
+
to make generation deterministic.
|
733 |
+
latents (`torch.FloatTensor`, *optional*):
|
734 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
735 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
736 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
737 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
738 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
739 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
740 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
741 |
+
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
742 |
+
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
743 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
744 |
+
The output format of the generate image. Choose between
|
745 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
746 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
747 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
748 |
+
callback (`Callable`, *optional*):
|
749 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
750 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
751 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
752 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
753 |
+
called at every step.
|
754 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
755 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
756 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
757 |
+
prompt.
|
758 |
+
mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
|
759 |
+
|
760 |
+
Examples:
|
761 |
+
|
762 |
+
Returns:
|
763 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
764 |
+
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
765 |
+
returned where the first element is a list with the generated images
|
766 |
+
"""
|
767 |
+
# 1. Check inputs. Raise error if not correct
|
768 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
769 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
770 |
+
|
771 |
+
# 2. Default height and width to transformer
|
772 |
+
if prompt is not None and isinstance(prompt, str):
|
773 |
+
batch_size = 1
|
774 |
+
elif prompt is not None and isinstance(prompt, list):
|
775 |
+
batch_size = len(prompt)
|
776 |
+
else:
|
777 |
+
batch_size = prompt_embeds.shape[0]
|
778 |
+
|
779 |
+
device = self._execution_device
|
780 |
+
|
781 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
782 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
783 |
+
# corresponds to doing no classifier free guidance.
|
784 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
785 |
+
|
786 |
+
# 3. Encode input prompt
|
787 |
+
(
|
788 |
+
prompt_embeds,
|
789 |
+
prompt_attention_mask,
|
790 |
+
negative_prompt_embeds,
|
791 |
+
negative_prompt_attention_mask,
|
792 |
+
) = self.encode_prompt(
|
793 |
+
prompt,
|
794 |
+
do_classifier_free_guidance,
|
795 |
+
negative_prompt=negative_prompt,
|
796 |
+
num_images_per_prompt=num_images_per_prompt,
|
797 |
+
device=device,
|
798 |
+
prompt_embeds=prompt_embeds,
|
799 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
800 |
+
prompt_attention_mask=prompt_attention_mask,
|
801 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
802 |
+
clean_caption=clean_caption,
|
803 |
+
max_sequence_length=max_sequence_length,
|
804 |
+
)
|
805 |
+
if do_classifier_free_guidance:
|
806 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
807 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
808 |
+
|
809 |
+
# 4. Prepare timesteps
|
810 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
811 |
+
timesteps = self.scheduler.timesteps
|
812 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
813 |
+
latent_timestep = timesteps[:1].repeat(batch_size)
|
814 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
815 |
+
is_strength_max = strength == 1.0
|
816 |
+
|
817 |
+
if video is not None:
|
818 |
+
video_length = video.shape[2]
|
819 |
+
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
820 |
+
init_video = init_video.to(dtype=torch.float32)
|
821 |
+
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
|
822 |
+
else:
|
823 |
+
init_video = None
|
824 |
+
|
825 |
+
# Prepare latent variables
|
826 |
+
num_channels_latents = self.vae.config.latent_channels
|
827 |
+
num_channels_transformer = self.transformer.config.in_channels
|
828 |
+
return_image_latents = num_channels_transformer == 4
|
829 |
+
|
830 |
+
# 5. Prepare latents.
|
831 |
+
latents_outputs = self.prepare_latents(
|
832 |
+
batch_size * num_images_per_prompt,
|
833 |
+
num_channels_latents,
|
834 |
+
height,
|
835 |
+
width,
|
836 |
+
video_length,
|
837 |
+
prompt_embeds.dtype,
|
838 |
+
device,
|
839 |
+
generator,
|
840 |
+
latents,
|
841 |
+
video=init_video,
|
842 |
+
timestep=latent_timestep,
|
843 |
+
is_strength_max=is_strength_max,
|
844 |
+
return_noise=True,
|
845 |
+
return_video_latents=return_image_latents,
|
846 |
+
)
|
847 |
+
if return_image_latents:
|
848 |
+
latents, noise, image_latents = latents_outputs
|
849 |
+
else:
|
850 |
+
latents, noise = latents_outputs
|
851 |
+
latents_dtype = latents.dtype
|
852 |
+
|
853 |
+
if mask_video is not None:
|
854 |
+
# Prepare mask latent variables
|
855 |
+
video_length = video.shape[2]
|
856 |
+
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
|
857 |
+
mask_condition = mask_condition.to(dtype=torch.float32)
|
858 |
+
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
|
859 |
+
|
860 |
+
if masked_video_latents is None:
|
861 |
+
masked_video = init_video * (mask_condition < 0.5) + torch.ones_like(init_video) * (mask_condition > 0.5) * -1
|
862 |
+
else:
|
863 |
+
masked_video = masked_video_latents
|
864 |
+
|
865 |
+
mask, masked_video_latents = self.prepare_mask_latents(
|
866 |
+
mask_condition,
|
867 |
+
masked_video,
|
868 |
+
batch_size,
|
869 |
+
height,
|
870 |
+
width,
|
871 |
+
prompt_embeds.dtype,
|
872 |
+
device,
|
873 |
+
generator,
|
874 |
+
do_classifier_free_guidance,
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
878 |
+
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
|
879 |
+
|
880 |
+
# Check that sizes of mask, masked image and latents match
|
881 |
+
if num_channels_transformer == 12:
|
882 |
+
# default case for runwayml/stable-diffusion-inpainting
|
883 |
+
num_channels_mask = mask.shape[1]
|
884 |
+
num_channels_masked_image = masked_video_latents.shape[1]
|
885 |
+
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
|
886 |
+
raise ValueError(
|
887 |
+
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
|
888 |
+
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
889 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
890 |
+
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
891 |
+
" `pipeline.transformer` or your `mask_image` or `image` input."
|
892 |
+
)
|
893 |
+
elif num_channels_transformer == 4:
|
894 |
+
raise ValueError(
|
895 |
+
f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
|
896 |
+
)
|
897 |
+
|
898 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
899 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
900 |
+
|
901 |
+
# 6.1 Prepare micro-conditions.
|
902 |
+
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
903 |
+
if self.transformer.config.sample_size == 128:
|
904 |
+
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
905 |
+
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
906 |
+
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
907 |
+
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
908 |
+
|
909 |
+
if do_classifier_free_guidance:
|
910 |
+
resolution = torch.cat([resolution, resolution], dim=0)
|
911 |
+
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
912 |
+
|
913 |
+
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
914 |
+
|
915 |
+
# 7. Denoising loop
|
916 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
917 |
+
|
918 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
919 |
+
for i, t in enumerate(timesteps):
|
920 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
921 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
922 |
+
|
923 |
+
if num_channels_transformer == 12:
|
924 |
+
mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
925 |
+
masked_video_latents_input = (
|
926 |
+
torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
|
927 |
+
)
|
928 |
+
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1)
|
929 |
+
|
930 |
+
current_timestep = t
|
931 |
+
if not torch.is_tensor(current_timestep):
|
932 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
933 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
934 |
+
is_mps = latent_model_input.device.type == "mps"
|
935 |
+
if isinstance(current_timestep, float):
|
936 |
+
dtype = torch.float32 if is_mps else torch.float64
|
937 |
+
else:
|
938 |
+
dtype = torch.int32 if is_mps else torch.int64
|
939 |
+
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
940 |
+
elif len(current_timestep.shape) == 0:
|
941 |
+
current_timestep = current_timestep[None].to(latent_model_input.device)
|
942 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
943 |
+
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
944 |
+
|
945 |
+
# predict noise model_output
|
946 |
+
noise_pred = self.transformer(
|
947 |
+
latent_model_input,
|
948 |
+
encoder_hidden_states=prompt_embeds,
|
949 |
+
encoder_attention_mask=prompt_attention_mask,
|
950 |
+
timestep=current_timestep,
|
951 |
+
added_cond_kwargs=added_cond_kwargs,
|
952 |
+
inpaint_latents=inpaint_latents.to(latent_model_input.dtype),
|
953 |
+
return_dict=False,
|
954 |
+
)[0]
|
955 |
+
|
956 |
+
# perform guidance
|
957 |
+
if do_classifier_free_guidance:
|
958 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
959 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
960 |
+
|
961 |
+
# learned sigma
|
962 |
+
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
963 |
+
|
964 |
+
# compute previous image: x_t -> x_t-1
|
965 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
966 |
+
|
967 |
+
# call the callback, if provided
|
968 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
969 |
+
progress_bar.update()
|
970 |
+
if callback is not None and i % callback_steps == 0:
|
971 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
972 |
+
callback(step_idx, t, latents)
|
973 |
+
|
974 |
+
# Post-processing
|
975 |
+
video = self.decode_latents(latents)
|
976 |
+
|
977 |
+
# Convert to tensor
|
978 |
+
if output_type == "latent":
|
979 |
+
video = torch.from_numpy(video)
|
980 |
+
|
981 |
+
if not return_dict:
|
982 |
+
return video
|
983 |
+
|
984 |
+
return EasyAnimatePipelineOutput(videos=video)
|
easyanimate/pipeline/pipeline_pixart_magvit.py
ADDED
@@ -0,0 +1,983 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import html
|
16 |
+
import inspect
|
17 |
+
import re
|
18 |
+
import urllib.parse as ul
|
19 |
+
from typing import Callable, List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from diffusers.image_processor import VaeImageProcessor
|
24 |
+
from diffusers.models import Transformer2DModel
|
25 |
+
from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
|
26 |
+
ImagePipelineOutput)
|
27 |
+
from diffusers.schedulers import DPMSolverMultistepScheduler
|
28 |
+
from diffusers.utils import (BACKENDS_MAPPING, deprecate, is_bs4_available,
|
29 |
+
is_ftfy_available, logging,
|
30 |
+
replace_example_docstring)
|
31 |
+
from diffusers.utils.torch_utils import randn_tensor
|
32 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
33 |
+
|
34 |
+
from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
if is_bs4_available():
|
39 |
+
from bs4 import BeautifulSoup
|
40 |
+
|
41 |
+
if is_ftfy_available():
|
42 |
+
import ftfy
|
43 |
+
|
44 |
+
EXAMPLE_DOC_STRING = """
|
45 |
+
Examples:
|
46 |
+
```py
|
47 |
+
>>> import torch
|
48 |
+
>>> from diffusers import PixArtAlphaPipeline
|
49 |
+
|
50 |
+
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
|
51 |
+
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
52 |
+
>>> # Enable memory optimizations.
|
53 |
+
>>> pipe.enable_model_cpu_offload()
|
54 |
+
|
55 |
+
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
56 |
+
>>> image = pipe(prompt).images[0]
|
57 |
+
```
|
58 |
+
"""
|
59 |
+
|
60 |
+
ASPECT_RATIO_1024_BIN = {
|
61 |
+
"0.25": [512.0, 2048.0],
|
62 |
+
"0.28": [512.0, 1856.0],
|
63 |
+
"0.32": [576.0, 1792.0],
|
64 |
+
"0.33": [576.0, 1728.0],
|
65 |
+
"0.35": [576.0, 1664.0],
|
66 |
+
"0.4": [640.0, 1600.0],
|
67 |
+
"0.42": [640.0, 1536.0],
|
68 |
+
"0.48": [704.0, 1472.0],
|
69 |
+
"0.5": [704.0, 1408.0],
|
70 |
+
"0.52": [704.0, 1344.0],
|
71 |
+
"0.57": [768.0, 1344.0],
|
72 |
+
"0.6": [768.0, 1280.0],
|
73 |
+
"0.68": [832.0, 1216.0],
|
74 |
+
"0.72": [832.0, 1152.0],
|
75 |
+
"0.78": [896.0, 1152.0],
|
76 |
+
"0.82": [896.0, 1088.0],
|
77 |
+
"0.88": [960.0, 1088.0],
|
78 |
+
"0.94": [960.0, 1024.0],
|
79 |
+
"1.0": [1024.0, 1024.0],
|
80 |
+
"1.07": [1024.0, 960.0],
|
81 |
+
"1.13": [1088.0, 960.0],
|
82 |
+
"1.21": [1088.0, 896.0],
|
83 |
+
"1.29": [1152.0, 896.0],
|
84 |
+
"1.38": [1152.0, 832.0],
|
85 |
+
"1.46": [1216.0, 832.0],
|
86 |
+
"1.67": [1280.0, 768.0],
|
87 |
+
"1.75": [1344.0, 768.0],
|
88 |
+
"2.0": [1408.0, 704.0],
|
89 |
+
"2.09": [1472.0, 704.0],
|
90 |
+
"2.4": [1536.0, 640.0],
|
91 |
+
"2.5": [1600.0, 640.0],
|
92 |
+
"3.0": [1728.0, 576.0],
|
93 |
+
"4.0": [2048.0, 512.0],
|
94 |
+
}
|
95 |
+
|
96 |
+
ASPECT_RATIO_512_BIN = {
|
97 |
+
"0.25": [256.0, 1024.0],
|
98 |
+
"0.28": [256.0, 928.0],
|
99 |
+
"0.32": [288.0, 896.0],
|
100 |
+
"0.33": [288.0, 864.0],
|
101 |
+
"0.35": [288.0, 832.0],
|
102 |
+
"0.4": [320.0, 800.0],
|
103 |
+
"0.42": [320.0, 768.0],
|
104 |
+
"0.48": [352.0, 736.0],
|
105 |
+
"0.5": [352.0, 704.0],
|
106 |
+
"0.52": [352.0, 672.0],
|
107 |
+
"0.57": [384.0, 672.0],
|
108 |
+
"0.6": [384.0, 640.0],
|
109 |
+
"0.68": [416.0, 608.0],
|
110 |
+
"0.72": [416.0, 576.0],
|
111 |
+
"0.78": [448.0, 576.0],
|
112 |
+
"0.82": [448.0, 544.0],
|
113 |
+
"0.88": [480.0, 544.0],
|
114 |
+
"0.94": [480.0, 512.0],
|
115 |
+
"1.0": [512.0, 512.0],
|
116 |
+
"1.07": [512.0, 480.0],
|
117 |
+
"1.13": [544.0, 480.0],
|
118 |
+
"1.21": [544.0, 448.0],
|
119 |
+
"1.29": [576.0, 448.0],
|
120 |
+
"1.38": [576.0, 416.0],
|
121 |
+
"1.46": [608.0, 416.0],
|
122 |
+
"1.67": [640.0, 384.0],
|
123 |
+
"1.75": [672.0, 384.0],
|
124 |
+
"2.0": [704.0, 352.0],
|
125 |
+
"2.09": [736.0, 352.0],
|
126 |
+
"2.4": [768.0, 320.0],
|
127 |
+
"2.5": [800.0, 320.0],
|
128 |
+
"3.0": [864.0, 288.0],
|
129 |
+
"4.0": [1024.0, 256.0],
|
130 |
+
}
|
131 |
+
|
132 |
+
ASPECT_RATIO_256_BIN = {
|
133 |
+
"0.25": [128.0, 512.0],
|
134 |
+
"0.28": [128.0, 464.0],
|
135 |
+
"0.32": [144.0, 448.0],
|
136 |
+
"0.33": [144.0, 432.0],
|
137 |
+
"0.35": [144.0, 416.0],
|
138 |
+
"0.4": [160.0, 400.0],
|
139 |
+
"0.42": [160.0, 384.0],
|
140 |
+
"0.48": [176.0, 368.0],
|
141 |
+
"0.5": [176.0, 352.0],
|
142 |
+
"0.52": [176.0, 336.0],
|
143 |
+
"0.57": [192.0, 336.0],
|
144 |
+
"0.6": [192.0, 320.0],
|
145 |
+
"0.68": [208.0, 304.0],
|
146 |
+
"0.72": [208.0, 288.0],
|
147 |
+
"0.78": [224.0, 288.0],
|
148 |
+
"0.82": [224.0, 272.0],
|
149 |
+
"0.88": [240.0, 272.0],
|
150 |
+
"0.94": [240.0, 256.0],
|
151 |
+
"1.0": [256.0, 256.0],
|
152 |
+
"1.07": [256.0, 240.0],
|
153 |
+
"1.13": [272.0, 240.0],
|
154 |
+
"1.21": [272.0, 224.0],
|
155 |
+
"1.29": [288.0, 224.0],
|
156 |
+
"1.38": [288.0, 208.0],
|
157 |
+
"1.46": [304.0, 208.0],
|
158 |
+
"1.67": [320.0, 192.0],
|
159 |
+
"1.75": [336.0, 192.0],
|
160 |
+
"2.0": [352.0, 176.0],
|
161 |
+
"2.09": [368.0, 176.0],
|
162 |
+
"2.4": [384.0, 160.0],
|
163 |
+
"2.5": [400.0, 160.0],
|
164 |
+
"3.0": [432.0, 144.0],
|
165 |
+
"4.0": [512.0, 128.0],
|
166 |
+
}
|
167 |
+
|
168 |
+
|
169 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
170 |
+
def retrieve_timesteps(
|
171 |
+
scheduler,
|
172 |
+
num_inference_steps: Optional[int] = None,
|
173 |
+
device: Optional[Union[str, torch.device]] = None,
|
174 |
+
timesteps: Optional[List[int]] = None,
|
175 |
+
**kwargs,
|
176 |
+
):
|
177 |
+
"""
|
178 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
179 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
scheduler (`SchedulerMixin`):
|
183 |
+
The scheduler to get timesteps from.
|
184 |
+
num_inference_steps (`int`):
|
185 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
186 |
+
`timesteps` must be `None`.
|
187 |
+
device (`str` or `torch.device`, *optional*):
|
188 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
189 |
+
timesteps (`List[int]`, *optional*):
|
190 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
191 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
192 |
+
must be `None`.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
196 |
+
second element is the number of inference steps.
|
197 |
+
"""
|
198 |
+
if timesteps is not None:
|
199 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
200 |
+
if not accepts_timesteps:
|
201 |
+
raise ValueError(
|
202 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
203 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
204 |
+
)
|
205 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
206 |
+
timesteps = scheduler.timesteps
|
207 |
+
num_inference_steps = len(timesteps)
|
208 |
+
else:
|
209 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
210 |
+
timesteps = scheduler.timesteps
|
211 |
+
return timesteps, num_inference_steps
|
212 |
+
|
213 |
+
|
214 |
+
class PixArtAlphaMagvitPipeline(DiffusionPipeline):
|
215 |
+
r"""
|
216 |
+
Pipeline for text-to-image generation using PixArt-Alpha.
|
217 |
+
|
218 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
219 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
220 |
+
|
221 |
+
Args:
|
222 |
+
vae ([`AutoencoderKL`]):
|
223 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
224 |
+
text_encoder ([`T5EncoderModel`]):
|
225 |
+
Frozen text-encoder. PixArt-Alpha uses
|
226 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
227 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
228 |
+
tokenizer (`T5Tokenizer`):
|
229 |
+
Tokenizer of class
|
230 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
231 |
+
transformer ([`Transformer2DModel`]):
|
232 |
+
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
233 |
+
scheduler ([`SchedulerMixin`]):
|
234 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
235 |
+
"""
|
236 |
+
|
237 |
+
bad_punct_regex = re.compile(
|
238 |
+
r"["
|
239 |
+
+ "#®•©™&@·º½¾¿¡§~"
|
240 |
+
+ r"\)"
|
241 |
+
+ r"\("
|
242 |
+
+ r"\]"
|
243 |
+
+ r"\["
|
244 |
+
+ r"\}"
|
245 |
+
+ r"\{"
|
246 |
+
+ r"\|"
|
247 |
+
+ "\\"
|
248 |
+
+ r"\/"
|
249 |
+
+ r"\*"
|
250 |
+
+ r"]{1,}"
|
251 |
+
) # noqa
|
252 |
+
|
253 |
+
_optional_components = ["tokenizer", "text_encoder"]
|
254 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
255 |
+
|
256 |
+
def __init__(
|
257 |
+
self,
|
258 |
+
tokenizer: T5Tokenizer,
|
259 |
+
text_encoder: T5EncoderModel,
|
260 |
+
vae: AutoencoderKLMagvit,
|
261 |
+
transformer: Transformer2DModel,
|
262 |
+
scheduler: DPMSolverMultistepScheduler,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
|
266 |
+
self.register_modules(
|
267 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
268 |
+
)
|
269 |
+
|
270 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
271 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
272 |
+
|
273 |
+
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
|
274 |
+
def mask_text_embeddings(self, emb, mask):
|
275 |
+
if emb.shape[0] == 1:
|
276 |
+
keep_index = mask.sum().item()
|
277 |
+
return emb[:, :, :keep_index, :], keep_index
|
278 |
+
else:
|
279 |
+
masked_feature = emb * mask[:, None, :, None]
|
280 |
+
return masked_feature, emb.shape[2]
|
281 |
+
|
282 |
+
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
283 |
+
def encode_prompt(
|
284 |
+
self,
|
285 |
+
prompt: Union[str, List[str]],
|
286 |
+
do_classifier_free_guidance: bool = True,
|
287 |
+
negative_prompt: str = "",
|
288 |
+
num_images_per_prompt: int = 1,
|
289 |
+
device: Optional[torch.device] = None,
|
290 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
291 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
292 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
293 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
294 |
+
clean_caption: bool = False,
|
295 |
+
max_sequence_length: int = 120,
|
296 |
+
**kwargs,
|
297 |
+
):
|
298 |
+
r"""
|
299 |
+
Encodes the prompt into text encoder hidden states.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
prompt (`str` or `List[str]`, *optional*):
|
303 |
+
prompt to be encoded
|
304 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
305 |
+
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
306 |
+
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
307 |
+
PixArt-Alpha, this should be "".
|
308 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
309 |
+
whether to use classifier free guidance or not
|
310 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
311 |
+
number of images that should be generated per prompt
|
312 |
+
device: (`torch.device`, *optional*):
|
313 |
+
torch device to place the resulting embeddings on
|
314 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
315 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
316 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
317 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
318 |
+
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
|
319 |
+
string.
|
320 |
+
clean_caption (`bool`, defaults to `False`):
|
321 |
+
If `True`, the function will preprocess and clean the provided caption before encoding.
|
322 |
+
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
|
323 |
+
"""
|
324 |
+
|
325 |
+
if "mask_feature" in kwargs:
|
326 |
+
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
327 |
+
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
328 |
+
|
329 |
+
if device is None:
|
330 |
+
device = self._execution_device
|
331 |
+
|
332 |
+
if prompt is not None and isinstance(prompt, str):
|
333 |
+
batch_size = 1
|
334 |
+
elif prompt is not None and isinstance(prompt, list):
|
335 |
+
batch_size = len(prompt)
|
336 |
+
else:
|
337 |
+
batch_size = prompt_embeds.shape[0]
|
338 |
+
|
339 |
+
# See Section 3.1. of the paper.
|
340 |
+
max_length = max_sequence_length
|
341 |
+
|
342 |
+
if prompt_embeds is None:
|
343 |
+
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
344 |
+
text_inputs = self.tokenizer(
|
345 |
+
prompt,
|
346 |
+
padding="max_length",
|
347 |
+
max_length=max_length,
|
348 |
+
truncation=True,
|
349 |
+
add_special_tokens=True,
|
350 |
+
return_tensors="pt",
|
351 |
+
)
|
352 |
+
text_input_ids = text_inputs.input_ids
|
353 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
354 |
+
|
355 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
356 |
+
text_input_ids, untruncated_ids
|
357 |
+
):
|
358 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
359 |
+
logger.warning(
|
360 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
361 |
+
f" {max_length} tokens: {removed_text}"
|
362 |
+
)
|
363 |
+
|
364 |
+
prompt_attention_mask = text_inputs.attention_mask
|
365 |
+
prompt_attention_mask = prompt_attention_mask.to(device)
|
366 |
+
|
367 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
368 |
+
prompt_embeds = prompt_embeds[0]
|
369 |
+
|
370 |
+
if self.text_encoder is not None:
|
371 |
+
dtype = self.text_encoder.dtype
|
372 |
+
elif self.transformer is not None:
|
373 |
+
dtype = self.transformer.dtype
|
374 |
+
else:
|
375 |
+
dtype = None
|
376 |
+
|
377 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
378 |
+
|
379 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
380 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
381 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
382 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
383 |
+
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
384 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
385 |
+
|
386 |
+
# get unconditional embeddings for classifier free guidance
|
387 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
388 |
+
uncond_tokens = [negative_prompt] * batch_size
|
389 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
390 |
+
max_length = prompt_embeds.shape[1]
|
391 |
+
uncond_input = self.tokenizer(
|
392 |
+
uncond_tokens,
|
393 |
+
padding="max_length",
|
394 |
+
max_length=max_length,
|
395 |
+
truncation=True,
|
396 |
+
return_attention_mask=True,
|
397 |
+
add_special_tokens=True,
|
398 |
+
return_tensors="pt",
|
399 |
+
)
|
400 |
+
negative_prompt_attention_mask = uncond_input.attention_mask
|
401 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
402 |
+
|
403 |
+
negative_prompt_embeds = self.text_encoder(
|
404 |
+
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
405 |
+
)
|
406 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
407 |
+
|
408 |
+
if do_classifier_free_guidance:
|
409 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
410 |
+
seq_len = negative_prompt_embeds.shape[1]
|
411 |
+
|
412 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
413 |
+
|
414 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
415 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
416 |
+
|
417 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
418 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
419 |
+
else:
|
420 |
+
negative_prompt_embeds = None
|
421 |
+
negative_prompt_attention_mask = None
|
422 |
+
|
423 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
424 |
+
|
425 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
426 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
427 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
428 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
429 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
430 |
+
# and should be between [0, 1]
|
431 |
+
|
432 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
433 |
+
extra_step_kwargs = {}
|
434 |
+
if accepts_eta:
|
435 |
+
extra_step_kwargs["eta"] = eta
|
436 |
+
|
437 |
+
# check if the scheduler accepts generator
|
438 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
439 |
+
if accepts_generator:
|
440 |
+
extra_step_kwargs["generator"] = generator
|
441 |
+
return extra_step_kwargs
|
442 |
+
|
443 |
+
def check_inputs(
|
444 |
+
self,
|
445 |
+
prompt,
|
446 |
+
height,
|
447 |
+
width,
|
448 |
+
negative_prompt,
|
449 |
+
callback_steps,
|
450 |
+
prompt_embeds=None,
|
451 |
+
negative_prompt_embeds=None,
|
452 |
+
prompt_attention_mask=None,
|
453 |
+
negative_prompt_attention_mask=None,
|
454 |
+
):
|
455 |
+
if height % 8 != 0 or width % 8 != 0:
|
456 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
457 |
+
|
458 |
+
if (callback_steps is None) or (
|
459 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
460 |
+
):
|
461 |
+
raise ValueError(
|
462 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
463 |
+
f" {type(callback_steps)}."
|
464 |
+
)
|
465 |
+
|
466 |
+
if prompt is not None and prompt_embeds is not None:
|
467 |
+
raise ValueError(
|
468 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
469 |
+
" only forward one of the two."
|
470 |
+
)
|
471 |
+
elif prompt is None and prompt_embeds is None:
|
472 |
+
raise ValueError(
|
473 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
474 |
+
)
|
475 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
476 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
477 |
+
|
478 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
479 |
+
raise ValueError(
|
480 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
481 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
482 |
+
)
|
483 |
+
|
484 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
485 |
+
raise ValueError(
|
486 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
487 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
488 |
+
)
|
489 |
+
|
490 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
491 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
492 |
+
|
493 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
494 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
495 |
+
|
496 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
497 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
498 |
+
raise ValueError(
|
499 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
500 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
501 |
+
f" {negative_prompt_embeds.shape}."
|
502 |
+
)
|
503 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
504 |
+
raise ValueError(
|
505 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
506 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
507 |
+
f" {negative_prompt_attention_mask.shape}."
|
508 |
+
)
|
509 |
+
|
510 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
511 |
+
def _text_preprocessing(self, text, clean_caption=False):
|
512 |
+
if clean_caption and not is_bs4_available():
|
513 |
+
logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
514 |
+
logger.warn("Setting `clean_caption` to False...")
|
515 |
+
clean_caption = False
|
516 |
+
|
517 |
+
if clean_caption and not is_ftfy_available():
|
518 |
+
logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
519 |
+
logger.warn("Setting `clean_caption` to False...")
|
520 |
+
clean_caption = False
|
521 |
+
|
522 |
+
if not isinstance(text, (tuple, list)):
|
523 |
+
text = [text]
|
524 |
+
|
525 |
+
def process(text: str):
|
526 |
+
if clean_caption:
|
527 |
+
text = self._clean_caption(text)
|
528 |
+
text = self._clean_caption(text)
|
529 |
+
else:
|
530 |
+
text = text.lower().strip()
|
531 |
+
return text
|
532 |
+
|
533 |
+
return [process(t) for t in text]
|
534 |
+
|
535 |
+
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
536 |
+
def _clean_caption(self, caption):
|
537 |
+
caption = str(caption)
|
538 |
+
caption = ul.unquote_plus(caption)
|
539 |
+
caption = caption.strip().lower()
|
540 |
+
caption = re.sub("<person>", "person", caption)
|
541 |
+
# urls:
|
542 |
+
caption = re.sub(
|
543 |
+
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
544 |
+
"",
|
545 |
+
caption,
|
546 |
+
) # regex for urls
|
547 |
+
caption = re.sub(
|
548 |
+
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
549 |
+
"",
|
550 |
+
caption,
|
551 |
+
) # regex for urls
|
552 |
+
# html:
|
553 |
+
caption = BeautifulSoup(caption, features="html.parser").text
|
554 |
+
|
555 |
+
# @<nickname>
|
556 |
+
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
557 |
+
|
558 |
+
# 31C0—31EF CJK Strokes
|
559 |
+
# 31F0—31FF Katakana Phonetic Extensions
|
560 |
+
# 3200—32FF Enclosed CJK Letters and Months
|
561 |
+
# 3300—33FF CJK Compatibility
|
562 |
+
# 3400—4DBF CJK Unified Ideographs Extension A
|
563 |
+
# 4DC0—4DFF Yijing Hexagram Symbols
|
564 |
+
# 4E00—9FFF CJK Unified Ideographs
|
565 |
+
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
566 |
+
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
567 |
+
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
568 |
+
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
569 |
+
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
570 |
+
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
571 |
+
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
572 |
+
#######################################################
|
573 |
+
|
574 |
+
# все виды тире / all types of dash --> "-"
|
575 |
+
caption = re.sub(
|
576 |
+
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
577 |
+
"-",
|
578 |
+
caption,
|
579 |
+
)
|
580 |
+
|
581 |
+
# кавычки к одному стандарту
|
582 |
+
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
583 |
+
caption = re.sub(r"[‘’]", "'", caption)
|
584 |
+
|
585 |
+
# "
|
586 |
+
caption = re.sub(r""?", "", caption)
|
587 |
+
# &
|
588 |
+
caption = re.sub(r"&", "", caption)
|
589 |
+
|
590 |
+
# ip adresses:
|
591 |
+
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
592 |
+
|
593 |
+
# article ids:
|
594 |
+
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
595 |
+
|
596 |
+
# \n
|
597 |
+
caption = re.sub(r"\\n", " ", caption)
|
598 |
+
|
599 |
+
# "#123"
|
600 |
+
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
601 |
+
# "#12345.."
|
602 |
+
caption = re.sub(r"#\d{5,}\b", "", caption)
|
603 |
+
# "123456.."
|
604 |
+
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
605 |
+
# filenames:
|
606 |
+
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
607 |
+
|
608 |
+
#
|
609 |
+
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
610 |
+
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
611 |
+
|
612 |
+
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
613 |
+
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
614 |
+
|
615 |
+
# this-is-my-cute-cat / this_is_my_cute_cat
|
616 |
+
regex2 = re.compile(r"(?:\-|\_)")
|
617 |
+
if len(re.findall(regex2, caption)) > 3:
|
618 |
+
caption = re.sub(regex2, " ", caption)
|
619 |
+
|
620 |
+
caption = ftfy.fix_text(caption)
|
621 |
+
caption = html.unescape(html.unescape(caption))
|
622 |
+
|
623 |
+
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
624 |
+
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
625 |
+
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
626 |
+
|
627 |
+
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
628 |
+
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
629 |
+
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
630 |
+
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
631 |
+
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
632 |
+
|
633 |
+
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
634 |
+
|
635 |
+
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
636 |
+
|
637 |
+
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
638 |
+
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
639 |
+
caption = re.sub(r"\s+", " ", caption)
|
640 |
+
|
641 |
+
caption.strip()
|
642 |
+
|
643 |
+
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
644 |
+
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
645 |
+
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
646 |
+
caption = re.sub(r"^\.\S+$", "", caption)
|
647 |
+
|
648 |
+
return caption.strip()
|
649 |
+
|
650 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
651 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
652 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
653 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
654 |
+
raise ValueError(
|
655 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
656 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
657 |
+
)
|
658 |
+
|
659 |
+
if latents is None:
|
660 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
661 |
+
else:
|
662 |
+
latents = latents.to(device)
|
663 |
+
|
664 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
665 |
+
latents = latents * self.scheduler.init_noise_sigma
|
666 |
+
return latents
|
667 |
+
|
668 |
+
@staticmethod
|
669 |
+
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
670 |
+
"""Returns binned height and width."""
|
671 |
+
ar = float(height / width)
|
672 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
673 |
+
default_hw = ratios[closest_ratio]
|
674 |
+
return int(default_hw[0]), int(default_hw[1])
|
675 |
+
|
676 |
+
@staticmethod
|
677 |
+
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
|
678 |
+
orig_height, orig_width = samples.shape[2], samples.shape[3]
|
679 |
+
|
680 |
+
# Check if resizing is needed
|
681 |
+
if orig_height != new_height or orig_width != new_width:
|
682 |
+
ratio = max(new_height / orig_height, new_width / orig_width)
|
683 |
+
resized_width = int(orig_width * ratio)
|
684 |
+
resized_height = int(orig_height * ratio)
|
685 |
+
|
686 |
+
# Resize
|
687 |
+
samples = F.interpolate(
|
688 |
+
samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
689 |
+
)
|
690 |
+
|
691 |
+
# Center Crop
|
692 |
+
start_x = (resized_width - new_width) // 2
|
693 |
+
end_x = start_x + new_width
|
694 |
+
start_y = (resized_height - new_height) // 2
|
695 |
+
end_y = start_y + new_height
|
696 |
+
samples = samples[:, :, start_y:end_y, start_x:end_x]
|
697 |
+
|
698 |
+
return samples
|
699 |
+
|
700 |
+
@torch.no_grad()
|
701 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
702 |
+
def __call__(
|
703 |
+
self,
|
704 |
+
prompt: Union[str, List[str]] = None,
|
705 |
+
negative_prompt: str = "",
|
706 |
+
num_inference_steps: int = 20,
|
707 |
+
timesteps: List[int] = None,
|
708 |
+
guidance_scale: float = 4.5,
|
709 |
+
num_images_per_prompt: Optional[int] = 1,
|
710 |
+
height: Optional[int] = None,
|
711 |
+
width: Optional[int] = None,
|
712 |
+
eta: float = 0.0,
|
713 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
714 |
+
latents: Optional[torch.FloatTensor] = None,
|
715 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
716 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
717 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
718 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
719 |
+
output_type: Optional[str] = "pil",
|
720 |
+
return_dict: bool = True,
|
721 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
722 |
+
callback_steps: int = 1,
|
723 |
+
clean_caption: bool = True,
|
724 |
+
use_resolution_binning: bool = False,
|
725 |
+
max_sequence_length: int = 120,
|
726 |
+
**kwargs,
|
727 |
+
) -> Union[ImagePipelineOutput, Tuple]:
|
728 |
+
"""
|
729 |
+
Function invoked when calling the pipeline for generation.
|
730 |
+
|
731 |
+
Args:
|
732 |
+
prompt (`str` or `List[str]`, *optional*):
|
733 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
734 |
+
instead.
|
735 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
736 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
737 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
738 |
+
less than `1`).
|
739 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
740 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
741 |
+
expense of slower inference.
|
742 |
+
timesteps (`List[int]`, *optional*):
|
743 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
744 |
+
timesteps are used. Must be in descending order.
|
745 |
+
guidance_scale (`float`, *optional*, defaults to 4.5):
|
746 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
747 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
748 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
749 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
750 |
+
usually at the expense of lower image quality.
|
751 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
752 |
+
The number of images to generate per prompt.
|
753 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
754 |
+
The height in pixels of the generated image.
|
755 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
756 |
+
The width in pixels of the generated image.
|
757 |
+
eta (`float`, *optional*, defaults to 0.0):
|
758 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
759 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
760 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
761 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
762 |
+
to make generation deterministic.
|
763 |
+
latents (`torch.FloatTensor`, *optional*):
|
764 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
765 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
766 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
767 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
768 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
769 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
770 |
+
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
|
771 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
772 |
+
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
773 |
+
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
774 |
+
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
775 |
+
Pre-generated attention mask for negative text embeddings.
|
776 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
777 |
+
The output format of the generate image. Choose between
|
778 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
779 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
780 |
+
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
781 |
+
callback (`Callable`, *optional*):
|
782 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
783 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
784 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
785 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
786 |
+
called at every step.
|
787 |
+
clean_caption (`bool`, *optional*, defaults to `True`):
|
788 |
+
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
789 |
+
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
790 |
+
prompt.
|
791 |
+
use_resolution_binning (`bool` defaults to `True`):
|
792 |
+
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
793 |
+
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
794 |
+
the requested resolution. Useful for generating non-square images.
|
795 |
+
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
|
796 |
+
|
797 |
+
Examples:
|
798 |
+
|
799 |
+
Returns:
|
800 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
801 |
+
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
802 |
+
returned where the first element is a list with the generated images
|
803 |
+
"""
|
804 |
+
if "mask_feature" in kwargs:
|
805 |
+
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
806 |
+
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
807 |
+
# 1. Check inputs. Raise error if not correct
|
808 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
809 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
810 |
+
if use_resolution_binning:
|
811 |
+
if self.transformer.config.sample_size == 128:
|
812 |
+
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
|
813 |
+
elif self.transformer.config.sample_size == 64:
|
814 |
+
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
815 |
+
elif self.transformer.config.sample_size == 32:
|
816 |
+
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
817 |
+
else:
|
818 |
+
raise ValueError("Invalid sample size")
|
819 |
+
orig_height, orig_width = height, width
|
820 |
+
height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
821 |
+
|
822 |
+
self.check_inputs(
|
823 |
+
prompt,
|
824 |
+
height,
|
825 |
+
width,
|
826 |
+
negative_prompt,
|
827 |
+
callback_steps,
|
828 |
+
prompt_embeds,
|
829 |
+
negative_prompt_embeds,
|
830 |
+
prompt_attention_mask,
|
831 |
+
negative_prompt_attention_mask,
|
832 |
+
)
|
833 |
+
|
834 |
+
# 2. Default height and width to transformer
|
835 |
+
if prompt is not None and isinstance(prompt, str):
|
836 |
+
batch_size = 1
|
837 |
+
elif prompt is not None and isinstance(prompt, list):
|
838 |
+
batch_size = len(prompt)
|
839 |
+
else:
|
840 |
+
batch_size = prompt_embeds.shape[0]
|
841 |
+
|
842 |
+
device = self._execution_device
|
843 |
+
|
844 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
845 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
846 |
+
# corresponds to doing no classifier free guidance.
|
847 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
848 |
+
|
849 |
+
# 3. Encode input prompt
|
850 |
+
(
|
851 |
+
prompt_embeds,
|
852 |
+
prompt_attention_mask,
|
853 |
+
negative_prompt_embeds,
|
854 |
+
negative_prompt_attention_mask,
|
855 |
+
) = self.encode_prompt(
|
856 |
+
prompt,
|
857 |
+
do_classifier_free_guidance,
|
858 |
+
negative_prompt=negative_prompt,
|
859 |
+
num_images_per_prompt=num_images_per_prompt,
|
860 |
+
device=device,
|
861 |
+
prompt_embeds=prompt_embeds,
|
862 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
863 |
+
prompt_attention_mask=prompt_attention_mask,
|
864 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
865 |
+
clean_caption=clean_caption,
|
866 |
+
max_sequence_length=max_sequence_length,
|
867 |
+
)
|
868 |
+
if do_classifier_free_guidance:
|
869 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
870 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
871 |
+
|
872 |
+
# 4. Prepare timesteps
|
873 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
874 |
+
|
875 |
+
# 5. Prepare latents.
|
876 |
+
latent_channels = self.transformer.config.in_channels
|
877 |
+
latents = self.prepare_latents(
|
878 |
+
batch_size * num_images_per_prompt,
|
879 |
+
latent_channels,
|
880 |
+
height,
|
881 |
+
width,
|
882 |
+
prompt_embeds.dtype,
|
883 |
+
device,
|
884 |
+
generator,
|
885 |
+
latents,
|
886 |
+
)
|
887 |
+
|
888 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
889 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
890 |
+
|
891 |
+
# 6.1 Prepare micro-conditions.
|
892 |
+
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
893 |
+
if self.transformer.config.sample_size == 128:
|
894 |
+
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
|
895 |
+
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
|
896 |
+
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
|
897 |
+
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
|
898 |
+
|
899 |
+
if do_classifier_free_guidance:
|
900 |
+
resolution = torch.cat([resolution, resolution], dim=0)
|
901 |
+
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
|
902 |
+
|
903 |
+
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
|
904 |
+
|
905 |
+
# 7. Denoising loop
|
906 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
907 |
+
|
908 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
909 |
+
for i, t in enumerate(timesteps):
|
910 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
911 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
912 |
+
|
913 |
+
current_timestep = t
|
914 |
+
if not torch.is_tensor(current_timestep):
|
915 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
916 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
917 |
+
is_mps = latent_model_input.device.type == "mps"
|
918 |
+
if isinstance(current_timestep, float):
|
919 |
+
dtype = torch.float32 if is_mps else torch.float64
|
920 |
+
else:
|
921 |
+
dtype = torch.int32 if is_mps else torch.int64
|
922 |
+
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
923 |
+
elif len(current_timestep.shape) == 0:
|
924 |
+
current_timestep = current_timestep[None].to(latent_model_input.device)
|
925 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
926 |
+
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
927 |
+
|
928 |
+
# predict noise model_output
|
929 |
+
noise_pred = self.transformer(
|
930 |
+
latent_model_input,
|
931 |
+
encoder_hidden_states=prompt_embeds,
|
932 |
+
encoder_attention_mask=prompt_attention_mask,
|
933 |
+
timestep=current_timestep,
|
934 |
+
added_cond_kwargs=added_cond_kwargs,
|
935 |
+
return_dict=False,
|
936 |
+
)[0]
|
937 |
+
|
938 |
+
# perform guidance
|
939 |
+
if do_classifier_free_guidance:
|
940 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
941 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
942 |
+
|
943 |
+
# learned sigma
|
944 |
+
if self.transformer.config.out_channels // 2 == latent_channels:
|
945 |
+
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
946 |
+
else:
|
947 |
+
noise_pred = noise_pred
|
948 |
+
|
949 |
+
# compute previous image: x_t -> x_t-1
|
950 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
951 |
+
|
952 |
+
# call the callback, if provided
|
953 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
954 |
+
progress_bar.update()
|
955 |
+
if callback is not None and i % callback_steps == 0:
|
956 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
957 |
+
callback(step_idx, t, latents)
|
958 |
+
|
959 |
+
if not output_type == "latent":
|
960 |
+
if self.vae.quant_conv.weight.ndim==5:
|
961 |
+
latents = latents.unsqueeze(2)
|
962 |
+
latents = latents.float()
|
963 |
+
self.vae.post_quant_conv = self.vae.post_quant_conv.float()
|
964 |
+
self.vae.decoder = self.vae.decoder.float()
|
965 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
966 |
+
if self.vae.quant_conv.weight.ndim==5:
|
967 |
+
image = image.permute(0,2,1,3,4).flatten(0, 1)
|
968 |
+
|
969 |
+
if use_resolution_binning:
|
970 |
+
image = self.resize_and_crop_tensor(image, orig_width, orig_height)
|
971 |
+
else:
|
972 |
+
image = latents
|
973 |
+
|
974 |
+
if not output_type == "latent":
|
975 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
976 |
+
|
977 |
+
# Offload all models
|
978 |
+
self.maybe_free_model_hooks()
|
979 |
+
|
980 |
+
if not return_dict:
|
981 |
+
return (image,)
|
982 |
+
|
983 |
+
return ImagePipelineOutput(images=image)
|
easyanimate/ui/ui.py
ADDED
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
2 |
+
"""
|
3 |
+
import gc
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import base64
|
8 |
+
import requests
|
9 |
+
from datetime import datetime
|
10 |
+
from glob import glob
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
import torch
|
14 |
+
import numpy as np
|
15 |
+
from diffusers import (AutoencoderKL, DDIMScheduler,
|
16 |
+
DPMSolverMultistepScheduler,
|
17 |
+
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
18 |
+
PNDMScheduler)
|
19 |
+
from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
|
20 |
+
from diffusers.utils.import_utils import is_xformers_available
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from safetensors import safe_open
|
23 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
24 |
+
|
25 |
+
from easyanimate.models.transformer3d import Transformer3DModel
|
26 |
+
from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
|
27 |
+
from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
|
28 |
+
from easyanimate.utils.utils import save_videos_grid
|
29 |
+
from PIL import Image
|
30 |
+
|
31 |
+
sample_idx = 0
|
32 |
+
scheduler_dict = {
|
33 |
+
"Euler": EulerDiscreteScheduler,
|
34 |
+
"Euler A": EulerAncestralDiscreteScheduler,
|
35 |
+
"DPM++": DPMSolverMultistepScheduler,
|
36 |
+
"PNDM": PNDMScheduler,
|
37 |
+
"DDIM": DDIMScheduler,
|
38 |
+
}
|
39 |
+
|
40 |
+
css = """
|
41 |
+
.toolbutton {
|
42 |
+
margin-buttom: 0em 0em 0em 0em;
|
43 |
+
max-width: 2.5em;
|
44 |
+
min-width: 2.5em !important;
|
45 |
+
height: 2.5em;
|
46 |
+
}
|
47 |
+
"""
|
48 |
+
|
49 |
+
class EasyAnimateController:
|
50 |
+
def __init__(self):
|
51 |
+
# config dirs
|
52 |
+
self.basedir = os.getcwd()
|
53 |
+
self.config_dir = os.path.join(self.basedir, "config")
|
54 |
+
self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
|
55 |
+
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
|
56 |
+
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
57 |
+
self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
58 |
+
self.savedir_sample = os.path.join(self.savedir, "sample")
|
59 |
+
self.edition = "v2"
|
60 |
+
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
|
61 |
+
os.makedirs(self.savedir, exist_ok=True)
|
62 |
+
|
63 |
+
self.diffusion_transformer_list = []
|
64 |
+
self.motion_module_list = []
|
65 |
+
self.personalized_model_list = []
|
66 |
+
|
67 |
+
self.refresh_diffusion_transformer()
|
68 |
+
self.refresh_motion_module()
|
69 |
+
self.refresh_personalized_model()
|
70 |
+
|
71 |
+
# config models
|
72 |
+
self.tokenizer = None
|
73 |
+
self.text_encoder = None
|
74 |
+
self.vae = None
|
75 |
+
self.transformer = None
|
76 |
+
self.pipeline = None
|
77 |
+
self.motion_module_path = "none"
|
78 |
+
self.base_model_path = "none"
|
79 |
+
self.lora_model_path = "none"
|
80 |
+
|
81 |
+
self.weight_dtype = torch.bfloat16
|
82 |
+
|
83 |
+
def refresh_diffusion_transformer(self):
|
84 |
+
self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
|
85 |
+
|
86 |
+
def refresh_motion_module(self):
|
87 |
+
motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
|
88 |
+
self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
|
89 |
+
|
90 |
+
def refresh_personalized_model(self):
|
91 |
+
personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
|
92 |
+
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
|
93 |
+
|
94 |
+
def update_edition(self, edition):
|
95 |
+
print("Update edition of EasyAnimate")
|
96 |
+
self.edition = edition
|
97 |
+
if edition == "v1":
|
98 |
+
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
|
99 |
+
return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
|
100 |
+
gr.update(visible=False), gr.update(value=512, minimum=384, maximum=704, step=32), \
|
101 |
+
gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
|
102 |
+
else:
|
103 |
+
self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
|
104 |
+
return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
|
105 |
+
gr.update(visible=True), gr.update(value=672, minimum=128, maximum=1280, step=16), \
|
106 |
+
gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
|
107 |
+
|
108 |
+
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
109 |
+
print("Update diffusion transformer")
|
110 |
+
if diffusion_transformer_dropdown == "none":
|
111 |
+
return gr.Dropdown.update()
|
112 |
+
if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
|
113 |
+
Choosen_AutoencoderKL = AutoencoderKLMagvit
|
114 |
+
else:
|
115 |
+
Choosen_AutoencoderKL = AutoencoderKL
|
116 |
+
self.vae = Choosen_AutoencoderKL.from_pretrained(
|
117 |
+
diffusion_transformer_dropdown,
|
118 |
+
subfolder="vae",
|
119 |
+
).to(self.weight_dtype)
|
120 |
+
self.transformer = Transformer3DModel.from_pretrained_2d(
|
121 |
+
diffusion_transformer_dropdown,
|
122 |
+
subfolder="transformer",
|
123 |
+
transformer_additional_kwargs=OmegaConf.to_container(self.inference_config.transformer_additional_kwargs)
|
124 |
+
).to(self.weight_dtype)
|
125 |
+
self.tokenizer = T5Tokenizer.from_pretrained(diffusion_transformer_dropdown, subfolder="tokenizer")
|
126 |
+
self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
|
127 |
+
|
128 |
+
# Get pipeline
|
129 |
+
self.pipeline = EasyAnimatePipeline(
|
130 |
+
vae=self.vae,
|
131 |
+
text_encoder=self.text_encoder,
|
132 |
+
tokenizer=self.tokenizer,
|
133 |
+
transformer=self.transformer,
|
134 |
+
scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
135 |
+
)
|
136 |
+
self.pipeline.enable_model_cpu_offload()
|
137 |
+
print("Update diffusion transformer done")
|
138 |
+
return gr.Dropdown.update()
|
139 |
+
|
140 |
+
def update_motion_module(self, motion_module_dropdown):
|
141 |
+
self.motion_module_path = motion_module_dropdown
|
142 |
+
print("Update motion module")
|
143 |
+
if motion_module_dropdown == "none":
|
144 |
+
return gr.Dropdown.update()
|
145 |
+
if self.transformer is None:
|
146 |
+
gr.Info(f"Please select a pretrained model path.")
|
147 |
+
return gr.Dropdown.update(value=None)
|
148 |
+
else:
|
149 |
+
motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
|
150 |
+
if motion_module_dropdown.endswith(".safetensors"):
|
151 |
+
from safetensors.torch import load_file, safe_open
|
152 |
+
motion_module_state_dict = load_file(motion_module_dropdown)
|
153 |
+
else:
|
154 |
+
if not os.path.isfile(motion_module_dropdown):
|
155 |
+
raise RuntimeError(f"{motion_module_dropdown} does not exist")
|
156 |
+
motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
|
157 |
+
missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
|
158 |
+
print("Update motion module done.")
|
159 |
+
return gr.Dropdown.update()
|
160 |
+
|
161 |
+
def update_base_model(self, base_model_dropdown):
|
162 |
+
self.base_model_path = base_model_dropdown
|
163 |
+
print("Update base model")
|
164 |
+
if base_model_dropdown == "none":
|
165 |
+
return gr.Dropdown.update()
|
166 |
+
if self.transformer is None:
|
167 |
+
gr.Info(f"Please select a pretrained model path.")
|
168 |
+
return gr.Dropdown.update(value=None)
|
169 |
+
else:
|
170 |
+
base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
|
171 |
+
base_model_state_dict = {}
|
172 |
+
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
|
173 |
+
for key in f.keys():
|
174 |
+
base_model_state_dict[key] = f.get_tensor(key)
|
175 |
+
self.transformer.load_state_dict(base_model_state_dict, strict=False)
|
176 |
+
print("Update base done")
|
177 |
+
return gr.Dropdown.update()
|
178 |
+
|
179 |
+
def update_lora_model(self, lora_model_dropdown):
|
180 |
+
print("Update lora model")
|
181 |
+
if lora_model_dropdown == "none":
|
182 |
+
self.lora_model_path = "none"
|
183 |
+
return gr.Dropdown.update()
|
184 |
+
lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
|
185 |
+
self.lora_model_path = lora_model_dropdown
|
186 |
+
return gr.Dropdown.update()
|
187 |
+
|
188 |
+
def generate(
|
189 |
+
self,
|
190 |
+
diffusion_transformer_dropdown,
|
191 |
+
motion_module_dropdown,
|
192 |
+
base_model_dropdown,
|
193 |
+
lora_model_dropdown,
|
194 |
+
lora_alpha_slider,
|
195 |
+
prompt_textbox,
|
196 |
+
negative_prompt_textbox,
|
197 |
+
sampler_dropdown,
|
198 |
+
sample_step_slider,
|
199 |
+
width_slider,
|
200 |
+
height_slider,
|
201 |
+
is_image,
|
202 |
+
length_slider,
|
203 |
+
cfg_scale_slider,
|
204 |
+
seed_textbox,
|
205 |
+
is_api = False,
|
206 |
+
):
|
207 |
+
global sample_idx
|
208 |
+
if self.transformer is None:
|
209 |
+
raise gr.Error(f"Please select a pretrained model path.")
|
210 |
+
|
211 |
+
if self.base_model_path != base_model_dropdown:
|
212 |
+
self.update_base_model(base_model_dropdown)
|
213 |
+
|
214 |
+
if self.motion_module_path != motion_module_dropdown:
|
215 |
+
self.update_motion_module(motion_module_dropdown)
|
216 |
+
|
217 |
+
if self.lora_model_path != lora_model_dropdown:
|
218 |
+
print("Update lora model")
|
219 |
+
self.update_lora_model(lora_model_dropdown)
|
220 |
+
|
221 |
+
if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
|
222 |
+
|
223 |
+
self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
224 |
+
if self.lora_model_path != "none":
|
225 |
+
# lora part
|
226 |
+
self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
227 |
+
self.pipeline.to("cuda")
|
228 |
+
|
229 |
+
if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
|
230 |
+
else: seed_textbox = np.random.randint(0, 1e10)
|
231 |
+
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
232 |
+
|
233 |
+
try:
|
234 |
+
sample = self.pipeline(
|
235 |
+
prompt_textbox,
|
236 |
+
negative_prompt = negative_prompt_textbox,
|
237 |
+
num_inference_steps = sample_step_slider,
|
238 |
+
guidance_scale = cfg_scale_slider,
|
239 |
+
width = width_slider,
|
240 |
+
height = height_slider,
|
241 |
+
video_length = length_slider if not is_image else 1,
|
242 |
+
generator = generator
|
243 |
+
).videos
|
244 |
+
except Exception as e:
|
245 |
+
gc.collect()
|
246 |
+
torch.cuda.empty_cache()
|
247 |
+
torch.cuda.ipc_collect()
|
248 |
+
if self.lora_model_path != "none":
|
249 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
250 |
+
if is_api:
|
251 |
+
return "", f"Error. error information is {str(e)}"
|
252 |
+
else:
|
253 |
+
return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
|
254 |
+
|
255 |
+
# lora part
|
256 |
+
if self.lora_model_path != "none":
|
257 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
258 |
+
|
259 |
+
sample_config = {
|
260 |
+
"prompt": prompt_textbox,
|
261 |
+
"n_prompt": negative_prompt_textbox,
|
262 |
+
"sampler": sampler_dropdown,
|
263 |
+
"num_inference_steps": sample_step_slider,
|
264 |
+
"guidance_scale": cfg_scale_slider,
|
265 |
+
"width": width_slider,
|
266 |
+
"height": height_slider,
|
267 |
+
"video_length": length_slider,
|
268 |
+
"seed_textbox": seed_textbox
|
269 |
+
}
|
270 |
+
json_str = json.dumps(sample_config, indent=4)
|
271 |
+
with open(os.path.join(self.savedir, "logs.json"), "a") as f:
|
272 |
+
f.write(json_str)
|
273 |
+
f.write("\n\n")
|
274 |
+
|
275 |
+
if not os.path.exists(self.savedir_sample):
|
276 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
277 |
+
index = len([path for path in os.listdir(self.savedir_sample)]) + 1
|
278 |
+
prefix = str(index).zfill(3)
|
279 |
+
|
280 |
+
gc.collect()
|
281 |
+
torch.cuda.empty_cache()
|
282 |
+
torch.cuda.ipc_collect()
|
283 |
+
if is_image or length_slider == 1:
|
284 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
|
285 |
+
|
286 |
+
image = sample[0, :, 0]
|
287 |
+
image = image.transpose(0, 1).transpose(1, 2)
|
288 |
+
image = (image * 255).numpy().astype(np.uint8)
|
289 |
+
image = Image.fromarray(image)
|
290 |
+
image.save(save_sample_path)
|
291 |
+
|
292 |
+
if is_api:
|
293 |
+
return save_sample_path, "Success"
|
294 |
+
else:
|
295 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
296 |
+
else:
|
297 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
|
298 |
+
save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
|
299 |
+
|
300 |
+
if is_api:
|
301 |
+
return save_sample_path, "Success"
|
302 |
+
else:
|
303 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
304 |
+
|
305 |
+
|
306 |
+
def ui():
|
307 |
+
controller = EasyAnimateController()
|
308 |
+
|
309 |
+
with gr.Blocks(css=css) as demo:
|
310 |
+
gr.Markdown(
|
311 |
+
"""
|
312 |
+
# EasyAnimate: Integrated generation of baseline scheme for videos and images.
|
313 |
+
Generate your videos easily
|
314 |
+
[Github](https://github.com/aigc-apps/EasyAnimate/)
|
315 |
+
"""
|
316 |
+
)
|
317 |
+
with gr.Column(variant="panel"):
|
318 |
+
gr.Markdown(
|
319 |
+
"""
|
320 |
+
### 1. EasyAnimate Edition (select easyanimate edition first).
|
321 |
+
"""
|
322 |
+
)
|
323 |
+
with gr.Row():
|
324 |
+
easyanimate_edition_dropdown = gr.Dropdown(
|
325 |
+
label="The config of EasyAnimate Edition",
|
326 |
+
choices=["v1", "v2"],
|
327 |
+
value="v2",
|
328 |
+
interactive=True,
|
329 |
+
)
|
330 |
+
gr.Markdown(
|
331 |
+
"""
|
332 |
+
### 2. Model checkpoints (select pretrained model path).
|
333 |
+
"""
|
334 |
+
)
|
335 |
+
with gr.Row():
|
336 |
+
diffusion_transformer_dropdown = gr.Dropdown(
|
337 |
+
label="Pretrained Model Path",
|
338 |
+
choices=controller.diffusion_transformer_list,
|
339 |
+
value="none",
|
340 |
+
interactive=True,
|
341 |
+
)
|
342 |
+
diffusion_transformer_dropdown.change(
|
343 |
+
fn=controller.update_diffusion_transformer,
|
344 |
+
inputs=[diffusion_transformer_dropdown],
|
345 |
+
outputs=[diffusion_transformer_dropdown]
|
346 |
+
)
|
347 |
+
|
348 |
+
diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
349 |
+
def refresh_diffusion_transformer():
|
350 |
+
controller.refresh_diffusion_transformer()
|
351 |
+
return gr.Dropdown.update(choices=controller.diffusion_transformer_list)
|
352 |
+
diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
|
353 |
+
|
354 |
+
with gr.Row():
|
355 |
+
motion_module_dropdown = gr.Dropdown(
|
356 |
+
label="Select motion module",
|
357 |
+
choices=controller.motion_module_list,
|
358 |
+
value="none",
|
359 |
+
interactive=True,
|
360 |
+
visible=False
|
361 |
+
)
|
362 |
+
|
363 |
+
motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
|
364 |
+
def update_motion_module():
|
365 |
+
controller.refresh_motion_module()
|
366 |
+
return gr.Dropdown.update(choices=controller.motion_module_list)
|
367 |
+
motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
|
368 |
+
|
369 |
+
base_model_dropdown = gr.Dropdown(
|
370 |
+
label="Select base Dreambooth model (optional)",
|
371 |
+
choices=controller.personalized_model_list,
|
372 |
+
value="none",
|
373 |
+
interactive=True,
|
374 |
+
)
|
375 |
+
|
376 |
+
lora_model_dropdown = gr.Dropdown(
|
377 |
+
label="Select LoRA model (optional)",
|
378 |
+
choices=["none"] + controller.personalized_model_list,
|
379 |
+
value="none",
|
380 |
+
interactive=True,
|
381 |
+
)
|
382 |
+
|
383 |
+
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
|
384 |
+
|
385 |
+
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
386 |
+
def update_personalized_model():
|
387 |
+
controller.refresh_personalized_model()
|
388 |
+
return [
|
389 |
+
gr.Dropdown.update(choices=controller.personalized_model_list),
|
390 |
+
gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
|
391 |
+
]
|
392 |
+
personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
|
393 |
+
|
394 |
+
with gr.Column(variant="panel"):
|
395 |
+
gr.Markdown(
|
396 |
+
"""
|
397 |
+
### 3. Configs for Generation.
|
398 |
+
"""
|
399 |
+
)
|
400 |
+
|
401 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
|
402 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
|
403 |
+
|
404 |
+
with gr.Row():
|
405 |
+
with gr.Column():
|
406 |
+
with gr.Row():
|
407 |
+
sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
|
408 |
+
sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=100, step=1)
|
409 |
+
|
410 |
+
width_slider = gr.Slider(label="Width", value=672, minimum=128, maximum=1280, step=16)
|
411 |
+
height_slider = gr.Slider(label="Height", value=384, minimum=128, maximum=1280, step=16)
|
412 |
+
with gr.Row():
|
413 |
+
is_image = gr.Checkbox(False, label="Generate Image")
|
414 |
+
length_slider = gr.Slider(label="Animation length", value=144, minimum=9, maximum=144, step=9)
|
415 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
|
416 |
+
|
417 |
+
with gr.Row():
|
418 |
+
seed_textbox = gr.Textbox(label="Seed", value=43)
|
419 |
+
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
420 |
+
seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
|
421 |
+
|
422 |
+
generate_button = gr.Button(value="Generate", variant='primary')
|
423 |
+
|
424 |
+
with gr.Column():
|
425 |
+
result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
|
426 |
+
result_video = gr.Video(label="Generated Animation", interactive=False)
|
427 |
+
infer_progress = gr.Textbox(
|
428 |
+
label="Generation Info",
|
429 |
+
value="No task currently",
|
430 |
+
interactive=False
|
431 |
+
)
|
432 |
+
|
433 |
+
is_image.change(
|
434 |
+
lambda x: gr.update(visible=not x),
|
435 |
+
inputs=[is_image],
|
436 |
+
outputs=[length_slider],
|
437 |
+
)
|
438 |
+
easyanimate_edition_dropdown.change(
|
439 |
+
fn=controller.update_edition,
|
440 |
+
inputs=[easyanimate_edition_dropdown],
|
441 |
+
outputs=[
|
442 |
+
easyanimate_edition_dropdown,
|
443 |
+
diffusion_transformer_dropdown,
|
444 |
+
motion_module_dropdown,
|
445 |
+
motion_module_refresh_button,
|
446 |
+
is_image,
|
447 |
+
width_slider,
|
448 |
+
height_slider,
|
449 |
+
length_slider,
|
450 |
+
]
|
451 |
+
)
|
452 |
+
generate_button.click(
|
453 |
+
fn=controller.generate,
|
454 |
+
inputs=[
|
455 |
+
diffusion_transformer_dropdown,
|
456 |
+
motion_module_dropdown,
|
457 |
+
base_model_dropdown,
|
458 |
+
lora_model_dropdown,
|
459 |
+
lora_alpha_slider,
|
460 |
+
prompt_textbox,
|
461 |
+
negative_prompt_textbox,
|
462 |
+
sampler_dropdown,
|
463 |
+
sample_step_slider,
|
464 |
+
width_slider,
|
465 |
+
height_slider,
|
466 |
+
is_image,
|
467 |
+
length_slider,
|
468 |
+
cfg_scale_slider,
|
469 |
+
seed_textbox,
|
470 |
+
],
|
471 |
+
outputs=[result_image, result_video, infer_progress]
|
472 |
+
)
|
473 |
+
return demo, controller
|
474 |
+
|
475 |
+
|
476 |
+
class EasyAnimateController_Modelscope:
|
477 |
+
def __init__(self, edition, config_path, model_name, savedir_sample):
|
478 |
+
# Config and model path
|
479 |
+
weight_dtype = torch.bfloat16
|
480 |
+
self.savedir_sample = savedir_sample
|
481 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
482 |
+
|
483 |
+
self.edition = edition
|
484 |
+
self.inference_config = OmegaConf.load(config_path)
|
485 |
+
# Get Transformer
|
486 |
+
self.transformer = Transformer3DModel.from_pretrained_2d(
|
487 |
+
model_name,
|
488 |
+
subfolder="transformer",
|
489 |
+
transformer_additional_kwargs=OmegaConf.to_container(self.inference_config['transformer_additional_kwargs'])
|
490 |
+
).to(weight_dtype)
|
491 |
+
if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
|
492 |
+
Choosen_AutoencoderKL = AutoencoderKLMagvit
|
493 |
+
else:
|
494 |
+
Choosen_AutoencoderKL = AutoencoderKL
|
495 |
+
self.vae = Choosen_AutoencoderKL.from_pretrained(
|
496 |
+
model_name,
|
497 |
+
subfolder="vae"
|
498 |
+
).to(weight_dtype)
|
499 |
+
self.tokenizer = T5Tokenizer.from_pretrained(
|
500 |
+
model_name,
|
501 |
+
subfolder="tokenizer"
|
502 |
+
)
|
503 |
+
self.text_encoder = T5EncoderModel.from_pretrained(
|
504 |
+
model_name,
|
505 |
+
subfolder="text_encoder",
|
506 |
+
torch_dtype=weight_dtype
|
507 |
+
)
|
508 |
+
self.pipeline = EasyAnimatePipeline(
|
509 |
+
vae=self.vae,
|
510 |
+
text_encoder=self.text_encoder,
|
511 |
+
tokenizer=self.tokenizer,
|
512 |
+
transformer=self.transformer,
|
513 |
+
scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
514 |
+
)
|
515 |
+
self.pipeline.enable_model_cpu_offload()
|
516 |
+
print("Update diffusion transformer done")
|
517 |
+
|
518 |
+
def generate(
|
519 |
+
self,
|
520 |
+
prompt_textbox,
|
521 |
+
negative_prompt_textbox,
|
522 |
+
sampler_dropdown,
|
523 |
+
sample_step_slider,
|
524 |
+
width_slider,
|
525 |
+
height_slider,
|
526 |
+
is_image,
|
527 |
+
length_slider,
|
528 |
+
cfg_scale_slider,
|
529 |
+
seed_textbox
|
530 |
+
):
|
531 |
+
if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
|
532 |
+
|
533 |
+
self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
534 |
+
self.pipeline.to("cuda")
|
535 |
+
|
536 |
+
if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
|
537 |
+
else: seed_textbox = np.random.randint(0, 1e10)
|
538 |
+
generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|
539 |
+
|
540 |
+
try:
|
541 |
+
sample = self.pipeline(
|
542 |
+
prompt_textbox,
|
543 |
+
negative_prompt = negative_prompt_textbox,
|
544 |
+
num_inference_steps = sample_step_slider,
|
545 |
+
guidance_scale = cfg_scale_slider,
|
546 |
+
width = width_slider,
|
547 |
+
height = height_slider,
|
548 |
+
video_length = length_slider if not is_image else 1,
|
549 |
+
generator = generator
|
550 |
+
).videos
|
551 |
+
except Exception as e:
|
552 |
+
gc.collect()
|
553 |
+
torch.cuda.empty_cache()
|
554 |
+
torch.cuda.ipc_collect()
|
555 |
+
return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
|
556 |
+
|
557 |
+
if not os.path.exists(self.savedir_sample):
|
558 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
559 |
+
index = len([path for path in os.listdir(self.savedir_sample)]) + 1
|
560 |
+
prefix = str(index).zfill(3)
|
561 |
+
|
562 |
+
gc.collect()
|
563 |
+
torch.cuda.empty_cache()
|
564 |
+
torch.cuda.ipc_collect()
|
565 |
+
if is_image or length_slider == 1:
|
566 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
|
567 |
+
|
568 |
+
image = sample[0, :, 0]
|
569 |
+
image = image.transpose(0, 1).transpose(1, 2)
|
570 |
+
image = (image * 255).numpy().astype(np.uint8)
|
571 |
+
image = Image.fromarray(image)
|
572 |
+
image.save(save_sample_path)
|
573 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
574 |
+
else:
|
575 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
|
576 |
+
save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
|
577 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
578 |
+
|
579 |
+
|
580 |
+
def ui_modelscope(edition, config_path, model_name, savedir_sample):
|
581 |
+
controller = EasyAnimateController_Modelscope(edition, config_path, model_name, savedir_sample)
|
582 |
+
|
583 |
+
with gr.Blocks(css=css) as demo:
|
584 |
+
gr.Markdown(
|
585 |
+
"""
|
586 |
+
# EasyAnimate: Integrated generation of baseline scheme for videos and images.
|
587 |
+
Generate your videos easily
|
588 |
+
[Github](https://github.com/aigc-apps/EasyAnimate/)
|
589 |
+
"""
|
590 |
+
)
|
591 |
+
with gr.Column(variant="panel"):
|
592 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
|
593 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
|
594 |
+
|
595 |
+
with gr.Row():
|
596 |
+
with gr.Column():
|
597 |
+
with gr.Row():
|
598 |
+
sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
|
599 |
+
sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
|
600 |
+
|
601 |
+
if edition == "v1":
|
602 |
+
width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
|
603 |
+
height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
|
604 |
+
with gr.Row():
|
605 |
+
is_image = gr.Checkbox(False, label="Generate Image", visible=False)
|
606 |
+
length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
|
607 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
|
608 |
+
else:
|
609 |
+
width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
|
610 |
+
height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
|
611 |
+
with gr.Column():
|
612 |
+
gr.Markdown(
|
613 |
+
"""
|
614 |
+
To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
|
615 |
+
If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
|
616 |
+
"""
|
617 |
+
)
|
618 |
+
with gr.Row():
|
619 |
+
is_image = gr.Checkbox(False, label="Generate Image")
|
620 |
+
length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
|
621 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
|
622 |
+
|
623 |
+
with gr.Row():
|
624 |
+
seed_textbox = gr.Textbox(label="Seed", value=43)
|
625 |
+
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
626 |
+
seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
|
627 |
+
|
628 |
+
generate_button = gr.Button(value="Generate", variant='primary')
|
629 |
+
|
630 |
+
with gr.Column():
|
631 |
+
result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
|
632 |
+
result_video = gr.Video(label="Generated Animation", interactive=False)
|
633 |
+
infer_progress = gr.Textbox(
|
634 |
+
label="Generation Info",
|
635 |
+
value="No task currently",
|
636 |
+
interactive=False
|
637 |
+
)
|
638 |
+
|
639 |
+
is_image.change(
|
640 |
+
lambda x: gr.update(visible=not x),
|
641 |
+
inputs=[is_image],
|
642 |
+
outputs=[length_slider],
|
643 |
+
)
|
644 |
+
|
645 |
+
generate_button.click(
|
646 |
+
fn=controller.generate,
|
647 |
+
inputs=[
|
648 |
+
prompt_textbox,
|
649 |
+
negative_prompt_textbox,
|
650 |
+
sampler_dropdown,
|
651 |
+
sample_step_slider,
|
652 |
+
width_slider,
|
653 |
+
height_slider,
|
654 |
+
is_image,
|
655 |
+
length_slider,
|
656 |
+
cfg_scale_slider,
|
657 |
+
seed_textbox,
|
658 |
+
],
|
659 |
+
outputs=[result_image, result_video, infer_progress]
|
660 |
+
)
|
661 |
+
return demo, controller
|
662 |
+
|
663 |
+
|
664 |
+
def post_eas(
|
665 |
+
prompt_textbox, negative_prompt_textbox,
|
666 |
+
sampler_dropdown, sample_step_slider, width_slider, height_slider,
|
667 |
+
is_image, length_slider, cfg_scale_slider, seed_textbox,
|
668 |
+
):
|
669 |
+
datas = {
|
670 |
+
"base_model_path": "none",
|
671 |
+
"motion_module_path": "none",
|
672 |
+
"lora_model_path": "none",
|
673 |
+
"lora_alpha_slider": 0.55,
|
674 |
+
"prompt_textbox": prompt_textbox,
|
675 |
+
"negative_prompt_textbox": negative_prompt_textbox,
|
676 |
+
"sampler_dropdown": sampler_dropdown,
|
677 |
+
"sample_step_slider": sample_step_slider,
|
678 |
+
"width_slider": width_slider,
|
679 |
+
"height_slider": height_slider,
|
680 |
+
"is_image": is_image,
|
681 |
+
"length_slider": length_slider,
|
682 |
+
"cfg_scale_slider": cfg_scale_slider,
|
683 |
+
"seed_textbox": seed_textbox,
|
684 |
+
}
|
685 |
+
# Token可以在公网地址调用信息中获取,详情请参见通用公网调用部分。
|
686 |
+
session = requests.session()
|
687 |
+
session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
|
688 |
+
|
689 |
+
response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas)
|
690 |
+
outputs = response.json()
|
691 |
+
return outputs
|
692 |
+
|
693 |
+
|
694 |
+
class EasyAnimateController_HuggingFace:
|
695 |
+
def __init__(self, edition, config_path, model_name, savedir_sample):
|
696 |
+
self.savedir_sample = savedir_sample
|
697 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
698 |
+
|
699 |
+
def generate(
|
700 |
+
self,
|
701 |
+
prompt_textbox,
|
702 |
+
negative_prompt_textbox,
|
703 |
+
sampler_dropdown,
|
704 |
+
sample_step_slider,
|
705 |
+
width_slider,
|
706 |
+
height_slider,
|
707 |
+
is_image,
|
708 |
+
length_slider,
|
709 |
+
cfg_scale_slider,
|
710 |
+
seed_textbox
|
711 |
+
):
|
712 |
+
outputs = post_eas(
|
713 |
+
prompt_textbox, negative_prompt_textbox,
|
714 |
+
sampler_dropdown, sample_step_slider, width_slider, height_slider,
|
715 |
+
is_image, length_slider, cfg_scale_slider, seed_textbox
|
716 |
+
)
|
717 |
+
base64_encoding = outputs["base64_encoding"]
|
718 |
+
decoded_data = base64.b64decode(base64_encoding)
|
719 |
+
|
720 |
+
if not os.path.exists(self.savedir_sample):
|
721 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
722 |
+
index = len([path for path in os.listdir(self.savedir_sample)]) + 1
|
723 |
+
prefix = str(index).zfill(3)
|
724 |
+
|
725 |
+
if is_image or length_slider == 1:
|
726 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
|
727 |
+
with open(save_sample_path, "wb") as file:
|
728 |
+
file.write(decoded_data)
|
729 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
730 |
+
else:
|
731 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
|
732 |
+
with open(save_sample_path, "wb") as file:
|
733 |
+
file.write(decoded_data)
|
734 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
735 |
+
|
736 |
+
|
737 |
+
def ui_huggingface(edition, config_path, model_name, savedir_sample):
|
738 |
+
controller = EasyAnimateController_HuggingFace(edition, config_path, model_name, savedir_sample)
|
739 |
+
|
740 |
+
with gr.Blocks(css=css) as demo:
|
741 |
+
gr.Markdown(
|
742 |
+
"""
|
743 |
+
# EasyAnimate: Integrated generation of baseline scheme for videos and images.
|
744 |
+
Generate your videos easily
|
745 |
+
[Github](https://github.com/aigc-apps/EasyAnimate/)
|
746 |
+
"""
|
747 |
+
)
|
748 |
+
with gr.Column(variant="panel"):
|
749 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
|
750 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
|
751 |
+
|
752 |
+
with gr.Row():
|
753 |
+
with gr.Column():
|
754 |
+
with gr.Row():
|
755 |
+
sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
|
756 |
+
sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
|
757 |
+
|
758 |
+
if edition == "v1":
|
759 |
+
width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
|
760 |
+
height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
|
761 |
+
with gr.Row():
|
762 |
+
is_image = gr.Checkbox(False, label="Generate Image", visible=False)
|
763 |
+
length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
|
764 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
|
765 |
+
else:
|
766 |
+
width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
|
767 |
+
height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
|
768 |
+
with gr.Column():
|
769 |
+
gr.Markdown(
|
770 |
+
"""
|
771 |
+
To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
|
772 |
+
If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
|
773 |
+
"""
|
774 |
+
)
|
775 |
+
with gr.Row():
|
776 |
+
is_image = gr.Checkbox(False, label="Generate Image")
|
777 |
+
length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
|
778 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
|
779 |
+
|
780 |
+
with gr.Row():
|
781 |
+
seed_textbox = gr.Textbox(label="Seed", value=43)
|
782 |
+
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
783 |
+
seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
|
784 |
+
|
785 |
+
generate_button = gr.Button(value="Generate", variant='primary')
|
786 |
+
|
787 |
+
with gr.Column():
|
788 |
+
result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
|
789 |
+
result_video = gr.Video(label="Generated Animation", interactive=False)
|
790 |
+
infer_progress = gr.Textbox(
|
791 |
+
label="Generation Info",
|
792 |
+
value="No task currently",
|
793 |
+
interactive=False
|
794 |
+
)
|
795 |
+
|
796 |
+
is_image.change(
|
797 |
+
lambda x: gr.update(visible=not x),
|
798 |
+
inputs=[is_image],
|
799 |
+
outputs=[length_slider],
|
800 |
+
)
|
801 |
+
|
802 |
+
generate_button.click(
|
803 |
+
fn=controller.generate,
|
804 |
+
inputs=[
|
805 |
+
prompt_textbox,
|
806 |
+
negative_prompt_textbox,
|
807 |
+
sampler_dropdown,
|
808 |
+
sample_step_slider,
|
809 |
+
width_slider,
|
810 |
+
height_slider,
|
811 |
+
is_image,
|
812 |
+
length_slider,
|
813 |
+
cfg_scale_slider,
|
814 |
+
seed_textbox,
|
815 |
+
],
|
816 |
+
outputs=[result_image, result_video, infer_progress]
|
817 |
+
)
|
818 |
+
return demo, controller
|
easyanimate/utils/__init__.py
ADDED
File without changes
|
easyanimate/utils/diffusion_utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
|
10 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two gaussians.
|
13 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
+
scalars, among other use cases.
|
15 |
+
"""
|
16 |
+
tensor = next(
|
17 |
+
(
|
18 |
+
obj
|
19 |
+
for obj in (mean1, logvar1, mean2, logvar2)
|
20 |
+
if isinstance(obj, th.Tensor)
|
21 |
+
),
|
22 |
+
None,
|
23 |
+
)
|
24 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
25 |
+
|
26 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
27 |
+
# Tensors, but it does not work for th.exp().
|
28 |
+
logvar1, logvar2 = [
|
29 |
+
x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
|
30 |
+
for x in (logvar1, logvar2)
|
31 |
+
]
|
32 |
+
|
33 |
+
return 0.5 * (
|
34 |
+
-1.0
|
35 |
+
+ logvar2
|
36 |
+
- logvar1
|
37 |
+
+ th.exp(logvar1 - logvar2)
|
38 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def approx_standard_normal_cdf(x):
|
43 |
+
"""
|
44 |
+
A fast approximation of the cumulative distribution function of the
|
45 |
+
standard normal.
|
46 |
+
"""
|
47 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
48 |
+
|
49 |
+
|
50 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
51 |
+
"""
|
52 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
53 |
+
:param x: the targets
|
54 |
+
:param means: the Gaussian mean Tensor.
|
55 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
56 |
+
:return: a tensor like x of log probabilities (in nats).
|
57 |
+
"""
|
58 |
+
centered_x = x - means
|
59 |
+
inv_stdv = th.exp(-log_scales)
|
60 |
+
normalized_x = centered_x * inv_stdv
|
61 |
+
return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
|
62 |
+
normalized_x
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
67 |
+
"""
|
68 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
69 |
+
given image.
|
70 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
71 |
+
rescaled to the range [-1, 1].
|
72 |
+
:param means: the Gaussian mean Tensor.
|
73 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
74 |
+
:return: a tensor like x of log probabilities (in nats).
|
75 |
+
"""
|
76 |
+
assert x.shape == means.shape == log_scales.shape
|
77 |
+
centered_x = x - means
|
78 |
+
inv_stdv = th.exp(-log_scales)
|
79 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
80 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
81 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
82 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
83 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
84 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
85 |
+
cdf_delta = cdf_plus - cdf_min
|
86 |
+
log_probs = th.where(
|
87 |
+
x < -0.999,
|
88 |
+
log_cdf_plus,
|
89 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
90 |
+
)
|
91 |
+
assert log_probs.shape == x.shape
|
92 |
+
return log_probs
|
easyanimate/utils/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import enum
|
8 |
+
import math
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch as th
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
16 |
+
|
17 |
+
|
18 |
+
def mean_flat(tensor):
|
19 |
+
"""
|
20 |
+
Take the mean over all non-batch dimensions.
|
21 |
+
"""
|
22 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
23 |
+
|
24 |
+
|
25 |
+
class ModelMeanType(enum.Enum):
|
26 |
+
"""
|
27 |
+
Which type of output the model predicts.
|
28 |
+
"""
|
29 |
+
|
30 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
31 |
+
START_X = enum.auto() # the model predicts x_0
|
32 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
33 |
+
|
34 |
+
|
35 |
+
class ModelVarType(enum.Enum):
|
36 |
+
"""
|
37 |
+
What is used as the model's output variance.
|
38 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
39 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
40 |
+
"""
|
41 |
+
|
42 |
+
LEARNED = enum.auto()
|
43 |
+
FIXED_SMALL = enum.auto()
|
44 |
+
FIXED_LARGE = enum.auto()
|
45 |
+
LEARNED_RANGE = enum.auto()
|
46 |
+
|
47 |
+
|
48 |
+
class LossType(enum.Enum):
|
49 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
50 |
+
RESCALED_MSE = (
|
51 |
+
enum.auto()
|
52 |
+
) # use raw MSE loss (with RESCALED_KL when learning variances)
|
53 |
+
KL = enum.auto() # use the variational lower-bound
|
54 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
55 |
+
|
56 |
+
def is_vb(self):
|
57 |
+
return self in [LossType.KL, LossType.RESCALED_KL]
|
58 |
+
|
59 |
+
|
60 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
61 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
62 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
63 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
64 |
+
return betas
|
65 |
+
|
66 |
+
|
67 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
68 |
+
"""
|
69 |
+
This is the deprecated API for creating beta schedules.
|
70 |
+
See get_named_beta_schedule() for the new library of schedules.
|
71 |
+
"""
|
72 |
+
if beta_schedule == "quad":
|
73 |
+
betas = (
|
74 |
+
np.linspace(
|
75 |
+
beta_start ** 0.5,
|
76 |
+
beta_end ** 0.5,
|
77 |
+
num_diffusion_timesteps,
|
78 |
+
dtype=np.float64,
|
79 |
+
)
|
80 |
+
** 2
|
81 |
+
)
|
82 |
+
elif beta_schedule == "linear":
|
83 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
84 |
+
elif beta_schedule == "warmup10":
|
85 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
86 |
+
elif beta_schedule == "warmup50":
|
87 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
88 |
+
elif beta_schedule == "const":
|
89 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
90 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
91 |
+
betas = 1.0 / np.linspace(
|
92 |
+
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
raise NotImplementedError(beta_schedule)
|
96 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
97 |
+
return betas
|
98 |
+
|
99 |
+
|
100 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
101 |
+
"""
|
102 |
+
Get a pre-defined beta schedule for the given name.
|
103 |
+
The beta schedule library consists of beta schedules which remain similar
|
104 |
+
in the limit of num_diffusion_timesteps.
|
105 |
+
Beta schedules may be added, but should not be removed or changed once
|
106 |
+
they are committed to maintain backwards compatibility.
|
107 |
+
"""
|
108 |
+
if schedule_name == "linear":
|
109 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
110 |
+
# diffusion steps.
|
111 |
+
scale = 1000 / num_diffusion_timesteps
|
112 |
+
return get_beta_schedule(
|
113 |
+
"linear",
|
114 |
+
beta_start=scale * 0.0001,
|
115 |
+
beta_end=scale * 0.02,
|
116 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
117 |
+
)
|
118 |
+
elif schedule_name == "squaredcos_cap_v2":
|
119 |
+
return betas_for_alpha_bar(
|
120 |
+
num_diffusion_timesteps,
|
121 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
125 |
+
|
126 |
+
|
127 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
128 |
+
"""
|
129 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
130 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
131 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
132 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
133 |
+
produces the cumulative product of (1-beta) up to that
|
134 |
+
part of the diffusion process.
|
135 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
136 |
+
prevent singularities.
|
137 |
+
"""
|
138 |
+
betas = []
|
139 |
+
for i in range(num_diffusion_timesteps):
|
140 |
+
t1 = i / num_diffusion_timesteps
|
141 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
142 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
143 |
+
return np.array(betas)
|
144 |
+
|
145 |
+
|
146 |
+
class GaussianDiffusion:
|
147 |
+
"""
|
148 |
+
Utilities for training and sampling diffusion models.
|
149 |
+
Original ported from this codebase:
|
150 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
151 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
152 |
+
starting at T and going to 1.
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
*,
|
158 |
+
betas,
|
159 |
+
model_mean_type,
|
160 |
+
model_var_type,
|
161 |
+
loss_type,
|
162 |
+
snr=False,
|
163 |
+
return_startx=False,
|
164 |
+
):
|
165 |
+
|
166 |
+
self.model_mean_type = model_mean_type
|
167 |
+
self.model_var_type = model_var_type
|
168 |
+
self.loss_type = loss_type
|
169 |
+
self.snr = snr
|
170 |
+
self.return_startx = return_startx
|
171 |
+
|
172 |
+
# Use float64 for accuracy.
|
173 |
+
betas = np.array(betas, dtype=np.float64)
|
174 |
+
self.betas = betas
|
175 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
176 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
177 |
+
|
178 |
+
self.num_timesteps = int(betas.shape[0])
|
179 |
+
|
180 |
+
alphas = 1.0 - betas
|
181 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
182 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
183 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
184 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
185 |
+
|
186 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
187 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
188 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
189 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
190 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
191 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
192 |
+
|
193 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
194 |
+
self.posterior_variance = (
|
195 |
+
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
196 |
+
)
|
197 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
198 |
+
self.posterior_log_variance_clipped = np.log(
|
199 |
+
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
200 |
+
) if len(self.posterior_variance) > 1 else np.array([])
|
201 |
+
|
202 |
+
self.posterior_mean_coef1 = (
|
203 |
+
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
204 |
+
)
|
205 |
+
self.posterior_mean_coef2 = (
|
206 |
+
(1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
207 |
+
)
|
208 |
+
|
209 |
+
def q_mean_variance(self, x_start, t):
|
210 |
+
"""
|
211 |
+
Get the distribution q(x_t | x_0).
|
212 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
213 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
214 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
215 |
+
"""
|
216 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
217 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
218 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
219 |
+
return mean, variance, log_variance
|
220 |
+
|
221 |
+
def q_sample(self, x_start, t, noise=None):
|
222 |
+
"""
|
223 |
+
Diffuse the data for a given number of diffusion steps.
|
224 |
+
In other words, sample from q(x_t | x_0).
|
225 |
+
:param x_start: the initial data batch.
|
226 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
227 |
+
:param noise: if specified, the split-out normal noise.
|
228 |
+
:return: A noisy version of x_start.
|
229 |
+
"""
|
230 |
+
if noise is None:
|
231 |
+
noise = th.randn_like(x_start)
|
232 |
+
assert noise.shape == x_start.shape
|
233 |
+
return (
|
234 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
235 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
236 |
+
)
|
237 |
+
|
238 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
239 |
+
"""
|
240 |
+
Compute the mean and variance of the diffusion posterior:
|
241 |
+
q(x_{t-1} | x_t, x_0)
|
242 |
+
"""
|
243 |
+
assert x_start.shape == x_t.shape
|
244 |
+
posterior_mean = (
|
245 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
246 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
247 |
+
)
|
248 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
249 |
+
posterior_log_variance_clipped = _extract_into_tensor(
|
250 |
+
self.posterior_log_variance_clipped, t, x_t.shape
|
251 |
+
)
|
252 |
+
assert (
|
253 |
+
posterior_mean.shape[0]
|
254 |
+
== posterior_variance.shape[0]
|
255 |
+
== posterior_log_variance_clipped.shape[0]
|
256 |
+
== x_start.shape[0]
|
257 |
+
)
|
258 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
259 |
+
|
260 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
261 |
+
"""
|
262 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
263 |
+
the initial x, x_0.
|
264 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
265 |
+
as input.
|
266 |
+
:param x: the [N x C x ...] tensor at time t.
|
267 |
+
:param t: a 1-D Tensor of timesteps.
|
268 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
269 |
+
:param denoised_fn: if not None, a function which applies to the
|
270 |
+
x_start prediction before it is used to sample. Applies before
|
271 |
+
clip_denoised.
|
272 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
273 |
+
pass to the model. This can be used for conditioning.
|
274 |
+
:return: a dict with the following keys:
|
275 |
+
- 'mean': the model mean output.
|
276 |
+
- 'variance': the model variance output.
|
277 |
+
- 'log_variance': the log of 'variance'.
|
278 |
+
- 'pred_xstart': the prediction for x_0.
|
279 |
+
"""
|
280 |
+
if model_kwargs is None:
|
281 |
+
model_kwargs = {}
|
282 |
+
|
283 |
+
B, C = x.shape[:2]
|
284 |
+
assert t.shape == (B,)
|
285 |
+
model_output = model(x, timestep=t, **model_kwargs)
|
286 |
+
if isinstance(model_output, tuple):
|
287 |
+
model_output, extra = model_output
|
288 |
+
else:
|
289 |
+
extra = None
|
290 |
+
|
291 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
292 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
293 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
294 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
295 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
296 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
297 |
+
frac = (model_var_values + 1) / 2
|
298 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
299 |
+
model_variance = th.exp(model_log_variance)
|
300 |
+
elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
|
301 |
+
model_variance, model_log_variance = {
|
302 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
303 |
+
# to get a better decoder log likelihood.
|
304 |
+
ModelVarType.FIXED_LARGE: (
|
305 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
306 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
307 |
+
),
|
308 |
+
ModelVarType.FIXED_SMALL: (
|
309 |
+
self.posterior_variance,
|
310 |
+
self.posterior_log_variance_clipped,
|
311 |
+
),
|
312 |
+
}[self.model_var_type]
|
313 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
314 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
315 |
+
else:
|
316 |
+
model_variance = th.zeros_like(model_output)
|
317 |
+
model_log_variance = th.zeros_like(model_output)
|
318 |
+
|
319 |
+
def process_xstart(x):
|
320 |
+
if denoised_fn is not None:
|
321 |
+
x = denoised_fn(x)
|
322 |
+
return x.clamp(-1, 1) if clip_denoised else x
|
323 |
+
|
324 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
325 |
+
pred_xstart = process_xstart(model_output)
|
326 |
+
else:
|
327 |
+
pred_xstart = process_xstart(
|
328 |
+
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
329 |
+
)
|
330 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
331 |
+
|
332 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
333 |
+
return {
|
334 |
+
"mean": model_mean,
|
335 |
+
"variance": model_variance,
|
336 |
+
"log_variance": model_log_variance,
|
337 |
+
"pred_xstart": pred_xstart,
|
338 |
+
"extra": extra,
|
339 |
+
}
|
340 |
+
|
341 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
342 |
+
assert x_t.shape == eps.shape
|
343 |
+
return (
|
344 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
345 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
346 |
+
)
|
347 |
+
|
348 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
349 |
+
return (
|
350 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
351 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
352 |
+
|
353 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
354 |
+
"""
|
355 |
+
Compute the mean for the previous step, given a function cond_fn that
|
356 |
+
computes the gradient of a conditional log probability with respect to
|
357 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
358 |
+
condition on y.
|
359 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
360 |
+
"""
|
361 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
362 |
+
return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
363 |
+
|
364 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
365 |
+
"""
|
366 |
+
Compute what the p_mean_variance output would have been, should the
|
367 |
+
model's score function be conditioned by cond_fn.
|
368 |
+
See condition_mean() for details on cond_fn.
|
369 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
370 |
+
from Song et al (2020).
|
371 |
+
"""
|
372 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
373 |
+
|
374 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
375 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
376 |
+
|
377 |
+
out = p_mean_var.copy()
|
378 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
379 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
380 |
+
return out
|
381 |
+
|
382 |
+
def p_sample(
|
383 |
+
self,
|
384 |
+
model,
|
385 |
+
x,
|
386 |
+
t,
|
387 |
+
clip_denoised=True,
|
388 |
+
denoised_fn=None,
|
389 |
+
cond_fn=None,
|
390 |
+
model_kwargs=None,
|
391 |
+
):
|
392 |
+
"""
|
393 |
+
Sample x_{t-1} from the model at the given timestep.
|
394 |
+
:param model: the model to sample from.
|
395 |
+
:param x: the current tensor at x_{t-1}.
|
396 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
397 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
398 |
+
:param denoised_fn: if not None, a function which applies to the
|
399 |
+
x_start prediction before it is used to sample.
|
400 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
401 |
+
similarly to the model.
|
402 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
403 |
+
pass to the model. This can be used for conditioning.
|
404 |
+
:return: a dict containing the following keys:
|
405 |
+
- 'sample': a random sample from the model.
|
406 |
+
- 'pred_xstart': a prediction of x_0.
|
407 |
+
"""
|
408 |
+
out = self.p_mean_variance(
|
409 |
+
model,
|
410 |
+
x,
|
411 |
+
t,
|
412 |
+
clip_denoised=clip_denoised,
|
413 |
+
denoised_fn=denoised_fn,
|
414 |
+
model_kwargs=model_kwargs,
|
415 |
+
)
|
416 |
+
noise = th.randn_like(x)
|
417 |
+
nonzero_mask = (
|
418 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
419 |
+
) # no noise when t == 0
|
420 |
+
if cond_fn is not None:
|
421 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
422 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
423 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
424 |
+
|
425 |
+
def p_sample_loop(
|
426 |
+
self,
|
427 |
+
model,
|
428 |
+
shape,
|
429 |
+
noise=None,
|
430 |
+
clip_denoised=True,
|
431 |
+
denoised_fn=None,
|
432 |
+
cond_fn=None,
|
433 |
+
model_kwargs=None,
|
434 |
+
device=None,
|
435 |
+
progress=False,
|
436 |
+
):
|
437 |
+
"""
|
438 |
+
Generate samples from the model.
|
439 |
+
:param model: the model module.
|
440 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
441 |
+
:param noise: if specified, the noise from the encoder to sample.
|
442 |
+
Should be of the same shape as `shape`.
|
443 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
444 |
+
:param denoised_fn: if not None, a function which applies to the
|
445 |
+
x_start prediction before it is used to sample.
|
446 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
447 |
+
similarly to the model.
|
448 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
449 |
+
pass to the model. This can be used for conditioning.
|
450 |
+
:param device: if specified, the device to create the samples on.
|
451 |
+
If not specified, use a model parameter's device.
|
452 |
+
:param progress: if True, show a tqdm progress bar.
|
453 |
+
:return: a non-differentiable batch of samples.
|
454 |
+
"""
|
455 |
+
final = None
|
456 |
+
for sample in self.p_sample_loop_progressive(
|
457 |
+
model,
|
458 |
+
shape,
|
459 |
+
noise=noise,
|
460 |
+
clip_denoised=clip_denoised,
|
461 |
+
denoised_fn=denoised_fn,
|
462 |
+
cond_fn=cond_fn,
|
463 |
+
model_kwargs=model_kwargs,
|
464 |
+
device=device,
|
465 |
+
progress=progress,
|
466 |
+
):
|
467 |
+
final = sample
|
468 |
+
return final["sample"]
|
469 |
+
|
470 |
+
def p_sample_loop_progressive(
|
471 |
+
self,
|
472 |
+
model,
|
473 |
+
shape,
|
474 |
+
noise=None,
|
475 |
+
clip_denoised=True,
|
476 |
+
denoised_fn=None,
|
477 |
+
cond_fn=None,
|
478 |
+
model_kwargs=None,
|
479 |
+
device=None,
|
480 |
+
progress=False,
|
481 |
+
):
|
482 |
+
"""
|
483 |
+
Generate samples from the model and yield intermediate samples from
|
484 |
+
each timestep of diffusion.
|
485 |
+
Arguments are the same as p_sample_loop().
|
486 |
+
Returns a generator over dicts, where each dict is the return value of
|
487 |
+
p_sample().
|
488 |
+
"""
|
489 |
+
if device is None:
|
490 |
+
device = next(model.parameters()).device
|
491 |
+
assert isinstance(shape, (tuple, list))
|
492 |
+
img = noise if noise is not None else th.randn(*shape, device=device)
|
493 |
+
indices = list(range(self.num_timesteps))[::-1]
|
494 |
+
|
495 |
+
if progress:
|
496 |
+
# Lazy import so that we don't depend on tqdm.
|
497 |
+
from tqdm.auto import tqdm
|
498 |
+
|
499 |
+
indices = tqdm(indices)
|
500 |
+
|
501 |
+
for i in indices:
|
502 |
+
t = th.tensor([i] * shape[0], device=device)
|
503 |
+
with th.no_grad():
|
504 |
+
out = self.p_sample(
|
505 |
+
model,
|
506 |
+
img,
|
507 |
+
t,
|
508 |
+
clip_denoised=clip_denoised,
|
509 |
+
denoised_fn=denoised_fn,
|
510 |
+
cond_fn=cond_fn,
|
511 |
+
model_kwargs=model_kwargs,
|
512 |
+
)
|
513 |
+
yield out
|
514 |
+
img = out["sample"]
|
515 |
+
|
516 |
+
def ddim_sample(
|
517 |
+
self,
|
518 |
+
model,
|
519 |
+
x,
|
520 |
+
t,
|
521 |
+
clip_denoised=True,
|
522 |
+
denoised_fn=None,
|
523 |
+
cond_fn=None,
|
524 |
+
model_kwargs=None,
|
525 |
+
eta=0.0,
|
526 |
+
):
|
527 |
+
"""
|
528 |
+
Sample x_{t-1} from the model using DDIM.
|
529 |
+
Same usage as p_sample().
|
530 |
+
"""
|
531 |
+
out = self.p_mean_variance(
|
532 |
+
model,
|
533 |
+
x,
|
534 |
+
t,
|
535 |
+
clip_denoised=clip_denoised,
|
536 |
+
denoised_fn=denoised_fn,
|
537 |
+
model_kwargs=model_kwargs,
|
538 |
+
)
|
539 |
+
if cond_fn is not None:
|
540 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
541 |
+
|
542 |
+
# Usually our model outputs epsilon, but we re-derive it
|
543 |
+
# in case we used x_start or x_prev prediction.
|
544 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
545 |
+
|
546 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
547 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
548 |
+
sigma = (
|
549 |
+
eta
|
550 |
+
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
551 |
+
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
552 |
+
)
|
553 |
+
# Equation 12.
|
554 |
+
noise = th.randn_like(x)
|
555 |
+
mean_pred = (
|
556 |
+
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
557 |
+
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
558 |
+
)
|
559 |
+
nonzero_mask = (
|
560 |
+
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
561 |
+
) # no noise when t == 0
|
562 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
563 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
564 |
+
|
565 |
+
def ddim_reverse_sample(
|
566 |
+
self,
|
567 |
+
model,
|
568 |
+
x,
|
569 |
+
t,
|
570 |
+
clip_denoised=True,
|
571 |
+
denoised_fn=None,
|
572 |
+
cond_fn=None,
|
573 |
+
model_kwargs=None,
|
574 |
+
eta=0.0,
|
575 |
+
):
|
576 |
+
"""
|
577 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
578 |
+
"""
|
579 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
580 |
+
out = self.p_mean_variance(
|
581 |
+
model,
|
582 |
+
x,
|
583 |
+
t,
|
584 |
+
clip_denoised=clip_denoised,
|
585 |
+
denoised_fn=denoised_fn,
|
586 |
+
model_kwargs=model_kwargs,
|
587 |
+
)
|
588 |
+
if cond_fn is not None:
|
589 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
590 |
+
# Usually our model outputs epsilon, but we re-derive it
|
591 |
+
# in case we used x_start or x_prev prediction.
|
592 |
+
eps = (
|
593 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
594 |
+
- out["pred_xstart"]
|
595 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
596 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
597 |
+
|
598 |
+
# Equation 12. reversed
|
599 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
600 |
+
|
601 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
602 |
+
|
603 |
+
def ddim_sample_loop(
|
604 |
+
self,
|
605 |
+
model,
|
606 |
+
shape,
|
607 |
+
noise=None,
|
608 |
+
clip_denoised=True,
|
609 |
+
denoised_fn=None,
|
610 |
+
cond_fn=None,
|
611 |
+
model_kwargs=None,
|
612 |
+
device=None,
|
613 |
+
progress=False,
|
614 |
+
eta=0.0,
|
615 |
+
):
|
616 |
+
"""
|
617 |
+
Generate samples from the model using DDIM.
|
618 |
+
Same usage as p_sample_loop().
|
619 |
+
"""
|
620 |
+
final = None
|
621 |
+
for sample in self.ddim_sample_loop_progressive(
|
622 |
+
model,
|
623 |
+
shape,
|
624 |
+
noise=noise,
|
625 |
+
clip_denoised=clip_denoised,
|
626 |
+
denoised_fn=denoised_fn,
|
627 |
+
cond_fn=cond_fn,
|
628 |
+
model_kwargs=model_kwargs,
|
629 |
+
device=device,
|
630 |
+
progress=progress,
|
631 |
+
eta=eta,
|
632 |
+
):
|
633 |
+
final = sample
|
634 |
+
return final["sample"]
|
635 |
+
|
636 |
+
def ddim_sample_loop_progressive(
|
637 |
+
self,
|
638 |
+
model,
|
639 |
+
shape,
|
640 |
+
noise=None,
|
641 |
+
clip_denoised=True,
|
642 |
+
denoised_fn=None,
|
643 |
+
cond_fn=None,
|
644 |
+
model_kwargs=None,
|
645 |
+
device=None,
|
646 |
+
progress=False,
|
647 |
+
eta=0.0,
|
648 |
+
):
|
649 |
+
"""
|
650 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
651 |
+
each timestep of DDIM.
|
652 |
+
Same usage as p_sample_loop_progressive().
|
653 |
+
"""
|
654 |
+
if device is None:
|
655 |
+
device = next(model.parameters()).device
|
656 |
+
assert isinstance(shape, (tuple, list))
|
657 |
+
img = noise if noise is not None else th.randn(*shape, device=device)
|
658 |
+
indices = list(range(self.num_timesteps))[::-1]
|
659 |
+
|
660 |
+
if progress:
|
661 |
+
# Lazy import so that we don't depend on tqdm.
|
662 |
+
from tqdm.auto import tqdm
|
663 |
+
|
664 |
+
indices = tqdm(indices)
|
665 |
+
|
666 |
+
for i in indices:
|
667 |
+
t = th.tensor([i] * shape[0], device=device)
|
668 |
+
with th.no_grad():
|
669 |
+
out = self.ddim_sample(
|
670 |
+
model,
|
671 |
+
img,
|
672 |
+
t,
|
673 |
+
clip_denoised=clip_denoised,
|
674 |
+
denoised_fn=denoised_fn,
|
675 |
+
cond_fn=cond_fn,
|
676 |
+
model_kwargs=model_kwargs,
|
677 |
+
eta=eta,
|
678 |
+
)
|
679 |
+
yield out
|
680 |
+
img = out["sample"]
|
681 |
+
|
682 |
+
def _vb_terms_bpd(
|
683 |
+
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
684 |
+
):
|
685 |
+
"""
|
686 |
+
Get a term for the variational lower-bound.
|
687 |
+
The resulting units are bits (rather than nats, as one might expect).
|
688 |
+
This allows for comparison to other papers.
|
689 |
+
:return: a dict with the following keys:
|
690 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
691 |
+
- 'pred_xstart': the x_0 predictions.
|
692 |
+
"""
|
693 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
694 |
+
x_start=x_start, x_t=x_t, t=t
|
695 |
+
)
|
696 |
+
out = self.p_mean_variance(
|
697 |
+
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
698 |
+
)
|
699 |
+
kl = normal_kl(
|
700 |
+
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
701 |
+
)
|
702 |
+
kl = mean_flat(kl) / np.log(2.0)
|
703 |
+
|
704 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
705 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
706 |
+
)
|
707 |
+
assert decoder_nll.shape == x_start.shape
|
708 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
709 |
+
|
710 |
+
# At the first timestep return the decoder NLL,
|
711 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
712 |
+
output = th.where((t == 0), decoder_nll, kl)
|
713 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
714 |
+
|
715 |
+
def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
|
716 |
+
"""
|
717 |
+
Compute training losses for a single timestep.
|
718 |
+
:param model: the model to evaluate loss on.
|
719 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
720 |
+
:param t: a batch of timestep indices.
|
721 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
722 |
+
pass to the model. This can be used for conditioning.
|
723 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
724 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
725 |
+
Some mean or variance settings may also have other keys.
|
726 |
+
"""
|
727 |
+
t = timestep
|
728 |
+
if model_kwargs is None:
|
729 |
+
model_kwargs = {}
|
730 |
+
if skip_noise:
|
731 |
+
x_t = x_start
|
732 |
+
else:
|
733 |
+
if noise is None:
|
734 |
+
noise = th.randn_like(x_start)
|
735 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
736 |
+
|
737 |
+
terms = {}
|
738 |
+
|
739 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
740 |
+
terms["loss"] = self._vb_terms_bpd(
|
741 |
+
model=model,
|
742 |
+
x_start=x_start,
|
743 |
+
x_t=x_t,
|
744 |
+
t=t,
|
745 |
+
clip_denoised=False,
|
746 |
+
model_kwargs=model_kwargs,
|
747 |
+
)["output"]
|
748 |
+
if self.loss_type == LossType.RESCALED_KL:
|
749 |
+
terms["loss"] *= self.num_timesteps
|
750 |
+
elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
|
751 |
+
model_output = model(x_t, timestep=t, **model_kwargs)[0]
|
752 |
+
|
753 |
+
if isinstance(model_output, dict) and model_output.get('x', None) is not None:
|
754 |
+
output = model_output['x']
|
755 |
+
else:
|
756 |
+
output = model_output
|
757 |
+
|
758 |
+
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
|
759 |
+
return self._extracted_from_training_losses_diffusers(x_t, output, t)
|
760 |
+
# self.model_var_type = ModelVarType.LEARNED_RANGE:4
|
761 |
+
if self.model_var_type in [
|
762 |
+
ModelVarType.LEARNED,
|
763 |
+
ModelVarType.LEARNED_RANGE,
|
764 |
+
]:
|
765 |
+
B, C = x_t.shape[:2]
|
766 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
767 |
+
output, model_var_values = th.split(output, C, dim=1)
|
768 |
+
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
|
769 |
+
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
|
770 |
+
# vb variational bound
|
771 |
+
terms["vb"] = self._vb_terms_bpd(
|
772 |
+
model=lambda *args, r=frozen_out, **kwargs: r,
|
773 |
+
x_start=x_start,
|
774 |
+
x_t=x_t,
|
775 |
+
t=t,
|
776 |
+
clip_denoised=False,
|
777 |
+
)["output"]
|
778 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
779 |
+
# Divide by 1000 for equivalence with initial implementation.
|
780 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
781 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
782 |
+
|
783 |
+
target = {
|
784 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
785 |
+
x_start=x_start, x_t=x_t, t=t
|
786 |
+
)[0],
|
787 |
+
ModelMeanType.START_X: x_start,
|
788 |
+
ModelMeanType.EPSILON: noise,
|
789 |
+
}[self.model_mean_type]
|
790 |
+
assert output.shape == target.shape == x_start.shape
|
791 |
+
if self.snr:
|
792 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
793 |
+
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
|
794 |
+
pred_startx = output
|
795 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
796 |
+
pred_noise = output
|
797 |
+
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
|
798 |
+
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
|
799 |
+
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
|
800 |
+
|
801 |
+
t = t[:, None, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
|
802 |
+
# best
|
803 |
+
target = th.where(t > 249, noise, x_start)
|
804 |
+
output = th.where(t > 249, pred_noise, pred_startx)
|
805 |
+
loss = (target - output) ** 2
|
806 |
+
if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
|
807 |
+
assert 'mask' in model_output
|
808 |
+
loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
|
809 |
+
mask = model_output['mask']
|
810 |
+
unmask = 1 - mask
|
811 |
+
terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
|
812 |
+
if model_kwargs['mask_loss_coef'] > 0:
|
813 |
+
terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
|
814 |
+
else:
|
815 |
+
terms["mse"] = mean_flat(loss)
|
816 |
+
terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
|
817 |
+
if "mae" in terms:
|
818 |
+
terms["loss"] = terms["loss"] + terms["mae"]
|
819 |
+
else:
|
820 |
+
raise NotImplementedError(self.loss_type)
|
821 |
+
|
822 |
+
return terms
|
823 |
+
|
824 |
+
def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
|
825 |
+
"""
|
826 |
+
Compute training losses for a single timestep.
|
827 |
+
:param model: the model to evaluate loss on.
|
828 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
829 |
+
:param t: a batch of timestep indices.
|
830 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
831 |
+
pass to the model. This can be used for conditioning.
|
832 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
833 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
834 |
+
Some mean or variance settings may also have other keys.
|
835 |
+
"""
|
836 |
+
t = timestep
|
837 |
+
if model_kwargs is None:
|
838 |
+
model_kwargs = {}
|
839 |
+
if skip_noise:
|
840 |
+
x_t = x_start
|
841 |
+
else:
|
842 |
+
if noise is None:
|
843 |
+
noise = th.randn_like(x_start)
|
844 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
845 |
+
|
846 |
+
terms = {}
|
847 |
+
|
848 |
+
if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
|
849 |
+
terms["loss"] = self._vb_terms_bpd(
|
850 |
+
model=model,
|
851 |
+
x_start=x_start,
|
852 |
+
x_t=x_t,
|
853 |
+
t=t,
|
854 |
+
clip_denoised=False,
|
855 |
+
model_kwargs=model_kwargs,
|
856 |
+
)["output"]
|
857 |
+
if self.loss_type == LossType.RESCALED_KL:
|
858 |
+
terms["loss"] *= self.num_timesteps
|
859 |
+
elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
|
860 |
+
output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]
|
861 |
+
|
862 |
+
if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
|
863 |
+
return self._extracted_from_training_losses_diffusers(x_t, output, t)
|
864 |
+
|
865 |
+
if self.model_var_type in [
|
866 |
+
ModelVarType.LEARNED,
|
867 |
+
ModelVarType.LEARNED_RANGE,
|
868 |
+
]:
|
869 |
+
B, C = x_t.shape[:2]
|
870 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
871 |
+
output, model_var_values = th.split(output, C, dim=1)
|
872 |
+
# Learn the variance using the variational bound, but don't let it affect our mean prediction.
|
873 |
+
frozen_out = th.cat([output.detach(), model_var_values], dim=1)
|
874 |
+
terms["vb"] = self._vb_terms_bpd(
|
875 |
+
model=lambda *args, r=frozen_out, **kwargs: r,
|
876 |
+
x_start=x_start,
|
877 |
+
x_t=x_t,
|
878 |
+
t=t,
|
879 |
+
clip_denoised=False,
|
880 |
+
)["output"]
|
881 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
882 |
+
# Divide by 1000 for equivalence with initial implementation.
|
883 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
884 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
885 |
+
|
886 |
+
target = {
|
887 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
|
888 |
+
x_start=x_start, x_t=x_t, t=t
|
889 |
+
)[0],
|
890 |
+
ModelMeanType.START_X: x_start,
|
891 |
+
ModelMeanType.EPSILON: noise,
|
892 |
+
}[self.model_mean_type]
|
893 |
+
assert output.shape == target.shape == x_start.shape
|
894 |
+
if self.snr:
|
895 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
896 |
+
pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
|
897 |
+
pred_startx = output
|
898 |
+
elif self.model_mean_type == ModelMeanType.EPSILON:
|
899 |
+
pred_noise = output
|
900 |
+
pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
|
901 |
+
# terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
|
902 |
+
# terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
|
903 |
+
|
904 |
+
t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
|
905 |
+
# best
|
906 |
+
target = th.where(t > 249, noise, x_start)
|
907 |
+
output = th.where(t > 249, pred_noise, pred_startx)
|
908 |
+
loss = (target - output) ** 2
|
909 |
+
terms["mse"] = mean_flat(loss)
|
910 |
+
terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
|
911 |
+
if "mae" in terms:
|
912 |
+
terms["loss"] = terms["loss"] + terms["mae"]
|
913 |
+
else:
|
914 |
+
raise NotImplementedError(self.loss_type)
|
915 |
+
|
916 |
+
return terms
|
917 |
+
|
918 |
+
def _extracted_from_training_losses_diffusers(self, x_t, output, t):
|
919 |
+
B, C = x_t.shape[:2]
|
920 |
+
assert output.shape == (B, C * 2, *x_t.shape[2:])
|
921 |
+
output = th.split(output, C, dim=1)[0]
|
922 |
+
return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
|
923 |
+
|
924 |
+
def _prior_bpd(self, x_start):
|
925 |
+
"""
|
926 |
+
Get the prior KL term for the variational lower-bound, measured in
|
927 |
+
bits-per-dim.
|
928 |
+
This term can't be optimized, as it only depends on the encoder.
|
929 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
930 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
931 |
+
"""
|
932 |
+
batch_size = x_start.shape[0]
|
933 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
934 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
935 |
+
kl_prior = normal_kl(
|
936 |
+
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
937 |
+
)
|
938 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
939 |
+
|
940 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
941 |
+
"""
|
942 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
943 |
+
as well as other related quantities.
|
944 |
+
:param model: the model to evaluate loss on.
|
945 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
946 |
+
:param clip_denoised: if True, clip denoised samples.
|
947 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
948 |
+
pass to the model. This can be used for conditioning.
|
949 |
+
:return: a dict containing the following keys:
|
950 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
951 |
+
- prior_bpd: the prior term in the lower-bound.
|
952 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
953 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
954 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
955 |
+
"""
|
956 |
+
device = x_start.device
|
957 |
+
batch_size = x_start.shape[0]
|
958 |
+
|
959 |
+
vb = []
|
960 |
+
xstart_mse = []
|
961 |
+
mse = []
|
962 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
963 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
964 |
+
noise = th.randn_like(x_start)
|
965 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
966 |
+
# Calculate VLB term at the current timestep
|
967 |
+
with th.no_grad():
|
968 |
+
out = self._vb_terms_bpd(
|
969 |
+
model,
|
970 |
+
x_start=x_start,
|
971 |
+
x_t=x_t,
|
972 |
+
t=t_batch,
|
973 |
+
clip_denoised=clip_denoised,
|
974 |
+
model_kwargs=model_kwargs,
|
975 |
+
)
|
976 |
+
vb.append(out["output"])
|
977 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
978 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
979 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
980 |
+
|
981 |
+
vb = th.stack(vb, dim=1)
|
982 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
983 |
+
mse = th.stack(mse, dim=1)
|
984 |
+
|
985 |
+
prior_bpd = self._prior_bpd(x_start)
|
986 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
987 |
+
return {
|
988 |
+
"total_bpd": total_bpd,
|
989 |
+
"prior_bpd": prior_bpd,
|
990 |
+
"vb": vb,
|
991 |
+
"xstart_mse": xstart_mse,
|
992 |
+
"mse": mse,
|
993 |
+
}
|
994 |
+
|
995 |
+
|
996 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
997 |
+
"""
|
998 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
999 |
+
:param arr: the 1-D numpy array.
|
1000 |
+
:param timesteps: a tensor of indices into the array to extract.
|
1001 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
1002 |
+
dimension equal to the length of timesteps.
|
1003 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
1004 |
+
"""
|
1005 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
1006 |
+
while len(res.shape) < len(broadcast_shape):
|
1007 |
+
res = res[..., None]
|
1008 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
easyanimate/utils/lora_utils.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
# https://github.com/bmaltais/kohya_ss
|
6 |
+
|
7 |
+
import hashlib
|
8 |
+
import math
|
9 |
+
import os
|
10 |
+
from collections import defaultdict
|
11 |
+
from io import BytesIO
|
12 |
+
from typing import List, Optional, Type, Union
|
13 |
+
|
14 |
+
import safetensors.torch
|
15 |
+
import torch
|
16 |
+
import torch.utils.checkpoint
|
17 |
+
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
|
18 |
+
from safetensors.torch import load_file
|
19 |
+
from transformers import T5EncoderModel
|
20 |
+
|
21 |
+
|
22 |
+
class LoRAModule(torch.nn.Module):
|
23 |
+
"""
|
24 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
lora_name,
|
30 |
+
org_module: torch.nn.Module,
|
31 |
+
multiplier=1.0,
|
32 |
+
lora_dim=4,
|
33 |
+
alpha=1,
|
34 |
+
dropout=None,
|
35 |
+
rank_dropout=None,
|
36 |
+
module_dropout=None,
|
37 |
+
):
|
38 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
39 |
+
super().__init__()
|
40 |
+
self.lora_name = lora_name
|
41 |
+
|
42 |
+
if org_module.__class__.__name__ == "Conv2d":
|
43 |
+
in_dim = org_module.in_channels
|
44 |
+
out_dim = org_module.out_channels
|
45 |
+
else:
|
46 |
+
in_dim = org_module.in_features
|
47 |
+
out_dim = org_module.out_features
|
48 |
+
|
49 |
+
self.lora_dim = lora_dim
|
50 |
+
if org_module.__class__.__name__ == "Conv2d":
|
51 |
+
kernel_size = org_module.kernel_size
|
52 |
+
stride = org_module.stride
|
53 |
+
padding = org_module.padding
|
54 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
55 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
56 |
+
else:
|
57 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
58 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
59 |
+
|
60 |
+
if type(alpha) == torch.Tensor:
|
61 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
62 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
63 |
+
self.scale = alpha / self.lora_dim
|
64 |
+
self.register_buffer("alpha", torch.tensor(alpha))
|
65 |
+
|
66 |
+
# same as microsoft's
|
67 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
68 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
69 |
+
|
70 |
+
self.multiplier = multiplier
|
71 |
+
self.org_module = org_module # remove in applying
|
72 |
+
self.dropout = dropout
|
73 |
+
self.rank_dropout = rank_dropout
|
74 |
+
self.module_dropout = module_dropout
|
75 |
+
|
76 |
+
def apply_to(self):
|
77 |
+
self.org_forward = self.org_module.forward
|
78 |
+
self.org_module.forward = self.forward
|
79 |
+
del self.org_module
|
80 |
+
|
81 |
+
def forward(self, x, *args, **kwargs):
|
82 |
+
weight_dtype = x.dtype
|
83 |
+
org_forwarded = self.org_forward(x)
|
84 |
+
|
85 |
+
# module dropout
|
86 |
+
if self.module_dropout is not None and self.training:
|
87 |
+
if torch.rand(1) < self.module_dropout:
|
88 |
+
return org_forwarded
|
89 |
+
|
90 |
+
lx = self.lora_down(x.to(self.lora_down.weight.dtype))
|
91 |
+
|
92 |
+
# normal dropout
|
93 |
+
if self.dropout is not None and self.training:
|
94 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
95 |
+
|
96 |
+
# rank dropout
|
97 |
+
if self.rank_dropout is not None and self.training:
|
98 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
99 |
+
if len(lx.size()) == 3:
|
100 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
101 |
+
elif len(lx.size()) == 4:
|
102 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
103 |
+
lx = lx * mask
|
104 |
+
|
105 |
+
# scaling for rank dropout: treat as if the rank is changed
|
106 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
107 |
+
else:
|
108 |
+
scale = self.scale
|
109 |
+
|
110 |
+
lx = self.lora_up(lx)
|
111 |
+
|
112 |
+
return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
|
113 |
+
|
114 |
+
|
115 |
+
def addnet_hash_legacy(b):
|
116 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
117 |
+
m = hashlib.sha256()
|
118 |
+
|
119 |
+
b.seek(0x100000)
|
120 |
+
m.update(b.read(0x10000))
|
121 |
+
return m.hexdigest()[0:8]
|
122 |
+
|
123 |
+
|
124 |
+
def addnet_hash_safetensors(b):
|
125 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
126 |
+
hash_sha256 = hashlib.sha256()
|
127 |
+
blksize = 1024 * 1024
|
128 |
+
|
129 |
+
b.seek(0)
|
130 |
+
header = b.read(8)
|
131 |
+
n = int.from_bytes(header, "little")
|
132 |
+
|
133 |
+
offset = n + 8
|
134 |
+
b.seek(offset)
|
135 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
136 |
+
hash_sha256.update(chunk)
|
137 |
+
|
138 |
+
return hash_sha256.hexdigest()
|
139 |
+
|
140 |
+
|
141 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
142 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
143 |
+
save time on indexing the model later."""
|
144 |
+
|
145 |
+
# Because writing user metadata to the file can change the result of
|
146 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
147 |
+
# calculating the hash, as they are meant to be immutable
|
148 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
149 |
+
|
150 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
151 |
+
b = BytesIO(bytes)
|
152 |
+
|
153 |
+
model_hash = addnet_hash_safetensors(b)
|
154 |
+
legacy_hash = addnet_hash_legacy(b)
|
155 |
+
return model_hash, legacy_hash
|
156 |
+
|
157 |
+
|
158 |
+
class LoRANetwork(torch.nn.Module):
|
159 |
+
TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
|
160 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
|
161 |
+
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
162 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
|
166 |
+
unet,
|
167 |
+
multiplier: float = 1.0,
|
168 |
+
lora_dim: int = 4,
|
169 |
+
alpha: float = 1,
|
170 |
+
dropout: Optional[float] = None,
|
171 |
+
module_class: Type[object] = LoRAModule,
|
172 |
+
add_lora_in_attn_temporal: bool = False,
|
173 |
+
varbose: Optional[bool] = False,
|
174 |
+
) -> None:
|
175 |
+
super().__init__()
|
176 |
+
self.multiplier = multiplier
|
177 |
+
|
178 |
+
self.lora_dim = lora_dim
|
179 |
+
self.alpha = alpha
|
180 |
+
self.dropout = dropout
|
181 |
+
|
182 |
+
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
183 |
+
print(f"neuron dropout: p={self.dropout}")
|
184 |
+
|
185 |
+
# create module instances
|
186 |
+
def create_modules(
|
187 |
+
is_unet: bool,
|
188 |
+
root_module: torch.nn.Module,
|
189 |
+
target_replace_modules: List[torch.nn.Module],
|
190 |
+
) -> List[LoRAModule]:
|
191 |
+
prefix = (
|
192 |
+
self.LORA_PREFIX_TRANSFORMER
|
193 |
+
if is_unet
|
194 |
+
else self.LORA_PREFIX_TEXT_ENCODER
|
195 |
+
)
|
196 |
+
loras = []
|
197 |
+
skipped = []
|
198 |
+
for name, module in root_module.named_modules():
|
199 |
+
if module.__class__.__name__ in target_replace_modules:
|
200 |
+
for child_name, child_module in module.named_modules():
|
201 |
+
is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
|
202 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
|
203 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
204 |
+
|
205 |
+
if not add_lora_in_attn_temporal:
|
206 |
+
if "attn_temporal" in child_name:
|
207 |
+
continue
|
208 |
+
|
209 |
+
if is_linear or is_conv2d:
|
210 |
+
lora_name = prefix + "." + name + "." + child_name
|
211 |
+
lora_name = lora_name.replace(".", "_")
|
212 |
+
|
213 |
+
dim = None
|
214 |
+
alpha = None
|
215 |
+
|
216 |
+
if is_linear or is_conv2d_1x1:
|
217 |
+
dim = self.lora_dim
|
218 |
+
alpha = self.alpha
|
219 |
+
|
220 |
+
if dim is None or dim == 0:
|
221 |
+
if is_linear or is_conv2d_1x1:
|
222 |
+
skipped.append(lora_name)
|
223 |
+
continue
|
224 |
+
|
225 |
+
lora = module_class(
|
226 |
+
lora_name,
|
227 |
+
child_module,
|
228 |
+
self.multiplier,
|
229 |
+
dim,
|
230 |
+
alpha,
|
231 |
+
dropout=dropout,
|
232 |
+
)
|
233 |
+
loras.append(lora)
|
234 |
+
return loras, skipped
|
235 |
+
|
236 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
237 |
+
|
238 |
+
self.text_encoder_loras = []
|
239 |
+
skipped_te = []
|
240 |
+
for i, text_encoder in enumerate(text_encoders):
|
241 |
+
text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
242 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
243 |
+
skipped_te += skipped
|
244 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
245 |
+
|
246 |
+
self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
|
247 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
248 |
+
|
249 |
+
# assertion
|
250 |
+
names = set()
|
251 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
252 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
253 |
+
names.add(lora.lora_name)
|
254 |
+
|
255 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
256 |
+
if apply_text_encoder:
|
257 |
+
print("enable LoRA for text encoder")
|
258 |
+
else:
|
259 |
+
self.text_encoder_loras = []
|
260 |
+
|
261 |
+
if apply_unet:
|
262 |
+
print("enable LoRA for U-Net")
|
263 |
+
else:
|
264 |
+
self.unet_loras = []
|
265 |
+
|
266 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
267 |
+
lora.apply_to()
|
268 |
+
self.add_module(lora.lora_name, lora)
|
269 |
+
|
270 |
+
def set_multiplier(self, multiplier):
|
271 |
+
self.multiplier = multiplier
|
272 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
273 |
+
lora.multiplier = self.multiplier
|
274 |
+
|
275 |
+
def load_weights(self, file):
|
276 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
277 |
+
from safetensors.torch import load_file
|
278 |
+
|
279 |
+
weights_sd = load_file(file)
|
280 |
+
else:
|
281 |
+
weights_sd = torch.load(file, map_location="cpu")
|
282 |
+
info = self.load_state_dict(weights_sd, False)
|
283 |
+
return info
|
284 |
+
|
285 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
286 |
+
self.requires_grad_(True)
|
287 |
+
all_params = []
|
288 |
+
|
289 |
+
def enumerate_params(loras):
|
290 |
+
params = []
|
291 |
+
for lora in loras:
|
292 |
+
params.extend(lora.parameters())
|
293 |
+
return params
|
294 |
+
|
295 |
+
if self.text_encoder_loras:
|
296 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
297 |
+
if text_encoder_lr is not None:
|
298 |
+
param_data["lr"] = text_encoder_lr
|
299 |
+
all_params.append(param_data)
|
300 |
+
|
301 |
+
if self.unet_loras:
|
302 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
303 |
+
if unet_lr is not None:
|
304 |
+
param_data["lr"] = unet_lr
|
305 |
+
all_params.append(param_data)
|
306 |
+
|
307 |
+
return all_params
|
308 |
+
|
309 |
+
def enable_gradient_checkpointing(self):
|
310 |
+
pass
|
311 |
+
|
312 |
+
def get_trainable_params(self):
|
313 |
+
return self.parameters()
|
314 |
+
|
315 |
+
def save_weights(self, file, dtype, metadata):
|
316 |
+
if metadata is not None and len(metadata) == 0:
|
317 |
+
metadata = None
|
318 |
+
|
319 |
+
state_dict = self.state_dict()
|
320 |
+
|
321 |
+
if dtype is not None:
|
322 |
+
for key in list(state_dict.keys()):
|
323 |
+
v = state_dict[key]
|
324 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
325 |
+
state_dict[key] = v
|
326 |
+
|
327 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
328 |
+
from safetensors.torch import save_file
|
329 |
+
|
330 |
+
# Precalculate model hashes to save time on indexing
|
331 |
+
if metadata is None:
|
332 |
+
metadata = {}
|
333 |
+
model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
|
334 |
+
metadata["sshs_model_hash"] = model_hash
|
335 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
336 |
+
|
337 |
+
save_file(state_dict, file, metadata)
|
338 |
+
else:
|
339 |
+
torch.save(state_dict, file)
|
340 |
+
|
341 |
+
def create_network(
|
342 |
+
multiplier: float,
|
343 |
+
network_dim: Optional[int],
|
344 |
+
network_alpha: Optional[float],
|
345 |
+
text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
|
346 |
+
transformer,
|
347 |
+
neuron_dropout: Optional[float] = None,
|
348 |
+
add_lora_in_attn_temporal: bool = False,
|
349 |
+
**kwargs,
|
350 |
+
):
|
351 |
+
if network_dim is None:
|
352 |
+
network_dim = 4 # default
|
353 |
+
if network_alpha is None:
|
354 |
+
network_alpha = 1.0
|
355 |
+
|
356 |
+
network = LoRANetwork(
|
357 |
+
text_encoder,
|
358 |
+
transformer,
|
359 |
+
multiplier=multiplier,
|
360 |
+
lora_dim=network_dim,
|
361 |
+
alpha=network_alpha,
|
362 |
+
dropout=neuron_dropout,
|
363 |
+
add_lora_in_attn_temporal=add_lora_in_attn_temporal,
|
364 |
+
varbose=True,
|
365 |
+
)
|
366 |
+
return network
|
367 |
+
|
368 |
+
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
|
369 |
+
LORA_PREFIX_TRANSFORMER = "lora_unet"
|
370 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
371 |
+
if state_dict is None:
|
372 |
+
state_dict = load_file(lora_path, device=device)
|
373 |
+
else:
|
374 |
+
state_dict = state_dict
|
375 |
+
updates = defaultdict(dict)
|
376 |
+
for key, value in state_dict.items():
|
377 |
+
layer, elem = key.split('.', 1)
|
378 |
+
updates[layer][elem] = value
|
379 |
+
|
380 |
+
for layer, elems in updates.items():
|
381 |
+
|
382 |
+
if "lora_te" in layer:
|
383 |
+
if transformer_only:
|
384 |
+
continue
|
385 |
+
else:
|
386 |
+
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
387 |
+
curr_layer = pipeline.text_encoder
|
388 |
+
else:
|
389 |
+
layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
|
390 |
+
curr_layer = pipeline.transformer
|
391 |
+
|
392 |
+
temp_name = layer_infos.pop(0)
|
393 |
+
while len(layer_infos) > -1:
|
394 |
+
try:
|
395 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
396 |
+
if len(layer_infos) > 0:
|
397 |
+
temp_name = layer_infos.pop(0)
|
398 |
+
elif len(layer_infos) == 0:
|
399 |
+
break
|
400 |
+
except Exception:
|
401 |
+
if len(layer_infos) == 0:
|
402 |
+
print('Error loading layer')
|
403 |
+
if len(temp_name) > 0:
|
404 |
+
temp_name += "_" + layer_infos.pop(0)
|
405 |
+
else:
|
406 |
+
temp_name = layer_infos.pop(0)
|
407 |
+
|
408 |
+
weight_up = elems['lora_up.weight'].to(dtype)
|
409 |
+
weight_down = elems['lora_down.weight'].to(dtype)
|
410 |
+
if 'alpha' in elems.keys():
|
411 |
+
alpha = elems['alpha'].item() / weight_up.shape[1]
|
412 |
+
else:
|
413 |
+
alpha = 1.0
|
414 |
+
|
415 |
+
curr_layer.weight.data = curr_layer.weight.data.to(device)
|
416 |
+
if len(weight_up.shape) == 4:
|
417 |
+
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
|
418 |
+
weight_down.squeeze(3).squeeze(2)).unsqueeze(
|
419 |
+
2).unsqueeze(3)
|
420 |
+
else:
|
421 |
+
curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
|
422 |
+
|
423 |
+
return pipeline
|
424 |
+
|
425 |
+
# TODO: Refactor with merge_lora.
|
426 |
+
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
|
427 |
+
"""Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
|
428 |
+
LORA_PREFIX_UNET = "lora_unet"
|
429 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
430 |
+
state_dict = load_file(lora_path, device=device)
|
431 |
+
|
432 |
+
updates = defaultdict(dict)
|
433 |
+
for key, value in state_dict.items():
|
434 |
+
layer, elem = key.split('.', 1)
|
435 |
+
updates[layer][elem] = value
|
436 |
+
|
437 |
+
for layer, elems in updates.items():
|
438 |
+
|
439 |
+
if "lora_te" in layer:
|
440 |
+
layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
441 |
+
curr_layer = pipeline.text_encoder
|
442 |
+
else:
|
443 |
+
layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
444 |
+
curr_layer = pipeline.transformer
|
445 |
+
|
446 |
+
temp_name = layer_infos.pop(0)
|
447 |
+
while len(layer_infos) > -1:
|
448 |
+
try:
|
449 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
450 |
+
if len(layer_infos) > 0:
|
451 |
+
temp_name = layer_infos.pop(0)
|
452 |
+
elif len(layer_infos) == 0:
|
453 |
+
break
|
454 |
+
except Exception:
|
455 |
+
if len(layer_infos) == 0:
|
456 |
+
print('Error loading layer')
|
457 |
+
if len(temp_name) > 0:
|
458 |
+
temp_name += "_" + layer_infos.pop(0)
|
459 |
+
else:
|
460 |
+
temp_name = layer_infos.pop(0)
|
461 |
+
|
462 |
+
weight_up = elems['lora_up.weight'].to(dtype)
|
463 |
+
weight_down = elems['lora_down.weight'].to(dtype)
|
464 |
+
if 'alpha' in elems.keys():
|
465 |
+
alpha = elems['alpha'].item() / weight_up.shape[1]
|
466 |
+
else:
|
467 |
+
alpha = 1.0
|
468 |
+
|
469 |
+
curr_layer.weight.data = curr_layer.weight.data.to(device)
|
470 |
+
if len(weight_up.shape) == 4:
|
471 |
+
curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
|
472 |
+
weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
473 |
+
else:
|
474 |
+
curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
|
475 |
+
|
476 |
+
return pipeline
|
easyanimate/utils/respace.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
from .gaussian_diffusion import GaussianDiffusion
|
10 |
+
|
11 |
+
|
12 |
+
def space_timesteps(num_timesteps, section_counts):
|
13 |
+
"""
|
14 |
+
Create a list of timesteps to use from an original diffusion process,
|
15 |
+
given the number of timesteps we want to take from equally-sized portions
|
16 |
+
of the original process.
|
17 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
18 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
19 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
20 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
21 |
+
from the DDIM paper is used, and only one section is allowed.
|
22 |
+
:param num_timesteps: the number of diffusion steps in the original
|
23 |
+
process to divide up.
|
24 |
+
:param section_counts: either a list of numbers, or a string containing
|
25 |
+
comma-separated numbers, indicating the step count
|
26 |
+
per section. As a special case, use "ddimN" where N
|
27 |
+
is a number of steps to use the striding from the
|
28 |
+
DDIM paper.
|
29 |
+
:return: a set of diffusion steps from the original process to use.
|
30 |
+
"""
|
31 |
+
if isinstance(section_counts, str):
|
32 |
+
if section_counts.startswith("ddim"):
|
33 |
+
desired_count = int(section_counts[len("ddim") :])
|
34 |
+
for i in range(1, num_timesteps):
|
35 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
36 |
+
return set(range(0, num_timesteps, i))
|
37 |
+
raise ValueError(
|
38 |
+
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
39 |
+
)
|
40 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
41 |
+
size_per = num_timesteps // len(section_counts)
|
42 |
+
extra = num_timesteps % len(section_counts)
|
43 |
+
start_idx = 0
|
44 |
+
all_steps = []
|
45 |
+
for i, section_count in enumerate(section_counts):
|
46 |
+
size = size_per + (1 if i < extra else 0)
|
47 |
+
if size < section_count:
|
48 |
+
raise ValueError(
|
49 |
+
f"cannot divide section of {size} steps into {section_count}"
|
50 |
+
)
|
51 |
+
frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
|
52 |
+
cur_idx = 0.0
|
53 |
+
taken_steps = []
|
54 |
+
for _ in range(section_count):
|
55 |
+
taken_steps.append(start_idx + round(cur_idx))
|
56 |
+
cur_idx += frac_stride
|
57 |
+
all_steps += taken_steps
|
58 |
+
start_idx += size
|
59 |
+
return set(all_steps)
|
60 |
+
|
61 |
+
|
62 |
+
class SpacedDiffusion(GaussianDiffusion):
|
63 |
+
"""
|
64 |
+
A diffusion process which can skip steps in a base diffusion process.
|
65 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
66 |
+
original diffusion process to retain.
|
67 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, use_timesteps, **kwargs):
|
71 |
+
self.use_timesteps = set(use_timesteps)
|
72 |
+
self.timestep_map = []
|
73 |
+
self.original_num_steps = len(kwargs["betas"])
|
74 |
+
|
75 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
76 |
+
last_alpha_cumprod = 1.0
|
77 |
+
new_betas = []
|
78 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
79 |
+
if i in self.use_timesteps:
|
80 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
81 |
+
last_alpha_cumprod = alpha_cumprod
|
82 |
+
self.timestep_map.append(i)
|
83 |
+
kwargs["betas"] = np.array(new_betas)
|
84 |
+
super().__init__(**kwargs)
|
85 |
+
|
86 |
+
def p_mean_variance(
|
87 |
+
self, model, *args, **kwargs
|
88 |
+
): # pylint: disable=signature-differs
|
89 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
90 |
+
|
91 |
+
def training_losses(
|
92 |
+
self, model, *args, **kwargs
|
93 |
+
): # pylint: disable=signature-differs
|
94 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
95 |
+
|
96 |
+
def training_losses_diffusers(
|
97 |
+
self, model, *args, **kwargs
|
98 |
+
): # pylint: disable=signature-differs
|
99 |
+
return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
|
100 |
+
|
101 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
102 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
103 |
+
|
104 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
105 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
106 |
+
|
107 |
+
def _wrap_model(self, model):
|
108 |
+
if isinstance(model, _WrappedModel):
|
109 |
+
return model
|
110 |
+
return _WrappedModel(
|
111 |
+
model, self.timestep_map, self.original_num_steps
|
112 |
+
)
|
113 |
+
|
114 |
+
def _scale_timesteps(self, t):
|
115 |
+
# Scaling is done by the wrapped model.
|
116 |
+
return t
|
117 |
+
|
118 |
+
|
119 |
+
class _WrappedModel:
|
120 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
121 |
+
self.model = model
|
122 |
+
self.timestep_map = timestep_map
|
123 |
+
# self.rescale_timesteps = rescale_timesteps
|
124 |
+
self.original_num_steps = original_num_steps
|
125 |
+
|
126 |
+
def __call__(self, x, timestep, **kwargs):
|
127 |
+
map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
|
128 |
+
new_ts = map_tensor[timestep]
|
129 |
+
# if self.rescale_timesteps:
|
130 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
131 |
+
return self.model(x, timestep=new_ts, **kwargs)
|
easyanimate/utils/utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import imageio
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision
|
7 |
+
import cv2
|
8 |
+
from einops import rearrange
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
def color_transfer(sc, dc):
|
13 |
+
"""
|
14 |
+
Transfer color distribution from of sc, referred to dc.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
sc (numpy.ndarray): input image to be transfered.
|
18 |
+
dc (numpy.ndarray): reference image
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
numpy.ndarray: Transferred color distribution on the sc.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def get_mean_and_std(img):
|
25 |
+
x_mean, x_std = cv2.meanStdDev(img)
|
26 |
+
x_mean = np.hstack(np.around(x_mean, 2))
|
27 |
+
x_std = np.hstack(np.around(x_std, 2))
|
28 |
+
return x_mean, x_std
|
29 |
+
|
30 |
+
sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
|
31 |
+
s_mean, s_std = get_mean_and_std(sc)
|
32 |
+
dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
|
33 |
+
t_mean, t_std = get_mean_and_std(dc)
|
34 |
+
img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
|
35 |
+
np.putmask(img_n, img_n > 255, 255)
|
36 |
+
np.putmask(img_n, img_n < 0, 0)
|
37 |
+
dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
|
38 |
+
return dst
|
39 |
+
|
40 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
|
41 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
42 |
+
outputs = []
|
43 |
+
for x in videos:
|
44 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
45 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
46 |
+
if rescale:
|
47 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
48 |
+
x = (x * 255).numpy().astype(np.uint8)
|
49 |
+
outputs.append(Image.fromarray(x))
|
50 |
+
|
51 |
+
if color_transfer_post_process:
|
52 |
+
for i in range(1, len(outputs)):
|
53 |
+
outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
|
54 |
+
|
55 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
56 |
+
if imageio_backend:
|
57 |
+
if path.endswith("mp4"):
|
58 |
+
imageio.mimsave(path, outputs, fps=fps)
|
59 |
+
else:
|
60 |
+
imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
|
61 |
+
else:
|
62 |
+
if path.endswith("mp4"):
|
63 |
+
path = path.replace('.mp4', '.gif')
|
64 |
+
outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
|
easyanimate/vae/LICENSE
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
2 |
+
|
3 |
+
CreativeML Open RAIL-M
|
4 |
+
dated August 22, 2022
|
5 |
+
|
6 |
+
Section I: PREAMBLE
|
7 |
+
|
8 |
+
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
9 |
+
|
10 |
+
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
11 |
+
|
12 |
+
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
13 |
+
|
14 |
+
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
15 |
+
|
16 |
+
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
17 |
+
|
18 |
+
NOW THEREFORE, You and Licensor agree as follows:
|
19 |
+
|
20 |
+
1. Definitions
|
21 |
+
|
22 |
+
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
23 |
+
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
24 |
+
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
25 |
+
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
26 |
+
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
27 |
+
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
28 |
+
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
29 |
+
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
30 |
+
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
31 |
+
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
32 |
+
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
33 |
+
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
34 |
+
|
35 |
+
Section II: INTELLECTUAL PROPERTY RIGHTS
|
36 |
+
|
37 |
+
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
38 |
+
|
39 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
40 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
41 |
+
|
42 |
+
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
43 |
+
|
44 |
+
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
45 |
+
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
46 |
+
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
47 |
+
You must cause any modified files to carry prominent notices stating that You changed the files;
|
48 |
+
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
49 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
50 |
+
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
51 |
+
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
52 |
+
|
53 |
+
Section IV: OTHER PROVISIONS
|
54 |
+
|
55 |
+
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
|
56 |
+
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
57 |
+
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
58 |
+
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
59 |
+
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
60 |
+
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
61 |
+
|
62 |
+
END OF TERMS AND CONDITIONS
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
Attachment A
|
68 |
+
|
69 |
+
Use Restrictions
|
70 |
+
|
71 |
+
You agree not to use the Model or Derivatives of the Model:
|
72 |
+
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
73 |
+
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
74 |
+
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
75 |
+
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
76 |
+
- To defame, disparage or otherwise harass others;
|
77 |
+
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
78 |
+
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
79 |
+
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
80 |
+
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
81 |
+
- To provide medical advice and medical results interpretation;
|
82 |
+
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
easyanimate/vae/README.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## VAE Training
|
2 |
+
|
3 |
+
English | [简体中文](./README_zh-CN.md)
|
4 |
+
|
5 |
+
After completing data preprocessing, we can obtain the following dataset:
|
6 |
+
|
7 |
+
```
|
8 |
+
📦 project/
|
9 |
+
├── 📂 datasets/
|
10 |
+
│ ├── 📂 internal_datasets/
|
11 |
+
│ ├── 📂 videos/
|
12 |
+
│ │ ├── 📄 00000001.mp4
|
13 |
+
│ │ ├── 📄 00000001.jpg
|
14 |
+
│ │ └── 📄 .....
|
15 |
+
│ └── 📄 json_of_internal_datasets.json
|
16 |
+
```
|
17 |
+
|
18 |
+
The json_of_internal_datasets.json is a standard JSON file. The file_path in the json can to be set as relative path, as shown in below:
|
19 |
+
```json
|
20 |
+
[
|
21 |
+
{
|
22 |
+
"file_path": "videos/00000001.mp4",
|
23 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
24 |
+
"type": "video"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"file_path": "train/00000001.jpg",
|
28 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
29 |
+
"type": "image"
|
30 |
+
},
|
31 |
+
.....
|
32 |
+
]
|
33 |
+
```
|
34 |
+
|
35 |
+
You can also set the path as absolute path as follow:
|
36 |
+
```json
|
37 |
+
[
|
38 |
+
{
|
39 |
+
"file_path": "/mnt/data/videos/00000001.mp4",
|
40 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
41 |
+
"type": "video"
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"file_path": "/mnt/data/train/00000001.jpg",
|
45 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
46 |
+
"type": "image"
|
47 |
+
},
|
48 |
+
.....
|
49 |
+
]
|
50 |
+
```
|
51 |
+
|
52 |
+
## Train Video VAE
|
53 |
+
We need to set config in ```easyanimate/vae/configs/autoencoder``` at first. The default config is ```autoencoder_kl_32x32x4_slice.yaml```. We need to set the some params in yaml file.
|
54 |
+
|
55 |
+
- ```data_json_path``` corresponds to the JSON file of the dataset.
|
56 |
+
- ```data_root``` corresponds to the root path of the dataset. If you want to use absolute path in json file, please delete this line.
|
57 |
+
- ```ckpt_path``` corresponds to the pretrained weights of the vae.
|
58 |
+
- ```gpus``` and num_nodes need to be set as the actual situation of your machine.
|
59 |
+
|
60 |
+
The we run shell file as follow:
|
61 |
+
```
|
62 |
+
sh scripts/train_vae.sh
|
63 |
+
```
|
easyanimate/vae/README_zh-CN.md
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## VAE 训练
|
2 |
+
|
3 |
+
[English](./README.md) | 简体中文
|
4 |
+
|
5 |
+
在完成数据预处理后,你可以获得这样的数据格式:
|
6 |
+
|
7 |
+
```
|
8 |
+
📦 project/
|
9 |
+
├── 📂 datasets/
|
10 |
+
│ ├── 📂 internal_datasets/
|
11 |
+
│ ├── 📂 videos/
|
12 |
+
│ │ ├── 📄 00000001.mp4
|
13 |
+
│ │ ├── 📄 00000001.jpg
|
14 |
+
│ │ └── 📄 .....
|
15 |
+
│ └── 📄 json_of_internal_datasets.json
|
16 |
+
```
|
17 |
+
|
18 |
+
json_of_internal_datasets.json是一个标准的json文件。json中的file_path可以被设置为相对路径,如下所示:
|
19 |
+
```json
|
20 |
+
[
|
21 |
+
{
|
22 |
+
"file_path": "videos/00000001.mp4",
|
23 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
24 |
+
"type": "video"
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"file_path": "train/00000001.jpg",
|
28 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
29 |
+
"type": "image"
|
30 |
+
},
|
31 |
+
.....
|
32 |
+
]
|
33 |
+
```
|
34 |
+
|
35 |
+
你也可以将路径设置为绝对路径:
|
36 |
+
```json
|
37 |
+
[
|
38 |
+
{
|
39 |
+
"file_path": "/mnt/data/videos/00000001.mp4",
|
40 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
41 |
+
"type": "video"
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"file_path": "/mnt/data/train/00000001.jpg",
|
45 |
+
"text": "A group of young men in suits and sunglasses are walking down a city street.",
|
46 |
+
"type": "image"
|
47 |
+
},
|
48 |
+
.....
|
49 |
+
]
|
50 |
+
```
|
51 |
+
|
52 |
+
## 训练 Video VAE
|
53 |
+
我们首先需要修改 ```easyanimate/vae/configs/autoencoder``` 中的配置文件。默认的配置文件是 ```autoencoder_kl_32x32x4_slice.yaml```。你需要修改以下参数:
|
54 |
+
|
55 |
+
- ```data_json_path``` json file 所在的目录。
|
56 |
+
- ```data_root``` 数据的根目录。如果你在json file中使用了绝对路径,请设置为空。
|
57 |
+
- ```ckpt_path``` 预训练的vae模型路径。
|
58 |
+
- ```gpus``` 以及 ```num_nodes``` 需要设置为你机器的实际gpu数目。
|
59 |
+
|
60 |
+
运行以下的脚本来训练vae:
|
61 |
+
```
|
62 |
+
sh scripts/train_vae.sh
|
63 |
+
```
|
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
|
4 |
+
params:
|
5 |
+
monitor: train/rec_loss
|
6 |
+
ckpt_path: models/videoVAE_omnigen_8x8x4_from_vae-ft-mse-840000-ema-pruned.ckpt
|
7 |
+
down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
|
8 |
+
"SpatialTemporalDownBlock3D",)
|
9 |
+
up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
|
10 |
+
"SpatialTemporalUpBlock3D",)
|
11 |
+
lossconfig:
|
12 |
+
target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
|
13 |
+
params:
|
14 |
+
disc_start: 50001
|
15 |
+
kl_weight: 1.0e-06
|
16 |
+
disc_weight: 0.5
|
17 |
+
l2_loss_weight: 0.1
|
18 |
+
l1_loss_weight: 1.0
|
19 |
+
perceptual_weight: 1.0
|
20 |
+
|
21 |
+
data:
|
22 |
+
target: train_vae.DataModuleFromConfig
|
23 |
+
|
24 |
+
params:
|
25 |
+
batch_size: 2
|
26 |
+
wrap: true
|
27 |
+
num_workers: 4
|
28 |
+
train:
|
29 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
|
30 |
+
params:
|
31 |
+
data_json_path: pretrain.json
|
32 |
+
data_root: /your_data_root # This is used in relative path
|
33 |
+
size: 128
|
34 |
+
degradation: pil_nearest
|
35 |
+
video_size: 128
|
36 |
+
video_len: 9
|
37 |
+
slice_interval: 1
|
38 |
+
validation:
|
39 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
|
40 |
+
params:
|
41 |
+
data_json_path: pretrain.json
|
42 |
+
data_root: /your_data_root # This is used in relative path
|
43 |
+
size: 128
|
44 |
+
degradation: pil_nearest
|
45 |
+
video_size: 128
|
46 |
+
video_len: 9
|
47 |
+
slice_interval: 1
|
48 |
+
|
49 |
+
lightning:
|
50 |
+
callbacks:
|
51 |
+
image_logger:
|
52 |
+
target: train_vae.ImageLogger
|
53 |
+
params:
|
54 |
+
batch_frequency: 5000
|
55 |
+
max_images: 8
|
56 |
+
increase_log_steps: True
|
57 |
+
|
58 |
+
trainer:
|
59 |
+
benchmark: True
|
60 |
+
accumulate_grad_batches: 1
|
61 |
+
gpus: "0"
|
62 |
+
num_nodes: 1
|
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
|
4 |
+
params:
|
5 |
+
slice_compression_vae: true
|
6 |
+
mini_batch_encoder: 8
|
7 |
+
mini_batch_decoder: 2
|
8 |
+
monitor: train/rec_loss
|
9 |
+
ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
|
10 |
+
down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
|
11 |
+
"SpatialTemporalDownBlock3D",)
|
12 |
+
up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
|
13 |
+
"SpatialTemporalUpBlock3D",)
|
14 |
+
lossconfig:
|
15 |
+
target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
|
16 |
+
params:
|
17 |
+
disc_start: 50001
|
18 |
+
kl_weight: 1.0e-06
|
19 |
+
disc_weight: 0.5
|
20 |
+
l2_loss_weight: 0.0
|
21 |
+
l1_loss_weight: 1.0
|
22 |
+
perceptual_weight: 1.0
|
23 |
+
|
24 |
+
data:
|
25 |
+
target: train_vae.DataModuleFromConfig
|
26 |
+
|
27 |
+
params:
|
28 |
+
batch_size: 1
|
29 |
+
wrap: true
|
30 |
+
num_workers: 8
|
31 |
+
train:
|
32 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
|
33 |
+
params:
|
34 |
+
data_json_path: pretrain.json
|
35 |
+
data_root: /your_data_root # This is used in relative path
|
36 |
+
size: 256
|
37 |
+
degradation: pil_nearest
|
38 |
+
video_size: 256
|
39 |
+
video_len: 25
|
40 |
+
slice_interval: 1
|
41 |
+
validation:
|
42 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
|
43 |
+
params:
|
44 |
+
data_json_path: pretrain.json
|
45 |
+
data_root: /your_data_root # This is used in relative path
|
46 |
+
size: 256
|
47 |
+
degradation: pil_nearest
|
48 |
+
video_size: 256
|
49 |
+
video_len: 25
|
50 |
+
slice_interval: 1
|
51 |
+
|
52 |
+
lightning:
|
53 |
+
callbacks:
|
54 |
+
image_logger:
|
55 |
+
target: train_vae.ImageLogger
|
56 |
+
params:
|
57 |
+
batch_frequency: 5000
|
58 |
+
max_images: 8
|
59 |
+
increase_log_steps: True
|
60 |
+
|
61 |
+
trainer:
|
62 |
+
benchmark: True
|
63 |
+
accumulate_grad_batches: 1
|
64 |
+
gpus: "0"
|
65 |
+
num_nodes: 1
|
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
|
4 |
+
params:
|
5 |
+
slice_compression_vae: true
|
6 |
+
train_decoder_only: true
|
7 |
+
mini_batch_encoder: 8
|
8 |
+
mini_batch_decoder: 2
|
9 |
+
monitor: train/rec_loss
|
10 |
+
ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
|
11 |
+
down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
|
12 |
+
"SpatialTemporalDownBlock3D",)
|
13 |
+
up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
|
14 |
+
"SpatialTemporalUpBlock3D",)
|
15 |
+
lossconfig:
|
16 |
+
target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
|
17 |
+
params:
|
18 |
+
disc_start: 50001
|
19 |
+
kl_weight: 1.0e-06
|
20 |
+
disc_weight: 0.5
|
21 |
+
l2_loss_weight: 1.0
|
22 |
+
l1_loss_weight: 0.0
|
23 |
+
perceptual_weight: 1.0
|
24 |
+
|
25 |
+
data:
|
26 |
+
target: train_vae.DataModuleFromConfig
|
27 |
+
|
28 |
+
params:
|
29 |
+
batch_size: 1
|
30 |
+
wrap: true
|
31 |
+
num_workers: 8
|
32 |
+
train:
|
33 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
|
34 |
+
params:
|
35 |
+
data_json_path: pretrain.json
|
36 |
+
data_root: /your_data_root # This is used in relative path
|
37 |
+
size: 256
|
38 |
+
degradation: pil_nearest
|
39 |
+
video_size: 256
|
40 |
+
video_len: 25
|
41 |
+
slice_interval: 1
|
42 |
+
validation:
|
43 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
|
44 |
+
params:
|
45 |
+
data_json_path: pretrain.json
|
46 |
+
data_root: /your_data_root # This is used in relative path
|
47 |
+
size: 256
|
48 |
+
degradation: pil_nearest
|
49 |
+
video_size: 256
|
50 |
+
video_len: 25
|
51 |
+
slice_interval: 1
|
52 |
+
|
53 |
+
lightning:
|
54 |
+
callbacks:
|
55 |
+
image_logger:
|
56 |
+
target: train_vae.ImageLogger
|
57 |
+
params:
|
58 |
+
batch_frequency: 5000
|
59 |
+
max_images: 8
|
60 |
+
increase_log_steps: True
|
61 |
+
|
62 |
+
trainer:
|
63 |
+
benchmark: True
|
64 |
+
accumulate_grad_batches: 1
|
65 |
+
gpus: "0"
|
66 |
+
num_nodes: 1
|
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
|
4 |
+
params:
|
5 |
+
slice_compression_vae: true
|
6 |
+
mini_batch_encoder: 8
|
7 |
+
mini_batch_decoder: 1
|
8 |
+
monitor: train/rec_loss
|
9 |
+
ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
|
10 |
+
down_block_types: ("SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
|
11 |
+
"SpatialTemporalDownBlock3D",)
|
12 |
+
up_block_types: ("SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
|
13 |
+
"SpatialTemporalUpBlock3D",)
|
14 |
+
lossconfig:
|
15 |
+
target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
|
16 |
+
params:
|
17 |
+
disc_start: 50001
|
18 |
+
kl_weight: 1.0e-06
|
19 |
+
disc_weight: 0.5
|
20 |
+
l2_loss_weight: 0.0
|
21 |
+
l1_loss_weight: 1.0
|
22 |
+
perceptual_weight: 1.0
|
23 |
+
|
24 |
+
|
25 |
+
data:
|
26 |
+
target: train_vae.DataModuleFromConfig
|
27 |
+
|
28 |
+
params:
|
29 |
+
batch_size: 1
|
30 |
+
wrap: true
|
31 |
+
num_workers: 8
|
32 |
+
train:
|
33 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
|
34 |
+
params:
|
35 |
+
data_json_path: pretrain.json
|
36 |
+
data_root: /your_data_root # This is used in relative path
|
37 |
+
size: 256
|
38 |
+
degradation: pil_nearest
|
39 |
+
video_size: 256
|
40 |
+
video_len: 33
|
41 |
+
slice_interval: 1
|
42 |
+
validation:
|
43 |
+
target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
|
44 |
+
params:
|
45 |
+
data_json_path: pretrain.json
|
46 |
+
data_root: /your_data_root # This is used in relative path
|
47 |
+
size: 256
|
48 |
+
degradation: pil_nearest
|
49 |
+
video_size: 256
|
50 |
+
video_len: 33
|
51 |
+
slice_interval: 1
|
52 |
+
|
53 |
+
lightning:
|
54 |
+
callbacks:
|
55 |
+
image_logger:
|
56 |
+
target: train_vae.ImageLogger
|
57 |
+
params:
|
58 |
+
batch_frequency: 5000
|
59 |
+
max_images: 8
|
60 |
+
increase_log_steps: True
|
61 |
+
|
62 |
+
trainer:
|
63 |
+
benchmark: True
|
64 |
+
accumulate_grad_batches: 1
|
65 |
+
gpus: "0"
|
66 |
+
num_nodes: 1
|
easyanimate/vae/environment.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: ldm
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.8.5
|
7 |
+
- pip=20.3
|
8 |
+
- cudatoolkit=11.3
|
9 |
+
- pytorch=1.11.0
|
10 |
+
- torchvision=0.12.0
|
11 |
+
- numpy=1.19.2
|
12 |
+
- pip:
|
13 |
+
- albumentations==0.4.3
|
14 |
+
- diffusers
|
15 |
+
- opencv-python==4.1.2.30
|
16 |
+
- pudb==2019.2
|
17 |
+
- invisible-watermark
|
18 |
+
- imageio==2.9.0
|
19 |
+
- imageio-ffmpeg==0.4.2
|
20 |
+
- pytorch-lightning==1.4.2
|
21 |
+
- omegaconf==2.1.1
|
22 |
+
- test-tube>=0.7.5
|
23 |
+
- streamlit>=0.73.1
|
24 |
+
- einops==0.3.0
|
25 |
+
- torch-fidelity==0.3.0
|
26 |
+
- transformers==4.19.2
|
27 |
+
- torchmetrics==0.6.0
|
28 |
+
- kornia==0.6
|
29 |
+
- -e .
|
easyanimate/vae/ldm/data/__init__.py
ADDED
File without changes
|
easyanimate/vae/ldm/data/base.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
|
3 |
+
from torch.utils.data import (ChainDataset, ConcatDataset, Dataset,
|
4 |
+
IterableDataset)
|
5 |
+
|
6 |
+
|
7 |
+
class Txt2ImgIterableBaseDataset(IterableDataset):
|
8 |
+
'''
|
9 |
+
Define an interface to make the IterableDatasets for text2img data chainable
|
10 |
+
'''
|
11 |
+
def __init__(self, num_records=0, valid_ids=None, size=256):
|
12 |
+
super().__init__()
|
13 |
+
self.num_records = num_records
|
14 |
+
self.valid_ids = valid_ids
|
15 |
+
self.sample_ids = valid_ids
|
16 |
+
self.size = size
|
17 |
+
|
18 |
+
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return self.num_records
|
22 |
+
|
23 |
+
@abstractmethod
|
24 |
+
def __iter__(self):
|
25 |
+
pass
|
easyanimate/vae/ldm/data/dataset_callback.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#-*- encoding:utf-8 -*-
|
2 |
+
from pytorch_lightning.callbacks import Callback
|
3 |
+
|
4 |
+
class DatasetCallback(Callback):
|
5 |
+
def __init__(self):
|
6 |
+
self.sampler_pos_start = 0
|
7 |
+
self.preload_used_idx_flag = False
|
8 |
+
|
9 |
+
def on_train_start(self, trainer, pl_module):
|
10 |
+
if not self.preload_used_idx_flag:
|
11 |
+
self.preload_used_idx_flag = True
|
12 |
+
trainer.train_dataloader.batch_sampler.sampler_pos_reload = self.sampler_pos_start
|
13 |
+
|
14 |
+
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
15 |
+
if trainer.train_dataloader is not None:
|
16 |
+
# Save sampler_pos_start parameters in the checkpoint
|
17 |
+
checkpoint['sampler_pos_start'] = trainer.train_dataloader.batch_sampler.sampler_pos_start
|
18 |
+
|
19 |
+
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
|
20 |
+
# Restore sampler_pos_start parameters from the checkpoint
|
21 |
+
if 'sampler_pos_start' in checkpoint:
|
22 |
+
self.sampler_pos_start = checkpoint.get('sampler_pos_start', 0)
|
23 |
+
print('Load sampler_pos_start from checkpoint, sampler_pos_start = %d' % self.sampler_pos_start)
|
24 |
+
else:
|
25 |
+
print('The sampler_pos_start is not in checkpoint')
|
easyanimate/vae/ldm/data/dataset_image_video.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import random
|
6 |
+
import shutil
|
7 |
+
import tarfile
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import albumentations
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import PIL
|
14 |
+
import torchvision.transforms.functional as TF
|
15 |
+
import yaml
|
16 |
+
from decord import VideoReader
|
17 |
+
from func_timeout import FunctionTimedOut, func_set_timeout
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
from PIL import Image
|
20 |
+
from torch.utils.data import (BatchSampler, Dataset, Sampler)
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from ..modules.image_degradation import (degradation_fn_bsr,
|
24 |
+
degradation_fn_bsr_light)
|
25 |
+
|
26 |
+
|
27 |
+
class ImageVideoSampler(BatchSampler):
|
28 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
sampler (Sampler): Base sampler.
|
32 |
+
dataset (Dataset): Dataset providing data information.
|
33 |
+
batch_size (int): Size of mini-batch.
|
34 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
35 |
+
its size would be less than ``batch_size``.
|
36 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self,
|
40 |
+
sampler: Sampler,
|
41 |
+
dataset: Dataset,
|
42 |
+
batch_size: int,
|
43 |
+
drop_last: bool = False
|
44 |
+
) -> None:
|
45 |
+
if not isinstance(sampler, Sampler):
|
46 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
47 |
+
f'but got {sampler}')
|
48 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
49 |
+
raise ValueError('batch_size should be a positive integer value, '
|
50 |
+
f'but got batch_size={batch_size}')
|
51 |
+
self.sampler = sampler
|
52 |
+
self.dataset = dataset
|
53 |
+
self.batch_size = batch_size
|
54 |
+
self.drop_last = drop_last
|
55 |
+
|
56 |
+
self.sampler_pos_start = 0
|
57 |
+
self.sampler_pos_reload = 0
|
58 |
+
|
59 |
+
self.num_samples_random = len(self.sampler)
|
60 |
+
# buckets for each aspect ratio
|
61 |
+
self.bucket = {'image':[], 'video':[]}
|
62 |
+
|
63 |
+
def set_epoch(self, epoch):
|
64 |
+
if hasattr(self.sampler, "set_epoch"):
|
65 |
+
self.sampler.set_epoch(epoch)
|
66 |
+
|
67 |
+
def __iter__(self):
|
68 |
+
for index_sampler, idx in enumerate(self.sampler):
|
69 |
+
if self.sampler_pos_reload != 0 and self.sampler_pos_reload < self.num_samples_random:
|
70 |
+
if index_sampler < self.sampler_pos_reload:
|
71 |
+
self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
|
72 |
+
continue
|
73 |
+
elif index_sampler == self.sampler_pos_reload:
|
74 |
+
self.sampler_pos_reload = 0
|
75 |
+
|
76 |
+
content_type = self.dataset.data.get_type(idx)
|
77 |
+
bucket = self.bucket[content_type]
|
78 |
+
bucket.append(idx)
|
79 |
+
# yield a batch of indices in the same aspect ratio group
|
80 |
+
if len(self.bucket['video']) == self.batch_size:
|
81 |
+
yield self.bucket['video']
|
82 |
+
self.bucket['video'] = []
|
83 |
+
elif len(self.bucket['image']) == self.batch_size:
|
84 |
+
yield self.bucket['image']
|
85 |
+
self.bucket['image'] = []
|
86 |
+
self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
|
87 |
+
|
88 |
+
class ImageVideoDataset(Dataset):
|
89 |
+
# update __getitem__() from ImageNetSR. If timeout for Pandas70M, throw exception.
|
90 |
+
# If caught exception(timeout or others), try another index until successful and return.
|
91 |
+
def __init__(self, size=None, video_size=128, video_len=25,
|
92 |
+
degradation=None, downscale_f=4, random_crop=True, min_crop_f=0.25, max_crop_f=1.,
|
93 |
+
s_t=None, slice_interval=None, data_root=None
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
Imagenet Superresolution Dataloader
|
97 |
+
Performs following ops in order:
|
98 |
+
1. crops a crop of size s from image either as random or center crop
|
99 |
+
2. resizes crop to size with cv2.area_interpolation
|
100 |
+
3. degrades resized crop with degradation_fn
|
101 |
+
|
102 |
+
:param size: resizing to size after cropping
|
103 |
+
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
|
104 |
+
:param downscale_f: Low Resolution Downsample factor
|
105 |
+
:param min_crop_f: determines crop size s,
|
106 |
+
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
|
107 |
+
:param max_crop_f: ""
|
108 |
+
:param data_root:
|
109 |
+
:param random_crop:
|
110 |
+
"""
|
111 |
+
self.base = self.get_base()
|
112 |
+
assert size
|
113 |
+
assert (size / downscale_f).is_integer()
|
114 |
+
self.size = size
|
115 |
+
self.LR_size = int(size / downscale_f)
|
116 |
+
self.min_crop_f = min_crop_f
|
117 |
+
self.max_crop_f = max_crop_f
|
118 |
+
assert(max_crop_f <= 1.)
|
119 |
+
self.center_crop = not random_crop
|
120 |
+
self.s_t = s_t
|
121 |
+
self.slice_interval = slice_interval
|
122 |
+
|
123 |
+
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
|
124 |
+
self.video_rescaler = albumentations.SmallestMaxSize(max_size=video_size, interpolation=cv2.INTER_AREA)
|
125 |
+
self.video_len = video_len
|
126 |
+
self.video_size = video_size
|
127 |
+
self.data_root = data_root
|
128 |
+
|
129 |
+
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
|
130 |
+
|
131 |
+
if degradation == "bsrgan":
|
132 |
+
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
|
133 |
+
|
134 |
+
elif degradation == "bsrgan_light":
|
135 |
+
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
|
136 |
+
else:
|
137 |
+
interpolation_fn = {
|
138 |
+
"cv_nearest": cv2.INTER_NEAREST,
|
139 |
+
"cv_bilinear": cv2.INTER_LINEAR,
|
140 |
+
"cv_bicubic": cv2.INTER_CUBIC,
|
141 |
+
"cv_area": cv2.INTER_AREA,
|
142 |
+
"cv_lanczos": cv2.INTER_LANCZOS4,
|
143 |
+
"pil_nearest": PIL.Image.NEAREST,
|
144 |
+
"pil_bilinear": PIL.Image.BILINEAR,
|
145 |
+
"pil_bicubic": PIL.Image.BICUBIC,
|
146 |
+
"pil_box": PIL.Image.BOX,
|
147 |
+
"pil_hamming": PIL.Image.HAMMING,
|
148 |
+
"pil_lanczos": PIL.Image.LANCZOS,
|
149 |
+
}[degradation]
|
150 |
+
|
151 |
+
self.pil_interpolation = degradation.startswith("pil_")
|
152 |
+
|
153 |
+
if self.pil_interpolation:
|
154 |
+
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
|
155 |
+
|
156 |
+
else:
|
157 |
+
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
|
158 |
+
interpolation=interpolation_fn)
|
159 |
+
|
160 |
+
def __len__(self):
|
161 |
+
return len(self.base)
|
162 |
+
|
163 |
+
def get_type(self, index):
|
164 |
+
return self.base[index].get('type', 'image')
|
165 |
+
|
166 |
+
def __getitem__(self, i):
|
167 |
+
@func_set_timeout(3) # time wait 3 seconds
|
168 |
+
def get_video_item(example):
|
169 |
+
if self.data_root is not None:
|
170 |
+
video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
|
171 |
+
else:
|
172 |
+
video_reader = VideoReader(example['file_path'])
|
173 |
+
video_length = len(video_reader)
|
174 |
+
|
175 |
+
clip_length = min(video_length, (self.video_len - 1) * self.slice_interval + 1)
|
176 |
+
start_idx = random.randint(0, video_length - clip_length)
|
177 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
|
178 |
+
|
179 |
+
pixel_values = video_reader.get_batch(batch_index).asnumpy()
|
180 |
+
|
181 |
+
del video_reader
|
182 |
+
out_images = []
|
183 |
+
LR_out_images = []
|
184 |
+
min_side_len = min(pixel_values[0].shape[:2])
|
185 |
+
|
186 |
+
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
187 |
+
crop_side_len = int(crop_side_len)
|
188 |
+
if self.center_crop:
|
189 |
+
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
190 |
+
else:
|
191 |
+
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
192 |
+
|
193 |
+
imgs = np.transpose(pixel_values, (1, 2, 3, 0))
|
194 |
+
imgs = self.cropper(image=imgs)["image"]
|
195 |
+
imgs = np.transpose(imgs, (3, 0, 1, 2))
|
196 |
+
for img in imgs:
|
197 |
+
image = self.video_rescaler(image=img)["image"]
|
198 |
+
out_images.append(image[None, :, :, :])
|
199 |
+
if self.pil_interpolation:
|
200 |
+
image_pil = PIL.Image.fromarray(image)
|
201 |
+
LR_image = self.degradation_process(image_pil)
|
202 |
+
LR_image = np.array(LR_image).astype(np.uint8)
|
203 |
+
else:
|
204 |
+
LR_image = self.degradation_process(image=image)["image"]
|
205 |
+
LR_out_images.append(LR_image[None, :, :, :])
|
206 |
+
|
207 |
+
example = {}
|
208 |
+
example['image'] = (np.concatenate(out_images) / 127.5 - 1.0).astype(np.float32)
|
209 |
+
example['LR_image'] = (np.concatenate(LR_out_images) / 127.5 - 1.0).astype(np.float32)
|
210 |
+
return example
|
211 |
+
|
212 |
+
example = self.base[i]
|
213 |
+
if example.get('type', 'image') == 'video':
|
214 |
+
while True:
|
215 |
+
try:
|
216 |
+
example = self.base[i]
|
217 |
+
return get_video_item(example)
|
218 |
+
except FunctionTimedOut:
|
219 |
+
print("stt catch: Function 'extract failed' timed out.")
|
220 |
+
i = random.randint(0, self.__len__() - 1)
|
221 |
+
except Exception as e:
|
222 |
+
print('stt catch', e)
|
223 |
+
i = random.randint(0, self.__len__() - 1)
|
224 |
+
elif example.get('type', 'image') == 'image':
|
225 |
+
while True:
|
226 |
+
try:
|
227 |
+
example = self.base[i]
|
228 |
+
if self.data_root is not None:
|
229 |
+
image = Image.open(os.path.join(self.data_root, example['file_path']))
|
230 |
+
else:
|
231 |
+
image = Image.open(example['file_path'])
|
232 |
+
image = image.convert("RGB")
|
233 |
+
image = np.array(image).astype(np.uint8)
|
234 |
+
|
235 |
+
min_side_len = min(image.shape[:2])
|
236 |
+
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
237 |
+
crop_side_len = int(crop_side_len)
|
238 |
+
|
239 |
+
if self.center_crop:
|
240 |
+
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
241 |
+
|
242 |
+
else:
|
243 |
+
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
244 |
+
|
245 |
+
image = self.cropper(image=image)["image"]
|
246 |
+
|
247 |
+
image = self.image_rescaler(image=image)["image"]
|
248 |
+
|
249 |
+
if self.pil_interpolation:
|
250 |
+
image_pil = PIL.Image.fromarray(image)
|
251 |
+
LR_image = self.degradation_process(image_pil)
|
252 |
+
LR_image = np.array(LR_image).astype(np.uint8)
|
253 |
+
|
254 |
+
else:
|
255 |
+
LR_image = self.degradation_process(image=image)["image"]
|
256 |
+
|
257 |
+
example = {}
|
258 |
+
example["image"] = (image/127.5 - 1.0).astype(np.float32)
|
259 |
+
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
|
260 |
+
return example
|
261 |
+
except Exception as e:
|
262 |
+
print("catch", e)
|
263 |
+
i = random.randint(0, self.__len__() - 1)
|
264 |
+
|
265 |
+
class CustomSRTrain(ImageVideoDataset):
|
266 |
+
def __init__(self, data_json_path, **kwargs):
|
267 |
+
self.data_json_path = data_json_path
|
268 |
+
super().__init__(**kwargs)
|
269 |
+
|
270 |
+
def get_base(self):
|
271 |
+
return [ann for ann in json.load(open(self.data_json_path))]
|
272 |
+
|
273 |
+
class CustomSRValidation(ImageVideoDataset):
|
274 |
+
def __init__(self, data_json_path, **kwargs):
|
275 |
+
self.data_json_path = data_json_path
|
276 |
+
super().__init__(**kwargs)
|
277 |
+
self.data_json_path = data_json_path
|
278 |
+
|
279 |
+
def get_base(self):
|
280 |
+
return [ann for ann in json.load(open(self.data_json_path))][:100] + \
|
281 |
+
[ann for ann in json.load(open(self.data_json_path))][-100:]
|
easyanimate/vae/ldm/lr_scheduler.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
9 |
+
self.lr_warm_up_steps = warm_up_steps
|
10 |
+
self.lr_start = lr_start
|
11 |
+
self.lr_min = lr_min
|
12 |
+
self.lr_max = lr_max
|
13 |
+
self.lr_max_decay_steps = max_decay_steps
|
14 |
+
self.last_lr = 0.
|
15 |
+
self.verbosity_interval = verbosity_interval
|
16 |
+
|
17 |
+
def schedule(self, n, **kwargs):
|
18 |
+
if self.verbosity_interval > 0:
|
19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
20 |
+
if n < self.lr_warm_up_steps:
|
21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
22 |
+
self.last_lr = lr
|
23 |
+
return lr
|
24 |
+
else:
|
25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
26 |
+
t = min(t, 1.0)
|
27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
28 |
+
1 + np.cos(t * np.pi))
|
29 |
+
self.last_lr = lr
|
30 |
+
return lr
|
31 |
+
|
32 |
+
def __call__(self, n, **kwargs):
|
33 |
+
return self.schedule(n,**kwargs)
|
34 |
+
|
35 |
+
|
36 |
+
class LambdaWarmUpCosineScheduler2:
|
37 |
+
"""
|
38 |
+
supports repeated iterations, configurable via lists
|
39 |
+
note: use with a base_lr of 1.0.
|
40 |
+
"""
|
41 |
+
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
42 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
43 |
+
self.lr_warm_up_steps = warm_up_steps
|
44 |
+
self.f_start = f_start
|
45 |
+
self.f_min = f_min
|
46 |
+
self.f_max = f_max
|
47 |
+
self.cycle_lengths = cycle_lengths
|
48 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
49 |
+
self.last_f = 0.
|
50 |
+
self.verbosity_interval = verbosity_interval
|
51 |
+
|
52 |
+
def find_in_interval(self, n):
|
53 |
+
interval = 0
|
54 |
+
for cl in self.cum_cycles[1:]:
|
55 |
+
if n <= cl:
|
56 |
+
return interval
|
57 |
+
interval += 1
|
58 |
+
|
59 |
+
def schedule(self, n, **kwargs):
|
60 |
+
cycle = self.find_in_interval(n)
|
61 |
+
n = n - self.cum_cycles[cycle]
|
62 |
+
if self.verbosity_interval > 0:
|
63 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
64 |
+
f"current cycle {cycle}")
|
65 |
+
if n < self.lr_warm_up_steps[cycle]:
|
66 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
67 |
+
self.last_f = f
|
68 |
+
return f
|
69 |
+
else:
|
70 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
71 |
+
t = min(t, 1.0)
|
72 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
73 |
+
1 + np.cos(t * np.pi))
|
74 |
+
self.last_f = f
|
75 |
+
return f
|
76 |
+
|
77 |
+
def __call__(self, n, **kwargs):
|
78 |
+
return self.schedule(n, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
88 |
+
f"current cycle {cycle}")
|
89 |
+
|
90 |
+
if n < self.lr_warm_up_steps[cycle]:
|
91 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
92 |
+
self.last_f = f
|
93 |
+
return f
|
94 |
+
else:
|
95 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
|
easyanimate/vae/ldm/modules/diffusionmodules/__init__.py
ADDED
File without changes
|
easyanimate/vae/ldm/modules/diffusionmodules/model.py
ADDED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
from ...util import instantiate_from_config
|
10 |
+
|
11 |
+
|
12 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
13 |
+
"""
|
14 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
15 |
+
From Fairseq.
|
16 |
+
Build sinusoidal embeddings.
|
17 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
18 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
19 |
+
"""
|
20 |
+
assert len(timesteps.shape) == 1
|
21 |
+
|
22 |
+
half_dim = embedding_dim // 2
|
23 |
+
emb = math.log(10000) / (half_dim - 1)
|
24 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
25 |
+
emb = emb.to(device=timesteps.device)
|
26 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
27 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
28 |
+
if embedding_dim % 2 == 1: # zero pad
|
29 |
+
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
30 |
+
return emb
|
31 |
+
|
32 |
+
|
33 |
+
def nonlinearity(x):
|
34 |
+
# swish
|
35 |
+
return x*torch.sigmoid(x)
|
36 |
+
|
37 |
+
|
38 |
+
def Normalize(in_channels, num_groups=32):
|
39 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
40 |
+
|
41 |
+
|
42 |
+
class Upsample(nn.Module):
|
43 |
+
def __init__(self, in_channels, with_conv):
|
44 |
+
super().__init__()
|
45 |
+
self.with_conv = with_conv
|
46 |
+
if self.with_conv:
|
47 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
48 |
+
in_channels,
|
49 |
+
kernel_size=3,
|
50 |
+
stride=1,
|
51 |
+
padding=1)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
55 |
+
if self.with_conv:
|
56 |
+
x = self.conv(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class Downsample(nn.Module):
|
61 |
+
def __init__(self, in_channels, with_conv):
|
62 |
+
super().__init__()
|
63 |
+
self.with_conv = with_conv
|
64 |
+
if self.with_conv:
|
65 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
66 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
67 |
+
in_channels,
|
68 |
+
kernel_size=3,
|
69 |
+
stride=2,
|
70 |
+
padding=0)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
if self.with_conv:
|
74 |
+
pad = (0,1,0,1)
|
75 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
76 |
+
x = self.conv(x)
|
77 |
+
else:
|
78 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
79 |
+
return x
|
80 |
+
|
81 |
+
|
82 |
+
class ResnetBlock(nn.Module):
|
83 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
84 |
+
dropout, temb_channels=512):
|
85 |
+
super().__init__()
|
86 |
+
self.in_channels = in_channels
|
87 |
+
out_channels = in_channels if out_channels is None else out_channels
|
88 |
+
self.out_channels = out_channels
|
89 |
+
self.use_conv_shortcut = conv_shortcut
|
90 |
+
|
91 |
+
self.norm1 = Normalize(in_channels)
|
92 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
93 |
+
out_channels,
|
94 |
+
kernel_size=3,
|
95 |
+
stride=1,
|
96 |
+
padding=1)
|
97 |
+
if temb_channels > 0:
|
98 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
99 |
+
out_channels)
|
100 |
+
self.norm2 = Normalize(out_channels)
|
101 |
+
self.dropout = torch.nn.Dropout(dropout)
|
102 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
if self.in_channels != self.out_channels:
|
108 |
+
if self.use_conv_shortcut:
|
109 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
110 |
+
out_channels,
|
111 |
+
kernel_size=3,
|
112 |
+
stride=1,
|
113 |
+
padding=1)
|
114 |
+
else:
|
115 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
116 |
+
out_channels,
|
117 |
+
kernel_size=1,
|
118 |
+
stride=1,
|
119 |
+
padding=0)
|
120 |
+
|
121 |
+
def forward(self, x, temb):
|
122 |
+
h = x
|
123 |
+
h = self.norm1(h)
|
124 |
+
h = nonlinearity(h)
|
125 |
+
h = self.conv1(h)
|
126 |
+
|
127 |
+
if temb is not None:
|
128 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
129 |
+
|
130 |
+
h = self.norm2(h)
|
131 |
+
h = nonlinearity(h)
|
132 |
+
h = self.dropout(h)
|
133 |
+
h = self.conv2(h)
|
134 |
+
|
135 |
+
if self.in_channels != self.out_channels:
|
136 |
+
if self.use_conv_shortcut:
|
137 |
+
x = self.conv_shortcut(x)
|
138 |
+
else:
|
139 |
+
x = self.nin_shortcut(x)
|
140 |
+
|
141 |
+
return x+h
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
class AttnBlock(nn.Module):
|
147 |
+
def __init__(self, in_channels):
|
148 |
+
super().__init__()
|
149 |
+
self.in_channels = in_channels
|
150 |
+
|
151 |
+
self.norm = Normalize(in_channels)
|
152 |
+
self.q = torch.nn.Conv2d(in_channels,
|
153 |
+
in_channels,
|
154 |
+
kernel_size=1,
|
155 |
+
stride=1,
|
156 |
+
padding=0)
|
157 |
+
self.k = torch.nn.Conv2d(in_channels,
|
158 |
+
in_channels,
|
159 |
+
kernel_size=1,
|
160 |
+
stride=1,
|
161 |
+
padding=0)
|
162 |
+
self.v = torch.nn.Conv2d(in_channels,
|
163 |
+
in_channels,
|
164 |
+
kernel_size=1,
|
165 |
+
stride=1,
|
166 |
+
padding=0)
|
167 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
168 |
+
in_channels,
|
169 |
+
kernel_size=1,
|
170 |
+
stride=1,
|
171 |
+
padding=0)
|
172 |
+
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
h_ = x
|
176 |
+
h_ = self.norm(h_)
|
177 |
+
q = self.q(h_)
|
178 |
+
k = self.k(h_)
|
179 |
+
v = self.v(h_)
|
180 |
+
|
181 |
+
# compute attention
|
182 |
+
b,c,h,w = q.shape
|
183 |
+
q = q.reshape(b,c,h*w)
|
184 |
+
q = q.permute(0,2,1) # b,hw,c
|
185 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
186 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
187 |
+
w_ = w_ * (int(c)**(-0.5))
|
188 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
189 |
+
|
190 |
+
# attend to values
|
191 |
+
v = v.reshape(b,c,h*w)
|
192 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
193 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
194 |
+
h_ = h_.reshape(b,c,h,w)
|
195 |
+
|
196 |
+
h_ = self.proj_out(h_)
|
197 |
+
|
198 |
+
return x+h_
|
199 |
+
|
200 |
+
class LinearAttention(nn.Module):
|
201 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
202 |
+
super().__init__()
|
203 |
+
self.heads = heads
|
204 |
+
hidden_dim = dim_head * heads
|
205 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
206 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
b, c, h, w = x.shape
|
210 |
+
qkv = self.to_qkv(x)
|
211 |
+
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
212 |
+
k = k.softmax(dim=-1)
|
213 |
+
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
214 |
+
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
215 |
+
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
216 |
+
return self.to_out(out)
|
217 |
+
|
218 |
+
class LinAttnBlock(LinearAttention):
|
219 |
+
"""to match AttnBlock usage"""
|
220 |
+
def __init__(self, in_channels):
|
221 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
222 |
+
|
223 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
224 |
+
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
225 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
226 |
+
if attn_type == "vanilla":
|
227 |
+
return AttnBlock(in_channels)
|
228 |
+
elif attn_type == "none":
|
229 |
+
return nn.Identity(in_channels)
|
230 |
+
else:
|
231 |
+
return LinAttnBlock(in_channels)
|
232 |
+
|
233 |
+
|
234 |
+
class Encoder(nn.Module):
|
235 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
236 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
237 |
+
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
238 |
+
**ignore_kwargs):
|
239 |
+
super().__init__()
|
240 |
+
if use_linear_attn: attn_type = "linear"
|
241 |
+
self.ch = ch
|
242 |
+
self.temb_ch = 0
|
243 |
+
self.num_resolutions = len(ch_mult)
|
244 |
+
self.num_res_blocks = num_res_blocks
|
245 |
+
self.resolution = resolution
|
246 |
+
self.in_channels = in_channels
|
247 |
+
|
248 |
+
# downsampling
|
249 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
250 |
+
self.ch,
|
251 |
+
kernel_size=3,
|
252 |
+
stride=1,
|
253 |
+
padding=1)
|
254 |
+
|
255 |
+
curr_res = resolution
|
256 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
257 |
+
self.in_ch_mult = in_ch_mult
|
258 |
+
self.down = nn.ModuleList()
|
259 |
+
for i_level in range(self.num_resolutions):
|
260 |
+
block = nn.ModuleList()
|
261 |
+
attn = nn.ModuleList()
|
262 |
+
block_in = ch*in_ch_mult[i_level]
|
263 |
+
block_out = ch*ch_mult[i_level]
|
264 |
+
for i_block in range(self.num_res_blocks):
|
265 |
+
block.append(ResnetBlock(in_channels=block_in,
|
266 |
+
out_channels=block_out,
|
267 |
+
temb_channels=self.temb_ch,
|
268 |
+
dropout=dropout))
|
269 |
+
block_in = block_out
|
270 |
+
if curr_res in attn_resolutions:
|
271 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
272 |
+
down = nn.Module()
|
273 |
+
down.block = block
|
274 |
+
down.attn = attn
|
275 |
+
if i_level != self.num_resolutions-1:
|
276 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
277 |
+
curr_res = curr_res // 2
|
278 |
+
self.down.append(down)
|
279 |
+
|
280 |
+
# middle
|
281 |
+
self.mid = nn.Module()
|
282 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
283 |
+
out_channels=block_in,
|
284 |
+
temb_channels=self.temb_ch,
|
285 |
+
dropout=dropout)
|
286 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
287 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
288 |
+
out_channels=block_in,
|
289 |
+
temb_channels=self.temb_ch,
|
290 |
+
dropout=dropout)
|
291 |
+
|
292 |
+
# end
|
293 |
+
self.norm_out = Normalize(block_in)
|
294 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
295 |
+
2*z_channels if double_z else z_channels,
|
296 |
+
kernel_size=3,
|
297 |
+
stride=1,
|
298 |
+
padding=1)
|
299 |
+
|
300 |
+
def forward(self, x):
|
301 |
+
# timestep embedding
|
302 |
+
temb = None
|
303 |
+
|
304 |
+
# downsampling
|
305 |
+
hs = [self.conv_in(x)]
|
306 |
+
for i_level in range(self.num_resolutions):
|
307 |
+
for i_block in range(self.num_res_blocks):
|
308 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
309 |
+
if len(self.down[i_level].attn) > 0:
|
310 |
+
h = self.down[i_level].attn[i_block](h)
|
311 |
+
hs.append(h)
|
312 |
+
if i_level != self.num_resolutions-1:
|
313 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
314 |
+
|
315 |
+
# middle
|
316 |
+
h = hs[-1]
|
317 |
+
h = self.mid.block_1(h, temb)
|
318 |
+
h = self.mid.attn_1(h)
|
319 |
+
h = self.mid.block_2(h, temb)
|
320 |
+
|
321 |
+
# end
|
322 |
+
h = self.norm_out(h)
|
323 |
+
h = nonlinearity(h)
|
324 |
+
h = self.conv_out(h)
|
325 |
+
return h
|
326 |
+
|
327 |
+
|
328 |
+
class Decoder(nn.Module):
|
329 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
330 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
331 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
332 |
+
attn_type="vanilla", **ignorekwargs):
|
333 |
+
super().__init__()
|
334 |
+
if use_linear_attn: attn_type = "linear"
|
335 |
+
self.ch = ch
|
336 |
+
self.temb_ch = 0
|
337 |
+
self.num_resolutions = len(ch_mult)
|
338 |
+
self.num_res_blocks = num_res_blocks
|
339 |
+
self.resolution = resolution
|
340 |
+
self.in_channels = in_channels
|
341 |
+
self.give_pre_end = give_pre_end
|
342 |
+
self.tanh_out = tanh_out
|
343 |
+
|
344 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
345 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
346 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
347 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
348 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
349 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
350 |
+
self.z_shape, np.prod(self.z_shape)))
|
351 |
+
|
352 |
+
# z to block_in
|
353 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
354 |
+
block_in,
|
355 |
+
kernel_size=3,
|
356 |
+
stride=1,
|
357 |
+
padding=1)
|
358 |
+
|
359 |
+
# middle
|
360 |
+
self.mid = nn.Module()
|
361 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
362 |
+
out_channels=block_in,
|
363 |
+
temb_channels=self.temb_ch,
|
364 |
+
dropout=dropout)
|
365 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
366 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
367 |
+
out_channels=block_in,
|
368 |
+
temb_channels=self.temb_ch,
|
369 |
+
dropout=dropout)
|
370 |
+
|
371 |
+
# upsampling
|
372 |
+
self.up = nn.ModuleList()
|
373 |
+
for i_level in reversed(range(self.num_resolutions)):
|
374 |
+
block = nn.ModuleList()
|
375 |
+
attn = nn.ModuleList()
|
376 |
+
block_out = ch*ch_mult[i_level]
|
377 |
+
for i_block in range(self.num_res_blocks+1):
|
378 |
+
block.append(ResnetBlock(in_channels=block_in,
|
379 |
+
out_channels=block_out,
|
380 |
+
temb_channels=self.temb_ch,
|
381 |
+
dropout=dropout))
|
382 |
+
block_in = block_out
|
383 |
+
if curr_res in attn_resolutions:
|
384 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
385 |
+
up = nn.Module()
|
386 |
+
up.block = block
|
387 |
+
up.attn = attn
|
388 |
+
if i_level != 0:
|
389 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
390 |
+
curr_res = curr_res * 2
|
391 |
+
self.up.insert(0, up) # prepend to get consistent order
|
392 |
+
|
393 |
+
# end
|
394 |
+
self.norm_out = Normalize(block_in)
|
395 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
396 |
+
out_ch,
|
397 |
+
kernel_size=3,
|
398 |
+
stride=1,
|
399 |
+
padding=1)
|
400 |
+
|
401 |
+
def forward(self, z):
|
402 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
403 |
+
self.last_z_shape = z.shape
|
404 |
+
|
405 |
+
# timestep embedding
|
406 |
+
temb = None
|
407 |
+
|
408 |
+
# z to block_in
|
409 |
+
h = self.conv_in(z)
|
410 |
+
|
411 |
+
# middle
|
412 |
+
h = self.mid.block_1(h, temb)
|
413 |
+
h = self.mid.attn_1(h)
|
414 |
+
h = self.mid.block_2(h, temb)
|
415 |
+
|
416 |
+
# upsampling
|
417 |
+
for i_level in reversed(range(self.num_resolutions)):
|
418 |
+
for i_block in range(self.num_res_blocks+1):
|
419 |
+
h = self.up[i_level].block[i_block](h, temb)
|
420 |
+
if len(self.up[i_level].attn) > 0:
|
421 |
+
h = self.up[i_level].attn[i_block](h)
|
422 |
+
if i_level != 0:
|
423 |
+
h = self.up[i_level].upsample(h)
|
424 |
+
|
425 |
+
# end
|
426 |
+
if self.give_pre_end:
|
427 |
+
return h
|
428 |
+
|
429 |
+
h = self.norm_out(h)
|
430 |
+
h = nonlinearity(h)
|
431 |
+
h = self.conv_out(h)
|
432 |
+
if self.tanh_out:
|
433 |
+
h = torch.tanh(h)
|
434 |
+
return h
|
435 |
+
|
436 |
+
|
437 |
+
class SimpleDecoder(nn.Module):
|
438 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
439 |
+
super().__init__()
|
440 |
+
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
441 |
+
ResnetBlock(in_channels=in_channels,
|
442 |
+
out_channels=2 * in_channels,
|
443 |
+
temb_channels=0, dropout=0.0),
|
444 |
+
ResnetBlock(in_channels=2 * in_channels,
|
445 |
+
out_channels=4 * in_channels,
|
446 |
+
temb_channels=0, dropout=0.0),
|
447 |
+
ResnetBlock(in_channels=4 * in_channels,
|
448 |
+
out_channels=2 * in_channels,
|
449 |
+
temb_channels=0, dropout=0.0),
|
450 |
+
nn.Conv2d(2*in_channels, in_channels, 1),
|
451 |
+
Upsample(in_channels, with_conv=True)])
|
452 |
+
# end
|
453 |
+
self.norm_out = Normalize(in_channels)
|
454 |
+
self.conv_out = torch.nn.Conv2d(in_channels,
|
455 |
+
out_channels,
|
456 |
+
kernel_size=3,
|
457 |
+
stride=1,
|
458 |
+
padding=1)
|
459 |
+
|
460 |
+
def forward(self, x):
|
461 |
+
for i, layer in enumerate(self.model):
|
462 |
+
if i in [1,2,3]:
|
463 |
+
x = layer(x, None)
|
464 |
+
else:
|
465 |
+
x = layer(x)
|
466 |
+
|
467 |
+
h = self.norm_out(x)
|
468 |
+
h = nonlinearity(h)
|
469 |
+
x = self.conv_out(h)
|
470 |
+
return x
|
471 |
+
|
472 |
+
|
473 |
+
class UpsampleDecoder(nn.Module):
|
474 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
475 |
+
ch_mult=(2,2), dropout=0.0):
|
476 |
+
super().__init__()
|
477 |
+
# upsampling
|
478 |
+
self.temb_ch = 0
|
479 |
+
self.num_resolutions = len(ch_mult)
|
480 |
+
self.num_res_blocks = num_res_blocks
|
481 |
+
block_in = in_channels
|
482 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
483 |
+
self.res_blocks = nn.ModuleList()
|
484 |
+
self.upsample_blocks = nn.ModuleList()
|
485 |
+
for i_level in range(self.num_resolutions):
|
486 |
+
res_block = []
|
487 |
+
block_out = ch * ch_mult[i_level]
|
488 |
+
for i_block in range(self.num_res_blocks + 1):
|
489 |
+
res_block.append(ResnetBlock(in_channels=block_in,
|
490 |
+
out_channels=block_out,
|
491 |
+
temb_channels=self.temb_ch,
|
492 |
+
dropout=dropout))
|
493 |
+
block_in = block_out
|
494 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
495 |
+
if i_level != self.num_resolutions - 1:
|
496 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
497 |
+
curr_res = curr_res * 2
|
498 |
+
|
499 |
+
# end
|
500 |
+
self.norm_out = Normalize(block_in)
|
501 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
502 |
+
out_channels,
|
503 |
+
kernel_size=3,
|
504 |
+
stride=1,
|
505 |
+
padding=1)
|
506 |
+
|
507 |
+
def forward(self, x):
|
508 |
+
# upsampling
|
509 |
+
h = x
|
510 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
511 |
+
for i_block in range(self.num_res_blocks + 1):
|
512 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
513 |
+
if i_level != self.num_resolutions - 1:
|
514 |
+
h = self.upsample_blocks[k](h)
|
515 |
+
h = self.norm_out(h)
|
516 |
+
h = nonlinearity(h)
|
517 |
+
h = self.conv_out(h)
|
518 |
+
return h
|
519 |
+
|
520 |
+
|
521 |
+
class LatentRescaler(nn.Module):
|
522 |
+
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
523 |
+
super().__init__()
|
524 |
+
# residual block, interpolate, residual block
|
525 |
+
self.factor = factor
|
526 |
+
self.conv_in = nn.Conv2d(in_channels,
|
527 |
+
mid_channels,
|
528 |
+
kernel_size=3,
|
529 |
+
stride=1,
|
530 |
+
padding=1)
|
531 |
+
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
532 |
+
out_channels=mid_channels,
|
533 |
+
temb_channels=0,
|
534 |
+
dropout=0.0) for _ in range(depth)])
|
535 |
+
self.attn = AttnBlock(mid_channels)
|
536 |
+
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
537 |
+
out_channels=mid_channels,
|
538 |
+
temb_channels=0,
|
539 |
+
dropout=0.0) for _ in range(depth)])
|
540 |
+
|
541 |
+
self.conv_out = nn.Conv2d(mid_channels,
|
542 |
+
out_channels,
|
543 |
+
kernel_size=1,
|
544 |
+
)
|
545 |
+
|
546 |
+
def forward(self, x):
|
547 |
+
x = self.conv_in(x)
|
548 |
+
for block in self.res_block1:
|
549 |
+
x = block(x, None)
|
550 |
+
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
551 |
+
x = self.attn(x)
|
552 |
+
for block in self.res_block2:
|
553 |
+
x = block(x, None)
|
554 |
+
x = self.conv_out(x)
|
555 |
+
return x
|
556 |
+
|
557 |
+
|
558 |
+
class MergedRescaleEncoder(nn.Module):
|
559 |
+
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
560 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
561 |
+
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
562 |
+
super().__init__()
|
563 |
+
intermediate_chn = ch * ch_mult[-1]
|
564 |
+
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
565 |
+
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
566 |
+
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
567 |
+
out_ch=None)
|
568 |
+
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
569 |
+
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
570 |
+
|
571 |
+
def forward(self, x):
|
572 |
+
x = self.encoder(x)
|
573 |
+
x = self.rescaler(x)
|
574 |
+
return x
|
575 |
+
|
576 |
+
|
577 |
+
class MergedRescaleDecoder(nn.Module):
|
578 |
+
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
579 |
+
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
580 |
+
super().__init__()
|
581 |
+
tmp_chn = z_channels*ch_mult[-1]
|
582 |
+
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
583 |
+
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
584 |
+
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
585 |
+
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
586 |
+
out_channels=tmp_chn, depth=rescale_module_depth)
|
587 |
+
|
588 |
+
def forward(self, x):
|
589 |
+
x = self.rescaler(x)
|
590 |
+
x = self.decoder(x)
|
591 |
+
return x
|
592 |
+
|
593 |
+
|
594 |
+
class Upsampler(nn.Module):
|
595 |
+
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
596 |
+
super().__init__()
|
597 |
+
assert out_size >= in_size
|
598 |
+
num_blocks = int(np.log2(out_size//in_size))+1
|
599 |
+
factor_up = 1.+ (out_size % in_size)
|
600 |
+
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
601 |
+
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
602 |
+
out_channels=in_channels)
|
603 |
+
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
604 |
+
attn_resolutions=[], in_channels=None, ch=in_channels,
|
605 |
+
ch_mult=[ch_mult for _ in range(num_blocks)])
|
606 |
+
|
607 |
+
def forward(self, x):
|
608 |
+
x = self.rescaler(x)
|
609 |
+
x = self.decoder(x)
|
610 |
+
return x
|
611 |
+
|
612 |
+
|
613 |
+
class Resize(nn.Module):
|
614 |
+
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
615 |
+
super().__init__()
|
616 |
+
self.with_conv = learned
|
617 |
+
self.mode = mode
|
618 |
+
if self.with_conv:
|
619 |
+
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
620 |
+
raise NotImplementedError()
|
621 |
+
assert in_channels is not None
|
622 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
623 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
624 |
+
in_channels,
|
625 |
+
kernel_size=4,
|
626 |
+
stride=2,
|
627 |
+
padding=1)
|
628 |
+
|
629 |
+
def forward(self, x, scale_factor=1.0):
|
630 |
+
if scale_factor==1.0:
|
631 |
+
return x
|
632 |
+
else:
|
633 |
+
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
634 |
+
return x
|
635 |
+
|
636 |
+
class FirstStagePostProcessor(nn.Module):
|
637 |
+
|
638 |
+
def __init__(self, ch_mult:list, in_channels,
|
639 |
+
pretrained_model:nn.Module=None,
|
640 |
+
reshape=False,
|
641 |
+
n_channels=None,
|
642 |
+
dropout=0.,
|
643 |
+
pretrained_config=None):
|
644 |
+
super().__init__()
|
645 |
+
if pretrained_config is None:
|
646 |
+
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
647 |
+
self.pretrained_model = pretrained_model
|
648 |
+
else:
|
649 |
+
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
650 |
+
self.instantiate_pretrained(pretrained_config)
|
651 |
+
|
652 |
+
self.do_reshape = reshape
|
653 |
+
|
654 |
+
if n_channels is None:
|
655 |
+
n_channels = self.pretrained_model.encoder.ch
|
656 |
+
|
657 |
+
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
|
658 |
+
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
|
659 |
+
stride=1,padding=1)
|
660 |
+
|
661 |
+
blocks = []
|
662 |
+
downs = []
|
663 |
+
ch_in = n_channels
|
664 |
+
for m in ch_mult:
|
665 |
+
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
|
666 |
+
ch_in = m * n_channels
|
667 |
+
downs.append(Downsample(ch_in, with_conv=False))
|
668 |
+
|
669 |
+
self.model = nn.ModuleList(blocks)
|
670 |
+
self.downsampler = nn.ModuleList(downs)
|
671 |
+
|
672 |
+
|
673 |
+
def instantiate_pretrained(self, config):
|
674 |
+
model = instantiate_from_config(config)
|
675 |
+
self.pretrained_model = model.eval()
|
676 |
+
# self.pretrained_model.train = False
|
677 |
+
for param in self.pretrained_model.parameters():
|
678 |
+
param.requires_grad = False
|
679 |
+
|
680 |
+
|
681 |
+
@torch.no_grad()
|
682 |
+
def encode_with_pretrained(self,x):
|
683 |
+
c = self.pretrained_model.encode(x)
|
684 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
685 |
+
c = c.mode()
|
686 |
+
return c
|
687 |
+
|
688 |
+
def forward(self,x):
|
689 |
+
z_fs = self.encode_with_pretrained(x)
|
690 |
+
z = self.proj_norm(z_fs)
|
691 |
+
z = self.proj(z)
|
692 |
+
z = nonlinearity(z)
|
693 |
+
|
694 |
+
for submodel, downmodel in zip(self.model,self.downsampler):
|
695 |
+
z = submodel(z,temb=None)
|
696 |
+
z = downmodel(z)
|
697 |
+
|
698 |
+
if self.do_reshape:
|
699 |
+
z = rearrange(z,'b c h w -> b (h w) c')
|
700 |
+
return z
|
701 |
+
|
easyanimate/vae/ldm/modules/diffusionmodules/util.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from einops import repeat
|
18 |
+
|
19 |
+
from ...util import instantiate_from_config
|
20 |
+
|
21 |
+
|
22 |
+
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
23 |
+
if schedule == "linear":
|
24 |
+
betas = (
|
25 |
+
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
26 |
+
)
|
27 |
+
|
28 |
+
elif schedule == "cosine":
|
29 |
+
timesteps = (
|
30 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
31 |
+
)
|
32 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
33 |
+
alphas = torch.cos(alphas).pow(2)
|
34 |
+
alphas = alphas / alphas[0]
|
35 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
36 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
37 |
+
|
38 |
+
elif schedule == "sqrt_linear":
|
39 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
40 |
+
elif schedule == "sqrt":
|
41 |
+
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
42 |
+
else:
|
43 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
44 |
+
return betas.numpy()
|
45 |
+
|
46 |
+
|
47 |
+
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
48 |
+
if ddim_discr_method == 'uniform':
|
49 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
50 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
51 |
+
elif ddim_discr_method == 'quad':
|
52 |
+
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
53 |
+
else:
|
54 |
+
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
55 |
+
|
56 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
57 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
58 |
+
steps_out = ddim_timesteps + 1
|
59 |
+
if verbose:
|
60 |
+
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
61 |
+
return steps_out
|
62 |
+
|
63 |
+
|
64 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
65 |
+
# select alphas for computing the variance schedule
|
66 |
+
alphas = alphacums[ddim_timesteps]
|
67 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
68 |
+
|
69 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
70 |
+
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
71 |
+
if verbose:
|
72 |
+
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
73 |
+
print(f'For the chosen value of eta, which is {eta}, '
|
74 |
+
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
75 |
+
return sigmas, alphas, alphas_prev
|
76 |
+
|
77 |
+
|
78 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
79 |
+
"""
|
80 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
81 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
82 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
83 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
84 |
+
produces the cumulative product of (1-beta) up to that
|
85 |
+
part of the diffusion process.
|
86 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
87 |
+
prevent singularities.
|
88 |
+
"""
|
89 |
+
betas = []
|
90 |
+
for i in range(num_diffusion_timesteps):
|
91 |
+
t1 = i / num_diffusion_timesteps
|
92 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
93 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
94 |
+
return np.array(betas)
|
95 |
+
|
96 |
+
|
97 |
+
def extract_into_tensor(a, t, x_shape):
|
98 |
+
b, *_ = t.shape
|
99 |
+
out = a.gather(-1, t)
|
100 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
101 |
+
|
102 |
+
|
103 |
+
def checkpoint(func, inputs, params, flag):
|
104 |
+
"""
|
105 |
+
Evaluate a function without caching intermediate activations, allowing for
|
106 |
+
reduced memory at the expense of extra compute in the backward pass.
|
107 |
+
:param func: the function to evaluate.
|
108 |
+
:param inputs: the argument sequence to pass to `func`.
|
109 |
+
:param params: a sequence of parameters `func` depends on but does not
|
110 |
+
explicitly take as arguments.
|
111 |
+
:param flag: if False, disable gradient checkpointing.
|
112 |
+
"""
|
113 |
+
if flag:
|
114 |
+
args = tuple(inputs) + tuple(params)
|
115 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
116 |
+
else:
|
117 |
+
return func(*inputs)
|
118 |
+
|
119 |
+
|
120 |
+
class CheckpointFunction(torch.autograd.Function):
|
121 |
+
@staticmethod
|
122 |
+
def forward(ctx, run_function, length, *args):
|
123 |
+
ctx.run_function = run_function
|
124 |
+
ctx.input_tensors = list(args[:length])
|
125 |
+
ctx.input_params = list(args[length:])
|
126 |
+
|
127 |
+
with torch.no_grad():
|
128 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
129 |
+
return output_tensors
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
def backward(ctx, *output_grads):
|
133 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
134 |
+
with torch.enable_grad():
|
135 |
+
# Fixes a bug where the first op in run_function modifies the
|
136 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
137 |
+
# Tensors.
|
138 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
139 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
140 |
+
input_grads = torch.autograd.grad(
|
141 |
+
output_tensors,
|
142 |
+
ctx.input_tensors + ctx.input_params,
|
143 |
+
output_grads,
|
144 |
+
allow_unused=True,
|
145 |
+
)
|
146 |
+
del ctx.input_tensors
|
147 |
+
del ctx.input_params
|
148 |
+
del output_tensors
|
149 |
+
return (None, None) + input_grads
|
150 |
+
|
151 |
+
|
152 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
153 |
+
"""
|
154 |
+
Create sinusoidal timestep embeddings.
|
155 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
156 |
+
These may be fractional.
|
157 |
+
:param dim: the dimension of the output.
|
158 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
159 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
160 |
+
"""
|
161 |
+
if not repeat_only:
|
162 |
+
half = dim // 2
|
163 |
+
freqs = torch.exp(
|
164 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
165 |
+
).to(device=timesteps.device)
|
166 |
+
args = timesteps[:, None].float() * freqs[None]
|
167 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
168 |
+
if dim % 2:
|
169 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
170 |
+
else:
|
171 |
+
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
172 |
+
return embedding
|
173 |
+
|
174 |
+
|
175 |
+
def zero_module(module):
|
176 |
+
"""
|
177 |
+
Zero out the parameters of a module and return it.
|
178 |
+
"""
|
179 |
+
for p in module.parameters():
|
180 |
+
p.detach().zero_()
|
181 |
+
return module
|
182 |
+
|
183 |
+
|
184 |
+
def scale_module(module, scale):
|
185 |
+
"""
|
186 |
+
Scale the parameters of a module and return it.
|
187 |
+
"""
|
188 |
+
for p in module.parameters():
|
189 |
+
p.detach().mul_(scale)
|
190 |
+
return module
|
191 |
+
|
192 |
+
|
193 |
+
def mean_flat(tensor):
|
194 |
+
"""
|
195 |
+
Take the mean over all non-batch dimensions.
|
196 |
+
"""
|
197 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
198 |
+
|
199 |
+
|
200 |
+
def normalization(channels):
|
201 |
+
"""
|
202 |
+
Make a standard normalization layer.
|
203 |
+
:param channels: number of input channels.
|
204 |
+
:return: an nn.Module for normalization.
|
205 |
+
"""
|
206 |
+
return GroupNorm32(32, channels)
|
207 |
+
|
208 |
+
|
209 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
210 |
+
class SiLU(nn.Module):
|
211 |
+
def forward(self, x):
|
212 |
+
return x * torch.sigmoid(x)
|
213 |
+
|
214 |
+
|
215 |
+
class GroupNorm32(nn.GroupNorm):
|
216 |
+
def forward(self, x):
|
217 |
+
return super().forward(x.float()).type(x.dtype)
|
218 |
+
|
219 |
+
def conv_nd(dims, *args, **kwargs):
|
220 |
+
"""
|
221 |
+
Create a 1D, 2D, or 3D convolution module.
|
222 |
+
"""
|
223 |
+
if dims == 1:
|
224 |
+
return nn.Conv1d(*args, **kwargs)
|
225 |
+
elif dims == 2:
|
226 |
+
return nn.Conv2d(*args, **kwargs)
|
227 |
+
elif dims == 3:
|
228 |
+
return nn.Conv3d(*args, **kwargs)
|
229 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
230 |
+
|
231 |
+
|
232 |
+
def linear(*args, **kwargs):
|
233 |
+
"""
|
234 |
+
Create a linear module.
|
235 |
+
"""
|
236 |
+
return nn.Linear(*args, **kwargs)
|
237 |
+
|
238 |
+
|
239 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
240 |
+
"""
|
241 |
+
Create a 1D, 2D, or 3D average pooling module.
|
242 |
+
"""
|
243 |
+
if dims == 1:
|
244 |
+
return nn.AvgPool1d(*args, **kwargs)
|
245 |
+
elif dims == 2:
|
246 |
+
return nn.AvgPool2d(*args, **kwargs)
|
247 |
+
elif dims == 3:
|
248 |
+
return nn.AvgPool3d(*args, **kwargs)
|
249 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
250 |
+
|
251 |
+
|
252 |
+
class HybridConditioner(nn.Module):
|
253 |
+
|
254 |
+
def __init__(self, c_concat_config, c_crossattn_config):
|
255 |
+
super().__init__()
|
256 |
+
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
257 |
+
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
258 |
+
|
259 |
+
def forward(self, c_concat, c_crossattn):
|
260 |
+
c_concat = self.concat_conditioner(c_concat)
|
261 |
+
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
262 |
+
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
263 |
+
|
264 |
+
|
265 |
+
def noise_like(shape, device, repeat=False):
|
266 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
267 |
+
noise = lambda: torch.randn(shape, device=device)
|
268 |
+
return repeat_noise() if repeat else noise()
|
easyanimate/vae/ldm/modules/distributions/__init__.py
ADDED
File without changes
|
easyanimate/vae/ldm/modules/distributions/distributions.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class AbstractDistribution:
|
6 |
+
def sample(self):
|
7 |
+
raise NotImplementedError()
|
8 |
+
|
9 |
+
def mode(self):
|
10 |
+
raise NotImplementedError()
|
11 |
+
|
12 |
+
|
13 |
+
class DiracDistribution(AbstractDistribution):
|
14 |
+
def __init__(self, value):
|
15 |
+
self.value = value
|
16 |
+
|
17 |
+
def sample(self):
|
18 |
+
return self.value
|
19 |
+
|
20 |
+
def mode(self):
|
21 |
+
return self.value
|
22 |
+
|
23 |
+
|
24 |
+
class DiagonalGaussianDistribution(object):
|
25 |
+
def __init__(self, parameters, deterministic=False):
|
26 |
+
self.parameters = parameters
|
27 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
28 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
29 |
+
self.deterministic = deterministic
|
30 |
+
self.std = torch.exp(0.5 * self.logvar)
|
31 |
+
self.var = torch.exp(self.logvar)
|
32 |
+
if self.deterministic:
|
33 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
34 |
+
|
35 |
+
def sample(self):
|
36 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
37 |
+
return x
|
38 |
+
|
39 |
+
def kl(self, other=None):
|
40 |
+
if self.deterministic:
|
41 |
+
return torch.Tensor([0.])
|
42 |
+
else:
|
43 |
+
if other is None:
|
44 |
+
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
45 |
+
+ self.var - 1.0 - self.logvar,
|
46 |
+
dim=[1, 2, 3])
|
47 |
+
else:
|
48 |
+
return 0.5 * torch.sum(
|
49 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
50 |
+
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
51 |
+
dim=[1, 2, 3])
|
52 |
+
|
53 |
+
def nll(self, sample, dims=[1,2,3]):
|
54 |
+
if self.deterministic:
|
55 |
+
return torch.Tensor([0.])
|
56 |
+
logtwopi = np.log(2.0 * np.pi)
|
57 |
+
return 0.5 * torch.sum(
|
58 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
59 |
+
dim=dims)
|
60 |
+
|
61 |
+
def mode(self):
|
62 |
+
return self.mean
|
63 |
+
|
64 |
+
|
65 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
66 |
+
"""
|
67 |
+
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
68 |
+
Compute the KL divergence between two gaussians.
|
69 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
70 |
+
scalars, among other use cases.
|
71 |
+
"""
|
72 |
+
tensor = None
|
73 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
74 |
+
if isinstance(obj, torch.Tensor):
|
75 |
+
tensor = obj
|
76 |
+
break
|
77 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
78 |
+
|
79 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
80 |
+
# Tensors, but it does not work for torch.exp().
|
81 |
+
logvar1, logvar2 = [
|
82 |
+
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
83 |
+
for x in (logvar1, logvar2)
|
84 |
+
]
|
85 |
+
|
86 |
+
return 0.5 * (
|
87 |
+
-1.0
|
88 |
+
+ logvar2
|
89 |
+
- logvar1
|
90 |
+
+ torch.exp(logvar1 - logvar2)
|
91 |
+
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
92 |
+
)
|