bubbliiiing commited on
Commit
19fe404
·
1 Parent(s): 6e9356e

Create Code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +40 -0
  2. config/easyanimate_image_magvit_v2.yaml +8 -0
  3. config/easyanimate_image_normal_v1.yaml +8 -0
  4. config/easyanimate_image_slicevae_v3.yaml +9 -0
  5. config/easyanimate_video_casual_motion_module_v1.yaml +27 -0
  6. config/easyanimate_video_long_sequence_v1.yaml +14 -0
  7. config/easyanimate_video_magvit_motion_module_v2.yaml +26 -0
  8. config/easyanimate_video_motion_module_v1.yaml +24 -0
  9. config/easyanimate_video_slicevae_motion_module_v3.yaml +27 -0
  10. easyanimate/__init__.py +0 -0
  11. easyanimate/api/api.py +96 -0
  12. easyanimate/api/post_infer.py +94 -0
  13. easyanimate/data/bucket_sampler.py +379 -0
  14. easyanimate/data/dataset_image.py +76 -0
  15. easyanimate/data/dataset_image_video.py +241 -0
  16. easyanimate/data/dataset_video.py +262 -0
  17. easyanimate/models/attention.py +1299 -0
  18. easyanimate/models/autoencoder_magvit.py +503 -0
  19. easyanimate/models/motion_module.py +575 -0
  20. easyanimate/models/patch.py +426 -0
  21. easyanimate/models/transformer2d.py +555 -0
  22. easyanimate/models/transformer3d.py +738 -0
  23. easyanimate/pipeline/pipeline_easyanimate.py +847 -0
  24. easyanimate/pipeline/pipeline_easyanimate_inpaint.py +984 -0
  25. easyanimate/pipeline/pipeline_pixart_magvit.py +983 -0
  26. easyanimate/ui/ui.py +818 -0
  27. easyanimate/utils/__init__.py +0 -0
  28. easyanimate/utils/diffusion_utils.py +92 -0
  29. easyanimate/utils/gaussian_diffusion.py +1008 -0
  30. easyanimate/utils/lora_utils.py +476 -0
  31. easyanimate/utils/respace.py +131 -0
  32. easyanimate/utils/utils.py +64 -0
  33. easyanimate/vae/LICENSE +82 -0
  34. easyanimate/vae/README.md +63 -0
  35. easyanimate/vae/README_zh-CN.md +63 -0
  36. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml +62 -0
  37. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml +65 -0
  38. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml +66 -0
  39. easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml +66 -0
  40. easyanimate/vae/environment.yaml +29 -0
  41. easyanimate/vae/ldm/data/__init__.py +0 -0
  42. easyanimate/vae/ldm/data/base.py +25 -0
  43. easyanimate/vae/ldm/data/dataset_callback.py +25 -0
  44. easyanimate/vae/ldm/data/dataset_image_video.py +281 -0
  45. easyanimate/vae/ldm/lr_scheduler.py +98 -0
  46. easyanimate/vae/ldm/modules/diffusionmodules/__init__.py +0 -0
  47. easyanimate/vae/ldm/modules/diffusionmodules/model.py +701 -0
  48. easyanimate/vae/ldm/modules/diffusionmodules/util.py +268 -0
  49. easyanimate/vae/ldm/modules/distributions/__init__.py +0 -0
  50. easyanimate/vae/ldm/modules/distributions/distributions.py +92 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from easyanimate.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
4
+ from easyanimate.ui.ui import ui_modelscope, ui, ui_huggingface
5
+
6
if __name__ == "__main__":
    # Which front end to build: "modelscope", "huggingface", or any other
    # value for the generic ui().
    ui_mode = "huggingface"
    # Address and port handed to gradio's launch()
    server_name = "0.0.0.0"
    server_port = 7860

    # The settings below are passed to the "modelscope" and "huggingface"
    # builders; the generic ui() takes no arguments.
    edition = "v2"
    config_path = "config/easyanimate_video_magvit_motion_module_v2.yaml"
    model_name = "models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
    savedir_sample = "samples"

    preset_builders = {
        "modelscope": ui_modelscope,
        "huggingface": ui_huggingface,
    }
    build_preset = preset_builders.get(ui_mode)
    if build_preset is None:
        demo, controller = ui()
    else:
        demo, controller = build_preset(edition, config_path, model_name, savedir_sample)

    # Launch gradio without blocking so the extra routes can be added below
    queued_demo = demo.queue(status_update_rate=1)
    app, _, _ = queued_demo.launch(
        server_name=server_name,
        server_port=server_port,
        prevent_thread_lock=True,
    )

    # Attach the HTTP API routes to the app returned by launch()
    for register_route in (
        infer_forward_api,
        update_diffusion_transformer_api,
        update_edition_api,
    ):
        register_route(None, app, controller)

    # launch() returned immediately, so keep the process alive
    while True:
        time.sleep(5)
config/easyanimate_image_magvit_v2.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ noise_scheduler_kwargs:
2
+ beta_start: 0.0001
3
+ beta_end: 0.02
4
+ beta_schedule: "linear"
5
+ steps_offset: 1
6
+
7
+ vae_kwargs:
8
+ enable_magvit: true
config/easyanimate_image_normal_v1.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ noise_scheduler_kwargs:
2
+ beta_start: 0.0001
3
+ beta_end: 0.02
4
+ beta_schedule: "linear"
5
+ steps_offset: 1
6
+
7
+ vae_kwargs:
8
+ enable_magvit: false
config/easyanimate_image_slicevae_v3.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ noise_scheduler_kwargs:
2
+ beta_start: 0.0001
3
+ beta_end: 0.02
4
+ beta_schedule: "linear"
5
+ steps_offset: 1
6
+
7
+ vae_kwargs:
8
+ enable_magvit: true
9
+ slice_compression_vae: true
config/easyanimate_video_casual_motion_module_v1.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformer_additional_kwargs:
2
+ patch_3d: false
3
+ fake_3d: false
4
+ casual_3d: true
5
+ casual_3d_upsampler_index: [16, 20]
6
+ time_patch_size: 4
7
+ basic_block_type: "motionmodule"
8
+ time_position_encoding_before_transformer: false
9
+ motion_module_type: "VanillaGrid"
10
+
11
+ motion_module_kwargs:
12
+ num_attention_heads: 8
13
+ num_transformer_block: 1
14
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
15
+ temporal_position_encoding: true
16
+ temporal_position_encoding_max_len: 4096
17
+ temporal_attention_dim_div: 1
18
+ block_size: 2
19
+
20
+ noise_scheduler_kwargs:
21
+ beta_start: 0.0001
22
+ beta_end: 0.02
23
+ beta_schedule: "linear"
24
+ steps_offset: 1
25
+
26
+ vae_kwargs:
27
+ enable_magvit: false
config/easyanimate_video_long_sequence_v1.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformer_additional_kwargs:
2
+ patch_3d: false
3
+ fake_3d: false
4
+ basic_block_type: "selfattentiontemporal"
5
+ time_position_encoding_before_transformer: true
6
+
7
+ noise_scheduler_kwargs:
8
+ beta_start: 0.0001
9
+ beta_end: 0.02
10
+ beta_schedule: "linear"
11
+ steps_offset: 1
12
+
13
+ vae_kwargs:
14
+ enable_magvit: false
config/easyanimate_video_magvit_motion_module_v2.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformer_additional_kwargs:
2
+ patch_3d: false
3
+ fake_3d: false
4
+ basic_block_type: "motionmodule"
5
+ time_position_encoding_before_transformer: false
6
+ motion_module_type: "Vanilla"
7
+ enable_uvit: true
8
+
9
+ motion_module_kwargs:
10
+ num_attention_heads: 8
11
+ num_transformer_block: 1
12
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
13
+ temporal_position_encoding: true
14
+ temporal_position_encoding_max_len: 4096
15
+ temporal_attention_dim_div: 1
16
+ block_size: 1
17
+
18
+ noise_scheduler_kwargs:
19
+ beta_start: 0.0001
20
+ beta_end: 0.02
21
+ beta_schedule: "linear"
22
+ steps_offset: 1
23
+
24
+ vae_kwargs:
25
+ enable_magvit: true
26
+ mini_batch_encoder: 9
config/easyanimate_video_motion_module_v1.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformer_additional_kwargs:
2
+ patch_3d: false
3
+ fake_3d: false
4
+ basic_block_type: "motionmodule"
5
+ time_position_encoding_before_transformer: false
6
+ motion_module_type: "VanillaGrid"
7
+
8
+ motion_module_kwargs:
9
+ num_attention_heads: 8
10
+ num_transformer_block: 1
11
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
12
+ temporal_position_encoding: true
13
+ temporal_position_encoding_max_len: 4096
14
+ temporal_attention_dim_div: 1
15
+ block_size: 2
16
+
17
+ noise_scheduler_kwargs:
18
+ beta_start: 0.0001
19
+ beta_end: 0.02
20
+ beta_schedule: "linear"
21
+ steps_offset: 1
22
+
23
+ vae_kwargs:
24
+ enable_magvit: false
config/easyanimate_video_slicevae_motion_module_v3.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformer_additional_kwargs:
2
+ patch_3d: false
3
+ fake_3d: false
4
+ basic_block_type: "motionmodule"
5
+ time_position_encoding_before_transformer: false
6
+ motion_module_type: "Vanilla"
7
+ enable_uvit: true
8
+
9
+ motion_module_kwargs:
10
+ num_attention_heads: 8
11
+ num_transformer_block: 1
12
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
13
+ temporal_position_encoding: true
14
+ temporal_position_encoding_max_len: 4096
15
+ temporal_attention_dim_div: 1
16
+ block_size: 1
17
+
18
+ noise_scheduler_kwargs:
19
+ beta_start: 0.0001
20
+ beta_end: 0.02
21
+ beta_schedule: "linear"
22
+ steps_offset: 1
23
+
24
+ vae_kwargs:
25
+ enable_magvit: true
26
+ slice_compression_vae: true
27
+ mini_batch_encoder: 8
easyanimate/__init__.py ADDED
File without changes
easyanimate/api/api.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import torch
4
+ import gradio as gr
5
+
6
+ from fastapi import FastAPI
7
+ from io import BytesIO
8
+
9
def encode_file_to_base64(file_path):
    """Read the file at ``file_path`` and return its contents Base64-encoded.

    Args:
        file_path: Path of the file to read; it is opened in binary mode.

    Returns:
        bytes: The Base64 encoding produced by ``base64.b64encode``.
    """
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes)
15
+
16
def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
    """Register ``POST /easyanimate/update_edition`` on ``app``.

    The handler reads ``edition`` from the request body (default ``'v2'``)
    and forwards it to ``controller.update_edition``. The response is always
    ``{"message": ...}``: ``"Success"``, or the error text when the
    controller raises.

    Args:
        _: Unused gradio Blocks placeholder.
        app: FastAPI application the route is attached to.
        controller: Object exposing ``update_edition(edition)``.
    """
    @app.post("/easyanimate/update_edition")
    def _update_edition_api(
        datas: dict,
    ):
        requested_edition = datas.get('edition', 'v2')

        try:
            controller.update_edition(requested_edition)
        except Exception as e:
            # Free cached CUDA memory after a failed switch
            torch.cuda.empty_cache()
            return {"message": f"Error. error information is {str(e)}"}

        return {"message": "Success"}
33
+
34
def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
    """Register ``POST /easyanimate/update_diffusion_transformer`` on ``app``.

    The handler reads ``diffusion_transformer_path`` from the request body
    (default ``'none'``) and forwards it to
    ``controller.update_diffusion_transformer``. The response is always
    ``{"message": ...}``: ``"Success"``, or the error text when the
    controller raises.

    Args:
        _: Unused gradio Blocks placeholder.
        app: FastAPI application the route is attached to.
        controller: Object exposing ``update_diffusion_transformer(path)``.
    """
    @app.post("/easyanimate/update_diffusion_transformer")
    def _update_diffusion_transformer_api(
        datas: dict,
    ):
        requested_path = datas.get('diffusion_transformer_path', 'none')

        try:
            controller.update_diffusion_transformer(requested_path)
        except Exception as e:
            # Free cached CUDA memory after a failed load
            torch.cuda.empty_cache()
            return {"message": f"Error. error information is {str(e)}"}

        return {"message": "Success"}
51
+
52
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
    """Register ``POST /easyanimate/infer_forward`` on ``app``.

    The handler pulls the generation settings out of the request body
    (falling back to the defaults below), calls ``controller.generate`` with
    ``is_api=True`` and returns a dict with three keys:

    * ``message``: the comment returned by the controller, or the error text.
    * ``save_sample_path``: path of the generated file, ``""`` on failure.
    * ``base64_encoding``: Base64 of that file, ``""`` when there is no file
      to read.

    Args:
        _: Unused gradio Blocks placeholder.
        app: FastAPI application the route is attached to.
        controller: Object exposing ``generate(...)``.
    """
    @app.post("/easyanimate/infer_forward")
    def _infer_forward_api(
        datas: dict,
    ):
        base_model_path = datas.get('base_model_path', 'none')
        motion_module_path = datas.get('motion_module_path', 'none')
        lora_model_path = datas.get('lora_model_path', 'none')
        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
        prompt_textbox = datas.get('prompt_textbox', None)
        negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
        sample_step_slider = datas.get('sample_step_slider', 30)
        width_slider = datas.get('width_slider', 672)
        height_slider = datas.get('height_slider', 384)
        is_image = datas.get('is_image', False)
        length_slider = datas.get('length_slider', 144)
        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
        seed_textbox = datas.get("seed_textbox", 43)

        try:
            save_sample_path, comment = controller.generate(
                "",
                base_model_path,
                motion_module_path,
                lora_model_path,
                lora_alpha_slider,
                prompt_textbox,
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                width_slider,
                height_slider,
                is_image,
                length_slider,
                cfg_scale_slider,
                seed_textbox,
                is_api = True,
            )
        except Exception as e:
            torch.cuda.empty_cache()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"

        # Only read the file when generation produced one. Encoding the empty
        # path used on failure would raise and hide the error message from
        # the client.
        base64_encoding = ""
        if save_sample_path:
            try:
                base64_encoding = encode_file_to_base64(save_sample_path)
            except OSError as e:
                comment = f"{comment} Failed to read the generated file: {e}"

        return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": base64_encoding}
easyanimate/api/post_infer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import sys
4
+ import time
5
+ from datetime import datetime
6
+ from io import BytesIO
7
+
8
+ import cv2
9
+ import requests
10
+ import base64
11
+
12
+
13
def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
    """POST a diffusion-transformer path to the EasyAnimate server.

    Args:
        diffusion_transformer_path: Value sent under the
            ``diffusion_transformer_path`` key.
        url: Base URL of the server.

    Returns:
        str: The response body decoded as UTF-8.
    """
    payload = json.dumps({"diffusion_transformer_path": diffusion_transformer_path})
    endpoint = f'{url}/easyanimate/update_diffusion_transformer'
    response = requests.post(endpoint, data=payload, timeout=1500)
    return response.content.decode('utf-8')
20
+
21
def post_update_edition(edition, url='http://127.0.0.1:7860'):
    """POST an edition name to the EasyAnimate server.

    The default ``url`` targets the loopback address, matching the other
    helpers in this module. ``0.0.0.0`` is a bind address and is not a
    portable destination for client connections.

    Args:
        edition: Value sent under the ``edition`` key.
        url: Base URL of the server.

    Returns:
        str: The response body decoded as UTF-8.
    """
    datas = json.dumps({
        "edition": edition
    })
    r = requests.post(f'{url}/easyanimate/update_edition', data=datas, timeout=1500)
    data = r.content.decode('utf-8')
    return data
28
+
29
def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
    """POST a generation request with fixed demo settings to the server.

    Only ``is_image`` and ``length_slider`` come from the caller; every other
    field in the request body is hard-coded below.

    Args:
        is_image: Value sent under the ``is_image`` key.
        length_slider: Value sent under the ``length_slider`` key.
        url: Base URL of the server.

    Returns:
        str: The response body decoded as UTF-8.
    """
    request_fields = {
        "base_model_path": "none",
        "motion_module_path": "none",
        "lora_model_path": "none",
        "lora_alpha_slider": 0.55,
        "prompt_textbox": "This video shows Mount saint helens, washington - the stunning scenery of a rocky mountains during golden hours - wide shot. A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea.",
        "negative_prompt_textbox": "Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera",
        "sampler_dropdown": "Euler",
        "sample_step_slider": 30,
        "width_slider": 672,
        "height_slider": 384,
        "is_image": is_image,
        "length_slider": length_slider,
        "cfg_scale_slider": 6,
        "seed_textbox": 43,
    }
    payload = json.dumps(request_fields)
    response = requests.post(f'{url}/easyanimate/infer_forward', data=payload, timeout=1500)
    return response.content.decode('utf-8')
49
+
50
if __name__ == '__main__':
    # Start the wall-clock timer
    time_start = time.time()

    # -------------------------- #
    #  Step 1: update edition
    # -------------------------- #
    edition = "v2"
    outputs = post_update_edition(edition)
    print('Output update edition: ', outputs)

    # -------------------------- #
    #  Step 2: update diffusion transformer
    # -------------------------- #
    diffusion_transformer_path = "/your-path/EasyAnimate/models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
    outputs = post_diffusion_transformer(diffusion_transformer_path)
    print('Output update diffusion transformer: ', outputs)

    # -------------------------- #
    #  Step 3: infer
    # -------------------------- #
    is_image = False
    length_slider = 27
    outputs = post_infer(is_image, length_slider)

    # Decode the Base64 payload returned by the server
    outputs = json.loads(outputs)
    base64_encoding = outputs["base64_encoding"]
    decoded_data = base64.b64decode(base64_encoding)

    if is_image or length_slider == 1:
        file_path = "1.png"
    else:
        file_path = "1.mp4"

    if decoded_data:
        with open(file_path, "wb") as file:
            file.write(decoded_data)
    else:
        # Nothing was generated; show the server message instead of writing
        # an empty file.
        print('No sample returned: ', outputs.get("message"))

    # Total elapsed wall-clock time in seconds. It is not reduced modulo 60,
    # so runs longer than a minute are reported correctly.
    time_end = time.time()
    time_sum = time_end - time_start
    print('# --------------------------------------------------------- #')
    print(f'#   Total expenditure: {time_sum}s')
    print('# --------------------------------------------------------- #')
easyanimate/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
# Aspect-ratio buckets. Each key is a ratio written as a string; the samplers
# and get_closest_ratio compare float(key) against height / width.
# NOTE(review): the values look like [height, width] in pixels (e.g. '0.25'
# maps to [256.0, 1024.0]) -- confirm against the callers that consume them.
ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}
# A 15-entry subset of the table above.
ASPECT_RATIO_RANDOM_CROP_512 = {
    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
    '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
    '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
}
# One weight per entry of ASPECT_RATIO_RANDOM_CROP_512, in the same order
# (15 values); the largest weights sit on the near-square ratios.
ASPECT_RATIO_RANDOM_CROP_PROB = [
    1, 2,
    4, 4, 4, 4,
    8, 8, 8,
    4, 4, 4, 4,
    2, 1
]
# Normalize the weights so they sum to 1.
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
    """Pick the bucket whose ratio key is nearest to ``height / width``.

    Args:
        height: Height of the sample.
        width: Width of the sample.
        ratios: Mapping from ratio strings to bucket values.

    Returns:
        tuple: ``(ratios[key], float(key))`` for the nearest key; on a tie the
        key that comes first in ``ratios`` wins.
    """
    target = height / width

    def distance(key):
        return abs(float(key) - target)

    nearest_key = min(ratios, key=distance)
    return ratios[nearest_key], float(nearest_key)
44
+
45
def get_image_size_without_loading(path):
    """Open the image at ``path`` with PIL and return its ``size`` attribute.

    Returns:
        tuple: ``(width, height)``.
    """
    image = Image.open(path)
    with image:
        dimensions = image.size  # (width, height)
    return dimensions
48
+
49
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Unlike a plain shuffle, the without-replacement path keeps a position
    counter (``_pos_start``) that advances with every yielded index and is
    reset to 0 once a full pass completes.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator
        # Offset into the next permutation at which iteration starts.
        # NOTE(review): presumably meant for resuming mid-epoch -- confirm
        # with the training script.
        self._pos_start = 0

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: seed a fresh one from a random int64.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw with replacement in chunks of 32, then the remainder.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                xx = torch.randperm(n, generator=generator).tolist()
                if self._pos_start >= n:
                    self._pos_start = 0
                # Debug output: first ten indices of the permutation and the
                # starting offset.
                print("xx top 10", xx[:10], self._pos_start)
                # Skip the first _pos_start entries of this permutation; the
                # counter advances after each yielded index.
                for idx in range(self._pos_start, n):
                    yield xx[idx]
                    self._pos_start = (self._pos_start + 1) % n
                self._pos_start = 0
            # Remainder when num_samples is not a multiple of n.
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
112
+
113
class AspectRatioBatchImageSampler(BatchSampler):
    """Batch sampler that groups images with a similar aspect ratio.

    Indices drawn from ``sampler`` are appended to the bucket of the closest
    ratio in ``aspect_ratios``; a bucket is yielded as a batch once it holds
    ``batch_size`` indices. ``drop_last`` and ``config`` are stored but not
    consulted by ``__iter__``, so partially filled buckets are never yielded.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Indexable collection of dict-like image records.
        batch_size (int): Size of mini-batch.
        train_folder (str): Optional folder joined in front of ``file_path``.
        aspect_ratios (dict): The predefined aspect ratios.
        drop_last (bool): Stored only.
        config: Stored only.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        train_folder: str = None,
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        batch_size_ok = isinstance(batch_size, int) and batch_size > 0
        if not batch_size_ok:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')

        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config

        # One pending-index list per aspect ratio
        self._aspect_ratio_buckets = {key: [] for key in aspect_ratios}
        self.current_available_bucket_keys = list(aspect_ratios)

    def __iter__(self):
        for idx in self.sampler:
            try:
                record = self.dataset[idx]
                width = record.get("width", None)
                height = record.get("height", None)

                if width is None or height is None:
                    # Size not stored in the record: read it from the file.
                    image_id, name = record['file_path'], record['text']
                    if self.train_folder is None:
                        image_path = image_id
                    else:
                        image_path = os.path.join(self.train_folder, image_id)
                    width, height = get_image_size_without_loading(image_path)
                else:
                    height = int(height)
                    width = int(width)

                ratio = height / width
            except Exception as e:
                # Unreadable record: report it and move on.
                print(e)
                continue

            # Nearest predefined ratio
            nearest_key = min(self.aspect_ratios, key=lambda r: abs(float(r) - ratio))
            if nearest_key not in self.current_available_bucket_keys:
                continue

            pending = self._aspect_ratio_buckets[nearest_key]
            pending.append(idx)
            # Emit a full batch of same-ratio indices, then empty the bucket
            if len(pending) == self.batch_size:
                yield list(pending)
                pending.clear()
186
+
187
class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping videos with similar aspect ratio into a same batch.

    Indices drawn from ``sampler`` are appended to the bucket of the closest
    ratio in ``aspect_ratios``; a bucket is yielded as a batch once it holds
    ``batch_size`` indices. ``drop_last`` and ``config`` are stored but not
    consulted by ``__iter__``, so partially filled buckets are never yielded.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Indexable collection of dict-like video records.
        batch_size (int): Size of mini-batch.
        video_folder (str): Folder joined in front of the video file name.
        train_data_format (str): ``"normal"`` reads ``file_path``/``text``
            from each record; any other value reads ``videoid``/``name``/
            ``page_dir`` and looks for ``<videoid>.mp4`` in ``video_folder``.
        aspect_ratios (dict): The predefined aspect ratios.
        drop_last (bool): Stored only.
        config: Stored only.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        video_folder: str = None,
        train_data_format: str = "webvid",
        aspect_ratios: dict = ASPECT_RATIO_512,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.video_folder = video_folder
        self.train_data_format = train_data_format
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # buckets for each aspect ratio
        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
        self.current_available_bucket_keys = list(aspect_ratios.keys())

    def __iter__(self):
        for idx in self.sampler:
            try:
                video_dict = self.dataset[idx]
                # Bind BOTH values for this record. Binding only ``width``
                # left ``height`` unbound on the first record and stale on
                # later ones.
                width, height = video_dict.get("width", None), video_dict.get("height", None)

                if width is None or height is None:
                    if self.train_data_format == "normal":
                        video_id, name = video_dict['file_path'], video_dict['text']
                        if self.video_folder is None:
                            video_dir = video_id
                        else:
                            video_dir = os.path.join(self.video_folder, video_id)
                    else:
                        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
                        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")

                    cap = cv2.VideoCapture(video_dir)
                    try:
                        # Read the frame size; OpenCV returns floats, so cast to int
                        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    finally:
                        # Release the capture handle once the size is read
                        cap.release()
                else:
                    height = int(height)
                    width = int(width)

                ratio = height / width
            except Exception as e:
                # Unreadable record (missing key, zero width, ...): report and skip
                print(e)
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]
269
+
270
class AspectRatioBatchImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images or videos with similar aspect ratio into a same batch.

    Records whose ``type`` is ``'image'`` (the default when the key is
    missing) and all other records are bucketed separately, so a batch never
    mixes the two. A bucket is yielded once it holds ``batch_size`` indices.
    ``drop_last`` is stored but not consulted by ``__iter__``, so partially
    filled buckets are never yielded.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Indexable collection of dict-like records.
        batch_size (int): Size of mini-batch.
        train_folder (str): Optional folder joined in front of ``file_path``.
        aspect_ratios (dict): The predefined aspect ratios.
        drop_last (bool): Stored only.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 train_folder: str = None,
                 aspect_ratios: dict = ASPECT_RATIO_512,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        self.aspect_ratios = aspect_ratios
        self.drop_last = drop_last

        # buckets for each aspect ratio, kept separately per content type
        self.current_available_bucket_keys = list(aspect_ratios.keys())
        self.bucket = {
            'image': {ratio: [] for ratio in aspect_ratios},
            'video': {ratio: [] for ratio in aspect_ratios}
        }

    def __iter__(self):
        for idx in self.sampler:
            # Fetch the record once and reuse it for both the type lookup and
            # the size lookup.
            sample = self.dataset[idx]
            kind = 'image' if sample.get('type', 'image') == 'image' else 'video'
            try:
                width, height = sample.get("width", None), sample.get("height", None)
                if width is None or height is None:
                    # Size not stored in the record: read it from the file.
                    file_id, name = sample['file_path'], sample['text']
                    if self.train_folder is None:
                        file_dir = file_id
                    else:
                        file_dir = os.path.join(self.train_folder, file_id)

                    if kind == 'image':
                        width, height = get_image_size_without_loading(file_dir)
                    else:
                        cap = cv2.VideoCapture(file_dir)
                        try:
                            # Read the frame size; OpenCV returns floats, so cast to int
                            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        finally:
                            # Release the capture handle once the size is read
                            cap.release()
                else:
                    height = int(height)
                    width = int(width)

                ratio = height / width
            except Exception as e:
                # Unreadable record (missing key, zero width, ...): report and skip
                print(e)
                continue
            # find the closest aspect ratio
            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self.bucket[kind][closest_ratio]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]
easyanimate/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
class CC15M(Dataset):
    """Image/caption dataset backed by a JSON annotation list.

    Every annotation entry is read through its ``file_path`` and ``text``
    keys. Images are opened with PIL and converted to RGB.

    Args:
        json_path: Path of the JSON annotation file.
        video_folder: Optional directory joined in front of every
            ``file_path``; when ``None`` the ``file_path`` is used as is.
        resolution: An int (used for both sides) or an iterable converted
            to a tuple; drives the resize / center-crop transforms.
        enable_bucket: When true, ``__getitem__`` returns the image as a
            numpy array and skips the torchvision transforms.
    """

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(json_path, 'r') as annotation_file:
            self.dataset = json.load(annotation_file)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load the RGB image and caption stored at annotation index ``idx``.

        Returns:
            A ``(PIL image, caption)`` pair.

        Raises:
            Whatever indexing the annotations or opening the image raises;
            ``__getitem__`` catches these and retries with another index.
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        # ``convert`` returns a fully loaded copy, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(video_dir) as image:
            pixel_values = image.convert("RGB")
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return ``{"pixel_values": ..., "text": ...}`` for one sample.

        A sample that fails to load is logged and replaced by a randomly
        chosen index; this repeats until a sample loads.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            # Bucket mode hands the raw image array to the caller.
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample
67
+
68
if __name__ == "__main__":
    # Smoke test: build the dataset and print the shape of every batch.
    # The constructor's keyword is ``json_path``; passing ``csv_path``
    # raised a TypeError before the dataset was even created.
    dataset = CC15M(
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
easyanimate/data/dataset_image_video.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from threading import Thread
8
+
9
+ import albumentations
10
+ import cv2
11
+ import gc
12
+ import numpy as np
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+ from func_timeout import func_timeout, FunctionTimedOut
16
+ from decord import VideoReader
17
+ from PIL import Image
18
+ from torch.utils.data import BatchSampler, Sampler
19
+ from torch.utils.data.dataset import Dataset
20
+ from contextlib import contextmanager
21
+
22
+ VIDEO_READER_TIMEOUT = 20
23
+
24
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper that groups indices by content type.

    Every yielded batch contains only image indices or only video indices.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information. Its
            ``dataset`` attribute is indexed to read each entry's ``type``.
        batch_size (int): Size of mini-batch.
        drop_last (bool): Stored on the instance but not consulted by
            ``__iter__``: only full batches are yielded, and indices left
            over in a bucket stay there after iteration ends.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # one bucket of pending indices per content type
        self.bucket = {'image':[], 'video':[]}

    def __iter__(self):
        """Yield lists of ``batch_size`` indices that share a content type."""
        for idx in self.sampler:
            entry_type = self.dataset.dataset[idx].get('type', 'image')
            # Anything that is not explicitly a video goes to the image
            # bucket, so an unexpected ``type`` value no longer raises
            # KeyError here.
            content_type = 'video' if entry_type == 'video' else 'image'
            self.bucket[content_type].append(idx)

            # yield a batch of indices of the same content type
            if len(self.bucket['video']) == self.batch_size:
                bucket = self.bucket['video']
                yield bucket[:]
                del bucket[:]
            elif len(self.bucket['image']) == self.batch_size:
                bucket = self.bucket['image']
                yield bucket[:]
                del bucket[:]
70
+
71
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Yield a ``VideoReader`` built from the given arguments.

    When the ``with`` block exits, normally or through an exception, the
    local reference to the reader is dropped and a garbage-collection
    pass is forced.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        # Drop our reference first so the collection pass below can
        # reclaim the reader if nothing else still holds it.
        del reader
        gc.collect()
79
+
80
def get_video_reader_batch(video_reader, batch_index):
    """Fetch the frames at ``batch_index`` and return them via ``asnumpy()``.

    Kept as a module-level function so it can be passed to
    ``func_timeout`` as the callable to run under a time limit.
    """
    return video_reader.get_batch(batch_index).asnumpy()
83
+
84
class ImageVideoDataset(Dataset):
    """Mixed image / video dataset driven by a CSV or JSON annotation file.

    Each annotation entry is read through its ``file_path`` and ``text``
    keys; an optional ``type`` key selects the video path when it equals
    ``'video'`` and the image path otherwise.

    Args:
        ann_path: Annotation file ending in ``.csv`` or ``.json``.
        data_root: Optional directory joined in front of ``file_path``.
        video_sample_size: Int or iterable giving the video crop size.
        video_sample_stride: Frame stride used when sampling a clip.
        video_sample_n_frames: Maximum number of frames per clip.
        image_sample_size: Int or iterable giving the image crop size.
        video_repeat: Number of times every video entry is appended to
            the dataset. With the default ``0`` no video entry is kept.
        text_drop_ratio: Probability of replacing the caption with ``''``.
        enable_bucket: When true, raw numpy frames / images are returned
            and the torchvision transforms are skipped.
        video_length_drop_start: Fraction used for the lower bound of the
            clip start index.
        video_length_drop_end: Fraction of the video length that clips
            may reach.
    """

    def __init__(
            self,
            ann_path, data_root=None,
            video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
            image_sample_size=512,
            video_repeat=0,
            text_drop_ratio=0.001,
            enable_bucket=False,
            video_length_drop_start=0.1,
            video_length_drop_end=0.9,
        ):
        # Loading annotations from files
        print(f"loading annotations from {ann_path} ...")
        dataset = self._load_annotations(ann_path)

        self.data_root = data_root

        # It's used to balance num of images and videos: images are kept
        # once, videos are appended ``video_repeat`` times.
        self.dataset = [data for data in dataset if data.get('type', 'image') != 'video']
        if video_repeat > 0:
            videos = [data for data in dataset if data.get('type', 'image') == 'video']
            for _ in range(video_repeat):
                self.dataset.extend(videos)
        del dataset

        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        # TODO: enable bucket training
        self.enable_bucket = enable_bucket
        self.text_drop_ratio = text_drop_ratio
        self.video_length_drop_start = video_length_drop_start
        self.video_length_drop_end = video_length_drop_end

        # Video params
        self.video_sample_stride = video_sample_stride
        self.video_sample_n_frames = video_sample_n_frames
        video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
        self.video_transforms = transforms.Compose(
            [
                transforms.Resize(video_sample_size[0]),
                transforms.CenterCrop(video_sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

        # Image params
        self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
        self.image_transforms = transforms.Compose([
            transforms.Resize(min(self.image_sample_size)),
            transforms.CenterCrop(self.image_sample_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

    @staticmethod
    def _load_annotations(ann_path):
        """Read the annotation entries from a ``.csv`` or ``.json`` file.

        Raises:
            ValueError: If ``ann_path`` has neither extension. Previously
                this case surfaced later as a NameError.
        """
        if ann_path.endswith('.csv'):
            with open(ann_path, 'r') as csvfile:
                return list(csv.DictReader(csvfile))
        if ann_path.endswith('.json'):
            # Close the handle deterministically instead of leaking it.
            with open(ann_path) as jsonfile:
                return json.load(jsonfile)
        raise ValueError(f"Unsupported annotation file {ann_path}; expected a .csv or .json file.")

    def get_batch(self, idx):
        """Load one sample and return ``(pixel_values, text, data_type)``.

        ``data_type`` is ``'video'`` or ``'image'``. Errors are raised to
        the caller; ``__getitem__`` retries with another index.

        Raises:
            ValueError: If the video yields no frames to sample, if the
                frame read times out or fails, or if the sampled start
                range is empty (raised by ``random.randint``).
        """
        data_info = self.dataset[idx % len(self.dataset)]

        if data_info.get('type', 'image')=='video':
            video_id, text = data_info['file_path'], data_info['text']

            if self.data_root is None:
                video_dir = video_id
            else:
                video_dir = os.path.join(self.data_root, video_id)

            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
                min_sample_n_frames = min(
                    self.video_sample_n_frames,
                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start))
                )
                if min_sample_n_frames == 0:
                    raise ValueError("No Frames in video.")

                video_length = int(self.video_length_drop_end * len(video_reader))
                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
                start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)

                try:
                    sample_args = (video_reader, batch_index)
                    # Bound the decode time so a bad file cannot stall loading.
                    pixel_values = func_timeout(
                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                    )
                except FunctionTimedOut:
                    raise ValueError(f"Read {idx} timeout.")
                except Exception as e:
                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")

                if not self.enable_bucket:
                    # Move channels in front of height/width, scale by 1/255,
                    # then resize / crop / normalize.
                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                    pixel_values = pixel_values / 255.
                    pixel_values = self.video_transforms(pixel_values)

            # Random use no text generation
            if random.random() < self.text_drop_ratio:
                text = ''
            return pixel_values, text, 'video'
        else:
            image_path, text = data_info['file_path'], data_info['text']
            if self.data_root is not None:
                image_path = os.path.join(self.data_root, image_path)
            # ``convert`` returns a loaded copy, so the file can be closed.
            with Image.open(image_path) as image_file:
                image = image_file.convert('RGB')
            if not self.enable_bucket:
                image = self.image_transforms(image).unsqueeze(0)
            else:
                image = np.expand_dims(np.array(image), 0)
            if random.random() < self.text_drop_ratio:
                text = ''
            return image, text, 'image'

    def __len__(self):
        """Return the number of entries after image/video balancing."""
        return self.length

    def __getitem__(self, idx):
        """Return a sample dict of the same content type as entry ``idx``.

        On any failure the error is printed and a random index is tried.
        A replacement whose type differs from the requested one is
        rejected, so the returned sample keeps the requested type.
        """
        data_info = self.dataset[idx % len(self.dataset)]
        data_type = data_info.get('type', 'image')
        while True:
            sample = {}
            try:
                data_info_local = self.dataset[idx % len(self.dataset)]
                data_type_local = data_info_local.get('type', 'image')
                if data_type_local != data_type:
                    raise ValueError("data_type_local != data_type")

                pixel_values, name, data_type = self.get_batch(idx)
                sample["pixel_values"] = pixel_values
                sample["text"] = name
                sample["data_type"] = data_type
                sample["idx"] = idx
                break
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)
        return sample
234
+
235
if __name__ == "__main__":
    # Smoke test: iterate the dataset through a DataLoader and print the
    # pixel tensor shape and the caption count of every batch.
    dataset = ImageVideoDataset(ann_path="test.json")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        num_workers=16,
    )
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
easyanimate/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
def _sample_block_bounds(h, w):
    """Sample a random rectangle and return ``(start_y, end_y, start_x, end_x)``.

    The centre is drawn uniformly over the frame; each side length is
    drawn from ``[side // 4, side // 4 * 3)`` and the rectangle is
    clipped to the frame.
    """
    center_x = torch.randint(0, w, (1,)).item()
    center_y = torch.randint(0, h, (1,)).item()
    # ``torch.randint`` needs low < high; for sides shorter than 4 the
    # original bounds collapse to (0, 0), so widen the range by one.
    low_x, low_y = w // 4, h // 4
    block_size_x = torch.randint(low_x, max(low_x * 3, low_x + 1), (1,)).item()  # block width range
    block_size_y = torch.randint(low_y, max(low_y * 3, low_y + 1), (1,)).item()  # block height range

    start_x = max(center_x - block_size_x // 2, 0)
    end_x = min(center_x + block_size_x // 2, w)
    start_y = max(center_y - block_size_y // 2, 0)
    end_y = min(center_y + block_size_y // 2, h)
    return start_y, end_y, start_x, end_x


def get_random_mask(shape):
    """Build a random binary mask for a clip of the given shape.

    Args:
        shape: A 4-sequence ``(f, c, h, w)``; ``c`` is ignored.

    Returns:
        A ``torch.uint8`` tensor of shape ``(f, 1, h, w)`` holding 1 at
        masked positions. One of four patterns is chosen at random:

        0. every frame except the first is masked;
        1. every frame except the first and last is masked;
        2. a random rectangle is masked in every frame;
        3. a random rectangle is masked in a random frame range that
           starts in the first half and ends in the second half.

    Raises:
        ValueError: If the drawn pattern index is unknown (not reachable
            with the current ``randint(0, 4)`` draw).
    """
    f, c, h, w = shape

    mask_index = np.random.randint(0, 4)
    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
    if mask_index == 0:
        mask[1:, :, :, :] = 1
    elif mask_index == 1:
        mask_frame_index = 1
        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
    elif mask_index == 2:
        start_y, end_y, start_x, end_x = _sample_block_bounds(h, w)
        mask[:, :, start_y:end_y, start_x:end_x] = 1
    elif mask_index == 3:
        start_y, end_y, start_x, end_x = _sample_block_bounds(h, w)

        half = f // 2
        # For a single-frame clip ``half`` is 0 and ``randint(0, 0)`` would
        # raise; clamp the upper bound so the draw stays valid.
        mask_frame_before = np.random.randint(0, max(half, 1))
        mask_frame_after = np.random.randint(half, f)
        # Only changes the single-frame case: keep at least one frame
        # inside the masked range.
        mask_frame_after = max(mask_frame_after, mask_frame_before + 1)
        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
    else:
        raise ValueError(f"The mask_index {mask_index} is not define")
    return mask
63
+
64
+
65
@contextmanager
def VideoReader_contextmanager(*args, **kwargs):
    """Yield a ``VideoReader`` built from the given arguments.

    When the ``with`` block exits, normally or through an exception, the
    local reference to the reader is dropped and a garbage-collection
    pass is forced.
    """
    reader = VideoReader(*args, **kwargs)
    try:
        yield reader
    finally:
        # Drop our reference first so the collection pass below can
        # reclaim the reader if nothing else still holds it.
        del reader
        gc.collect()
73
+
74
+
75
def get_video_reader_batch(video_reader, batch_index):
    """Fetch the frames at ``batch_index`` and return them via ``asnumpy()``.

    Kept as a module-level function so it can be passed to
    ``func_timeout`` as the callable to run under a time limit.
    """
    return video_reader.get_batch(batch_index).asnumpy()
78
+
79
+
80
class WebVid10M(Dataset):
    """Video/caption dataset read from a CSV annotation file.

    Every CSV row is read through its ``videoid`` and ``name`` columns;
    the video is loaded from ``<video_folder>/<videoid>.mp4``.

    Args:
        csv_path: Path of the CSV annotation file.
        video_folder: Directory holding the ``.mp4`` files.
        sample_size: Int or iterable giving the crop size.
        sample_stride: Frame stride used when sampling a clip.
        sample_n_frames: Number of frames per clip.
        enable_bucket: When true, raw numpy frames are returned and the
            torchvision transforms are skipped.
        enable_inpaint: When true, a random mask and the masked pixels
            are added to each sample.
        is_image: When true, a single random frame is returned instead
            of a clip.
    """

    def __init__(
            self,
            csv_path, video_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            enable_bucket=False, enable_inpaint=False, is_image=False,
        ):
        print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        """Load the frames and caption of annotation row ``idx``.

        Returns:
            ``(pixel_values, name)``. ``pixel_values`` is a float tensor
            scaled by 1/255 with channels moved in front of height/width,
            or the raw frame array in bucket mode; with ``is_image`` only
            the first (single) frame is returned.
        """
        video_dict = self.dataset[idx]
        # The ``page_dir`` column was read but never used; it is no longer
        # required to be present.
        videoid, name = video_dict['videoid'], video_dict['name']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        # Use the shared context manager so the reader is released in both
        # the bucket and the non-bucket path.
        with VideoReader_contextmanager(video_dir) as video_reader:
            video_length = len(video_reader)

            if not self.is_image:
                clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
                start_idx = random.randint(0, video_length - clip_length)
                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
            else:
                batch_index = [random.randint(0, video_length - 1)]

            pixel_values = video_reader.get_batch(batch_index).asnumpy()

        if not self.enable_bucket:
            pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation rows."""
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict, retrying with random indices on failure.

        The dict holds ``pixel_values`` and ``text``; with
        ``enable_inpaint`` it also holds ``mask`` and
        ``mask_pixel_values`` (masked positions set to -1).
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            # NOTE(review): ``pixel_values.size()`` presumably needs a torch
            # tensor with four dimensions; combining enable_inpaint with
            # enable_bucket or is_image looks unsupported - confirm.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample
155
+
156
+
157
class VideoDataset(Dataset):
    """Video/caption dataset backed by a JSON annotation list.

    Every annotation entry is read through its ``file_path`` and ``text``
    keys.

    Args:
        json_path: Path of the JSON annotation file.
        video_folder: Optional directory joined in front of every
            ``file_path``; when ``None`` the ``file_path`` is used as is.
        sample_size: Int or iterable giving the crop size.
        sample_stride: Frame stride used when sampling a clip.
        sample_n_frames: Number of frames per clip.
        enable_bucket: When true, raw numpy frames are returned and the
            torchvision transforms are skipped.
        enable_inpaint: When true, a random mask and the masked pixels
            are added to each sample.
    """

    def __init__(
        self,
        json_path, video_folder=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        enable_bucket=False, enable_inpaint=False
    ):
        print(f"loading annotations from {json_path} ...")
        # Close the annotation file deterministically instead of leaking
        # the handle returned by a bare ``open``.
        with open(json_path, 'r') as annotation_file:
            self.dataset = json.load(annotation_file)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.enable_bucket = enable_bucket
        self.enable_inpaint = enable_inpaint

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose(
            [
                transforms.Resize(sample_size[0]),
                transforms.CenterCrop(sample_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )

    def get_batch(self, idx):
        """Load the frames and caption of annotation entry ``idx``.

        Returns:
            ``(pixel_values, name)``. ``pixel_values`` is a float tensor
            scaled by 1/255 with channels moved in front of height/width,
            or the raw frame array in bucket mode.

        Raises:
            ValueError: If reading the frames times out or fails.
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']

        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)

        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
            video_length = len(video_reader)

            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

            try:
                sample_args = (video_reader, batch_index)
                # Bound the decode time so a bad file cannot stall loading.
                pixel_values = func_timeout(
                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
                )
            except FunctionTimedOut:
                raise ValueError(f"Read {idx} timeout.")
            except Exception as e:
                raise ValueError(f"Failed to extract frames from video. Error is {e}.")

            if not self.enable_bucket:
                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
                pixel_values = pixel_values / 255.

        return pixel_values, name

    def __len__(self):
        """Return the number of annotation entries."""
        return self.length

    def __getitem__(self, idx):
        """Return one sample dict, retrying with random indices on failure.

        The dict holds ``pixel_values`` and ``text``; with
        ``enable_inpaint`` it also holds ``mask`` and
        ``mask_pixel_values`` (masked positions set to -1).
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                print("Error info:", e)
                idx = random.randint(0, self.length-1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        if self.enable_inpaint:
            # NOTE(review): ``pixel_values.size()`` presumably needs a torch
            # tensor; combining enable_inpaint with enable_bucket looks
            # unsupported - confirm.
            mask = get_random_mask(pixel_values.size())
            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
        else:
            sample = dict(pixel_values=pixel_values, text=name)
        return sample
241
+
242
+
243
if __name__ == "__main__":
    # Smoke test: build one of the two datasets and print batch shapes.
    # The JSON-backed dataset is the enabled branch.
    if True:
        dataset = VideoDataset(
            json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
            sample_size=256,
            sample_stride=4,
            sample_n_frames=16,
        )

    # The CSV-backed dataset is kept as a disabled alternative.
    if False:
        dataset = WebVid10M(
            csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
            video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
            sample_size=256,
            sample_stride=4,
            sample_n_frames=16,
            is_image=False,
        )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
easyanimate/models/attention.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.nn.init as init
20
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
21
+ from diffusers.models.attention import AdaLayerNorm, FeedForward
22
+ from diffusers.models.attention_processor import Attention
23
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
24
+ from diffusers.models.lora import LoRACompatibleLinear
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
26
+ from diffusers.utils import USE_PEFT_BACKEND
27
+ from diffusers.utils.import_utils import is_xformers_available
28
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
29
+ from einops import rearrange, repeat
30
+ from torch import nn
31
+
32
+ from .motion_module import get_motion_module
33
+
34
+ if is_xformers_available():
35
+ import xformers
36
+ import xformers.ops
37
+ else:
38
+ xformers = None
39
+
40
+
41
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    Gated self-attention dense layer mixing visual features with object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # Object features are projected to the query width so they can be
        # concatenated with the visual features.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Both gates start at 0, so tanh(gate) is 0 and the layer initially
        # returns its input unchanged.
        self.alpha_attn = nn.Parameter(torch.tensor(0.0))
        self.alpha_dense = nn.Parameter(torch.tensor(0.0))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        # A disabled layer is a pass-through.
        if not self.enabled:
            return x

        visual_token_count = x.shape[1]
        projected_objs = self.linear(objs)

        # Attend over visual and object tokens together, then keep only
        # the positions that belong to the visual tokens.
        joint_tokens = torch.cat([x, projected_objs], dim=1)
        attended = self.attn(self.norm1(joint_tokens))[:, :visual_token_count, :]

        x = x + self.alpha_attn.tanh() * attended
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x
81
+
82
+
83
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the module."""
    for param in module.parameters():
        # Zero through a detached view so the write is not tracked by autograd.
        param.detach().zero_()
    return module
88
+
89
+
90
+
91
+ class KVCompressionCrossAttention(nn.Module):
92
+ r"""
93
+ A cross attention layer.
94
+
95
+ Parameters:
96
+ query_dim (`int`): The number of channels in the query.
97
+ cross_attention_dim (`int`, *optional*):
98
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
99
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
100
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
101
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
+ bias (`bool`, *optional*, defaults to False):
103
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ query_dim: int,
109
+ cross_attention_dim: Optional[int] = None,
110
+ heads: int = 8,
111
+ dim_head: int = 64,
112
+ dropout: float = 0.0,
113
+ bias=False,
114
+ upcast_attention: bool = False,
115
+ upcast_softmax: bool = False,
116
+ added_kv_proj_dim: Optional[int] = None,
117
+ norm_num_groups: Optional[int] = None,
118
+ ):
119
+ super().__init__()
120
+ inner_dim = dim_head * heads
121
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
122
+ self.upcast_attention = upcast_attention
123
+ self.upcast_softmax = upcast_softmax
124
+
125
+ self.scale = dim_head**-0.5
126
+
127
+ self.heads = heads
128
+ # for slice_size > 0 the attention score computation
129
+ # is split across the batch axis to save memory
130
+ # You can set slice_size with `set_attention_slice`
131
+ self.sliceable_head_dim = heads
132
+ self._slice_size = None
133
+ self._use_memory_efficient_attention_xformers = True
134
+ self.added_kv_proj_dim = added_kv_proj_dim
135
+
136
+ if norm_num_groups is not None:
137
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
138
+ else:
139
+ self.group_norm = None
140
+
141
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
142
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
143
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
144
+
145
+ if self.added_kv_proj_dim is not None:
146
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
147
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
148
+
149
+ self.kv_compression = nn.Conv2d(
150
+ query_dim,
151
+ query_dim,
152
+ groups=query_dim,
153
+ kernel_size=2,
154
+ stride=2,
155
+ bias=True
156
+ )
157
+ self.kv_compression_norm = nn.LayerNorm(query_dim)
158
+ init.constant_(self.kv_compression.weight, 1 / 4)
159
+ if self.kv_compression.bias is not None:
160
+ init.constant_(self.kv_compression.bias, 0)
161
+
162
+ self.to_out = nn.ModuleList([])
163
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
164
+ self.to_out.append(nn.Dropout(dropout))
165
+
166
+ def reshape_heads_to_batch_dim(self, tensor):
167
+ batch_size, seq_len, dim = tensor.shape
168
+ head_size = self.heads
169
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
170
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
171
+ return tensor
172
+
173
+ def reshape_batch_dim_to_heads(self, tensor):
174
+ batch_size, seq_len, dim = tensor.shape
175
+ head_size = self.heads
176
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
177
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
178
+ return tensor
179
+
180
+ def set_attention_slice(self, slice_size):
181
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
182
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
183
+
184
+ self._slice_size = slice_size
185
+
186
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
    """Run attention with full-resolution queries and spatially compressed keys/values.

    Queries come from `hidden_states` unchanged. Keys and values are laid out as `num_frames` feature maps
    of size `height` x `width`, passed through `self.kv_compression` and `self.kv_compression_norm`
    (see `_compress_kv`) and only then split into heads.

    Args:
        hidden_states: Tensor of shape `(batch, num_frames * height * width, channels)`.
        encoder_hidden_states: Optional context tensor. When `self.added_kv_proj_dim` is set it is projected
            with `add_k_proj` / `add_v_proj` and prepended (uncompressed) to the compressed keys/values.
            Otherwise, when given, it replaces `hidden_states` as the key/value source and therefore has
            to follow the same `(frames, height, width)` token layout.
        attention_mask: Optional additive mask. If its last dimension differs from the query length it is
            right-padded with zeros by the query length, then repeated once per head along dim 0.
        num_frames: Number of frames folded into the token axis of the key/value source.
        height: Height of each frame's token grid.
        width: Width of each frame's token grid.

    Returns:
        The attention output after the output projection `self.to_out[0]` and dropout `self.to_out[1]`.
    """
    _, sequence_length, _ = hidden_states.shape

    if self.group_norm is not None:
        # The norm is applied on a (batch, channels, tokens) view, hence the transposes.
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states)
    dim = query.shape[-1]
    query = self.reshape_heads_to_batch_dim(query)

    use_added_kv = self.added_kv_proj_dim is not None
    if use_added_kv:
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
    else:
        kv_source = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(kv_source)
        value = self.to_v(kv_source)

    key = self._compress_kv(key, num_frames, height, width, query.dtype)
    value = self._compress_kv(value, num_frames, height, width, query.dtype)

    key = self.reshape_heads_to_batch_dim(key)
    value = self.reshape_heads_to_batch_dim(value)

    if use_added_kv:
        encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)

        # Context tokens go first, followed by the compressed spatial tokens.
        key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)

    if attention_mask is not None:
        # NOTE(review): the mask width is compared with the query length although the keys have been
        # compressed - confirm that masks are only used where those lengths agree.
        if attention_mask.shape[-1] != query.shape[1]:
            target_length = query.shape[1]
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
            attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

    # attention, what we cannot get enough of
    if self._use_memory_efficient_attention_xformers:
        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        # Some versions of xformers return output in fp32, cast it back to the dtype of the input
        hidden_states = hidden_states.to(query.dtype)
    else:
        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
            hidden_states = self._attention(query, key, value, attention_mask)
        else:
            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)

    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states

def _compress_kv(self, tensor, num_frames, height, width, dtype):
    """Compress a key or value tensor with `self.kv_compression` on its per-frame spatial grid.

    Args:
        tensor: Tensor of shape `(batch, num_frames * height * width, channels)`.
        num_frames: Number of frames folded into the token axis.
        height: Height of each frame's token grid.
        width: Width of each frame's token grid.
        dtype: Dtype the result is cast to (the query dtype in `forward`).

    Returns:
        Tensor of shape `(batch, compressed_tokens, channels)` after `self.kv_compression_norm`, cast to
        `dtype`; the number of tokens is whatever `self.kv_compression` produces.
    """
    # Tokens -> one 2-D feature map per frame so that the compression layer can act spatially.
    tensor = rearrange(tensor, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
    tensor = self.kv_compression(tensor)
    # Back to a token sequence; h and w are inferred from the compressed map.
    tensor = rearrange(tensor, "(b f) c h w -> b (f h w) c", f=num_frames)
    tensor = self.kv_compression_norm(tensor)
    return tensor.to(dtype)
267
+
268
+ def _attention(self, query, key, value, attention_mask=None):
269
+ if self.upcast_attention:
270
+ query = query.float()
271
+ key = key.float()
272
+
273
+ attention_scores = torch.baddbmm(
274
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
275
+ query,
276
+ key.transpose(-1, -2),
277
+ beta=0,
278
+ alpha=self.scale,
279
+ )
280
+
281
+ if attention_mask is not None:
282
+ attention_scores = attention_scores + attention_mask
283
+
284
+ if self.upcast_softmax:
285
+ attention_scores = attention_scores.float()
286
+
287
+ attention_probs = attention_scores.softmax(dim=-1)
288
+
289
+ # cast back to the original dtype
290
+ attention_probs = attention_probs.to(value.dtype)
291
+
292
+ # compute attention output
293
+ hidden_states = torch.bmm(attention_probs, value)
294
+
295
+ # reshape hidden_states
296
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
297
+ return hidden_states
298
+
299
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
300
+ batch_size_attention = query.shape[0]
301
+ hidden_states = torch.zeros(
302
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
303
+ )
304
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
305
+ for i in range(hidden_states.shape[0] // slice_size):
306
+ start_idx = i * slice_size
307
+ end_idx = (i + 1) * slice_size
308
+
309
+ query_slice = query[start_idx:end_idx]
310
+ key_slice = key[start_idx:end_idx]
311
+
312
+ if self.upcast_attention:
313
+ query_slice = query_slice.float()
314
+ key_slice = key_slice.float()
315
+
316
+ attn_slice = torch.baddbmm(
317
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
318
+ query_slice,
319
+ key_slice.transpose(-1, -2),
320
+ beta=0,
321
+ alpha=self.scale,
322
+ )
323
+
324
+ if attention_mask is not None:
325
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
326
+
327
+ if self.upcast_softmax:
328
+ attn_slice = attn_slice.float()
329
+
330
+ attn_slice = attn_slice.softmax(dim=-1)
331
+
332
+ # cast back to the original dtype
333
+ attn_slice = attn_slice.to(value.dtype)
334
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
335
+
336
+ hidden_states[start_idx:end_idx] = attn_slice
337
+
338
+ # reshape hidden_states
339
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
340
+ return hidden_states
341
+
342
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
    """Compute attention through `xformers.ops.memory_efficient_attention`.

    The inputs are made contiguous before the call, `attention_mask` is passed as `attn_bias`, and the
    result has its heads merged back into the channel axis by `reshape_batch_dim_to_heads`.
    """
    # TODO attention_mask
    contiguous_qkv = [tensor.contiguous() for tensor in (query, key, value)]
    attended = xformers.ops.memory_efficient_attention(*contiguous_qkv, attn_bias=attention_mask)
    return self.reshape_batch_dim_to_heads(attended)
350
+
351
+
352
@maybe_allow_in_graph
class TemporalTransformerBlock(nn.Module):
    r"""
    A Temporal Transformer block.

    `forward` applies, in order: self-attention computed independently for every frame (`attn1`, optionally
    with key/value compression), an optional GLIGEN fuser, the motion module returned by `get_motion_module`
    (`attn_temporal`), an optional second attention (`attn2`) and a feed-forward network. The outputs of both
    attentions and of the feed-forward are added to `hidden_states`; the motion module's output replaces it.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or
            `"ada_norm_single"`.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon passed to the `nn.LayerNorm` layers.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
        kvcompression (`bool`, *optional*, defaults to `False`):
            Whether `attn1` is a `KVCompressionCrossAttention` instead of a plain `Attention`.
        motion_module_type (`str`, *optional*, defaults to `"VanillaGrid"`):
            Forwarded to `get_motion_module` to select the motion module.
        motion_module_kwargs (*optional*, defaults to `None`):
            Forwarded to `get_motion_module`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        # kv compression
        kvcompression: Optional[bool] = False,
        # motion module kwargs
        motion_module_type="VanillaGrid",
        motion_module_kwargs=None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.kvcompression = kvcompression
        # Both attention classes are constructed with the same arguments; only the class differs.
        attn1_cls = KVCompressionCrossAttention if kvcompression else Attention
        self.attn1 = attn1_cls(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        self.attn_temporal = get_motion_module(
            in_channels=dim,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        # With "ada_norm_single" the feed-forward input is normalised by `norm2` in `forward`, so no `norm3`.
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Configure chunked feed-forward: `chunk_size` elements per chunk along axis `dim` (`None` disables)."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        num_frames: int = 16,
        height: int = 32,
        width: int = 32,
    ) -> torch.FloatTensor:
        """Apply the block to a sequence of video tokens.

        Args:
            hidden_states: Tensor of shape `(batch, num_frames * height * width, dim)`.
            attention_mask: Optional mask forwarded to `attn1`.
            encoder_hidden_states: Optional context for `attn2` (and for `attn1` when `only_cross_attention`).
            encoder_attention_mask: Optional mask forwarded to `attn2`.
            timestep: Conditioning used by the adaptive norms; for `"ada_norm_single"` it is reshaped to
                `(batch, 6, -1)` and added to `scale_shift_table`.
            cross_attention_kwargs: Extra keyword arguments for the attention layers. `"scale"` is also
                used as the feed-forward LoRA scale and `"gligen"` is popped and routed to the fuser.
            class_labels: Forwarded to `norm1` when `norm_type` is `"ada_norm_zero"`.
            num_frames: Number of frames folded into the token axis.
            height: Height of each frame's token grid.
            width: Width of each frame's token grid.

        Returns:
            Tensor with the same layout as `hidden_states`.

        Raises:
            ValueError: If no known norm type is active, or if chunked feed-forward is enabled and the
                chunked dimension is not divisible by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        # Fold the frames into the batch so that `attn1` attends within each frame only.
        norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
        if self.kvcompression:
            # Every batch entry is now a single frame, hence num_frames=1 for the compression layout.
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                num_frames=1,
                height=height,
                width=width,
                **cross_attention_kwargs,
            )
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
        attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 2.75. Temp-Attention
        if self.attn_temporal is not None:
            # The motion module consumes a (batch, channels, frames, height, width) volume.
            attn_output = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=num_frames, h=height, w=width)
            attn_output = self.attn_temporal(attn_output)
            hidden_states = rearrange(attn_output, "b c f h w -> b (f h w) c")

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            else:
                raise ValueError("Incorrect norm")

            # `use_ada_layer_norm_single` is a bool, so the former `is None` test could never be true and
            # the positional embedding was silently skipped here.
            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
668
+
669
+
670
@maybe_allow_in_graph
class SelfAttentionTemporalTransformerBlock(nn.Module):
    r"""
    A Transformer block made of self-attention (`attn1`), an optional second attention (`attn2`) and a
    feed-forward network, each added to `hidden_states`.

    This block attends over the full token sequence it is given: it performs no per-frame reshaping and holds
    no motion module.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or
            `"ada_norm_single"`.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon passed to the `nn.LayerNorm` layers.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        # With "ada_norm_single" the feed-forward input is normalised by `norm2` in `forward`, so no `norm3`.
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Configure chunked feed-forward: `chunk_size` elements per chunk along axis `dim` (`None` disables)."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Apply self-attention, optional cross-attention and the feed-forward network.

        Args:
            hidden_states: Input token sequence; its first axis is the batch.
            attention_mask: Optional mask forwarded to `attn1`.
            encoder_hidden_states: Optional context for `attn2` (and for `attn1` when `only_cross_attention`).
            encoder_attention_mask: Optional mask forwarded to `attn2`.
            timestep: Conditioning used by the adaptive norms; for `"ada_norm_single"` it is reshaped to
                `(batch, 6, -1)` and added to `scale_shift_table`.
            cross_attention_kwargs: Extra keyword arguments for the attention layers. `"scale"` is also
                used as the feed-forward LoRA scale and `"gligen"` is popped and routed to the fuser.
            class_labels: Forwarded to `norm1` when `norm_type` is `"ada_norm_zero"`.

        Returns:
            The updated hidden states.

        Raises:
            ValueError: If no known norm type is active, or if chunked feed-forward is enabled and the
                chunked dimension is not divisible by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            else:
                raise ValueError("Incorrect norm")

            # `use_ada_layer_norm_single` is a bool, so the former `is None` test could never be true and
            # the positional embedding was silently skipped here.
            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
941
+
942
+
943
@maybe_allow_in_graph
class KVCompressionTransformerBlock(nn.Module):
    r"""
    A transformer block (self-attention, optional cross-attention, feed-forward) whose self-attention layer can be
    swapped for `KVCompressionCrossAttention`.

    When `kvcompression` is enabled, `forward` additionally passes `num_frames`, `height` and `width` to the
    self-attention layer.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Required for `"ada_norm"` and `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"`, `"ada_norm_zero"` or
            `"ada_norm_single"`.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon passed to the `nn.LayerNorm` layers.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to. Only `"sinusoidal"` creates an embedding layer.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
        kvcompression (`bool`, *optional*, defaults to `False`):
            Whether the self-attention layer is a `KVCompressionCrossAttention` instead of a plain `Attention`.

    Raises:
        ValueError: If an adaptive norm type is requested without `num_embeds_ada_norm`, or positional embeddings
            are requested without `num_positional_embeddings`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        kvcompression: Optional[bool] = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # Exactly one of these flags selects the normalization path taken in `forward`.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.kvcompression = kvcompression
        # Both attention variants receive the same constructor arguments; only the class differs.
        self_attn_cls = KVCompressionCrossAttention if kvcompression else Attention
        self.attn1 = self_attn_cls(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        if not self.use_ada_layer_norm_single:
            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # 4. Fuser (only created for the gated attention types; used for GLIGEN inputs in `forward`)
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

        # 5. Scale-shift for PixArt-Alpha.
        if self.use_ada_layer_norm_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or, with `chunk_size=None`, disable) chunked feed-forward along dimension `dim`."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
        num_frames: int = 16,
        height: int = 32,
        width: int = 32,
        use_reentrant: bool = False,
    ) -> torch.FloatTensor:
        """
        Run self-attention, optional GLIGEN fusion, optional cross-attention and the feed-forward layer, each with
        a residual connection.

        Args:
            hidden_states: Input features; the first dimension is used as the batch size.
            attention_mask: Mask forwarded to the self-attention layer.
            encoder_hidden_states: Conditioning features for cross-attention (also used by the first attention
                layer when `only_cross_attention` is set).
            encoder_attention_mask: Mask forwarded to the cross-attention layer.
            timestep: Conditioning for the adaptive norms; for `"ada_norm_single"` it is reshaped to
                `(batch_size, 6, -1)` and added to the learned scale/shift table.
            cross_attention_kwargs: Extra kwargs for the attention layers. `"scale"` is read as the LoRA scale
                and `"gligen"` is popped and routed to the fuser.
            class_labels: Passed to `norm1` when `norm_type == "ada_norm_zero"`.
            num_frames, height, width: Forwarded to the self-attention layer only when `kvcompression` is on.
            use_reentrant: Accepted for interface compatibility; not used in this method.

        Returns:
            The transformed hidden states.

        Raises:
            ValueError: If no known norm type is active, or the chunked dimension is not divisible by the
                configured chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.use_layer_norm:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.use_ada_layer_norm_single:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
            ).chunk(6, dim=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs (work on a copy so the caller's dict is not mutated by `pop`)
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        if self.kvcompression:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                num_frames=num_frames,
                height=height,
                width=width,
                **cross_attention_kwargs,
            )
        else:
            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )

        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.use_ada_layer_norm_single:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.use_ada_layer_norm_single:
                # For PixArt norm2 isn't applied here:
                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                norm_hidden_states = hidden_states
            else:
                raise ValueError("Incorrect norm")

            # `use_ada_layer_norm_single` is a bool, so it must be tested for falsiness; comparing it with
            # `None` would never match and the positional embedding would be silently skipped here.
            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        if not self.use_ada_layer_norm_single:
            norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self.use_ada_layer_norm_single:
            # The single-norm variant has no `norm3`; it reuses `norm2` before the feed-forward layer.
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.use_ada_layer_norm_single:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
1243
+
1244
+
1245
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation/projection-in, dropout, linear projection-out and an optional final dropout.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Select the input projection + activation. An unknown name is rejected explicitly instead of
        # leaving `act_fn` unbound and failing later with an unrelated error.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(
                f"Unsupported `activation_fn`: {activation_fn!r}. Expected one of 'gelu', 'gelu-approximate',"
                " 'geglu' or 'geglu-approximate'."
            )

        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Apply the layers in order; `scale` is forwarded only to the LoRA-aware layer types."""
        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states
easyanimate/models/autoencoder_magvit.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders import FromOriginalVAEMixin
21
+ from diffusers.models.attention_processor import (
22
+ ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
23
+ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
24
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
25
+ DiagonalGaussianDistribution)
26
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.utils.accelerate_utils import apply_forward_hook
29
+ from torch import nn
30
+
31
+ from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
32
+ from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
33
+
34
+
35
def str_eval(item):
    """Evaluate `item` as a Python expression when it is a string; return anything else unchanged.

    Used so that config values stored as strings can be turned back into Python objects.
    """
    if isinstance(item, str):
        # SECURITY NOTE(review): `eval` executes arbitrary code. This is only safe while the strings come
        # from trusted config files; consider `ast.literal_eval` if the values are always plain literals.
        return eval(item)
    return item
40
+
41
+ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
42
+ r"""
43
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
44
+
45
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
46
+ for all models (such as downloading or saving).
47
+
48
+ Parameters:
49
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
50
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
51
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
52
+ Tuple of downsample block types.
53
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
54
+ Tuple of upsample block types.
55
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
56
+ Tuple of block output channels.
57
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
58
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
59
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
60
+ scaling_factor (`float`, *optional*, defaults to 0.1825):
61
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
62
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
63
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67
+ force_upcast (`bool`, *optional*, default to `True`):
68
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
69
+ can be fine-tuned / trained to a lower range without losing too much precision in which case
70
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
71
+ """
72
+
73
+ _supports_gradient_checkpointing = True
74
+
75
@register_to_config
def __init__(
    self,
    in_channels: int = 3,
    out_channels: int = 3,
    ch=128,
    ch_mult=[1, 2, 4, 4],
    use_gc_blocks=None,
    down_block_types: tuple = None,
    up_block_types: tuple = None,
    mid_block_type: str = "MidBlock3D",
    mid_block_use_attention: bool = True,
    mid_block_attention_type: str = "3d",
    mid_block_num_attention_heads: int = 1,
    layers_per_block: int = 2,
    act_fn: str = "silu",
    num_attention_heads: int = 1,
    latent_channels: int = 4,
    norm_num_groups: int = 32,
    scaling_factor: float = 0.1825,
    slice_compression_vae=False,
    mini_batch_encoder=9,
    mini_batch_decoder=3,
):
    """Build the encoder, decoder and the 1x1x1 (post-)quantization convolutions.

    `down_block_types` / `up_block_types` may be given as strings, in which case they are passed
    through `str_eval` first. Slicing and tiling are disabled by default.
    """
    super().__init__()

    # Block type lists may arrive as strings from a config file.
    down_block_types = str_eval(down_block_types)
    up_block_types = str_eval(up_block_types)

    # Settings that the encoder and the decoder receive identically.
    common_kwargs = dict(
        ch=ch,
        ch_mult=ch_mult,
        use_gc_blocks=use_gc_blocks,
        mid_block_type=mid_block_type,
        mid_block_use_attention=mid_block_use_attention,
        mid_block_attention_type=mid_block_attention_type,
        mid_block_num_attention_heads=mid_block_num_attention_heads,
        layers_per_block=layers_per_block,
        norm_num_groups=norm_num_groups,
        act_fn=act_fn,
        num_attention_heads=num_attention_heads,
        slice_compression_vae=slice_compression_vae,
    )

    # The encoder emits 2 * latent_channels (double_z=True): the moments fed to the posterior in `encode`.
    self.encoder = omnigen_Mag_Encoder(
        in_channels=in_channels,
        out_channels=latent_channels,
        down_block_types=down_block_types,
        double_z=True,
        mini_batch_encoder=mini_batch_encoder,
        **common_kwargs,
    )
    self.decoder = omnigen_Mag_Decoder(
        in_channels=latent_channels,
        out_channels=out_channels,
        up_block_types=up_block_types,
        mini_batch_decoder=mini_batch_decoder,
        **common_kwargs,
    )

    moment_channels = 2 * latent_channels
    self.quant_conv = nn.Conv3d(moment_channels, moment_channels, kernel_size=1)
    self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

    self.slice_compression_vae = slice_compression_vae
    self.mini_batch_encoder = mini_batch_encoder
    self.mini_batch_decoder = mini_batch_decoder
    self.scaling_factor = scaling_factor

    # Slicing / tiling are opt-in.
    self.use_slicing = False
    self.use_tiling = False

    # Tiling geometry: the latent tile size is the sample tile size divided by 2 ** (len(ch_mult) - 1).
    self.tile_sample_min_size = 256
    self.tile_overlap_factor = 0.25
    downscale = 2 ** (len(ch_mult) - 1)
    self.tile_latent_min_size = int(self.tile_sample_min_size / downscale)
153
+
154
def _set_gradient_checkpointing(self, module, value=False):
    """Set `gradient_checkpointing` on `module` if it is the Magvit encoder or decoder; otherwise do nothing."""
    if not isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
        return
    module.gradient_checkpointing = value
157
+
158
@property
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every processor found in the model, keyed by
        `"<dotted module path>.processor"`.
    """
    collected: Dict[str, AttentionProcessor] = {}

    def _collect(path: str, module: torch.nn.Module) -> None:
        # Record this module's processor (if it exposes one), then descend into its children.
        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor(return_deprecated_lora=True)
        for child_name, child in module.named_children():
            _collect(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _collect(top_name, top_module)

    return collected
182
+
183
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance, which is installed on **all** layers exposing
            `set_processor`, or a dict mapping `"<dotted module path>.processor"` to the processor for
            that layer. A dict must contain exactly one entry per attention layer; its entries are
            popped as they are assigned.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
    """
    count = len(self.attn_processors.keys())
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def _assign(path: str, module: torch.nn.Module) -> None:
        if hasattr(module, "set_processor"):
            if per_layer:
                module.set_processor(processor.pop(f"{path}.processor"))
            else:
                module.set_processor(processor)
        for child_name, child in module.named_children():
            _assign(f"{path}.{child_name}", child)

    for top_name, top_module in self.named_children():
        _assign(top_name, top_module)
217
+
218
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
219
def set_default_attn_processor(self):
    """
    Replace all attention processors with the default one for their family.

    Raises:
        ValueError: If the current processors belong neither entirely to the added-KV family nor
            entirely to the cross-attention family.
    """

    def _all_belong_to(family) -> bool:
        return all(proc.__class__ in family for proc in self.attn_processors.values())

    if _all_belong_to(ADDED_KV_ATTENTION_PROCESSORS):
        default_processor = AttnAddedKVProcessor()
    elif _all_belong_to(CROSS_ATTENTION_PROCESSORS):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(default_processor)
233
+
234
@apply_forward_hook
def encode(
    self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
    """
    Encode a batch of inputs into a latent posterior distribution.

    Args:
        x (`torch.FloatTensor`): Input batch. The last two dimensions are compared with the tile size
            when tiling is enabled.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return an `AutoencoderKLOutput` instead of a plain tuple.

    Returns:
        An `AutoencoderKLOutput` wrapping the `DiagonalGaussianDistribution` if `return_dict` is True,
        otherwise a 1-tuple containing the distribution.
    """
    # Inputs larger than one tile are delegated to the tiled path when tiling is on.
    if self.use_tiling:
        if x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size:
            return self.tiled_encode(x, return_dict=return_dict)

    if self.use_slicing and x.shape[0] > 1:
        # Encode one batch element at a time, then re-assemble the batch.
        h = torch.cat([self.encoder(single) for single in x.split(1)])
    else:
        h = self.encoder(x)

    posterior = DiagonalGaussianDistribution(self.quant_conv(h))

    if return_dict:
        return AutoencoderKLOutput(latent_dist=posterior)
    return (posterior,)
266
+
267
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    """Decode latents `z`, delegating to `tiled_decode` when tiling is on and `z` exceeds one latent tile.

    Returns a `DecoderOutput` if `return_dict` is True, otherwise a 1-tuple with the decoded tensor.
    """
    if self.use_tiling:
        if z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size:
            return self.tiled_decode(z, return_dict=return_dict)

    dec = self.decoder(self.post_quant_conv(z))

    if return_dict:
        return DecoderOutput(sample=dec)
    return (dec,)
277
+
278
@apply_forward_hook
def decode(
    self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
    """
    Decode a batch of latents.

    Args:
        z (`torch.FloatTensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `DecoderOutput` instead of a plain tuple.
        generator: Accepted for interface compatibility; not used in this method.

    Returns:
        A `DecoderOutput` if `return_dict` is True, otherwise a 1-tuple with the decoded tensor.
    """
    if self.use_slicing and z.shape[0] > 1:
        # Decode one batch element at a time, then re-assemble the batch.
        per_item = [self._decode(single).sample for single in z.split(1)]
        decoded = torch.cat(per_item)
    else:
        decoded = self._decode(z).sample

    if return_dict:
        return DecoderOutput(sample=decoded)
    return (decoded,)
306
+
307
def blend_v(
    self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
    """Cross-fade the last rows (dim 3) of `a` into the first rows of `b`, in place, and return `b`.

    The blend width is clamped to the dim-3 size of both tensors. Row `k` of `b` becomes
    `a[-width + k] * (1 - k / width) + b[k] * (k / width)`.
    """
    width = min(a.shape[3], b.shape[3], blend_extent)
    for row in range(width):
        weight_b = row / width
        from_a = a[:, :, :, row - width, :]
        b[:, :, :, row, :] = from_a * (1 - weight_b) + b[:, :, :, row, :] * weight_b
    return b
316
+
317
def blend_h(
    self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
    """Cross-fade the last columns (dim 4) of `a` into the first columns of `b`, in place, and return `b`.

    The blend width is clamped to the dim-4 size of both tensors. Column `k` of `b` becomes
    `a[-width + k] * (1 - k / width) + b[k] * (k / width)`.
    """
    width = min(a.shape[4], b.shape[4], blend_extent)
    for col in range(width):
        weight_b = col / width
        from_a = a[:, :, :, :, col - width]
        b[:, :, :, :, col] = from_a * (1 - weight_b) + b[:, :, :, :, col] * weight_b
    return b
326
+
327
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
    """Encode `x` tile by tile over dims 3 and 4 and stitch the blended results into one posterior.

    Tiles are `tile_sample_min_size` wide and start every `tile_sample_min_size * (1 - overlap)`
    pixels, so neighbours overlap. Each encoded tile is cross-faded with the tile above and the tile
    to its left, then cropped before concatenation.

    Returns an `AutoencoderKLOutput` if `return_dict` is True, otherwise a 1-tuple with the posterior.
    """
    tile_size = self.tile_sample_min_size
    stride = int(tile_size * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
    keep = self.tile_latent_min_size - blend_extent

    # Pass 1: encode every overlapping tile independently (row-major order).
    encoded = [
        [
            self.quant_conv(self.encoder(x[:, :, :, top : top + tile_size, left : left + tile_size]))
            for left in range(0, x.shape[4], stride)
        ]
        for top in range(0, x.shape[3], stride)
    ]

    # Pass 2: blend each tile with its upper and left neighbours (in place), crop, and stitch.
    stitched_rows = []
    for i, tile_row in enumerate(encoded):
        cropped = []
        for j, tile in enumerate(tile_row):
            if i > 0:
                tile = self.blend_v(encoded[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(tile_row[j - 1], tile, blend_extent)
            cropped.append(tile[:, :, :, :keep, :keep])
        stitched_rows.append(torch.cat(cropped, dim=4))

    moments = torch.cat(stitched_rows, dim=3)
    posterior = DiagonalGaussianDistribution(moments)

    if return_dict:
        return AutoencoderKLOutput(latent_dist=posterior)
    return (posterior,)
368
+
369
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    """Decode latents `z` tile by tile over dims 3 and 4 and stitch the blended results together.

    Latent tiles are `tile_latent_min_size` wide and start every
    `tile_latent_min_size * (1 - overlap)` positions; the overlap avoids seams between tiles. Each
    decoded tile is cross-faded with the tile above and the tile to its left, then cropped.

    Returns a `DecoderOutput` if `return_dict` is True, otherwise a 1-tuple with the decoded tensor.
    """
    tile_size = self.tile_latent_min_size
    stride = int(tile_size * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
    keep = self.tile_sample_min_size - blend_extent

    # Pass 1: decode every overlapping latent tile independently (row-major order).
    decoded_tiles = [
        [
            self.decoder(self.post_quant_conv(z[:, :, :, top : top + tile_size, left : left + tile_size]))
            for left in range(0, z.shape[4], stride)
        ]
        for top in range(0, z.shape[3], stride)
    ]

    # Pass 2: blend each tile with its upper and left neighbours (in place), crop, and stitch.
    stitched_rows = []
    for i, tile_row in enumerate(decoded_tiles):
        cropped = []
        for j, tile in enumerate(tile_row):
            if i > 0:
                tile = self.blend_v(decoded_tiles[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(tile_row[j - 1], tile, blend_extent)
            cropped.append(tile[:, :, :, :keep, :keep])
        stitched_rows.append(torch.cat(cropped, dim=4))

    dec = torch.cat(stitched_rows, dim=3)

    if return_dict:
        return DecoderOutput(sample=dec)
    return (dec,)
409
+
410
def forward(
    self,
    sample: torch.FloatTensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
    r"""
    Encode `sample`, pick a latent from the posterior and decode it again.

    Args:
        sample (`torch.FloatTensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            If `True`, draw the latent with `posterior.sample(generator=generator)`;
            otherwise use `posterior.mode()`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        generator (`torch.Generator`, *optional*):
            Forwarded to `posterior.sample` when `sample_posterior` is `True`.
    """
    posterior = self.encode(sample).latent_dist
    latent = posterior.sample(generator=generator) if sample_posterior else posterior.mode()
    reconstruction = self.decode(latent).sample

    if return_dict:
        return DecoderOutput(sample=reconstruction)
    return (reconstruction,)
437
+
438
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
439
def fuse_qkv_projections(self):
    """
    Turn on fused QKV projections for every `Attention` submodule.

    The current attention processors are stored in
    `self.original_attn_processors` before `fuse_projections(fuse=True)` is
    called on each `Attention` module. Models whose processors have "Added"
    in their class name (added KV projections) are rejected.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(processor.__class__.__name__) for processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Remember the processors currently installed before fusing.
    self.original_attn_processors = self.attn_processors

    for submodule in self.modules():
        if isinstance(submodule, Attention):
            submodule.fuse_projections(fuse=True)
461
+
462
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
463
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the attention processors saved in `original_attn_processors`
    via `set_attn_processor`. Does nothing when no processors were saved,
    including when the attribute has not been created yet.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # The attribute is only assigned by `fuse_qkv_projections` in the code
    # visible here, so read it defensively instead of assuming it exists.
    saved_processors = getattr(self, "original_attn_processors", None)
    if saved_processors is not None:
        self.set_attn_processor(saved_processors)
475
+
476
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
    """Build the model from `config.json` and load its weights from disk.

    Args:
        pretrained_model_path: Directory holding `config.json` and the weights.
        subfolder: Optional sub-directory joined onto `pretrained_model_path`.
        **vae_additional_kwargs: Forwarded to `cls.from_config`.

    Returns:
        The model with the state dict loaded non-strictly; the missing and
        unexpected keys are printed.

    Raises:
        RuntimeError: If `config.json` is missing, or if neither the
            `.safetensors` nor the `.bin` weights file exists.
    """
    import json
    import os
    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = cls.from_config(config, **vae_additional_kwargs)
    from diffusers.utils import WEIGHTS_NAME
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    # Swap the extension of the weights file name only. A str.replace on the
    # full path would also rewrite ".bin" inside directory names.
    safetensors_name = os.path.splitext(WEIGHTS_NAME)[0] + ".safetensors"
    model_file_safetensors = os.path.join(pretrained_model_path, safetensors_name)
    if os.path.exists(model_file_safetensors):
        from safetensors.torch import load_file
        state_dict = load_file(model_file_safetensors)
    else:
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from trusted sources, or prefer the .safetensors variant.
        state_dict = torch.load(model_file, map_location="cpu")
    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
    print(m, u)
    return model
easyanimate/models/motion_module.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ """
3
+ import math
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from diffusers.models.attention import FeedForward
9
+ from diffusers.utils.import_utils import is_xformers_available
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+
13
+ if is_xformers_available():
14
+ import xformers
15
+ import xformers.ops
16
+ else:
17
+ xformers = None
18
+
19
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Cast query and key to float32 before computing the attention scores.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Cast the attention scores to float32 before the softmax.
        added_kv_proj_dim (`int`, *optional*):
            If given, `add_k_proj` / `add_v_proj` layers are created and their projections of
            `encoder_hidden_states` are concatenated in front of the self-attention keys and values.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` with that many groups is applied to the input of `forward`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False
        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        # to_out[0] is the output projection, to_out[1] the dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """Switch the xformers attention path on or off (`attention_op` is ignored)."""
        self._use_memory_efficient_attention_xformers = valid

    def reshape_heads_to_batch_dim(self, tensor):
        """Reshape (batch, seq, heads * d) into (batch * heads, seq, d)."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Reshape (batch * heads, seq, d) back into (batch, seq, heads * d)."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_attention_slice(self, slice_size):
        """Set how many (batch * heads) rows are attended at once; `None` disables slicing.

        Raises:
            ValueError: If `slice_size` exceeds the number of heads.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        self._slice_size = slice_size

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from `hidden_states` to `encoder_hidden_states` (or to itself when it is None).

        Args:
            hidden_states: Tensor with three axes (batch, sequence, channels).
            encoder_hidden_states: Optional source of keys and values.
            attention_mask: Optional additive mask; it is repeated once per head.

        Returns:
            The attention output after the output projection and dropout.
        """
        _, sequence_length, _ = hidden_states.shape

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            key = self.to_k(hidden_states)
            value = self.to_v(hidden_states)
            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)

            key = self.reshape_heads_to_batch_dim(key)
            value = self.reshape_heads_to_batch_dim(value)
            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)

            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)

            key = self.reshape_heads_to_batch_dim(key)
            value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states

    def _attention(self, query, key, value, attention_mask=None):
        """Scaled dot-product attention over the whole (batch * heads) axis at once."""
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # beta=0 means the uninitialised `torch.empty` input is ignored; it
        # only provides the output shape, dtype and device.
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype
        attention_probs = attention_probs.to(value.dtype)

        # compute attention output
        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
        """Same result as `_attention`, computed `slice_size` rows at a time.

        `sequence_length` and `dim` size the output buffer: it is allocated
        as (batch * heads, sequence_length, dim // heads).
        """
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
        # Step over the batch axis so a trailing partial slice is processed
        # too; `batch // slice_size` iterations would leave those rows zero.
        for start_idx in range(0, batch_size_attention, slice_size):
            end_idx = min(start_idx + slice_size, batch_size_attention)

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]

            if self.upcast_attention:
                query_slice = query_slice.float()
                key_slice = key_slice.float()

            attn_slice = torch.baddbmm(
                torch.empty(
                    end_idx - start_idx, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device
                ),
                query_slice,
                key_slice.transpose(-1, -2),
                beta=0,
                alpha=self.scale,
            )

            if attention_mask is not None:
                attn_slice = attn_slice + attention_mask[start_idx:end_idx]

            if self.upcast_softmax:
                attn_slice = attn_slice.float()

            attn_slice = attn_slice.softmax(dim=-1)

            # cast back to the original dtype
            attn_slice = attn_slice.to(value.dtype)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        """Attention through `xformers.ops.memory_efficient_attention`."""
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
245
+
246
def zero_module(module):
    """Set every parameter of `module` to zero in place and return `module`."""
    for param in module.parameters():
        # Zero through a detached view so the write is not tracked by autograd.
        param.detach().zero_()
    return module
251
+
252
def get_motion_module(
    in_channels,
    motion_module_type: str,
    motion_module_kwargs: dict,
):
    """Build the temporal module selected by `motion_module_type`.

    Args:
        in_channels: Channel count forwarded to `VanillaTemporalModule`.
        motion_module_type: `"Vanilla"` or `"VanillaGrid"`; the latter builds
            the same module with `grid=True`.
        motion_module_kwargs: Extra keyword arguments for the module.

    Returns:
        A `VanillaTemporalModule` instance.

    Raises:
        ValueError: If `motion_module_type` is not one of the supported names.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
    if motion_module_type == "VanillaGrid":
        return VanillaTemporalModule(in_channels=in_channels, grid=True, **motion_module_kwargs,)
    raise ValueError(
        f"Unknown motion_module_type {motion_module_type!r}; expected 'Vanilla' or 'VanillaGrid'."
    )
263
+
264
class VanillaTemporalModule(nn.Module):
    """Thin wrapper that owns a single `TemporalTransformer3DModel`.

    The per-head width handed to the transformer is
    `in_channels // num_attention_heads // temporal_attention_dim_div`. When
    `zero_initialize` is true, the transformer's `proj_out` parameters are
    zeroed with `zero_module`.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads                = 8,
        num_transformer_block              = 2,
        attention_block_types              =( "Temporal_Self", "Temporal_Self" ),
        cross_frame_attention_mode         = None,
        temporal_position_encoding         = False,
        temporal_position_encoding_max_len = 4096,
        temporal_attention_dim_div         = 1,
        zero_initialize                    = True,
        block_size                         = 1,
        grid                               = False,
    ):
        super().__init__()

        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
            grid=grid,
            block_size=block_size,
        )
        if zero_initialize:
            transformer.proj_out = zero_module(transformer.proj_out)
        self.temporal_transformer = transformer

    def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
        """Run the temporal transformer on `input_tensor`; `anchor_frame_idx` is not used."""
        return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
302
+
303
class TemporalTransformer3DModel(nn.Module):
    """GroupNorm -> linear in -> stack of `TemporalTransformerBlock` -> linear out, with a residual.

    `forward` expects a tensor with five axes laid out as "b c f h w" and
    returns a tensor in the same layout.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,

        num_layers,
        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
        dropout                            = 0.0,
        norm_num_groups                    = 32,
        cross_attention_dim                = 768,
        activation_fn                      = "geglu",
        attention_bias                     = False,
        upcast_attention                   = False,

        cross_frame_attention_mode         = None,
        temporal_position_encoding         = False,
        temporal_position_encoding_max_len = 4096,
        grid                               = False,
        block_size                         = 1,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.block_size = block_size

        # Every layer is built from the same settings.
        block_settings = dict(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            attention_block_types=attention_block_types,
            dropout=dropout,
            norm_num_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
            block_size=block_size,
            grid=grid,
        )
        self.transformer_blocks = nn.ModuleList(
            [TemporalTransformerBlock(**block_settings) for _ in range(num_layers)]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply the temporal transformer; `attention_mask` is accepted but not forwarded to the blocks."""
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]

        # Fold frames into the batch axis so the norm runs per frame.
        frames = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        batch, channel, height, weight = frames.shape
        residual = frames

        tokens = self.norm(frames)
        channels = tokens.shape[1]
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch, height * weight, channels)
        tokens = self.proj_in(tokens)

        # Transformer Blocks
        for block in self.transformer_blocks:
            tokens = block(
                tokens,
                encoder_hidden_states=encoder_hidden_states,
                video_length=video_length,
                height=height,
                weight=weight,
            )

        # Project back, restore the channel-first layout and add the residual.
        tokens = self.proj_out(tokens)
        tokens = tokens.reshape(batch, height, weight, channels).permute(0, 3, 1, 2).contiguous()

        output = tokens + residual
        return rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
382
+
383
class TemporalTransformerBlock(nn.Module):
    """Pre-norm residual block: one `VersatileAttention` per entry of `attention_block_types`, then a feed-forward.

    For each block name, the part before the first "_" is used as the
    attention mode, and `cross_attention_dim` is only passed on for names
    ending in "_Cross".
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
        dropout                            = 0.0,
        norm_num_groups                    = 32,
        cross_attention_dim                = 768,
        activation_fn                      = "geglu",
        attention_bias                     = False,
        upcast_attention                   = False,
        cross_frame_attention_mode         = None,
        temporal_position_encoding         = False,
        temporal_position_encoding_max_len = 4096,
        block_size                         = 1,
        grid                               = False,
    ):
        super().__init__()

        layers = []
        layer_norms = []
        for block_name in attention_block_types:
            mode = block_name.split("_")[0]
            context_dim = cross_attention_dim if block_name.endswith("_Cross") else None
            layers.append(
                VersatileAttention(
                    attention_mode=mode,
                    cross_attention_dim=context_dim,

                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,

                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                    block_size=block_size,
                    grid=grid,
                )
            )
            layer_norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(layers)
        self.norms = nn.ModuleList(layer_norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
        """Apply each attention layer and the feed-forward, each with a residual connection.

        `encoder_hidden_states` only reaches layers whose `is_cross_attention`
        is true; `attention_mask` is accepted but not used.
        """
        for layer, layer_norm in zip(self.attention_blocks, self.norms):
            context = encoder_hidden_states if layer.is_cross_attention else None
            attended = layer(
                layer_norm(hidden_states),
                encoder_hidden_states=context,
                video_length=video_length,
                height=height,
                weight=weight,
            )
            hidden_states = attended + hidden_states

        return self.ff(self.ff_norm(hidden_states)) + hidden_states
450
+
451
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes along axis 1 of the input.

    A table of shape (1, max_len, d_model) is precomputed and registered as
    the buffer `pe`: even channels hold sines and odd channels cosines of
    `position * div_term`.

    Args:
        d_model: Channel count of the inputs; may be even or odd.
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions in the table.
    """

    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 4096
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd-indexed channel than even-indexed
        # ones, so trim the angle table to match the cosine slots.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `dropout(x + pe[:, :x.size(1)])`."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
470
+
471
class VersatileAttention(CrossAttention):
    """`CrossAttention` that attends along the frame axis of a video.

    `forward` rearranges its input from "(b f) d c" to "(b d) f c" so that
    the sequence axis seen by the attention is the frame axis. With `grid`
    enabled, each `block_size` x `block_size` spatial block is additionally
    merged into the sequence axis. Only `attention_mode == "Temporal"` is
    supported.

    Args:
        attention_mode: Must be `"Temporal"`.
        cross_frame_attention_mode: Accepted but not used by this class.
        temporal_position_encoding: If true, a `PositionalEncoding` is applied
            along the frame axis.
        temporal_position_encoding_max_len: `max_len` of that encoding.
        grid: Enables the block-wise spatial grouping described above.
        block_size: Side length of the spatial blocks used when `grid` is set.
        *args, **kwargs: Forwarded to `CrossAttention.__init__`.
    """

    def __init__(
            self,
            attention_mode                     = None,
            cross_frame_attention_mode         = None,
            temporal_position_encoding         = False,
            temporal_position_encoding_max_len = 4096,
            grid                               = False,
            block_size                         = 1,
            *args, **kwargs
        ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        # `.get` so that omitting cross_attention_dim means self-attention
        # instead of raising KeyError.
        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None

        self.block_size = block_size
        self.grid = grid
        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.,
            max_len=temporal_position_encoding_max_len
        ) if (temporal_position_encoding and attention_mode == "Temporal") else None

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
        """Temporal attention over `hidden_states` laid out as "(b f) d c".

        Args:
            hidden_states: Tensor with three axes; the leading axis is
                batch * `video_length`.
            encoder_hidden_states: Optional keys/values source; it is
                repeated `d` times along the batch axis to match the query.
            attention_mask: Optional additive mask.
            video_length: Number of frames `f` folded into the leading axis.
            height, weight: Spatial grid size (`weight` is the width); only
                read when `grid` is enabled.

        Returns:
            Tensor in the same "(b f) d c" layout as the input.
        """
        if self.attention_mode == "Temporal":
            # for add pos_encoder
            before_d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            if self.grid:
                hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
                hidden_states = rearrange(hidden_states, "b f (h w) c -> b f h w c", h=height, w=weight)

                hidden_states = rearrange(hidden_states, "b f (h n) (w m) c -> (b h w) (f n m) c", n=self.block_size, m=self.block_size)
                d = before_d // self.block_size // self.block_size
            else:
                d = before_d
            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
        else:
            raise NotImplementedError

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)
        # The sequence axis changed in the rearrange above, so the sliced
        # attention buffer must be sized from the query, not from the input.
        sequence_length = query.shape[1]

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # Attend in chunks of `bs` rows of the (batch * heads) axis.
        bs = 512
        new_hidden_states = []
        for i in range(0, query.shape[0], bs):
            query_chunk = query[i : i + bs]
            key_chunk = key[i : i + bs]
            value_chunk = value[i : i + bs]
            mask_chunk = attention_mask[i : i + bs] if attention_mask is not None else attention_mask
            # attention, what we cannot get enough of
            if self._use_memory_efficient_attention_xformers:
                hidden_states = self._memory_efficient_attention_xformers(query_chunk, key_chunk, value_chunk, mask_chunk)
                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
                hidden_states = hidden_states.to(query.dtype)
            else:
                if self._slice_size is None or query_chunk.shape[0] // self._slice_size == 1:
                    hidden_states = self._attention(query_chunk, key_chunk, value_chunk, mask_chunk)
                else:
                    hidden_states = self._sliced_attention(query_chunk, key_chunk, value_chunk, sequence_length, dim, mask_chunk)
            new_hidden_states.append(hidden_states)
        hidden_states = torch.cat(new_hidden_states, dim = 0)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
            if self.grid:
                # Undo the block grouping and flatten the spatial grid again.
                hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
                hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
                hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")

        return hidden_states
easyanimate/models/patch.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.nn.init as init
7
+ import math
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sine/cosine position embedding for a grid.

    grid_size: an int (square grid) or a pair indexed as (grid_size[0], grid_size[1]). Coordinates along
    each axis are divided by (size / base_size) and by interpolation_scale before being embedded.

    return: pos_embed with one row per grid cell and embed_dim columns; when cls_token is set and
    extra_tokens > 0, extra_tokens rows of zeros are placed in front.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    size_0, size_1 = grid_size[0], grid_size[1]

    coords_h = np.arange(size_0, dtype=np.float32) / (size_0 / base_size) / interpolation_scale
    coords_w = np.arange(size_1, dtype=np.float32) / (size_1 / base_size) / interpolation_scale
    # here w goes first
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, size_1, size_0])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
32
+
33
+
34
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Embed grid[0] and grid[1] with embed_dim // 2 channels each and concatenate them along axis 1.

    Raises ValueError when embed_dim is not divisible by 2.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # grid[0] fills the first half of the channels, grid[1] the second half.
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
44
+
45
+
46
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a flat list of positions.

    embed_dim: output dimension D for each position (must be even)
    pos: array of positions; it is flattened to size (M,)
    out: (M, D), the first D/2 columns are sines and the last D/2 are cosines
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64)
    exponents /= embed_dim / 2.0
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    # One angle per (position, frequency) pair: shape (M, D/2).
    angles = np.outer(pos.reshape(-1), frequencies)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
65
+
66
class Patch1D(nn.Module):
    """Downsample a ``(batch, channels, length)`` tensor along its last axis.

    With ``use_conv=True`` a strided ``nn.Conv1d`` is used whose kernel size equals
    ``stride``. Its weights are initialised so that, at construction time, output
    channel ``i`` is the mean of input channel ``i`` over each window (the same
    result as average pooling); the bias starts at zero. With ``use_conv=False``
    an ``nn.AvgPool1d`` is used and ``out_channels`` must equal ``channels``.

    Args:
        channels: number of input channels.
        use_conv: use the learnable convolution instead of average pooling.
        out_channels: number of output channels; defaults to ``channels``.
        stride: downsampling factor, also used as the kernel size.
        padding: padding passed to the convolution (unused by the pooling path).
        name: stored on the module as ``self.name``.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        stride: int = 2,
        padding: int = 0,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
            init.constant_(self.conv.weight, 0.0)
            with torch.no_grad():
                # weight is (out_channels, in_channels, kernel); only the channels
                # present on both sides can be wired one-to-one. Bounding by the
                # smaller count avoids an IndexError when out_channels > channels.
                for i in range(min(self.out_channels, self.channels)):
                    self.conv.weight[i, i] = 1 / stride
            init.constant_(self.conv.bias, 0.0)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the downsampling layer; ``inputs.shape[1]`` must equal ``channels``."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
96
+
97
class UnPatch1D(nn.Module):
    """Upsample a ``(batch, channels, length)`` tensor by 2 along its last axis.

    Three modes, chosen at construction time:

    * ``use_conv_transpose=True``: a ``nn.ConvTranspose1d`` (kernel 4, stride 2,
      padding 1) does the upsampling on its own.
    * ``use_conv=True``: nearest-neighbour interpolation followed by a
      ``nn.Conv1d`` (kernel 3, padding 1).
    * neither: nearest-neighbour interpolation only; ``self.conv`` is ``None``.

    ``use_conv_transpose`` takes precedence when both flags are set.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample ``inputs``; ``inputs.shape[1]`` must equal ``channels``."""
        assert inputs.shape[1] == self.channels

        # The transposed convolution upsamples by itself, no interpolation needed.
        if self.use_conv_transpose:
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
130
+
131
class Upsampler(nn.Module):
    """Base module that only records its spatial and temporal upsampling factors.

    No ``forward`` is defined here; subclasses provide the actual upsampling.
    """

    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()

        # Stored as plain attributes for subclasses and callers to read.
        self.temporal_upsample_factor = temporal_upsample_factor
        self.spatial_upsample_factor = spatial_upsample_factor
141
+
142
class TemporalUpsampler3D(Upsampler):
    """Double the number of frames of a 5D tensor while keeping the first frame.

    The first frame along dim 2 is passed through unchanged; the remaining frames
    are upsampled by 2 along dim 2 with trilinear interpolation (the last two dims
    keep their size). A ``T``-frame input therefore yields ``1 + 2 * (T - 1)``
    frames. Inputs with a single frame are returned as-is.
    """

    def __init__(self):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Nothing to interpolate when there is only one frame.
        if x.shape[2] <= 1:
            return x

        head = x[:, :, :1]
        tail = F.interpolate(x[:, :, 1:], scale_factor=(2, 1, 1), mode="trilinear")
        return torch.cat([head, tail], dim=2)
155
+
156
def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, else ``t`` repeated ``length`` times as a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
158
+
159
def divisible_by(num, den):
    """Return whether ``num`` leaves no remainder when divided by ``den``."""
    remainder = num % den
    return remainder == 0
161
+
162
def is_odd(n):
    """Return whether ``n`` is not divisible by 2."""
    remainder = n % 2
    return not (remainder == 0)
164
+
165
class CausalConv3d(nn.Conv3d):
    """3D convolution that pads the time axis on the leading side only.

    The underlying ``nn.Conv3d`` is built with zero temporal padding; ``forward``
    prepends ``(t_kernel - 1) * t_dilation`` frames (replicating the first frame)
    before convolving, so an output frame never depends on later input frames.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or 3-tuple ``(t, h, w)``.
        stride: int or 3-tuple ``(t, h, w)``.
        padding: spatial padding. An int is used for both height and width;
            ``None`` derives the padding from kernel size, dilation and stride.
        dilation: int or 3-tuple ``(t, h, w)``.
        **kwargs: forwarded to ``nn.Conv3d``.

    Raises:
        NotImplementedError: if ``padding`` is neither ``None`` nor an int.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int],
        stride=1,  # : int | tuple[int, int, int] = 1,
        padding=1,  # : int | tuple[int, int, int],  # TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int] = 1,
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # Number of frames prepended in forward(); the conv itself gets 0 temporal padding.
        t_pad = (t_ks - 1) * t_dilation
        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            # Previously `assert NotImplementedError`, which never fails and let
            # unsupported values fall through to an undefined h_pad / w_pad.
            raise NotImplementedError(f"padding must be None or an int, got {padding!r} instead.")

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

        self.temporal_padding = t_pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the front of the time axis, then run the convolution.

        Args:
            x: tensor laid out as ``(B, C, T, H, W)``.
        """
        # F.pad lists the last dim first: (W_left, W_right, H_top, H_bottom, T_front, T_back).
        x = F.pad(
            x,
            pad=(0, 0, 0, 0, self.temporal_padding, 0),
            mode="replicate",  # TODO: check if this is necessary
        )
        return super().forward(x)
219
+
220
class PatchEmbed3D(nn.Module):
    """3D patch embedding with 2D sine/cosine positional embeddings.

    A strided ``nn.Conv3d`` (kernel = stride = ``(time_patch_size, patch_size,
    patch_size)``) turns a ``(b, c, f, h, w)`` input into patch features. Frames
    are then folded into the batch axis and a 2D positional embedding is added to
    the flattened spatial tokens.

    Args:
        height, width: reference input size used to pre-compute the positional
            embedding buffer.
        patch_size: spatial patch size.
        time_patch_size: temporal patch size.
        in_channels: number of input channels.
        embed_dim: number of output channels per token.
        layer_norm: apply a non-affine ``nn.LayerNorm`` to the tokens.
        flatten: flatten the spatial dims into a token axis.
        bias: whether the projection has a bias.
        interpolation_scale: forwarded to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the cached embedding on the real (height, width) patch grid. The
        # previous int(sqrt(num_patches)) square grid gave a mis-shaped or
        # mis-positioned buffer when height != width; square sizes are unchanged.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` (laid out as ``b c f h w``) and add positional embeddings.

        With ``flatten=True`` the result has one row per (batch, frame) pair and
        one token per spatial patch.
        """
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, H*W, C)
        if self.layer_norm:
            latent = self.norm(latent)
        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
285
+
286
class PatchEmbedF3D(nn.Module):
    """"Fake 3D" patch embedding: a 2D spatial projection followed by a 1D temporal one.

    Each frame of a ``(b, c, f, h, w)`` input is patchified with a strided
    ``nn.Conv2d``; the frame axis of every spatial location is then downsampled
    with a ``Patch1D`` convolution whose stride is ``patch_size``. Frames are
    finally folded into the batch axis and a 2D sine/cosine positional embedding
    is added to the flattened spatial tokens.

    Args:
        height, width: reference input size used to pre-compute the positional
            embedding buffer.
        patch_size: spatial patch size, also used as the temporal stride.
        in_channels: number of input channels.
        embed_dim: number of output channels per token.
        layer_norm: apply a non-affine ``nn.LayerNorm`` to the tokens.
        flatten: flatten the spatial dims into a token axis.
        bias: whether the spatial projection has a bias.
        interpolation_scale: forwarded to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.proj_t = Patch1D(
            embed_dim, True, stride=patch_size
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the cached embedding on the real (height, width) patch grid. The
        # previous int(sqrt(num_patches)) square grid gave a mis-shaped or
        # mis-positioned buffer when height != width; square sizes are unchanged.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` (laid out as ``b c f h w``) and add positional embeddings."""
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        num_frames = latent.shape[2]

        # Spatial patchify, frame by frame.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        latent = self.proj(latent)
        latent = rearrange(latent, "(b f) c h w -> b c f h w", f=num_frames)

        # Temporal downsampling, one spatial location at a time.
        latent = rearrange(latent, "b c f h w -> (b h w) c f")
        latent = self.proj_t(latent)
        # After the spatial projection the grid is (height, width) patches. The
        # previous hard-coded h//2, w//2 was only correct for patch_size == 2.
        latent = rearrange(latent, "(b h w) c f -> b c f h w", h=height, w=width)

        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, H*W, C)
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
361
+
362
class CasualPatchEmbed3D(nn.Module):
    """3D patch embedding whose projection is a ``CausalConv3d``.

    Same layout as a plain 3D patch embedding, but the strided projection
    (kernel = stride = ``(time_patch_size, patch_size, patch_size)``) is a
    ``CausalConv3d`` built with ``padding=None``. Frames are folded into the batch
    axis after the projection and a 2D sine/cosine positional embedding is added
    to the flattened spatial tokens.

    Args:
        height, width: reference input size used to pre-compute the positional
            embedding buffer.
        patch_size: spatial patch size.
        time_patch_size: temporal patch size.
        in_channels: number of input channels.
        embed_dim: number of output channels per token.
        layer_norm: apply a non-affine ``nn.LayerNorm`` to the tokens.
        flatten: flatten the spatial dims into a token axis.
        bias: whether the projection has a bias.
        interpolation_scale: forwarded to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = CausalConv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
            padding=None,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the cached embedding on the real (height, width) patch grid. The
        # previous int(sqrt(num_patches)) square grid gave a mis-shaped or
        # mis-positioned buffer when height != width; square sizes are unchanged.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` (laid out as ``b c f h w``) and add positional embeddings.

        With ``flatten=True`` the result has one row per (batch, frame) pair and
        one token per spatial patch.
        """
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, H*W, C)
        if self.layer_norm:
            latent = self.norm(latent)
        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
easyanimate/models/transformer2d.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.nn.init as init
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import BasicTransformerBlock
25
+ from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
26
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.models.normalization import AdaLayerNormSingle
29
+ from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
30
+ is_torch_version)
31
+ from einops import rearrange
32
+ from torch import nn
33
+
34
try:
    from diffusers.models.embeddings import PixArtAlphaTextProjection
except ImportError:
    # Only a missing name should trigger the fallback; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and unrelated errors raised on import.
    # NOTE(review): presumably for diffusers releases that expose this class as
    # `CaptionProjection` - confirm against the supported diffusers versions.
    from diffusers.models.embeddings import \
        CaptionProjection as PixArtAlphaTextProjection
39
+
40
+ from .attention import (KVCompressionTransformerBlock,
41
+ SelfAttentionTemporalTransformerBlock,
42
+ TemporalTransformerBlock)
43
+
44
+
45
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor`):
            The hidden states output conditioned on the `encoder_hidden_states` input. Its shape is
            `(batch_size, num_channels, height, width)` for continuous inputs, or
            `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete, in
            which case it holds probability distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
57
+
58
+
59
+ class Transformer2DModel(ModelMixin, ConfigMixin):
60
+ """
61
+ A 2D Transformer model for image-like data.
62
+
63
+ Parameters:
64
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
65
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
66
+ in_channels (`int`, *optional*):
67
+ The number of channels in the input and output (specify if the input is **continuous**).
68
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
69
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
70
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
71
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
72
+ This is fixed during training since it is used to learn a number of position embeddings.
73
+ num_vector_embeds (`int`, *optional*):
74
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
75
+ Includes the class for the masked latent pixel.
76
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
77
+ num_embeds_ada_norm ( `int`, *optional*):
78
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
79
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
80
+ added to the hidden states.
81
+
82
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
83
+ attention_bias (`bool`, *optional*):
84
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
85
+ """
86
+ _supports_gradient_checkpointing = True
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ num_attention_heads: int = 16,
92
+ attention_head_dim: int = 88,
93
+ in_channels: Optional[int] = None,
94
+ out_channels: Optional[int] = None,
95
+ num_layers: int = 1,
96
+ dropout: float = 0.0,
97
+ norm_num_groups: int = 32,
98
+ cross_attention_dim: Optional[int] = None,
99
+ attention_bias: bool = False,
100
+ sample_size: Optional[int] = None,
101
+ num_vector_embeds: Optional[int] = None,
102
+ patch_size: Optional[int] = None,
103
+ activation_fn: str = "geglu",
104
+ num_embeds_ada_norm: Optional[int] = None,
105
+ use_linear_projection: bool = False,
106
+ only_cross_attention: bool = False,
107
+ double_self_attention: bool = False,
108
+ upcast_attention: bool = False,
109
+ norm_type: str = "layer_norm",
110
+ norm_elementwise_affine: bool = True,
111
+ norm_eps: float = 1e-5,
112
+ attention_type: str = "default",
113
+ caption_channels: int = None,
114
+ # block type
115
+ basic_block_type: str = "basic",
116
+ ):
117
+ super().__init__()
118
+ self.use_linear_projection = use_linear_projection
119
+ self.num_attention_heads = num_attention_heads
120
+ self.attention_head_dim = attention_head_dim
121
+ self.basic_block_type = basic_block_type
122
+ inner_dim = num_attention_heads * attention_head_dim
123
+
124
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
125
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
126
+
127
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
128
+ # Define whether input is continuous or discrete depending on configuration
129
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
130
+ self.is_input_vectorized = num_vector_embeds is not None
131
+ self.is_input_patches = in_channels is not None and patch_size is not None
132
+
133
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
134
+ deprecation_message = (
135
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
136
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
137
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
138
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
139
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
140
+ )
141
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
142
+ norm_type = "ada_norm"
143
+
144
+ if self.is_input_continuous and self.is_input_vectorized:
145
+ raise ValueError(
146
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
147
+ " sure that either `in_channels` or `num_vector_embeds` is None."
148
+ )
149
+ elif self.is_input_vectorized and self.is_input_patches:
150
+ raise ValueError(
151
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
152
+ " sure that either `num_vector_embeds` or `num_patches` is None."
153
+ )
154
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
155
+ raise ValueError(
156
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
157
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
158
+ )
159
+
160
+ # 2. Define input layers
161
+ if self.is_input_continuous:
162
+ self.in_channels = in_channels
163
+
164
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
165
+ if use_linear_projection:
166
+ self.proj_in = linear_cls(in_channels, inner_dim)
167
+ else:
168
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
169
+ elif self.is_input_vectorized:
170
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
171
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
172
+
173
+ self.height = sample_size
174
+ self.width = sample_size
175
+ self.num_vector_embeds = num_vector_embeds
176
+ self.num_latent_pixels = self.height * self.width
177
+
178
+ self.latent_image_embedding = ImagePositionalEmbeddings(
179
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
180
+ )
181
+ elif self.is_input_patches:
182
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
183
+
184
+ self.height = sample_size
185
+ self.width = sample_size
186
+
187
+ self.patch_size = patch_size
188
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
189
+ interpolation_scale = max(interpolation_scale, 1)
190
+ self.pos_embed = PatchEmbed(
191
+ height=sample_size,
192
+ width=sample_size,
193
+ patch_size=patch_size,
194
+ in_channels=in_channels,
195
+ embed_dim=inner_dim,
196
+ interpolation_scale=interpolation_scale,
197
+ )
198
+
199
+ basic_block = {
200
+ "basic": BasicTransformerBlock,
201
+ "kvcompression": KVCompressionTransformerBlock,
202
+ }[self.basic_block_type]
203
+ if self.basic_block_type == "kvcompression":
204
+ self.transformer_blocks = nn.ModuleList(
205
+ [
206
+ basic_block(
207
+ inner_dim,
208
+ num_attention_heads,
209
+ attention_head_dim,
210
+ dropout=dropout,
211
+ cross_attention_dim=cross_attention_dim,
212
+ activation_fn=activation_fn,
213
+ num_embeds_ada_norm=num_embeds_ada_norm,
214
+ attention_bias=attention_bias,
215
+ only_cross_attention=only_cross_attention,
216
+ double_self_attention=double_self_attention,
217
+ upcast_attention=upcast_attention,
218
+ norm_type=norm_type,
219
+ norm_elementwise_affine=norm_elementwise_affine,
220
+ norm_eps=norm_eps,
221
+ attention_type=attention_type,
222
+ kvcompression=False if d < 14 else True,
223
+ )
224
+ for d in range(num_layers)
225
+ ]
226
+ )
227
+ else:
228
+ # 3. Define transformers blocks
229
+ self.transformer_blocks = nn.ModuleList(
230
+ [
231
+ BasicTransformerBlock(
232
+ inner_dim,
233
+ num_attention_heads,
234
+ attention_head_dim,
235
+ dropout=dropout,
236
+ cross_attention_dim=cross_attention_dim,
237
+ activation_fn=activation_fn,
238
+ num_embeds_ada_norm=num_embeds_ada_norm,
239
+ attention_bias=attention_bias,
240
+ only_cross_attention=only_cross_attention,
241
+ double_self_attention=double_self_attention,
242
+ upcast_attention=upcast_attention,
243
+ norm_type=norm_type,
244
+ norm_elementwise_affine=norm_elementwise_affine,
245
+ norm_eps=norm_eps,
246
+ attention_type=attention_type,
247
+ )
248
+ for d in range(num_layers)
249
+ ]
250
+ )
251
+
252
+ # 4. Define output layers
253
+ self.out_channels = in_channels if out_channels is None else out_channels
254
+ if self.is_input_continuous:
255
+ # TODO: should use out_channels for continuous projections
256
+ if use_linear_projection:
257
+ self.proj_out = linear_cls(inner_dim, in_channels)
258
+ else:
259
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
260
+ elif self.is_input_vectorized:
261
+ self.norm_out = nn.LayerNorm(inner_dim)
262
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
263
+ elif self.is_input_patches and norm_type != "ada_norm_single":
264
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
265
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
266
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
267
+ elif self.is_input_patches and norm_type == "ada_norm_single":
268
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
269
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
270
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
271
+
272
+ # 5. PixArt-Alpha blocks.
273
+ self.adaln_single = None
274
+ self.use_additional_conditions = False
275
+ if norm_type == "ada_norm_single":
276
+ self.use_additional_conditions = self.config.sample_size == 128
277
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
278
+ # additional conditions until we find better name
279
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
280
+
281
+ self.caption_projection = None
282
+ if caption_channels is not None:
283
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
284
+
285
+ self.gradient_checkpointing = False
286
+
287
+ def _set_gradient_checkpointing(self, module, value=False):
288
+ if hasattr(module, "gradient_checkpointing"):
289
+ module.gradient_checkpointing = value
290
+
291
+ def forward(
292
+ self,
293
+ hidden_states: torch.Tensor,
294
+ encoder_hidden_states: Optional[torch.Tensor] = None,
295
+ timestep: Optional[torch.LongTensor] = None,
296
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
297
+ class_labels: Optional[torch.LongTensor] = None,
298
+ cross_attention_kwargs: Dict[str, Any] = None,
299
+ attention_mask: Optional[torch.Tensor] = None,
300
+ encoder_attention_mask: Optional[torch.Tensor] = None,
301
+ return_dict: bool = True,
302
+ ):
303
+ """
304
+ The [`Transformer2DModel`] forward method.
305
+
306
+ Args:
307
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
308
+ Input `hidden_states`.
309
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
310
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
311
+ self-attention.
312
+ timestep ( `torch.LongTensor`, *optional*):
313
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
314
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
315
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
316
+ `AdaLayerZeroNorm`.
317
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
318
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
319
+ `self.processor` in
320
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
321
+ attention_mask ( `torch.Tensor`, *optional*):
322
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
323
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
324
+ negative values to the attention scores corresponding to "discard" tokens.
325
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
326
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
327
+
328
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
329
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
330
+
331
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
332
+ above. This bias will be added to the cross-attention scores.
333
+ return_dict (`bool`, *optional*, defaults to `True`):
334
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
335
+ tuple.
336
+
337
+ Returns:
338
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
339
+ `tuple` where the first element is the sample tensor.
340
+ """
341
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
342
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
343
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
344
+ # expects mask of shape:
345
+ # [batch, key_tokens]
346
+ # adds singleton query_tokens dimension:
347
+ # [batch, 1, key_tokens]
348
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
349
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
350
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
351
+ if attention_mask is not None and attention_mask.ndim == 2:
352
+ # assume that mask is expressed as:
353
+ # (1 = keep, 0 = discard)
354
+ # convert mask into a bias that can be added to attention scores:
355
+ # (keep = +0, discard = -10000.0)
356
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
357
+ attention_mask = attention_mask.unsqueeze(1)
358
+
359
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
360
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
361
+ encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
362
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
363
+
364
+ # Retrieve lora scale.
365
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
366
+
367
+ # 1. Input
368
+ if self.is_input_continuous:
369
+ batch, _, height, width = hidden_states.shape
370
+ residual = hidden_states
371
+
372
+ hidden_states = self.norm(hidden_states)
373
+ if not self.use_linear_projection:
374
+ hidden_states = (
375
+ self.proj_in(hidden_states, scale=lora_scale)
376
+ if not USE_PEFT_BACKEND
377
+ else self.proj_in(hidden_states)
378
+ )
379
+ inner_dim = hidden_states.shape[1]
380
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
381
+ else:
382
+ inner_dim = hidden_states.shape[1]
383
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
384
+ hidden_states = (
385
+ self.proj_in(hidden_states, scale=lora_scale)
386
+ if not USE_PEFT_BACKEND
387
+ else self.proj_in(hidden_states)
388
+ )
389
+
390
+ elif self.is_input_vectorized:
391
+ hidden_states = self.latent_image_embedding(hidden_states)
392
+ elif self.is_input_patches:
393
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
394
+ hidden_states = self.pos_embed(hidden_states)
395
+
396
+ if self.adaln_single is not None:
397
+ if self.use_additional_conditions and added_cond_kwargs is None:
398
+ raise ValueError(
399
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
400
+ )
401
+ batch_size = hidden_states.shape[0]
402
+ timestep, embedded_timestep = self.adaln_single(
403
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
404
+ )
405
+
406
+ # 2. Blocks
407
+ if self.caption_projection is not None:
408
+ batch_size = hidden_states.shape[0]
409
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
410
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
411
+
412
+ for block in self.transformer_blocks:
413
+ if self.training and self.gradient_checkpointing:
414
+ args = {
415
+ "basic": [],
416
+ "kvcompression": [1, height, width],
417
+ }[self.basic_block_type]
418
+ hidden_states = torch.utils.checkpoint.checkpoint(
419
+ block,
420
+ hidden_states,
421
+ attention_mask,
422
+ encoder_hidden_states,
423
+ encoder_attention_mask,
424
+ timestep,
425
+ cross_attention_kwargs,
426
+ class_labels,
427
+ *args,
428
+ use_reentrant=False,
429
+ )
430
+ else:
431
+ kwargs = {
432
+ "basic": {},
433
+ "kvcompression": {"num_frames":1, "height":height, "width":width},
434
+ }[self.basic_block_type]
435
+ hidden_states = block(
436
+ hidden_states,
437
+ attention_mask=attention_mask,
438
+ encoder_hidden_states=encoder_hidden_states,
439
+ encoder_attention_mask=encoder_attention_mask,
440
+ timestep=timestep,
441
+ cross_attention_kwargs=cross_attention_kwargs,
442
+ class_labels=class_labels,
443
+ **kwargs
444
+ )
445
+
446
+ # 3. Output
447
+ if self.is_input_continuous:
448
+ if not self.use_linear_projection:
449
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
450
+ hidden_states = (
451
+ self.proj_out(hidden_states, scale=lora_scale)
452
+ if not USE_PEFT_BACKEND
453
+ else self.proj_out(hidden_states)
454
+ )
455
+ else:
456
+ hidden_states = (
457
+ self.proj_out(hidden_states, scale=lora_scale)
458
+ if not USE_PEFT_BACKEND
459
+ else self.proj_out(hidden_states)
460
+ )
461
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
462
+
463
+ output = hidden_states + residual
464
+ elif self.is_input_vectorized:
465
+ hidden_states = self.norm_out(hidden_states)
466
+ logits = self.out(hidden_states)
467
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
468
+ logits = logits.permute(0, 2, 1)
469
+
470
+ # log(p(x_0))
471
+ output = F.log_softmax(logits.double(), dim=1).float()
472
+
473
+ if self.is_input_patches:
474
+ if self.config.norm_type != "ada_norm_single":
475
+ conditioning = self.transformer_blocks[0].norm1.emb(
476
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
477
+ )
478
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
479
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
480
+ hidden_states = self.proj_out_2(hidden_states)
481
+ elif self.config.norm_type == "ada_norm_single":
482
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
483
+ hidden_states = self.norm_out(hidden_states)
484
+ # Modulation
485
+ hidden_states = hidden_states * (1 + scale) + shift
486
+ hidden_states = self.proj_out(hidden_states)
487
+ hidden_states = hidden_states.squeeze(1)
488
+
489
+ # unpatchify
490
+ if self.adaln_single is None:
491
+ height = width = int(hidden_states.shape[1] ** 0.5)
492
+ hidden_states = hidden_states.reshape(
493
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
494
+ )
495
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
496
+ output = hidden_states.reshape(
497
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
498
+ )
499
+
500
+ if not return_dict:
501
+ return (output,)
502
+
503
+ return Transformer2DModelOutput(sample=output)
504
+
505
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs=None):
    """Build the model from ``config.json`` and load a checkpoint stored next to it.

    The checkpoint is read from ``<pretrained_model_path>[/<subfolder>]``; a
    ``.safetensors`` file is preferred over the ``.bin`` file when both exist.
    Before loading, three tensors are adapted when their shape differs from
    the freshly built model:

    * ``pos_embed.proj.weight`` is repeated twice along dim 1,
    * ``proj_out.weight`` is repeated ``patch_size`` times along dim 0,
    * ``proj_out.bias`` is repeated ``patch_size`` times.

    Every checkpoint entry that is unknown to the model, or whose shape still
    differs after the adaptation, is reported on stdout and skipped.

    Args:
        pretrained_model_path: Directory holding ``config.json`` and the weights.
        subfolder: Optional sub-directory appended to ``pretrained_model_path``.
        patch_size: Repeat factor used when adapting ``proj_out`` weights/bias.
        transformer_additional_kwargs: Extra keyword arguments forwarded to
            ``cls.from_config``. ``None`` (the default) means no extras.

    Returns:
        The model instance with the compatible checkpoint tensors loaded.

    Raises:
        RuntimeError: If ``config.json`` or the weight file does not exist.
        KeyError: If ``pos_embed.proj.weight``, ``proj_out.weight`` or
            ``proj_out.bias`` is missing from the model or the checkpoint.
    """
    # Avoid a shared mutable default; None behaves like the former `{}`.
    if transformer_additional_kwargs is None:
        transformer_additional_kwargs = {}

    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
    print(f"loaded 2D transformer's pretrained weights from {pretrained_model_path} ...")

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r") as f:
        config = json.load(f)

    from diffusers.utils import WEIGHTS_NAME
    model = cls.from_config(config, **transformer_additional_kwargs)
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    # Swap the extension on the file name only, so a ".bin" occurring in a
    # directory name of the path is never rewritten.
    model_file_safetensors = os.path.join(
        pretrained_model_path, WEIGHTS_NAME.replace(".bin", ".safetensors")
    )
    if os.path.exists(model_file_safetensors):
        from safetensors.torch import load_file
        state_dict = load_file(model_file_safetensors)
    else:
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints.
        state_dict = torch.load(model_file, map_location="cpu")

    # Snapshot once: parameter shapes do not change while we adapt the checkpoint.
    model_state = model.state_dict()

    if model_state['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
        # Tile the patch-embedding weight itself (the original tiled
        # `proj_out.weight` here, which can never produce a matching shape).
        state_dict['pos_embed.proj.weight'] = torch.tile(state_dict['pos_embed.proj.weight'], [1, 2, 1, 1])

    if model_state['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
        state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])

    if model_state['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
        state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])

    # Keep only tensors the model knows about and whose shape matches.
    tmp_state_dict = {}
    for key in state_dict:
        if key in model_state and model_state[key].size() == state_dict[key].size():
            tmp_state_dict[key] = state_dict[key]
        else:
            print(key, "Size don't match, skip")
    state_dict = tmp_state_dict

    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

    # Counts parameters whose name contains "motion_modules." and reports them in millions.
    params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
    print(f"### Postion Parameters: {sum(params) / 1e6} M")

    return model
easyanimate/models/transformer3d.py ADDED
@@ -0,0 +1,738 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import math
16
+ import os
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, Optional
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.nn.init as init
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.attention import BasicTransformerBlock
26
+ from diffusers.models.embeddings import PatchEmbed, Timesteps, TimestepEmbedding
27
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.normalization import AdaLayerNormSingle
30
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
31
+ from einops import rearrange
32
+ from torch import nn
33
+ from typing import Dict, Optional, Tuple
34
+
35
+ from .attention import (SelfAttentionTemporalTransformerBlock,
36
+ TemporalTransformerBlock)
37
+ from .patch import Patch1D, PatchEmbed3D, PatchEmbedF3D, UnPatch1D, TemporalUpsampler3D, CasualPatchEmbed3D
38
+
39
+ try:
40
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
41
+ except:
42
+ from diffusers.models.embeddings import \
43
+ CaptionProjection as PixArtAlphaTextProjection
44
+
45
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    for param in module.parameters():
        # Work on a detached alias so the in-place write is not tracked by autograd.
        detached = param.detach()
        detached.zero_()
    return module
50
+
51
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedding with optional resolution / aspect-ratio conditioning (PixArt-Alpha).

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if not use_additional_conditions:
            return

        # Size conditioning: one shared sinusoidal projection, two separate embedders.
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

        # Zero the last linear layer of both embedders so they contribute nothing at init.
        for size_embedder in (self.resolution_embedder, self.aspect_ratio_embedder):
            size_embedder.linear_2 = zero_module(size_embedder.linear_2)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Return the timestep embedding, plus the size embeddings when enabled.

        ``resolution`` and ``aspect_ratio`` are only read when
        ``use_additional_conditions`` is set; each is flattened, projected,
        embedded and reshaped to ``(batch_size, -1)`` before both are
        concatenated along dim 1 and added to the timestep embedding.
        """
        projected_time = self.time_proj(timestep)
        time_embedding = self.timestep_embedder(projected_time.to(dtype=hidden_dtype))  # (N, D)

        if not self.use_additional_conditions:
            return time_embedding

        def embed_size(values, embedder):
            projected = self.additional_condition_proj(values.flatten()).to(hidden_dtype)
            return embedder(projected).reshape(batch_size, -1)

        resolution_emb = embed_size(resolution, self.resolution_embedder)
        aspect_ratio_emb = embed_size(aspect_ratio, self.aspect_ratio_embedder)
        return time_embedding + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
89
+
90
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``timestep`` and project it to the modulation vector.

        Args:
            timestep: Timestep tensor passed to the embedding module.
            added_cond_kwargs: Keyword arguments forwarded to the embedding
                module (``resolution`` and ``aspect_ratio``). When ``None`` or
                empty, both are passed as ``None``.
            batch_size: Forwarded to the embedding module.
            hidden_dtype: Forwarded to the embedding module.

        Returns:
            A 2-tuple ``(linear(silu(embedded_timestep)), embedded_timestep)``.
        """
        # No modulation happening here.
        # The declared default (None) used to crash on `**None`; fall back to
        # explicit empty conditions so the default call path works.
        if not added_cond_kwargs:
            added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
121
+
122
+
123
class TimePositionalEncoding(nn.Module):
    """Fixed sinusoidal encoding added along the frame axis of a ``b c f h w`` tensor.

    A ``(1, max_len, d_model)`` table is registered as the buffer ``pe``: even
    channels hold sines, odd channels hold cosines. ``forward`` adds the first
    ``f`` rows to every spatial location and then applies dropout.
    """

    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 24
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        frame_index = torch.arange(max_len)[:, None]
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = frame_index * inv_freq

        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0::2] = torch.sin(angles)
        table[0, :, 1::2] = torch.cos(angles)
        self.register_buffer('pe', table)

    def forward(self, x):
        batch, _, _, height, width = x.size()
        # Put frames on the sequence axis so the table broadcasts over (b h w).
        tokens = rearrange(x, "b c f h w -> (b h w) f c")
        tokens = tokens + self.pe[:, :tokens.size(1)]
        restored = rearrange(tokens, "(b h w) f c -> b c f h w", b=batch, h=height, w=width)
        return self.dropout(restored)
145
+
146
@dataclass
class Transformer3DModelOutput(BaseOutput):
    """
    Output container named as the dict-style return type of `Transformer3DModel.forward`.

    Args:
        sample (`torch.FloatTensor`):
            The output tensor of the model.
            NOTE(review): the previous shape description was copied from the 2D model; for video
            input this is presumably `(batch_size, num_channels, frames, height, width)` — confirm
            against the end of `Transformer3DModel.forward`.
    """

    # The only field; holds the model's output tensor.
    sample: torch.FloatTensor
158
+
159
+
160
+ class Transformer3DModel(ModelMixin, ConfigMixin):
161
+ """
162
+ A 3D Transformer model for image-like data.
163
+
164
+ Parameters:
165
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
166
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
167
+ in_channels (`int`, *optional*):
168
+ The number of channels in the input and output (specify if the input is **continuous**).
169
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
170
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
171
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
172
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
173
+ This is fixed during training since it is used to learn a number of position embeddings.
174
+ num_vector_embeds (`int`, *optional*):
175
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
176
+ Includes the class for the masked latent pixel.
177
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
178
+ num_embeds_ada_norm ( `int`, *optional*):
179
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
180
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
181
+ added to the hidden states.
182
+
183
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
184
+ attention_bias (`bool`, *optional*):
185
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
186
+ """
187
+
188
+ _supports_gradient_checkpointing = True
189
+
190
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = None,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    double_self_attention: bool = False,
    upcast_attention: bool = False,
    norm_type: str = "layer_norm",
    norm_elementwise_affine: bool = True,
    norm_eps: float = 1e-5,
    attention_type: str = "default",
    caption_channels: int = None,
    # block type
    basic_block_type: str = "motionmodule",
    # enable_uvit
    enable_uvit: bool = False,

    # 3d patch params
    patch_3d: bool = False,
    fake_3d: bool = False,
    time_patch_size: Optional[int] = None,

    casual_3d: bool = False,
    casual_3d_upsampler_index: Optional[list] = None,

    # motion module kwargs
    motion_module_type = "VanillaGrid",
    motion_module_kwargs = None,

    # time position encoding
    time_position_encoding_before_transformer = False
):
    """Build the patch embedding, transformer blocks and output head.

    Only the arguments specific to this 3D variant are described here; the
    remaining ones are forwarded unchanged to the transformer blocks.

    Args:
        sample_size: Required. Used as both height and width of the patch
            embedding and to derive its interpolation scale
            (``max(sample_size // 64, 1)``).
        patch_size: Spatial patch size.
        time_patch_size: Temporal patch size; falls back to ``patch_size``.
        basic_block_type: ``"motionmodule"``, ``"kvcompression_motionmodule"``
            or ``"selfattentiontemporal"``; any other value selects
            ``BasicTransformerBlock``.
        enable_uvit: When set, creates 13 zero-initialised
            ``inner_dim -> inner_dim`` linear layers in ``long_connect_fc``.
        patch_3d / fake_3d / casual_3d: Select the patch-embedding class
            (``casual_3d`` takes precedence over ``patch_3d``).
        casual_3d_upsampler_index: Stored on the instance as is.
        motion_module_type / motion_module_kwargs: Forwarded to
            ``TemporalTransformerBlock`` for the two motion-module block types.
        caption_channels: When not ``None``, a caption projection to
            ``inner_dim`` is created.
        time_position_encoding_before_transformer: When set, creates a
            ``TimePositionalEncoding`` with ``max_len=4096``.

    Raises:
        AssertionError: If ``sample_size`` is ``None``.
    """
    super().__init__()
    inner_dim = num_attention_heads * attention_head_dim

    self.use_linear_projection = use_linear_projection
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    self.enable_uvit = enable_uvit
    self.basic_block_type = basic_block_type
    self.patch_3d = patch_3d
    self.fake_3d = fake_3d
    self.casual_3d = casual_3d
    self.casual_3d_upsampler_index = casual_3d_upsampler_index

    assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"

    self.height = sample_size
    self.width = sample_size

    self.patch_size = patch_size
    self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
    # => 64 (= 512 pixart) has interpolation scale 1
    interpolation_scale = max(self.config.sample_size // 64, 1)

    # 1. Patch embedding. Every variant shares these keyword arguments; the
    #    two true-3D variants additionally receive the temporal patch size.
    patch_embed_kwargs = dict(
        height=sample_size,
        width=sample_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=inner_dim,
        interpolation_scale=interpolation_scale,
    )
    if self.casual_3d:
        self.pos_embed = CasualPatchEmbed3D(time_patch_size=self.time_patch_size, **patch_embed_kwargs)
    elif self.patch_3d:
        if self.fake_3d:
            self.pos_embed = PatchEmbedF3D(**patch_embed_kwargs)
        else:
            self.pos_embed = PatchEmbed3D(time_patch_size=self.time_patch_size, **patch_embed_kwargs)
    else:
        self.pos_embed = PatchEmbed(**patch_embed_kwargs)

    # 3. Define transformers blocks
    shared_block_kwargs = dict(
        dropout=dropout,
        cross_attention_dim=cross_attention_dim,
        activation_fn=activation_fn,
        num_embeds_ada_norm=num_embeds_ada_norm,
        attention_bias=attention_bias,
        only_cross_attention=only_cross_attention,
        double_self_attention=double_self_attention,
        upcast_attention=upcast_attention,
        norm_type=norm_type,
        norm_elementwise_affine=norm_elementwise_affine,
        norm_eps=norm_eps,
        attention_type=attention_type,
    )
    motion_kwargs = dict(
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
    )

    def build_block(layer_index):
        # One block per layer; the block class depends on `basic_block_type`.
        dims = (inner_dim, num_attention_heads, attention_head_dim)
        if self.basic_block_type == "motionmodule":
            return TemporalTransformerBlock(*dims, **shared_block_kwargs, **motion_kwargs)
        if self.basic_block_type == "kvcompression_motionmodule":
            # KV compression is switched on from layer 14 onwards (hard-coded threshold).
            return TemporalTransformerBlock(
                *dims,
                kvcompression=layer_index >= 14,
                **shared_block_kwargs,
                **motion_kwargs,
            )
        if self.basic_block_type == "selfattentiontemporal":
            return SelfAttentionTemporalTransformerBlock(*dims, **shared_block_kwargs)
        return BasicTransformerBlock(*dims, **shared_block_kwargs)

    self.transformer_blocks = nn.ModuleList([build_block(d) for d in range(num_layers)])

    if self.casual_3d:
        self.unpatch1d = TemporalUpsampler3D()
    elif self.patch_3d and self.fake_3d:
        self.unpatch1d = UnPatch1D(inner_dim, True)

    if self.enable_uvit:
        # 13 long-skip projections (hard-coded count), all zero-initialised.
        self.long_connect_fc = nn.ModuleList(
            [nn.Linear(inner_dim, inner_dim, True) for _ in range(13)]
        )
        for index in range(13):
            self.long_connect_fc[index] = zero_module(self.long_connect_fc[index])

    # 4. Define output layers
    self.out_channels = in_channels if out_channels is None else out_channels
    uses_ada_norm_single = norm_type == "ada_norm_single"

    self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
    if not uses_ada_norm_single:
        self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
    else:
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)

    # True 3D patches also unfold along time, so the head predicts more values per token.
    if self.patch_3d and not self.fake_3d:
        head_dim = self.time_patch_size * patch_size * patch_size * self.out_channels
    else:
        head_dim = patch_size * patch_size * self.out_channels
    if not uses_ada_norm_single:
        self.proj_out_2 = nn.Linear(inner_dim, head_dim)
    else:
        self.proj_out = nn.Linear(inner_dim, head_dim)

    # 5. PixArt-Alpha blocks.
    self.adaln_single = None
    self.use_additional_conditions = False
    if uses_ada_norm_single:
        self.use_additional_conditions = self.config.sample_size == 128
        # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
        # additional conditions until we find better name
        self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

    self.caption_projection = None
    if caption_channels is not None:
        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

    self.gradient_checkpointing = False

    self.time_position_encoding_before_transformer = time_position_encoding_before_transformer
    if self.time_position_encoding_before_transformer:
        self.t_pos = TimePositionalEncoding(max_len = 4096, d_model = inner_dim)
449
+
450
+ def _set_gradient_checkpointing(self, module, value=False):
451
+ if hasattr(module, "gradient_checkpointing"):
452
+ module.gradient_checkpointing = value
453
+
454
+ def forward(
455
+ self,
456
+ hidden_states: torch.Tensor,
457
+ inpaint_latents: torch.Tensor = None,
458
+ encoder_hidden_states: Optional[torch.Tensor] = None,
459
+ timestep: Optional[torch.LongTensor] = None,
460
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
461
+ class_labels: Optional[torch.LongTensor] = None,
462
+ cross_attention_kwargs: Dict[str, Any] = None,
463
+ attention_mask: Optional[torch.Tensor] = None,
464
+ encoder_attention_mask: Optional[torch.Tensor] = None,
465
+ return_dict: bool = True,
466
+ ):
467
+ """
468
+ The [`Transformer2DModel`] forward method.
469
+
470
+ Args:
471
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
472
+ Input `hidden_states`.
473
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
474
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
475
+ self-attention.
476
+ timestep ( `torch.LongTensor`, *optional*):
477
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
478
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
479
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
480
+ `AdaLayerZeroNorm`.
481
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
482
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
483
+ `self.processor` in
484
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
485
+ attention_mask ( `torch.Tensor`, *optional*):
486
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
487
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
488
+ negative values to the attention scores corresponding to "discard" tokens.
489
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
490
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
491
+
492
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
493
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
494
+
495
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
496
+ above. This bias will be added to the cross-attention scores.
497
+ return_dict (`bool`, *optional*, defaults to `True`):
498
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
499
+ tuple.
500
+
501
+ Returns:
502
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer3DModelOutput`] is returned, otherwise a
503
+ `tuple` where the first element is the sample tensor.
504
+ """
505
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
506
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
507
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
508
+ # expects mask of shape:
509
+ # [batch, key_tokens]
510
+ # adds singleton query_tokens dimension:
511
+ # [batch, 1, key_tokens]
512
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
513
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
514
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
515
+ if attention_mask is not None and attention_mask.ndim == 2:
516
+ # assume that mask is expressed as:
517
+ # (1 = keep, 0 = discard)
518
+ # convert mask into a bias that can be added to attention scores:
519
+ # (keep = +0, discard = -10000.0)
520
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
521
+ attention_mask = attention_mask.unsqueeze(1)
522
+
523
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
524
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
525
+ encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
526
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
527
+
528
+ if inpaint_latents is not None:
529
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
530
+ # 1. Input
531
+ if self.casual_3d:
532
+ video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
533
+ elif self.patch_3d:
534
+ video_length, height, width = hidden_states.shape[-3] // self.time_patch_size, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
535
+ else:
536
+ video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
537
+ hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
538
+
539
+ hidden_states = self.pos_embed(hidden_states)
540
+ if self.adaln_single is not None:
541
+ if self.use_additional_conditions and added_cond_kwargs is None:
542
+ raise ValueError(
543
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
544
+ )
545
+ batch_size = hidden_states.shape[0] // video_length
546
+ timestep, embedded_timestep = self.adaln_single(
547
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
548
+ )
549
+ hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
550
+
551
+ # hidden_states
552
+ # bs, c, f, h, w => b (f h w ) c
553
+ if self.time_position_encoding_before_transformer:
554
+ hidden_states = self.t_pos(hidden_states)
555
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
556
+
557
+ # 2. Blocks
558
+ if self.caption_projection is not None:
559
+ batch_size = hidden_states.shape[0]
560
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
561
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
562
+
563
+ skips = []
564
+ skip_index = 0
565
+ for index, block in enumerate(self.transformer_blocks):
566
+ if self.enable_uvit:
567
+ if index >= 15:
568
+ long_connect = self.long_connect_fc[skip_index](skips.pop())
569
+ hidden_states = hidden_states + long_connect
570
+ skip_index += 1
571
+
572
+ if self.casual_3d_upsampler_index is not None and index in self.casual_3d_upsampler_index:
573
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
574
+ hidden_states = self.unpatch1d(hidden_states)
575
+ video_length = (video_length - 1) * 2 + 1
576
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=video_length, h=height, w=width)
577
+
578
+ if self.training and self.gradient_checkpointing:
579
+
580
+ def create_custom_forward(module, return_dict=None):
581
+ def custom_forward(*inputs):
582
+ if return_dict is not None:
583
+ return module(*inputs, return_dict=return_dict)
584
+ else:
585
+ return module(*inputs)
586
+
587
+ return custom_forward
588
+
589
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
590
+ args = {
591
+ "basic": [],
592
+ "motionmodule": [video_length, height, width],
593
+ "selfattentiontemporal": [video_length, height, width],
594
+ "kvcompression_motionmodule": [video_length, height, width],
595
+ }[self.basic_block_type]
596
+ hidden_states = torch.utils.checkpoint.checkpoint(
597
+ create_custom_forward(block),
598
+ hidden_states,
599
+ attention_mask,
600
+ encoder_hidden_states,
601
+ encoder_attention_mask,
602
+ timestep,
603
+ cross_attention_kwargs,
604
+ class_labels,
605
+ *args,
606
+ **ckpt_kwargs,
607
+ )
608
+ else:
609
+ kwargs = {
610
+ "basic": {},
611
+ "motionmodule": {"num_frames":video_length, "height":height, "width":width},
612
+ "selfattentiontemporal": {"num_frames":video_length, "height":height, "width":width},
613
+ "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
614
+ }[self.basic_block_type]
615
+ hidden_states = block(
616
+ hidden_states,
617
+ attention_mask=attention_mask,
618
+ encoder_hidden_states=encoder_hidden_states,
619
+ encoder_attention_mask=encoder_attention_mask,
620
+ timestep=timestep,
621
+ cross_attention_kwargs=cross_attention_kwargs,
622
+ class_labels=class_labels,
623
+ **kwargs
624
+ )
625
+
626
+ if self.enable_uvit:
627
+ if index < 13:
628
+ skips.append(hidden_states)
629
+
630
+ if self.fake_3d and self.patch_3d:
631
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> (b h w) c f", f=video_length, w=width, h=height)
632
+ hidden_states = self.unpatch1d(hidden_states)
633
+ hidden_states = rearrange(hidden_states, "(b h w) c f -> b (f h w) c", w=width, h=height)
634
+
635
+ # 3. Output
636
+ if self.config.norm_type != "ada_norm_single":
637
+ conditioning = self.transformer_blocks[0].norm1.emb(
638
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
639
+ )
640
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
641
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
642
+ hidden_states = self.proj_out_2(hidden_states)
643
+ elif self.config.norm_type == "ada_norm_single":
644
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
645
+ hidden_states = self.norm_out(hidden_states)
646
+ # Modulation
647
+ hidden_states = hidden_states * (1 + scale) + shift
648
+ hidden_states = self.proj_out(hidden_states)
649
+ hidden_states = hidden_states.squeeze(1)
650
+
651
+ # unpatchify
652
+ if self.adaln_single is None:
653
+ height = width = int(hidden_states.shape[1] ** 0.5)
654
+ if self.patch_3d:
655
+ if self.fake_3d:
656
+ hidden_states = hidden_states.reshape(
657
+ shape=(-1, video_length * self.patch_size, height, width, self.patch_size, self.patch_size, self.out_channels)
658
+ )
659
+ hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
660
+ else:
661
+ hidden_states = hidden_states.reshape(
662
+ shape=(-1, video_length, height, width, self.time_patch_size, self.patch_size, self.patch_size, self.out_channels)
663
+ )
664
+ hidden_states = torch.einsum("nfhwopqc->ncfohpwq", hidden_states)
665
+ output = hidden_states.reshape(
666
+ shape=(-1, self.out_channels, video_length * self.time_patch_size, height * self.patch_size, width * self.patch_size)
667
+ )
668
+ else:
669
+ hidden_states = hidden_states.reshape(
670
+ shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
671
+ )
672
+ hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
673
+ output = hidden_states.reshape(
674
+ shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
675
+ )
676
+
677
+ if not return_dict:
678
+ return (output,)
679
+
680
+ return Transformer3DModelOutput(sample=output)
681
+
682
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs=None):
    """
    Build the model from a checkpoint directory and load every weight whose shape is compatible.

    The directory must contain `config.json` plus either a `.safetensors` or a `.bin` weight file (the
    safetensors file is preferred when both exist). Checkpoint tensors whose key is unknown to the model, or
    whose shape still differs after the adaptations below, are skipped with a printed notice.

    Args:
        pretrained_model_path (`str`):
            Directory holding the config and weights.
        subfolder (`str`, *optional*):
            Sub-directory of `pretrained_model_path` to read from instead.
        patch_size (`int`, defaults to 2):
            Tiling factor applied to `proj_out.weight` / `proj_out.bias` when their shapes differ from the model's.
        transformer_additional_kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded to `cls.from_config`. `None` is treated as an empty dict.

    Returns:
        The instantiated model with the compatible weights loaded (`strict=False`).

    Raises:
        RuntimeError: If `config.json` or the weight file cannot be found.
    """
    if transformer_additional_kwargs is None:
        transformer_additional_kwargs = {}

    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r") as f:
        config = json.load(f)

    from diffusers.utils import WEIGHTS_NAME
    model = cls.from_config(config, **transformer_additional_kwargs)
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    model_file_safetensors = model_file.replace(".bin", ".safetensors")
    if os.path.exists(model_file_safetensors):
        from safetensors.torch import load_file
        state_dict = load_file(model_file_safetensors)
    else:
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = torch.load(model_file, map_location="cpu")

    # Build the model's state dict once: every call to `state_dict()` walks the whole module tree, and the
    # returned tensors share storage with the parameters, so in-place writes below still reach the model.
    model_state = model.state_dict()

    pos_key = 'pos_embed.proj.weight'
    if model_state[pos_key].size() != state_dict[pos_key].size():
        new_shape = model_state[pos_key].size()
        if len(new_shape) == 5:
            # Insert a new axis at dim 2, broadcast to the model's shape, and zero every slice of that
            # axis except the last one.
            state_dict[pos_key] = state_dict[pos_key].unsqueeze(2).expand(new_shape).clone()
            state_dict[pos_key][:, :, :-1] = 0
        else:
            # Copy the checkpoint weight into the first 4 entries of dim 1 and zero the remainder
            # (presumably the extra input channels of an inpainting variant - TODO confirm).
            model_state[pos_key][:, :4, :, :] = state_dict[pos_key]
            model_state[pos_key][:, 4:, :, :] = 0
            state_dict[pos_key] = model_state[pos_key]

    if model_state['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
        state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])

    if model_state['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
        state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])

    # Keep only tensors the model knows about and whose shape matches.
    compatible_state_dict = {}
    for key, value in state_dict.items():
        if key in model_state and model_state[key].size() == value.size():
            compatible_state_dict[key] = value
        else:
            print(key, "Size don't match, skip")

    m, u = model.load_state_dict(compatible_state_dict, strict=False)
    print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

    params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
    print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")

    return model
easyanimate/pipeline/pipeline_easyanimate.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import copy
18
+ import re
19
+ import urllib.parse as ul
20
+ from dataclasses import dataclass
21
+ from typing import Callable, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from diffusers import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL
28
+ from diffusers.schedulers import DPMSolverMultistepScheduler
29
+ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
30
+ is_bs4_available, is_ftfy_available, logging,
31
+ replace_example_docstring)
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from einops import rearrange
34
+ from tqdm import tqdm
35
+ from transformers import T5EncoderModel, T5Tokenizer
36
+
37
+ from ..models.transformer3d import Transformer3DModel
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ if is_bs4_available():
42
+ from bs4 import BeautifulSoup
43
+
44
+ if is_ftfy_available():
45
+ import ftfy
46
+
47
+
48
# Usage example text for the pipeline documentation.
# NOTE(review): the example imports `EasyAnimatePipeline` from `diffusers` and loads PixArt-alpha
# checkpoints, and reads `.images` from the result - confirm this still matches this pipeline.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import EasyAnimatePipeline

        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
        >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
        >>> # Enable memory optimizations.
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A small cactus with a happy face in the Sahara desert."
        >>> image = pipe(prompt).images[0]
        ```
"""
63
+
64
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
65
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure `scheduler` through its `set_timesteps` method and return the schedule it ends up with.

    Either a step count or an explicit list of timesteps drives the configuration. Extra keyword arguments
    are passed straight through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Used only when `timesteps` is `None`.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. When given, the returned step count is the length of the resulting
            schedule and `num_inference_steps` is ignored.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` attribute after configuration, and the number
        of inference steps.

    Raises:
        ValueError: If `timesteps` is given but `scheduler.set_timesteps` has no `timesteps` parameter.
    """
    if timesteps is None:
        # Default path: let the scheduler derive its own spacing from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
107
+
108
@dataclass
class EasyAnimatePipelineOutput(BaseOutput):
    """
    Output container with a single `videos` field.

    NOTE(review): presumably what `EasyAnimatePipeline.__call__` returns; the call site is not visible
    from here - confirm.
    """

    # Video data as a torch tensor or a NumPy array; the axis layout is not established here.
    videos: Union[torch.Tensor, np.ndarray]
111
+
112
+ class EasyAnimatePipeline(DiffusionPipeline):
113
+ r"""
114
+ Pipeline for text-to-image generation using PixArt-Alpha.
115
+
116
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
117
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
118
+
119
+ Args:
120
+ vae ([`AutoencoderKL`]):
121
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
122
+ text_encoder ([`T5EncoderModel`]):
123
+ Frozen text-encoder. PixArt-Alpha uses
124
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
125
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
126
+ tokenizer (`T5Tokenizer`):
127
+ Tokenizer of class
128
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
129
+ transformer ([`Transformer3DModel`]):
130
+ A text conditioned `Transformer3DModel` to denoise the encoded image latents.
131
+ scheduler ([`SchedulerMixin`]):
132
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
133
+ """
134
+ bad_punct_regex = re.compile(
135
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
136
+ ) # noqa
137
+
138
+ _optional_components = ["tokenizer", "text_encoder"]
139
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
140
+
141
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKL,
    transformer: Transformer3DModel,
    scheduler: DPMSolverMultistepScheduler,
):
    """
    Register the pipeline components and derive the VAE scale factor.

    Args:
        tokenizer: Tokenizer registered under the name `tokenizer`.
        text_encoder: Text encoder registered under the name `text_encoder`.
        vae: Autoencoder registered under the name `vae`; its `config.block_out_channels` determines
            `vae_scale_factor`.
        transformer: Denoising transformer registered under the name `transformer`.
        scheduler: Scheduler registered under the name `scheduler`.
    """
    super().__init__()

    components = dict(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )
    self.register_modules(**components)

    # The scale factor doubles for every entry of `block_out_channels` after the first.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
157
+
158
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
159
def mask_text_embeddings(self, emb, mask):
    """
    Apply an attention mask to text embeddings.

    Args:
        emb: Embedding tensor indexed with four axes; axis 0 is the batch and axis 2 is the one that
            gets truncated or masked.
        mask: Mask tensor. For a batch of one, only its sum is used; otherwise it is indexed as
            `mask[:, None, :, None]` and multiplied into `emb`.

    Returns:
        A `(embeddings, length)` tuple. For a batch of one, the embeddings are truncated along axis 2 to
        the number of kept positions and `length` is that number. Otherwise the embeddings are multiplied
        by the mask and `length` is `emb.shape[2]`.
    """
    if emb.shape[0] == 1:
        # `.item()` yields a float for floating-point masks, which is not a valid slice bound.
        keep_index = int(mask.sum().item())
        return emb[:, :, :keep_index, :], keep_index

    masked_feature = emb * mask[:, None, :, None]
    return masked_feature, emb.shape[2]
166
+
167
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
168
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.FloatTensor] = None,
    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 120,
    **kwargs,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
            instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
            PixArt-Alpha, this should be "".
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the ""
            string.
        prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask matching `prompt_embeds`. Must be supplied together with `prompt_embeds`; it is
            computed by the tokenizer otherwise.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask matching `negative_prompt_embeds`. Must be supplied together with
            `negative_prompt_embeds` when guidance is enabled.
        clean_caption (`bool`, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.
        max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.

    Returns:
        A 4-tuple `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
        negative_prompt_attention_mask)`. Every tensor has `batch * num_images_per_prompt` rows, with the
        copies of one prompt adjacent to each other. The two negative entries are `None` when
        `do_classifier_free_guidance` is `False`.
    """

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # See Section 3.1. of the paper.
    max_length = max_sequence_length

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0]

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # The mask must use the same row order as the embeddings (all copies of one prompt adjacent);
    # tiling along the batch axis would pair masks with the wrong prompt when bs_embed > 1.
    prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens = [negative_prompt] * batch_size
        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Same row order as the negative embeddings above.
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(
            bs_embed * num_images_per_prompt, -1
        )
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
309
+
310
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
311
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` is able to accept.

    Scheduler `step` signatures differ, so `eta` and `generator` are only forwarded when the signature
    declares a parameter of that name. eta corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Value to pass as `generator` when supported.
        eta: Value to pass as `eta` when supported.

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"` for the names the scheduler's `step` declares.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    optional_args = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_args if name in step_params}
327
+
328
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_steps,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the user-facing call arguments; returns `None` when everything is consistent.

    Raises:
        ValueError: If `height` or `width` is not divisible by 8; if `callback_steps` is not a positive
            integer; if `prompt` and `prompt_embeds` are both given or both missing; if `prompt` is
            neither `str` nor `list`; if `negative_prompt_embeds` is combined with `prompt` or with
            `negative_prompt`; or if `prompt_embeds` and `negative_prompt_embeds` differ in shape.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_are_valid = callback_steps is not None and isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Pre-computed negative embeddings exclude both textual inputs.
    if has_prompt and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
380
+
381
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
382
+ def _text_preprocessing(self, text, clean_caption=False):
383
+ if clean_caption and not is_bs4_available():
384
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
385
+ logger.warn("Setting `clean_caption` to False...")
386
+ clean_caption = False
387
+
388
+ if clean_caption and not is_ftfy_available():
389
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
390
+ logger.warn("Setting `clean_caption` to False...")
391
+ clean_caption = False
392
+
393
+ if not isinstance(text, (tuple, list)):
394
+ text = [text]
395
+
396
+ def process(text: str):
397
+ if clean_caption:
398
+ text = self._clean_caption(text)
399
+ text = self._clean_caption(text)
400
+ else:
401
+ text = text.lower().strip()
402
+ return text
403
+
404
+ return [process(t) for t in text]
405
+
406
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
407
def _clean_caption(self, caption):
    """
    Scrub a caption of web and markup noise and return the cleaned, stripped string.

    The caption is URL-unquoted, lower-cased and then passed through a fixed sequence of regex
    substitutions (URLs, HTML via BeautifulSoup, @-handles, several CJK ranges, dash and quote
    normalisation, IP addresses, ids and file names, marketing phrases, stray punctuation), with
    `ftfy.fix_text` and a double `html.unescape` applied in between.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        `str`: The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalise quotation marks to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # The stripped result has to be kept: the `^...$` anchored substitutions below only match
    # when no whitespace is left around the caption.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
520
+
521
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
522
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
    """Return the initial latents for the denoising loop.

    The target shape is ``(batch_size, num_channels_latents, frames,
    height // vae_scale_factor, width // vae_scale_factor)``. When the VAE's
    ``quant_conv`` weight is 5-dimensional, ``frames`` is
    ``video_length // mini_batch_encoder * mini_batch_decoder`` (or ``1`` for
    a single frame); otherwise it is ``video_length`` unchanged.

    Args:
        batch_size: Effective batch size of the latents.
        num_channels_latents: Channel count of the latents.
        video_length: Number of requested frames.
        height: Output height in pixels.
        width: Output width in pixels.
        dtype: dtype used when fresh noise is sampled.
        device: Device the latents are created on / moved to.
        generator: A generator or a list of generators used for sampling.
        latents: Optional pre-made latents; they are only moved to ``device``.

    Returns:
        The latents multiplied by ``self.scheduler.init_noise_sigma``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from
            ``batch_size``.
    """
    latent_frames = video_length
    if self.vae.quant_conv.weight.ndim == 5:
        frames_per_encode = self.vae.mini_batch_encoder
        frames_per_decode = self.vae.mini_batch_decoder
        if video_length != 1:
            latent_frames = int(video_length // frames_per_encode * frames_per_decode)
        else:
            latent_frames = 1

    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    shape = (batch_size, num_channels_latents, latent_frames, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        latents = latents.to(device)
    else:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return latents * self.scheduler.init_noise_sigma
544
+
545
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
    """Blend the interior frames of ``video`` with a VAE re-encoding of them.

    The first ``mini_batch_encoder // 2`` frames and the remaining
    ``mini_batch_encoder - mini_batch_encoder // 2`` trailing frames are left
    alone. The frames in between are encoded and decoded again by
    ``self.vae`` and averaged with the existing frames. ``video`` is modified
    in place and also returned.

    Args:
        video: Tensor indexed as ``[:, :, frame, :, :]``.
        mini_batch_encoder: Frames per encode call (and size of the skipped border).
        mini_batch_decoder: Latent frames per decode call.

    Returns:
        ``video`` itself; unchanged when it has at most ``mini_batch_encoder`` frames.
    """
    if video.size()[2] <= mini_batch_encoder:
        return video

    lead = mini_batch_encoder // 2
    trail = mini_batch_encoder - lead
    interior = slice(lead, -trail)
    pixel_values = video[:, :, interior]

    # Encode the interior frames, in one call or in encoder-sized chunks.
    if self.vae.slice_compression_vae:
        latents = self.vae.encode(pixel_values)[0].sample()
    else:
        with torch.no_grad():
            latent_chunks = [
                self.vae.encode(pixel_values[:, :, start: start + mini_batch_encoder, :, :])[0].sample()
                for start in range(0, pixel_values.shape[2], mini_batch_encoder)
            ]
        latents = torch.cat(latent_chunks, dim=2)

    # Decode them back, in one call or in decoder-sized chunks.
    if self.vae.slice_compression_vae:
        middle_video = self.vae.decode(latents)[0]
    else:
        with torch.no_grad():
            decoded_chunks = [
                self.vae.decode(latents[:, :, start: start + mini_batch_decoder, :, :])[0]
                for start in range(0, latents.shape[2], mini_batch_decoder)
            ]
        middle_video = torch.cat(decoded_chunks, 2)

    video[:, :, interior] = (video[:, :, interior] + middle_video) / 2
    return video
578
+
579
def decode_latents(self, latents):
    """Decode latents into a float32 NumPy video with values in ``[0, 1]``.

    The latents are first multiplied by ``1 / 0.18215``. If the VAE's
    ``quant_conv`` weight is 5-dimensional the latents are decoded as a video
    (whole, or in ``mini_batch_decoder``-sized chunks along dim 2), clamped to
    ``[-1, 1]`` and passed through ``self.smooth_output``. Otherwise every
    frame is decoded separately and the frames are reassembled.

    Args:
        latents: Tensor laid out as ``b c f h w``.

    Returns:
        ``numpy.ndarray`` obtained from ``(video / 2 + 0.5).clamp(0, 1)``.
    """
    video_length = latents.shape[2]
    latents = 1 / 0.18215 * latents

    if self.vae.quant_conv.weight.ndim == 5:
        frames_per_encode = self.vae.mini_batch_encoder
        frames_per_decode = self.vae.mini_batch_decoder
        if self.vae.slice_compression_vae:
            video = self.vae.decode(latents)[0]
        else:
            with torch.no_grad():
                decoded_chunks = [
                    self.vae.decode(latents[:, :, start: start + frames_per_decode, :, :])[0]
                    for start in range(0, latents.shape[2], frames_per_decode)
                ]
            video = torch.cat(decoded_chunks, 2)
        video = video.clamp(-1, 1)
        video = self.smooth_output(video, frames_per_encode, frames_per_decode).cpu().clamp(-1, 1)
    else:
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        frames = [
            self.vae.decode(latents[frame_idx:frame_idx + 1]).sample
            for frame_idx in tqdm(range(latents.shape[0]))
        ]
        video = rearrange(torch.cat(frames), "(b f) c h w -> b c f h w", f=video_length)

    video = (video / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return video.cpu().float().numpy()
609
+
610
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    video_length: Optional[int] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: List[int] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "latent",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    max_sequence_length: int = 120,
    **kwargs,
) -> Union[EasyAnimatePipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the generation. If not defined, one has to pass `prompt_embeds`
            instead.
        video_length (`int`, *optional*):
            Number of frames to generate; forwarded to `prepare_latents` as the temporal size. Must be given
            when `latents` is not passed.
        negative_prompt (`str`, *optional*, defaults to `""`):
            The prompt not to guide the generation. Only encoded when classifier-free guidance is active
            (`guidance_scale > 1`) and `negative_prompt_embeds` is not passed.
        num_inference_steps (`int`, *optional*, defaults to 20):
            The number of denoising steps.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process; forwarded to `retrieve_timesteps`.
        guidance_scale (`float`, *optional*, defaults to 4.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            Guidance is enabled when `guidance_scale > 1`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of samples to generate per prompt.
        height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated frames.
        width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated frames.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Forwarded
            to `prepare_extra_step_kwargs`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of torch generator(s) used when sampling the initial latents.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated latents. If not provided, they are sampled in `prepare_latents` using `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, they are generated from `prompt`.
        prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask belonging to `prompt_embeds`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from
            `negative_prompt`.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask belonging to `negative_prompt_embeds`.
        output_type (`str`, *optional*, defaults to `"latent"`):
            With `"latent"` the decoded video is returned as a `torch.Tensor`; with any other value it stays
            the NumPy array produced by `decode_latents`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to wrap the result in an [`EasyAnimatePipelineOutput`].
        callback (`Callable`, *optional*):
            Called as `callback(step: int, timestep: int, latents: torch.FloatTensor)` during inference.
        callback_steps (`int`, *optional*, defaults to 1):
            The callback is only invoked on loop iterations `i` with `i % callback_steps == 0`.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Forwarded to `encode_prompt`; whether to clean the caption before creating embeddings.
        max_sequence_length (`int`, *optional*, defaults to 120):
            Forwarded to `encode_prompt`; maximum token length of the prompt.
        **kwargs:
            Accepted for call-site compatibility; not used by this method.

    Examples:

    Returns:
        [`EasyAnimatePipelineOutput`] or the decoded video:
            If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] whose `videos` field holds the decoded
            video; otherwise the decoded video itself (it is not wrapped in a tuple).

    Raises:
        ValueError: If neither `prompt` (as `str` or `list`) nor `prompt_embeds` is given, or if both
            `video_length` and `latents` are missing.
    """
    # 1. Default height and width to the transformer's configured sample size
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor

    # 2. Derive the batch size from whichever prompt input was supplied
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Previously this fell through to `None.shape` and raised an opaque AttributeError.
        raise ValueError("Provide either `prompt` (a `str` or a list of `str`) or `prompt_embeds`.")

    if video_length is None and latents is None:
        # Without either value the latent shape cannot be built; fail early with a clear message.
        raise ValueError("`video_length` must be provided when `latents` is not passed.")

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        # Unconditional half first, conditional half second; `chunk(2)` below relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        video_length,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Prepare micro-conditions (only filled in when the configured sample size is 128).
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if self.transformer.config.sample_size == 128:
        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=current_timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # learned sigma: when the model emits twice the latent channels, keep only the first half
            if self.transformer.config.out_channels // 2 == latent_channels:
                noise_pred = noise_pred.chunk(2, dim=1)[0]

            # compute previous image: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 8. Post-processing: decode the latents into a NumPy video
    video = self.decode_latents(latents)

    # Convert to tensor
    if output_type == "latent":
        video = torch.from_numpy(video)

    if not return_dict:
        return video

    return EasyAnimatePipelineOutput(videos=video)
easyanimate/pipeline/pipeline_easyanimate_inpaint.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import copy
19
+ import urllib.parse as ul
20
+ from dataclasses import dataclass
21
+ from typing import Callable, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from diffusers import DiffusionPipeline, ImagePipelineOutput
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL
28
+ from diffusers.schedulers import DPMSolverMultistepScheduler
29
+ from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
30
+ is_bs4_available, is_ftfy_available, logging,
31
+ replace_example_docstring)
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from einops import rearrange
34
+ from tqdm import tqdm
35
+ from transformers import T5EncoderModel, T5Tokenizer
36
+
37
+ from ..models.transformer3d import Transformer3DModel
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ if is_bs4_available():
42
+ from bs4 import BeautifulSoup
43
+
44
+ if is_ftfy_available():
45
+ import ftfy
46
+
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> import torch
52
+ >>> from diffusers import EasyAnimatePipeline
53
+
54
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
55
+ >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
56
+ >>> # Enable memory optimizations.
57
+ >>> pipe.enable_model_cpu_offload()
58
+
59
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
60
+ >>> image = pipe(prompt).images[0]
61
+ ```
62
+ """
63
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
64
def retrieve_latents(encoder_output, generator):
    """Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a ``latent_dist`` attribute
            (sampled with ``generator``) or a ``latents`` attribute.
        generator: Passed to ``latent_dist.sample`` when that attribute exists.

    Returns:
        ``encoder_output.latent_dist.sample(generator)`` when ``latent_dist``
        is present, else ``encoder_output.latents``.

    Raises:
        AttributeError: If the object has neither attribute.
    """
    # `latent_dist` takes precedence over ready-made `latents`.
    if hasattr(encoder_output, "latent_dist"):
        return encoder_output.latent_dist.sample(generator)
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
71
+
72
@dataclass
class EasyAnimatePipelineOutput(BaseOutput):
    """Output container returned by the EasyAnimate pipeline.

    Attributes:
        videos: The produced video data, held either as a ``torch.Tensor``
            or as a ``numpy.ndarray``.
    """

    # Single payload field; its type depends on what the pipeline hands over.
    videos: Union[torch.Tensor, np.ndarray]
75
+
76
+ class EasyAnimateInpaintPipeline(DiffusionPipeline):
77
+ r"""
78
+ Pipeline for text-to-image generation using PixArt-Alpha.
79
+
80
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
81
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
82
+
83
+ Args:
84
+ vae ([`AutoencoderKL`]):
85
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
86
+ text_encoder ([`T5EncoderModel`]):
87
+ Frozen text-encoder. PixArt-Alpha uses
88
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
89
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
90
+ tokenizer (`T5Tokenizer`):
91
+ Tokenizer of class
92
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
93
+ transformer ([`Transformer3DModel`]):
94
+ A text conditioned `Transformer3DModel` to denoise the encoded image latents.
95
+ scheduler ([`SchedulerMixin`]):
96
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
97
+ """
98
+ bad_punct_regex = re.compile(
99
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
100
+ ) # noqa
101
+
102
+ _optional_components = ["tokenizer", "text_encoder"]
103
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
104
+
105
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKL,
    transformer: Transformer3DModel,
    scheduler: DPMSolverMultistepScheduler,
):
    """Register the pipeline components and build the image/mask processors.

    Args:
        tokenizer: Tokenizer used for the prompts.
        text_encoder: Text encoder producing the prompt embeddings.
        vae: Autoencoder; its ``config.block_out_channels`` determines
            ``vae_scale_factor``.
        transformer: Denoising transformer.
        scheduler: Scheduler driving the denoising steps.
    """
    super().__init__()

    components = dict(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )
    self.register_modules(**components)

    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # Images are normalized; masks are converted to grayscale and binarized instead.
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=True)
    mask_options = dict(do_normalize=False, do_binarize=True, do_convert_grayscale=True)
    self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, **mask_options)
124
+
125
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
126
def mask_text_embeddings(self, emb, mask):
    """Apply an attention mask to text embeddings.

    Args:
        emb: Embedding tensor indexed with four dimensions; the token axis
            is dimension 2.
        mask: Attention mask; for batches larger than one it is broadcast as
            ``mask[:, None, :, None]``.

    Returns:
        A ``(embeddings, length)`` tuple. For a batch of one the embeddings
        are cut to the first ``mask.sum()`` tokens and ``length`` is that
        count. For larger batches the embeddings are multiplied by the mask
        and ``length`` is ``emb.shape[2]``.
    """
    if emb.shape[0] != 1:
        # Batched input: zero out masked positions, keep the full length.
        return emb * mask[:, None, :, None], emb.shape[2]

    # Single sample: truncate to the number of positions the mask keeps.
    keep_index = mask.sum().item()
    return emb[:, :, :keep_index, :], keep_index
133
+
134
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
135
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.FloatTensor] = None,
    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 120,
    **kwargs,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        negative_prompt (`str`, *optional*, defaults to `""`):
            The prompt not to guide the generation. Only encoded when `do_classifier_free_guidance` is `True`
            and `negative_prompt_embeds` is not passed.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of samples that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on; defaults to `self._execution_device`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings are generated from `prompt`.
            When provided, `prompt_attention_mask` must be provided as well.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. When provided (with guidance enabled),
            `negative_prompt_attention_mask` must be provided as well.
        prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask belonging to `prompt_embeds`.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask belonging to `negative_prompt_embeds`.
        clean_caption (`bool`, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.
        max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.

    Returns:
        Tuple `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`.
        Embeddings and masks are repeated `num_images_per_prompt` times per prompt, with row `k` of a mask
        belonging to row `k` of its embeddings. The two negative entries are `None` when
        `do_classifier_free_guidance` is `False`.

    Raises:
        ValueError: If neither `prompt` nor `prompt_embeds` is usable, or if embeddings are passed without
            their attention mask.
    """

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    if device is None:
        device = self._execution_device

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Previously this fell through to `None.shape` and raised an opaque AttributeError.
        raise ValueError("Provide either `prompt` (a `str` or a list of `str`) or `prompt_embeds`.")

    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("`prompt_attention_mask` must be passed together with `prompt_embeds`.")
    if (
        do_classifier_free_guidance
        and negative_prompt_embeds is not None
        and negative_prompt_attention_mask is None
    ):
        raise ValueError(
            "`negative_prompt_attention_mask` must be passed together with `negative_prompt_embeds`."
        )

    # See Section 3.1. of the paper.
    max_length = max_sequence_length

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            # The message used to name CLIP; this pipeline's text encoder is T5.
            logger.warning(
                "The following part of your input was truncated because the text encoder can only handle"
                f" sequences up to {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0]

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # The mask must follow the same per-prompt ordering as the embeddings above
    # ([p0, p0, ..., p1, p1, ...]); tiling with `repeat(n, 1)` would interleave prompts instead.
    prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens = [negative_prompt] * batch_size
        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same token length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Same per-prompt ordering as the embeddings (see the positive mask above).
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(
            bs_embed * num_images_per_prompt, -1
        )
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
276
+
277
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
278
+ def prepare_extra_step_kwargs(self, generator, eta):
279
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
280
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
281
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
282
+ # and should be between [0, 1]
283
+
284
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
285
+ extra_step_kwargs = {}
286
+ if accepts_eta:
287
+ extra_step_kwargs["eta"] = eta
288
+
289
+ # check if the scheduler accepts generator
290
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
291
+ if accepts_generator:
292
+ extra_step_kwargs["generator"] = generator
293
+ return extra_step_kwargs
294
+
295
+ def check_inputs(
296
+ self,
297
+ prompt,
298
+ height,
299
+ width,
300
+ negative_prompt,
301
+ callback_steps,
302
+ prompt_embeds=None,
303
+ negative_prompt_embeds=None,
304
+ ):
305
+ if height % 8 != 0 or width % 8 != 0:
306
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
307
+
308
+ if (callback_steps is None) or (
309
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
310
+ ):
311
+ raise ValueError(
312
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
313
+ f" {type(callback_steps)}."
314
+ )
315
+
316
+ if prompt is not None and prompt_embeds is not None:
317
+ raise ValueError(
318
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
319
+ " only forward one of the two."
320
+ )
321
+ elif prompt is None and prompt_embeds is None:
322
+ raise ValueError(
323
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
324
+ )
325
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
326
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
327
+
328
+ if prompt is not None and negative_prompt_embeds is not None:
329
+ raise ValueError(
330
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
331
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
332
+ )
333
+
334
+ if negative_prompt is not None and negative_prompt_embeds is not None:
335
+ raise ValueError(
336
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
337
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
338
+ )
339
+
340
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
341
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
342
+ raise ValueError(
343
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
344
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
345
+ f" {negative_prompt_embeds.shape}."
346
+ )
347
+
348
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
349
+ def _text_preprocessing(self, text, clean_caption=False):
350
+ if clean_caption and not is_bs4_available():
351
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
352
+ logger.warn("Setting `clean_caption` to False...")
353
+ clean_caption = False
354
+
355
+ if clean_caption and not is_ftfy_available():
356
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
357
+ logger.warn("Setting `clean_caption` to False...")
358
+ clean_caption = False
359
+
360
+ if not isinstance(text, (tuple, list)):
361
+ text = [text]
362
+
363
+ def process(text: str):
364
+ if clean_caption:
365
+ text = self._clean_caption(text)
366
+ text = self._clean_caption(text)
367
+ else:
368
+ text = text.lower().strip()
369
+ return text
370
+
371
+ return [process(t) for t in text]
372
+
373
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
374
+ def _clean_caption(self, caption):
375
+ caption = str(caption)
376
+ caption = ul.unquote_plus(caption)
377
+ caption = caption.strip().lower()
378
+ caption = re.sub("<person>", "person", caption)
379
+ # urls:
380
+ caption = re.sub(
381
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
382
+ "",
383
+ caption,
384
+ ) # regex for urls
385
+ caption = re.sub(
386
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
387
+ "",
388
+ caption,
389
+ ) # regex for urls
390
+ # html:
391
+ caption = BeautifulSoup(caption, features="html.parser").text
392
+
393
+ # @<nickname>
394
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
395
+
396
+ # 31C0—31EF CJK Strokes
397
+ # 31F0—31FF Katakana Phonetic Extensions
398
+ # 3200—32FF Enclosed CJK Letters and Months
399
+ # 3300—33FF CJK Compatibility
400
+ # 3400—4DBF CJK Unified Ideographs Extension A
401
+ # 4DC0—4DFF Yijing Hexagram Symbols
402
+ # 4E00—9FFF CJK Unified Ideographs
403
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
404
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
405
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
406
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
407
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
408
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
409
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
410
+ #######################################################
411
+
412
+ # все виды тире / all types of dash --> "-"
413
+ caption = re.sub(
414
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
415
+ "-",
416
+ caption,
417
+ )
418
+
419
+ # кавычки к одному стандарту
420
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
421
+ caption = re.sub(r"[‘’]", "'", caption)
422
+
423
+ # &quot;
424
+ caption = re.sub(r"&quot;?", "", caption)
425
+ # &amp
426
+ caption = re.sub(r"&amp", "", caption)
427
+
428
+ # ip adresses:
429
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
430
+
431
+ # article ids:
432
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
433
+
434
+ # \n
435
+ caption = re.sub(r"\\n", " ", caption)
436
+
437
+ # "#123"
438
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
439
+ # "#12345.."
440
+ caption = re.sub(r"#\d{5,}\b", "", caption)
441
+ # "123456.."
442
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
443
+ # filenames:
444
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
445
+
446
+ #
447
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
448
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
449
+
450
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
451
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
452
+
453
+ # this-is-my-cute-cat / this_is_my_cute_cat
454
+ regex2 = re.compile(r"(?:\-|\_)")
455
+ if len(re.findall(regex2, caption)) > 3:
456
+ caption = re.sub(regex2, " ", caption)
457
+
458
+ caption = ftfy.fix_text(caption)
459
+ caption = html.unescape(html.unescape(caption))
460
+
461
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
462
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
463
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
464
+
465
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
466
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
467
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
468
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
469
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
470
+
471
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
472
+
473
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
474
+
475
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
476
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
477
+ caption = re.sub(r"\s+", " ", caption)
478
+
479
+ caption.strip()
480
+
481
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
482
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
483
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
484
+ caption = re.sub(r"^\.\S+$", "", caption)
485
+
486
+ return caption.strip()
487
+
488
+ def prepare_latents(
489
+ self,
490
+ batch_size,
491
+ num_channels_latents,
492
+ height,
493
+ width,
494
+ video_length,
495
+ dtype,
496
+ device,
497
+ generator,
498
+ latents=None,
499
+ video=None,
500
+ timestep=None,
501
+ is_strength_max=True,
502
+ return_noise=False,
503
+ return_video_latents=False,
504
+ ):
505
+ if self.vae.quant_conv.weight.ndim==5:
506
+ shape = (batch_size, num_channels_latents, int(video_length // 5 * 2) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
507
+ else:
508
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
509
+ if isinstance(generator, list) and len(generator) != batch_size:
510
+ raise ValueError(
511
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
512
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
513
+ )
514
+
515
+ if return_video_latents or (latents is None and not is_strength_max):
516
+ video = video.to(device=device, dtype=dtype)
517
+
518
+ if video.shape[1] == 4:
519
+ video_latents = video
520
+ else:
521
+ video_length = video.shape[2]
522
+ video = rearrange(video, "b c f h w -> (b f) c h w")
523
+ video_latents = self._encode_vae_image(image=video, generator=generator)
524
+ video_latents = rearrange(video_latents, "(b f) c h w -> b c f h w", f=video_length)
525
+ video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1)
526
+
527
+ if latents is None:
528
+ rand_device = "cpu" if device.type == "mps" else device
529
+
530
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
531
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
532
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
533
+ else:
534
+ noise = latents.to(device)
535
+ if latents.shape != shape:
536
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
537
+ latents = latents.to(device)
538
+
539
+ # scale the initial noise by the standard deviation required by the scheduler
540
+ latents = latents * self.scheduler.init_noise_sigma
541
+ outputs = (latents,)
542
+
543
+ if return_noise:
544
+ outputs += (noise,)
545
+
546
+ if return_video_latents:
547
+ outputs += (video_latents,)
548
+
549
+ return outputs
550
+
551
+ def decode_latents(self, latents):
552
+ video_length = latents.shape[2]
553
+ latents = 1 / 0.18215 * latents
554
+ if self.vae.quant_conv.weight.ndim==5:
555
+ mini_batch_decoder = 2
556
+ # Decoder
557
+ video = []
558
+ for i in range(0, latents.shape[2], mini_batch_decoder):
559
+ with torch.no_grad():
560
+ start_index = i
561
+ end_index = i + mini_batch_decoder
562
+ latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
563
+ video.append(latents_bs)
564
+
565
+ # Smooth
566
+ mini_batch_encoder = 5
567
+ video = torch.cat(video, 2).cpu()
568
+ for i in range(mini_batch_encoder, video.shape[2], mini_batch_encoder):
569
+ origin_before = copy.deepcopy(video[:, :, i - 1, :, :])
570
+ origin_after = copy.deepcopy(video[:, :, i, :, :])
571
+
572
+ video[:, :, i - 1, :, :] = origin_before * 0.75 + origin_after * 0.25
573
+ video[:, :, i, :, :] = origin_before * 0.25 + origin_after * 0.75
574
+ video = video.clamp(-1, 1)
575
+ else:
576
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
577
+ # video = self.vae.decode(latents).sample
578
+ video = []
579
+ for frame_idx in tqdm(range(latents.shape[0])):
580
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
581
+ video = torch.cat(video)
582
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
583
+ video = (video / 2 + 0.5).clamp(0, 1)
584
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
585
+ video = video.cpu().float().numpy()
586
+ return video
587
+
588
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
589
+ if isinstance(generator, list):
590
+ image_latents = [
591
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
592
+ for i in range(image.shape[0])
593
+ ]
594
+ image_latents = torch.cat(image_latents, dim=0)
595
+ else:
596
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
597
+
598
+ image_latents = self.vae.config.scaling_factor * image_latents
599
+
600
+ return image_latents
601
+
602
+ def prepare_mask_latents(
603
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
604
+ ):
605
+ # resize the mask to latents shape as we concatenate the mask to the latents
606
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
607
+ # and half precision
608
+ video_length = mask.shape[2]
609
+
610
+ mask = mask.to(device=device, dtype=self.vae.dtype)
611
+ if self.vae.quant_conv.weight.ndim==5:
612
+ bs = 1
613
+ new_mask = []
614
+ for i in range(0, mask.shape[0], bs):
615
+ mini_batch = 5
616
+ new_mask_mini_batch = []
617
+ for j in range(0, mask.shape[2], mini_batch):
618
+ mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
619
+ mask_bs = self.vae.encode(mask_bs)[0]
620
+ mask_bs = mask_bs.sample()
621
+ new_mask_mini_batch.append(mask_bs)
622
+ new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
623
+ new_mask.append(new_mask_mini_batch)
624
+ mask = torch.cat(new_mask, dim = 0)
625
+ mask = mask * 0.1825
626
+
627
+ else:
628
+ if mask.shape[1] == 4:
629
+ mask = mask
630
+ else:
631
+ video_length = mask.shape[2]
632
+ mask = rearrange(mask, "b c f h w -> (b f) c h w")
633
+ mask = self._encode_vae_image(mask, generator=generator)
634
+ mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
635
+
636
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
637
+ if self.vae.quant_conv.weight.ndim==5:
638
+ bs = 1
639
+ new_mask_pixel_values = []
640
+ for i in range(0, masked_image.shape[0], bs):
641
+ mini_batch = 5
642
+ new_mask_pixel_values_mini_batch = []
643
+ for j in range(0, masked_image.shape[2], mini_batch):
644
+ mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
645
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
646
+ mask_pixel_values_bs = mask_pixel_values_bs.sample()
647
+ new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
648
+ new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
649
+ new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
650
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
651
+ masked_image_latents = masked_image_latents * 0.1825
652
+
653
+ else:
654
+ if masked_image.shape[1] == 4:
655
+ masked_image_latents = masked_image
656
+ else:
657
+ video_length = mask.shape[2]
658
+ masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
659
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
660
+ masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
661
+
662
+ # aligning device to prevent device errors when concating it with the latent model input
663
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
664
+ return mask, masked_image_latents
665
+
666
+ @torch.no_grad()
667
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
668
+ def __call__(
669
+ self,
670
+ prompt: Union[str, List[str]] = None,
671
+ video_length: Optional[int] = None,
672
+ video: Union[torch.FloatTensor] = None,
673
+ mask_video: Union[torch.FloatTensor] = None,
674
+ masked_video_latents: Union[torch.FloatTensor] = None,
675
+ negative_prompt: str = "",
676
+ num_inference_steps: int = 20,
677
+ timesteps: List[int] = None,
678
+ guidance_scale: float = 4.5,
679
+ num_images_per_prompt: Optional[int] = 1,
680
+ height: Optional[int] = None,
681
+ width: Optional[int] = None,
682
+ strength: float = 1.0,
683
+ eta: float = 0.0,
684
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
685
+ latents: Optional[torch.FloatTensor] = None,
686
+ prompt_embeds: Optional[torch.FloatTensor] = None,
687
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
688
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
689
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
690
+ output_type: Optional[str] = "latent",
691
+ return_dict: bool = True,
692
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
693
+ callback_steps: int = 1,
694
+ clean_caption: bool = True,
695
+ mask_feature: bool = True,
696
+ max_sequence_length: int = 120
697
+ ) -> Union[EasyAnimatePipelineOutput, Tuple]:
698
+ """
699
+ Function invoked when calling the pipeline for generation.
700
+
701
+ Args:
702
+ prompt (`str` or `List[str]`, *optional*):
703
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
704
+ instead.
705
+ negative_prompt (`str` or `List[str]`, *optional*):
706
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
707
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
708
+ less than `1`).
709
+ num_inference_steps (`int`, *optional*, defaults to 100):
710
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
711
+ expense of slower inference.
712
+ timesteps (`List[int]`, *optional*):
713
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
714
+ timesteps are used. Must be in descending order.
715
+ guidance_scale (`float`, *optional*, defaults to 7.0):
716
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
717
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
718
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
719
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
720
+ usually at the expense of lower image quality.
721
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
722
+ The number of images to generate per prompt.
723
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
724
+ The height in pixels of the generated image.
725
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
726
+ The width in pixels of the generated image.
727
+ eta (`float`, *optional*, defaults to 0.0):
728
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
729
+ [`schedulers.DDIMScheduler`], will be ignored for others.
730
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
731
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
732
+ to make generation deterministic.
733
+ latents (`torch.FloatTensor`, *optional*):
734
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
735
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
736
+ tensor will ge generated by sampling using the supplied random `generator`.
737
+ prompt_embeds (`torch.FloatTensor`, *optional*):
738
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
739
+ provided, text embeddings will be generated from `prompt` input argument.
740
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
741
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
742
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
743
+ output_type (`str`, *optional*, defaults to `"pil"`):
744
+ The output format of the generate image. Choose between
745
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
746
+ return_dict (`bool`, *optional*, defaults to `True`):
747
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
748
+ callback (`Callable`, *optional*):
749
+ A function that will be called every `callback_steps` steps during inference. The function will be
750
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
751
+ callback_steps (`int`, *optional*, defaults to 1):
752
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
753
+ called at every step.
754
+ clean_caption (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
756
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
757
+ prompt.
758
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
759
+
760
+ Examples:
761
+
762
+ Returns:
763
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
764
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
765
+ returned where the first element is a list with the generated images
766
+ """
767
+ # 1. Check inputs. Raise error if not correct
768
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
769
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
770
+
771
+ # 2. Default height and width to transformer
772
+ if prompt is not None and isinstance(prompt, str):
773
+ batch_size = 1
774
+ elif prompt is not None and isinstance(prompt, list):
775
+ batch_size = len(prompt)
776
+ else:
777
+ batch_size = prompt_embeds.shape[0]
778
+
779
+ device = self._execution_device
780
+
781
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
782
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
783
+ # corresponds to doing no classifier free guidance.
784
+ do_classifier_free_guidance = guidance_scale > 1.0
785
+
786
+ # 3. Encode input prompt
787
+ (
788
+ prompt_embeds,
789
+ prompt_attention_mask,
790
+ negative_prompt_embeds,
791
+ negative_prompt_attention_mask,
792
+ ) = self.encode_prompt(
793
+ prompt,
794
+ do_classifier_free_guidance,
795
+ negative_prompt=negative_prompt,
796
+ num_images_per_prompt=num_images_per_prompt,
797
+ device=device,
798
+ prompt_embeds=prompt_embeds,
799
+ negative_prompt_embeds=negative_prompt_embeds,
800
+ prompt_attention_mask=prompt_attention_mask,
801
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
802
+ clean_caption=clean_caption,
803
+ max_sequence_length=max_sequence_length,
804
+ )
805
+ if do_classifier_free_guidance:
806
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
807
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
808
+
809
+ # 4. Prepare timesteps
810
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
811
+ timesteps = self.scheduler.timesteps
812
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
813
+ latent_timestep = timesteps[:1].repeat(batch_size)
814
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
815
+ is_strength_max = strength == 1.0
816
+
817
+ if video is not None:
818
+ video_length = video.shape[2]
819
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
820
+ init_video = init_video.to(dtype=torch.float32)
821
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
822
+ else:
823
+ init_video = None
824
+
825
+ # Prepare latent variables
826
+ num_channels_latents = self.vae.config.latent_channels
827
+ num_channels_transformer = self.transformer.config.in_channels
828
+ return_image_latents = num_channels_transformer == 4
829
+
830
+ # 5. Prepare latents.
831
+ latents_outputs = self.prepare_latents(
832
+ batch_size * num_images_per_prompt,
833
+ num_channels_latents,
834
+ height,
835
+ width,
836
+ video_length,
837
+ prompt_embeds.dtype,
838
+ device,
839
+ generator,
840
+ latents,
841
+ video=init_video,
842
+ timestep=latent_timestep,
843
+ is_strength_max=is_strength_max,
844
+ return_noise=True,
845
+ return_video_latents=return_image_latents,
846
+ )
847
+ if return_image_latents:
848
+ latents, noise, image_latents = latents_outputs
849
+ else:
850
+ latents, noise = latents_outputs
851
+ latents_dtype = latents.dtype
852
+
853
+ if mask_video is not None:
854
+ # Prepare mask latent variables
855
+ video_length = video.shape[2]
856
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
857
+ mask_condition = mask_condition.to(dtype=torch.float32)
858
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
859
+
860
+ if masked_video_latents is None:
861
+ masked_video = init_video * (mask_condition < 0.5) + torch.ones_like(init_video) * (mask_condition > 0.5) * -1
862
+ else:
863
+ masked_video = masked_video_latents
864
+
865
+ mask, masked_video_latents = self.prepare_mask_latents(
866
+ mask_condition,
867
+ masked_video,
868
+ batch_size,
869
+ height,
870
+ width,
871
+ prompt_embeds.dtype,
872
+ device,
873
+ generator,
874
+ do_classifier_free_guidance,
875
+ )
876
+ else:
877
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
878
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
879
+
880
+ # Check that sizes of mask, masked image and latents match
881
+ if num_channels_transformer == 12:
882
+ # default case for runwayml/stable-diffusion-inpainting
883
+ num_channels_mask = mask.shape[1]
884
+ num_channels_masked_image = masked_video_latents.shape[1]
885
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
886
+ raise ValueError(
887
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
888
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
889
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
890
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
891
+ " `pipeline.transformer` or your `mask_image` or `image` input."
892
+ )
893
+ elif num_channels_transformer == 4:
894
+ raise ValueError(
895
+ f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
896
+ )
897
+
898
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
899
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
900
+
901
+ # 6.1 Prepare micro-conditions.
902
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
903
+ if self.transformer.config.sample_size == 128:
904
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
905
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
906
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
907
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
908
+
909
+ if do_classifier_free_guidance:
910
+ resolution = torch.cat([resolution, resolution], dim=0)
911
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
912
+
913
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
914
+
915
+ # 7. Denoising loop
916
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
917
+
918
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
919
+ for i, t in enumerate(timesteps):
920
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
921
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
922
+
923
+ if num_channels_transformer == 12:
924
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
925
+ masked_video_latents_input = (
926
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
927
+ )
928
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1)
929
+
930
+ current_timestep = t
931
+ if not torch.is_tensor(current_timestep):
932
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
933
+ # This would be a good case for the `match` statement (Python 3.10+)
934
+ is_mps = latent_model_input.device.type == "mps"
935
+ if isinstance(current_timestep, float):
936
+ dtype = torch.float32 if is_mps else torch.float64
937
+ else:
938
+ dtype = torch.int32 if is_mps else torch.int64
939
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
940
+ elif len(current_timestep.shape) == 0:
941
+ current_timestep = current_timestep[None].to(latent_model_input.device)
942
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
943
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
944
+
945
+ # predict noise model_output
946
+ noise_pred = self.transformer(
947
+ latent_model_input,
948
+ encoder_hidden_states=prompt_embeds,
949
+ encoder_attention_mask=prompt_attention_mask,
950
+ timestep=current_timestep,
951
+ added_cond_kwargs=added_cond_kwargs,
952
+ inpaint_latents=inpaint_latents.to(latent_model_input.dtype),
953
+ return_dict=False,
954
+ )[0]
955
+
956
+ # perform guidance
957
+ if do_classifier_free_guidance:
958
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
959
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
960
+
961
+ # learned sigma
962
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
963
+
964
+ # compute previous image: x_t -> x_t-1
965
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
966
+
967
+ # call the callback, if provided
968
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
969
+ progress_bar.update()
970
+ if callback is not None and i % callback_steps == 0:
971
+ step_idx = i // getattr(self.scheduler, "order", 1)
972
+ callback(step_idx, t, latents)
973
+
974
+ # Post-processing
975
+ video = self.decode_latents(latents)
976
+
977
+ # Convert to tensor
978
+ if output_type == "latent":
979
+ video = torch.from_numpy(video)
980
+
981
+ if not return_dict:
982
+ return video
983
+
984
+ return EasyAnimatePipelineOutput(videos=video)
easyanimate/pipeline/pipeline_pixart_magvit.py ADDED
@@ -0,0 +1,983 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ from typing import Callable, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.image_processor import VaeImageProcessor
24
+ from diffusers.models import Transformer2DModel
25
+ from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
26
+ ImagePipelineOutput)
27
+ from diffusers.schedulers import DPMSolverMultistepScheduler
28
+ from diffusers.utils import (BACKENDS_MAPPING, deprecate, is_bs4_available,
29
+ is_ftfy_available, logging,
30
+ replace_example_docstring)
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from transformers import T5EncoderModel, T5Tokenizer
33
+
34
+ from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+ if is_bs4_available():
39
+ from bs4 import BeautifulSoup
40
+
41
+ if is_ftfy_available():
42
+ import ftfy
43
+
44
+ EXAMPLE_DOC_STRING = """
45
+ Examples:
46
+ ```py
47
+ >>> import torch
48
+ >>> from diffusers import PixArtAlphaPipeline
49
+
50
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
51
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
52
+ >>> # Enable memory optimizations.
53
+ >>> pipe.enable_model_cpu_offload()
54
+
55
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
56
+ >>> image = pipe(prompt).images[0]
57
+ ```
58
+ """
59
+
60
# Resolution bins keyed by the aspect ratio (height / width) written as a string.
# Each value is a [height, width] pair in pixels, stored as floats.
ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}

# The 512 and 256 tables hold exactly the 1024 entries divided by 2 and 4
# respectively (same keys, same order), so they are derived rather than
# spelled out a second and third time.
ASPECT_RATIO_512_BIN = {
    ratio: [height / 2, width / 2] for ratio, (height, width) in ASPECT_RATIO_1024_BIN.items()
}

ASPECT_RATIO_256_BIN = {
    ratio: [height / 4, width / 4] for ratio, (height, width) in ASPECT_RATIO_1024_BIN.items()
}
167
+
168
+
169
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
170
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure `scheduler` via its `set_timesteps` method and read the resulting schedule back from it.

    Extra keyword arguments are forwarded to `scheduler.set_timesteps` unchanged.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose `set_timesteps` is called and whose `timesteps` attribute is returned.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when `timesteps` is `None`.
        device (`str` or `torch.device`, *optional*):
            Passed to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom schedule. When given, the scheduler's `set_timesteps` must accept a `timesteps` keyword and
            the returned step count is the length of the resulting schedule.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timestep schedule and the number of inference steps.

    Raises:
        ValueError: if `timesteps` is given but `scheduler.set_timesteps` has no `timesteps` parameter.
    """
    # Default spacing: let the scheduler derive the schedule from the step count.
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: only valid for schedulers whose set_timesteps takes `timesteps`.
    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in set_timesteps_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
212
+
213
+
214
+ class PixArtAlphaMagvitPipeline(DiffusionPipeline):
215
+ r"""
216
+ Pipeline for text-to-image generation using PixArt-Alpha.
217
+
218
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
219
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
220
+
221
+ Args:
222
+ vae ([`AutoencoderKL`]):
223
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
224
+ text_encoder ([`T5EncoderModel`]):
225
+ Frozen text-encoder. PixArt-Alpha uses
226
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
227
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
228
+ tokenizer (`T5Tokenizer`):
229
+ Tokenizer of class
230
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
231
+ transformer ([`Transformer2DModel`]):
232
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
233
+ scheduler ([`SchedulerMixin`]):
234
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
235
+ """
236
+
237
+ bad_punct_regex = re.compile(
238
+ r"["
239
+ + "#®•©™&@·º½¾¿¡§~"
240
+ + r"\)"
241
+ + r"\("
242
+ + r"\]"
243
+ + r"\["
244
+ + r"\}"
245
+ + r"\{"
246
+ + r"\|"
247
+ + "\\"
248
+ + r"\/"
249
+ + r"\*"
250
+ + r"]{1,}"
251
+ ) # noqa
252
+
253
+ _optional_components = ["tokenizer", "text_encoder"]
254
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
255
+
256
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    vae: AutoencoderKLMagvit,
    transformer: Transformer2DModel,
    scheduler: DPMSolverMultistepScheduler,
):
    """Register the five pipeline components and build the image processor for the VAE's scale factor."""
    super().__init__()

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Scale factor is 2 ** (number of entries in the VAE's block_out_channels - 1).
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
272
+
273
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
274
def mask_text_embeddings(self, emb, mask):
    """
    Apply an attention mask to text embeddings.

    `emb` is indexed with four axes and `mask` with two; the mask is applied along the third axis of `emb`.

    Returns:
        A `(masked_embeddings, kept_length)` tuple. For a single-item batch the embeddings are truncated to the
        first `mask.sum()` positions and that count is returned. Otherwise masked-out positions are multiplied
        by the mask and the full length `emb.shape[2]` is returned.
    """
    if emb.shape[0] != 1:
        # Batched input keeps its length; masked positions are zeroed by the multiplication.
        return emb * mask[:, None, :, None], emb.shape[2]

    kept = mask.sum().item()
    return emb[:, :, :kept, :], kept
281
+
282
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
283
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    do_classifier_free_guidance: bool = True,
    negative_prompt: str = "",
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.FloatTensor] = None,
    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
    max_sequence_length: int = 120,
    **kwargs,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded. May be `None` when `prompt_embeds` is given.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to also produce the unconditional (negative) embeddings.
        negative_prompt (`str`, *optional*, defaults to `""`):
            The prompt not to guide the image generation. It is repeated once per prompt in the batch. Only used
            when `do_classifier_free_guidance` is `True` and `negative_prompt_embeds` is not given.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to `self._execution_device`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings.
        prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask matching `prompt_embeds`. Overwritten by the tokenizer's mask when the embeddings are
            computed here.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Attention mask matching `negative_prompt_embeds`.
        clean_caption (`bool`, defaults to `False`):
            If `True`, the function will preprocess and clean the provided caption before encoding.
        max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.

    Returns:
        `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask)`. Each
        tensor has `batch * num_images_per_prompt` rows, with the copies of one prompt stored contiguously. The
        two negative entries are `None` when `do_classifier_free_guidance` is `False`.
    """

    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # See Section 3.1. of the paper.
    max_length = max_sequence_length

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Tell the user which part of the prompt did not fit into `max_length` tokens.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because T5 can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0]

    # Cast to the text encoder's dtype, falling back to the transformer's; None leaves the dtype unchanged.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.transformer is not None:
        dtype = self.transformer.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    # The mask must be duplicated in the same per-prompt order as the embeddings (p0, p0, ..., p1, p1, ...).
    # Tiling it along the batch axis would give (m0, m1, m0, m1, ...) and pair masks with the wrong prompts
    # whenever both the batch size and `num_images_per_prompt` exceed one.
    prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
    prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens = [negative_prompt] * batch_size
        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        negative_prompt_attention_mask = uncond_input.attention_mask
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # Same per-prompt ordering as the negative embeddings above.
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
        negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
    else:
        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
424
+
425
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
426
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only included when the
    scheduler's `step` method declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        A dict containing `"eta"` and/or `"generator"`, in that order, or an empty dict.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
442
+
443
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_steps,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
):
    """
    Validate the generation arguments and raise `ValueError` on the first inconsistency found.

    Checks, in order: `height`/`width` divisible by 8; `callback_steps` a positive integer; exactly one of
    `prompt` / `prompt_embeds` given and `prompt` of type `str` or `list`; `negative_prompt_embeds` not
    combined with `prompt` or `negative_prompt`; an attention mask supplied with each pre-computed embedding;
    matching shapes between positive and negative embeddings and between their masks.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
    if not is_positive_int:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is None:
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
    else:
        if prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if has_negative_embeds and prompt is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_negative_embeds and negative_prompt is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and prompt_attention_mask is None:
        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

    if has_negative_embeds and negative_prompt_attention_mask is None:
        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

    # Shape agreement only matters when both embedding sets were passed in directly.
    if prompt_embeds is None or not has_negative_embeds:
        return

    if prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
    if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
        raise ValueError(
            "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
            f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
            f" {negative_prompt_attention_mask.shape}."
        )
509
+
510
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
511
+ def _text_preprocessing(self, text, clean_caption=False):
512
+ if clean_caption and not is_bs4_available():
513
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
514
+ logger.warn("Setting `clean_caption` to False...")
515
+ clean_caption = False
516
+
517
+ if clean_caption and not is_ftfy_available():
518
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
519
+ logger.warn("Setting `clean_caption` to False...")
520
+ clean_caption = False
521
+
522
+ if not isinstance(text, (tuple, list)):
523
+ text = [text]
524
+
525
+ def process(text: str):
526
+ if clean_caption:
527
+ text = self._clean_caption(text)
528
+ text = self._clean_caption(text)
529
+ else:
530
+ text = text.lower().strip()
531
+ return text
532
+
533
+ return [process(t) for t in text]
534
+
535
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
536
def _clean_caption(self, caption):
    """
    Clean a raw caption with a fixed sequence of regex substitutions.

    The caption is URL-unquoted, lower-cased and stripped, then URLs, HTML markup (via BeautifulSoup),
    @-handles, several CJK code-point ranges, numeric ids, file names and boilerplate phrases are removed,
    dashes and quotes are normalised, and text encoding is repaired with `ftfy`. The order of the
    substitutions matters and is kept as is.

    Args:
        caption: the caption to clean; converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalise quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # The result of strip() must be kept: the substitutions below are anchored with ^ and $, so whitespace
    # left at either end by the earlier replacements would stop them from matching.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
649
+
650
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
651
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents for sampling, scaled by the scheduler's `init_noise_sigma`.

    When `latents` is `None`, noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the given latents are moved to `device` and used as they are.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
667
+
668
@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
    """
    Return the binned `(height, width)` whose aspect-ratio key is closest to `height / width`.

    `ratios` maps aspect ratios written as strings to `[height, width]` pairs; the values of the chosen entry
    are converted to `int`.
    """
    target_ratio = float(height / width)

    def distance(ratio_key):
        return abs(float(ratio_key) - target_ratio)

    chosen = ratios[min(ratios.keys(), key=distance)]
    return int(chosen[0]), int(chosen[1])
675
+
676
@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
    """
    Resize `samples` so they cover `new_height` x `new_width`, then centre-crop to exactly that size.

    Height and width are read from axes 2 and 3 of `samples`. Both are scaled by the same factor (the larger
    of the two target/original ratios) with bilinear interpolation, so the aspect ratio is kept and the excess
    is cropped away. The input is returned unchanged when it already has the requested size.
    """
    orig_height, orig_width = samples.shape[2], samples.shape[3]
    if orig_height == new_height and orig_width == new_width:
        return samples

    # One common scale factor, large enough that both target dimensions are covered.
    scale = max(new_height / orig_height, new_width / orig_width)
    scaled_height = int(orig_height * scale)
    scaled_width = int(orig_width * scale)

    samples = F.interpolate(samples, size=(scaled_height, scaled_width), mode="bilinear", align_corners=False)

    # Centre crop
    top = (scaled_height - new_height) // 2
    left = (scaled_width - new_width) // 2
    return samples[:, :, top : top + new_height, left : left + new_width]
699
+
700
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    timesteps: List[int] = None,
    guidance_scale: float = 4.5,
    num_images_per_prompt: Optional[int] = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    prompt_attention_mask: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    use_resolution_binning: bool = False,
    max_sequence_length: int = 120,
    **kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        negative_prompt (`str` or `List[str]`, *optional*, defaults to `""`):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is not greater than `1`).
        num_inference_steps (`int`, *optional*, defaults to 20):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced
            `num_inference_steps` timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
            to the text `prompt`, usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies
            to [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, text embeddings will be generated from `prompt` input argument.
        prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Pre-generated attention mask for text embeddings.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
            provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
            Pre-generated attention mask for negative text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Pass `"latent"`
            to skip VAE decoding and get the final latents back.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will
            be called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and
            `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from
            the raw prompt.
        use_resolution_binning (`bool`, defaults to `False`):
            If set to `True`, the requested height and width are first mapped to the closest resolution of
            the aspect-ratio bin matching the transformer sample size (`ASPECT_RATIO_1024_BIN` for 128,
            `ASPECT_RATIO_512_BIN` for 64, `ASPECT_RATIO_256_BIN` for 32). After the produced latents are
            decoded into images, they are resized and cropped back to the requested resolution. Useful for
            generating non-square images.
        max_sequence_length (`int`, defaults to 120):
            Maximum sequence length to use with the `prompt`.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
            returned where the first element is a list with the generated images.

    Raises:
        ValueError: If `use_resolution_binning` is set and the transformer sample size is not 128, 64 or 32.
    """
    if "mask_feature" in kwargs:
        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

    # 1. Default height and width to the transformer's native resolution
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor
    if use_resolution_binning:
        if self.transformer.config.sample_size == 128:
            aspect_ratio_bin = ASPECT_RATIO_1024_BIN
        elif self.transformer.config.sample_size == 64:
            aspect_ratio_bin = ASPECT_RATIO_512_BIN
        elif self.transformer.config.sample_size == 32:
            aspect_ratio_bin = ASPECT_RATIO_256_BIN
        else:
            raise ValueError("Invalid sample size")
        # Remember the requested size so the decoded image can be mapped back to it.
        orig_height, orig_width = height, width
        height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

    # 2. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    )

    # The batch size comes from the prompt(s), or from the embeddings when no prompt is given.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,
    )
    if do_classifier_free_guidance:
        # Unconditional half first, conditional half second; `chunk(2)` below relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 6.1 Prepare micro-conditions (only built when the transformer sample size is 128).
    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
    if self.transformer.config.sample_size == 128:
        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
        resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
        aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

        if do_classifier_free_guidance:
            resolution = torch.cat([resolution, resolution], dim=0)
            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(current_timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                timestep=current_timestep,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # learned sigma: when the model emits twice the latent channels, keep only the first half
            if self.transformer.config.out_channels // 2 == latent_channels:
                noise_pred = noise_pred.chunk(2, dim=1)[0]

            # compute previous image: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # 8. Decode the latents (skipped when the caller asked for raw latents)
    if output_type != "latent":
        # A 5-D `quant_conv` weight marks a video-style (3D) VAE; give it a singleton frame axis.
        is_video_vae = self.vae.quant_conv.weight.ndim == 5
        if is_video_vae:
            latents = latents.unsqueeze(2)
        # Decoding runs in float32; note this also leaves the VAE decoder modules cast to float32.
        latents = latents.float()
        self.vae.post_quant_conv = self.vae.post_quant_conv.float()
        self.vae.decoder = self.vae.decoder.float()
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        if is_video_vae:
            # Fold the frame axis back into the batch axis.
            image = image.permute(0, 2, 1, 3, 4).flatten(0, 1)

        if use_resolution_binning:
            image = self.resize_and_crop_tensor(image, orig_width, orig_height)

        image = self.image_processor.postprocess(image, output_type=output_type)
    else:
        image = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
easyanimate/ui/ui.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import gc
4
+ import json
5
+ import os
6
+ import random
7
+ import base64
8
+ import requests
9
+ from datetime import datetime
10
+ from glob import glob
11
+
12
+ import gradio as gr
13
+ import torch
14
+ import numpy as np
15
+ from diffusers import (AutoencoderKL, DDIMScheduler,
16
+ DPMSolverMultistepScheduler,
17
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
18
+ PNDMScheduler)
19
+ from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
20
+ from diffusers.utils.import_utils import is_xformers_available
21
+ from omegaconf import OmegaConf
22
+ from safetensors import safe_open
23
+ from transformers import T5EncoderModel, T5Tokenizer
24
+
25
+ from easyanimate.models.transformer3d import Transformer3DModel
26
+ from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
27
+ from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
28
+ from easyanimate.utils.utils import save_videos_grid
29
+ from PIL import Image
30
+
31
+ sample_idx = 0
32
+ scheduler_dict = {
33
+ "Euler": EulerDiscreteScheduler,
34
+ "Euler A": EulerAncestralDiscreteScheduler,
35
+ "DPM++": DPMSolverMultistepScheduler,
36
+ "PNDM": PNDMScheduler,
37
+ "DDIM": DDIMScheduler,
38
+ }
39
+
40
+ css = """
41
+ .toolbutton {
42
+ margin-buttom: 0em 0em 0em 0em;
43
+ max-width: 2.5em;
44
+ min-width: 2.5em !important;
45
+ height: 2.5em;
46
+ }
47
+ """
48
+
49
+ class EasyAnimateController:
50
+ def __init__(self):
51
+ # config dirs
52
+ self.basedir = os.getcwd()
53
+ self.config_dir = os.path.join(self.basedir, "config")
54
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
55
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
56
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
57
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
58
+ self.savedir_sample = os.path.join(self.savedir, "sample")
59
+ self.edition = "v2"
60
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
61
+ os.makedirs(self.savedir, exist_ok=True)
62
+
63
+ self.diffusion_transformer_list = []
64
+ self.motion_module_list = []
65
+ self.personalized_model_list = []
66
+
67
+ self.refresh_diffusion_transformer()
68
+ self.refresh_motion_module()
69
+ self.refresh_personalized_model()
70
+
71
+ # config models
72
+ self.tokenizer = None
73
+ self.text_encoder = None
74
+ self.vae = None
75
+ self.transformer = None
76
+ self.pipeline = None
77
+ self.motion_module_path = "none"
78
+ self.base_model_path = "none"
79
+ self.lora_model_path = "none"
80
+
81
+ self.weight_dtype = torch.bfloat16
82
+
83
+ def refresh_diffusion_transformer(self):
84
+ self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
85
+
86
+ def refresh_motion_module(self):
87
+ motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
88
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
89
+
90
+ def refresh_personalized_model(self):
91
+ personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
92
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
93
+
94
+ def update_edition(self, edition):
95
+ print("Update edition of EasyAnimate")
96
+ self.edition = edition
97
+ if edition == "v1":
98
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
99
+ return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
100
+ gr.update(visible=False), gr.update(value=512, minimum=384, maximum=704, step=32), \
101
+ gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
102
+ else:
103
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
104
+ return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
105
+ gr.update(visible=True), gr.update(value=672, minimum=128, maximum=1280, step=16), \
106
+ gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
107
+
108
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
109
+ print("Update diffusion transformer")
110
+ if diffusion_transformer_dropdown == "none":
111
+ return gr.Dropdown.update()
112
+ if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
113
+ Choosen_AutoencoderKL = AutoencoderKLMagvit
114
+ else:
115
+ Choosen_AutoencoderKL = AutoencoderKL
116
+ self.vae = Choosen_AutoencoderKL.from_pretrained(
117
+ diffusion_transformer_dropdown,
118
+ subfolder="vae",
119
+ ).to(self.weight_dtype)
120
+ self.transformer = Transformer3DModel.from_pretrained_2d(
121
+ diffusion_transformer_dropdown,
122
+ subfolder="transformer",
123
+ transformer_additional_kwargs=OmegaConf.to_container(self.inference_config.transformer_additional_kwargs)
124
+ ).to(self.weight_dtype)
125
+ self.tokenizer = T5Tokenizer.from_pretrained(diffusion_transformer_dropdown, subfolder="tokenizer")
126
+ self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
127
+
128
+ # Get pipeline
129
+ self.pipeline = EasyAnimatePipeline(
130
+ vae=self.vae,
131
+ text_encoder=self.text_encoder,
132
+ tokenizer=self.tokenizer,
133
+ transformer=self.transformer,
134
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
135
+ )
136
+ self.pipeline.enable_model_cpu_offload()
137
+ print("Update diffusion transformer done")
138
+ return gr.Dropdown.update()
139
+
140
+ def update_motion_module(self, motion_module_dropdown):
141
+ self.motion_module_path = motion_module_dropdown
142
+ print("Update motion module")
143
+ if motion_module_dropdown == "none":
144
+ return gr.Dropdown.update()
145
+ if self.transformer is None:
146
+ gr.Info(f"Please select a pretrained model path.")
147
+ return gr.Dropdown.update(value=None)
148
+ else:
149
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
150
+ if motion_module_dropdown.endswith(".safetensors"):
151
+ from safetensors.torch import load_file, safe_open
152
+ motion_module_state_dict = load_file(motion_module_dropdown)
153
+ else:
154
+ if not os.path.isfile(motion_module_dropdown):
155
+ raise RuntimeError(f"{motion_module_dropdown} does not exist")
156
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
157
+ missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
158
+ print("Update motion module done.")
159
+ return gr.Dropdown.update()
160
+
161
+ def update_base_model(self, base_model_dropdown):
162
+ self.base_model_path = base_model_dropdown
163
+ print("Update base model")
164
+ if base_model_dropdown == "none":
165
+ return gr.Dropdown.update()
166
+ if self.transformer is None:
167
+ gr.Info(f"Please select a pretrained model path.")
168
+ return gr.Dropdown.update(value=None)
169
+ else:
170
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
171
+ base_model_state_dict = {}
172
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
173
+ for key in f.keys():
174
+ base_model_state_dict[key] = f.get_tensor(key)
175
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
176
+ print("Update base done")
177
+ return gr.Dropdown.update()
178
+
179
+ def update_lora_model(self, lora_model_dropdown):
180
+ print("Update lora model")
181
+ if lora_model_dropdown == "none":
182
+ self.lora_model_path = "none"
183
+ return gr.Dropdown.update()
184
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
185
+ self.lora_model_path = lora_model_dropdown
186
+ return gr.Dropdown.update()
187
+
188
+ def generate(
189
+ self,
190
+ diffusion_transformer_dropdown,
191
+ motion_module_dropdown,
192
+ base_model_dropdown,
193
+ lora_model_dropdown,
194
+ lora_alpha_slider,
195
+ prompt_textbox,
196
+ negative_prompt_textbox,
197
+ sampler_dropdown,
198
+ sample_step_slider,
199
+ width_slider,
200
+ height_slider,
201
+ is_image,
202
+ length_slider,
203
+ cfg_scale_slider,
204
+ seed_textbox,
205
+ is_api = False,
206
+ ):
207
+ global sample_idx
208
+ if self.transformer is None:
209
+ raise gr.Error(f"Please select a pretrained model path.")
210
+
211
+ if self.base_model_path != base_model_dropdown:
212
+ self.update_base_model(base_model_dropdown)
213
+
214
+ if self.motion_module_path != motion_module_dropdown:
215
+ self.update_motion_module(motion_module_dropdown)
216
+
217
+ if self.lora_model_path != lora_model_dropdown:
218
+ print("Update lora model")
219
+ self.update_lora_model(lora_model_dropdown)
220
+
221
+ if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
222
+
223
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
224
+ if self.lora_model_path != "none":
225
+ # lora part
226
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
227
+ self.pipeline.to("cuda")
228
+
229
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
230
+ else: seed_textbox = np.random.randint(0, 1e10)
231
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
232
+
233
+ try:
234
+ sample = self.pipeline(
235
+ prompt_textbox,
236
+ negative_prompt = negative_prompt_textbox,
237
+ num_inference_steps = sample_step_slider,
238
+ guidance_scale = cfg_scale_slider,
239
+ width = width_slider,
240
+ height = height_slider,
241
+ video_length = length_slider if not is_image else 1,
242
+ generator = generator
243
+ ).videos
244
+ except Exception as e:
245
+ gc.collect()
246
+ torch.cuda.empty_cache()
247
+ torch.cuda.ipc_collect()
248
+ if self.lora_model_path != "none":
249
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
250
+ if is_api:
251
+ return "", f"Error. error information is {str(e)}"
252
+ else:
253
+ return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
254
+
255
+ # lora part
256
+ if self.lora_model_path != "none":
257
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
258
+
259
+ sample_config = {
260
+ "prompt": prompt_textbox,
261
+ "n_prompt": negative_prompt_textbox,
262
+ "sampler": sampler_dropdown,
263
+ "num_inference_steps": sample_step_slider,
264
+ "guidance_scale": cfg_scale_slider,
265
+ "width": width_slider,
266
+ "height": height_slider,
267
+ "video_length": length_slider,
268
+ "seed_textbox": seed_textbox
269
+ }
270
+ json_str = json.dumps(sample_config, indent=4)
271
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
272
+ f.write(json_str)
273
+ f.write("\n\n")
274
+
275
+ if not os.path.exists(self.savedir_sample):
276
+ os.makedirs(self.savedir_sample, exist_ok=True)
277
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
278
+ prefix = str(index).zfill(3)
279
+
280
+ gc.collect()
281
+ torch.cuda.empty_cache()
282
+ torch.cuda.ipc_collect()
283
+ if is_image or length_slider == 1:
284
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
285
+
286
+ image = sample[0, :, 0]
287
+ image = image.transpose(0, 1).transpose(1, 2)
288
+ image = (image * 255).numpy().astype(np.uint8)
289
+ image = Image.fromarray(image)
290
+ image.save(save_sample_path)
291
+
292
+ if is_api:
293
+ return save_sample_path, "Success"
294
+ else:
295
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
296
+ else:
297
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
298
+ save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
299
+
300
+ if is_api:
301
+ return save_sample_path, "Success"
302
+ else:
303
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
304
+
305
+
306
+ def ui():
307
+ controller = EasyAnimateController()
308
+
309
+ with gr.Blocks(css=css) as demo:
310
+ gr.Markdown(
311
+ """
312
+ # EasyAnimate: Integrated generation of baseline scheme for videos and images.
313
+ Generate your videos easily
314
+ [Github](https://github.com/aigc-apps/EasyAnimate/)
315
+ """
316
+ )
317
+ with gr.Column(variant="panel"):
318
+ gr.Markdown(
319
+ """
320
+ ### 1. EasyAnimate Edition (select easyanimate edition first).
321
+ """
322
+ )
323
+ with gr.Row():
324
+ easyanimate_edition_dropdown = gr.Dropdown(
325
+ label="The config of EasyAnimate Edition",
326
+ choices=["v1", "v2"],
327
+ value="v2",
328
+ interactive=True,
329
+ )
330
+ gr.Markdown(
331
+ """
332
+ ### 2. Model checkpoints (select pretrained model path).
333
+ """
334
+ )
335
+ with gr.Row():
336
+ diffusion_transformer_dropdown = gr.Dropdown(
337
+ label="Pretrained Model Path",
338
+ choices=controller.diffusion_transformer_list,
339
+ value="none",
340
+ interactive=True,
341
+ )
342
+ diffusion_transformer_dropdown.change(
343
+ fn=controller.update_diffusion_transformer,
344
+ inputs=[diffusion_transformer_dropdown],
345
+ outputs=[diffusion_transformer_dropdown]
346
+ )
347
+
348
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
349
+ def refresh_diffusion_transformer():
350
+ controller.refresh_diffusion_transformer()
351
+ return gr.Dropdown.update(choices=controller.diffusion_transformer_list)
352
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
353
+
354
+ with gr.Row():
355
+ motion_module_dropdown = gr.Dropdown(
356
+ label="Select motion module",
357
+ choices=controller.motion_module_list,
358
+ value="none",
359
+ interactive=True,
360
+ visible=False
361
+ )
362
+
363
+ motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
364
+ def update_motion_module():
365
+ controller.refresh_motion_module()
366
+ return gr.Dropdown.update(choices=controller.motion_module_list)
367
+ motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
368
+
369
+ base_model_dropdown = gr.Dropdown(
370
+ label="Select base Dreambooth model (optional)",
371
+ choices=controller.personalized_model_list,
372
+ value="none",
373
+ interactive=True,
374
+ )
375
+
376
+ lora_model_dropdown = gr.Dropdown(
377
+ label="Select LoRA model (optional)",
378
+ choices=["none"] + controller.personalized_model_list,
379
+ value="none",
380
+ interactive=True,
381
+ )
382
+
383
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
384
+
385
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
386
+ def update_personalized_model():
387
+ controller.refresh_personalized_model()
388
+ return [
389
+ gr.Dropdown.update(choices=controller.personalized_model_list),
390
+ gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
391
+ ]
392
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
393
+
394
+ with gr.Column(variant="panel"):
395
+ gr.Markdown(
396
+ """
397
+ ### 3. Configs for Generation.
398
+ """
399
+ )
400
+
401
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
402
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
403
+
404
+ with gr.Row():
405
+ with gr.Column():
406
+ with gr.Row():
407
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
408
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=100, step=1)
409
+
410
+ width_slider = gr.Slider(label="Width", value=672, minimum=128, maximum=1280, step=16)
411
+ height_slider = gr.Slider(label="Height", value=384, minimum=128, maximum=1280, step=16)
412
+ with gr.Row():
413
+ is_image = gr.Checkbox(False, label="Generate Image")
414
+ length_slider = gr.Slider(label="Animation length", value=144, minimum=9, maximum=144, step=9)
415
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
416
+
417
+ with gr.Row():
418
+ seed_textbox = gr.Textbox(label="Seed", value=43)
419
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
420
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
421
+
422
+ generate_button = gr.Button(value="Generate", variant='primary')
423
+
424
+ with gr.Column():
425
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
426
+ result_video = gr.Video(label="Generated Animation", interactive=False)
427
+ infer_progress = gr.Textbox(
428
+ label="Generation Info",
429
+ value="No task currently",
430
+ interactive=False
431
+ )
432
+
433
+ is_image.change(
434
+ lambda x: gr.update(visible=not x),
435
+ inputs=[is_image],
436
+ outputs=[length_slider],
437
+ )
438
+ easyanimate_edition_dropdown.change(
439
+ fn=controller.update_edition,
440
+ inputs=[easyanimate_edition_dropdown],
441
+ outputs=[
442
+ easyanimate_edition_dropdown,
443
+ diffusion_transformer_dropdown,
444
+ motion_module_dropdown,
445
+ motion_module_refresh_button,
446
+ is_image,
447
+ width_slider,
448
+ height_slider,
449
+ length_slider,
450
+ ]
451
+ )
452
+ generate_button.click(
453
+ fn=controller.generate,
454
+ inputs=[
455
+ diffusion_transformer_dropdown,
456
+ motion_module_dropdown,
457
+ base_model_dropdown,
458
+ lora_model_dropdown,
459
+ lora_alpha_slider,
460
+ prompt_textbox,
461
+ negative_prompt_textbox,
462
+ sampler_dropdown,
463
+ sample_step_slider,
464
+ width_slider,
465
+ height_slider,
466
+ is_image,
467
+ length_slider,
468
+ cfg_scale_slider,
469
+ seed_textbox,
470
+ ],
471
+ outputs=[result_image, result_video, infer_progress]
472
+ )
473
+ return demo, controller
474
+
475
+
476
class EasyAnimateController_Modelscope:
    """Controller that loads one fixed EasyAnimate model and serves generation requests.

    The transformer, VAE, T5 tokenizer and T5 text encoder are loaded once in
    ``__init__`` from sub-folders of ``model_name`` and wired into an
    ``EasyAnimatePipeline``; ``generate`` then only swaps the scheduler and
    runs the pipeline.
    """

    def __init__(self, edition, config_path, model_name, savedir_sample):
        """Load every model component and build the pipeline.

        :param edition: edition tag; ``generate`` uses it to pick the output fps
            (12 for ``"v1"``, 24 otherwise).
        :param config_path: path of the OmegaConf YAML with
            ``transformer_additional_kwargs``, ``vae_kwargs`` and
            ``noise_scheduler_kwargs``.
        :param model_name: path/identifier passed to every ``from_pretrained`` call.
        :param savedir_sample: directory where generated samples are written
            (created if missing).
        """
        # Config and model path
        weight_dtype = torch.bfloat16
        self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.edition = edition
        self.inference_config = OmegaConf.load(config_path)
        # Get Transformer
        self.transformer = Transformer3DModel.from_pretrained_2d(
            model_name,
            subfolder="transformer",
            transformer_additional_kwargs=OmegaConf.to_container(self.inference_config['transformer_additional_kwargs'])
        ).to(weight_dtype)
        # The config decides which autoencoder implementation matches the checkpoint.
        if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
            Choosen_AutoencoderKL = AutoencoderKLMagvit
        else:
            Choosen_AutoencoderKL = AutoencoderKL
        self.vae = Choosen_AutoencoderKL.from_pretrained(
            model_name,
            subfolder="vae"
        ).to(weight_dtype)
        self.tokenizer = T5Tokenizer.from_pretrained(
            model_name,
            subfolder="tokenizer"
        )
        self.text_encoder = T5EncoderModel.from_pretrained(
            model_name,
            subfolder="text_encoder",
            torch_dtype=weight_dtype
        )
        self.pipeline = EasyAnimatePipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            transformer=self.transformer,
            scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        )
        self.pipeline.enable_model_cpu_offload()
        print("Update diffusion transformer done")

    def generate(
        self,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        height_slider,
        is_image,
        length_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Run one generation and save the result under ``self.savedir_sample``.

        :param sampler_dropdown: key into ``scheduler_dict`` selecting the scheduler.
        :param is_image: when truthy a single frame is generated and saved as PNG.
        :param length_slider: number of frames for video generation.
        :param seed_textbox: seed; an empty value or ``-1`` draws a random seed.
        :return: a 3-tuple ``(image update, video update, status message)``.
            If the pipeline raises, the updates are empty and the message
            carries the error text.
        """
        if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()

        self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        self.pipeline.to("cuda")

        # Test for an empty seed BEFORE converting it: the previous order
        # evaluated int("") first and raised ValueError instead of falling
        # back to a random seed.
        seed_text = "" if seed_textbox is None else str(seed_textbox).strip()
        if seed_text != "" and int(seed_text) != -1:
            torch.manual_seed(int(seed_text))
            seed_textbox = seed_text
        else:
            seed_textbox = np.random.randint(0, 1e10)
        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))

        try:
            sample = self.pipeline(
                prompt_textbox,
                negative_prompt     = negative_prompt_textbox,
                num_inference_steps = sample_step_slider,
                guidance_scale      = cfg_scale_slider,
                width               = width_slider,
                height              = height_slider,
                video_length        = length_slider if not is_image else 1,
                generator           = generator
            ).videos
        except Exception as e:
            # Free whatever the failed run allocated, then report the error in the UI.
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"

        if not os.path.exists(self.savedir_sample):
            os.makedirs(self.savedir_sample, exist_ok=True)
        # File name is the 1-based count of entries already in the directory, zero-padded.
        index = len([path for path in os.listdir(self.savedir_sample)]) + 1
        prefix = str(index).zfill(3)

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        if is_image or length_slider == 1:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")

            # Take the first frame of the first sample and move the leading axis last.
            # NOTE(review): presumably sample is (batch, channel, frame, h, w) in [0, 1] - confirm.
            image = sample[0, :, 0]
            image = image.transpose(0, 1).transpose(1, 2)
            image = (image * 255).numpy().astype(np.uint8)
            image = Image.fromarray(image)
            image.save(save_sample_path)
            return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        else:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
            save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
            return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
578
+
579
+
580
def ui_modelscope(edition, config_path, model_name, savedir_sample):
    """Build the Gradio demo backed by a locally loaded model.

    :param edition: edition tag; ``"v1"`` selects a different set of slider
        ranges and hides the "Generate Image" checkbox.
    :param config_path: forwarded to ``EasyAnimateController_Modelscope``.
    :param model_name: forwarded to ``EasyAnimateController_Modelscope``.
    :param savedir_sample: forwarded to ``EasyAnimateController_Modelscope``.
    :return: ``(demo, controller)`` - the ``gr.Blocks`` app and its controller.
    """
    controller = EasyAnimateController_Modelscope(edition, config_path, model_name, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # EasyAnimate: Integrated generation of baseline scheme for videos and images.
            Generate your videos easily
            [Github](https://github.com/aigc-apps/EasyAnimate/)
            """
        )
        with gr.Column(variant="panel"):
            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)

                    if edition == "v1":
                        width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
                        height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
                        with gr.Row():
                            # The checkbox exists (generate() needs the input) but is hidden for v1.
                            is_image = gr.Checkbox(False, label="Generate Image", visible=False)
                            length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
                        cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
                    else:
                        width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
                        height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
                        with gr.Column():
                            # NOTE(review): the text says "frame rate" but the limit that
                            # is enforced below is the frame count (length_slider max 81).
                            gr.Markdown(
                                """
                                To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
                                If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
                                """
                            )
                            with gr.Row():
                                is_image = gr.Checkbox(False, label="Generate Image")
                                length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
                        cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=43)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # random.randint needs integer bounds: a float such as 1e8 is
                        # deprecated since Python 3.10 and raises TypeError from 3.12.
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                with gr.Column():
                    result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
                    result_video = gr.Video(label="Generated Animation", interactive=False)
                    infer_progress = gr.Textbox(
                        label="Generation Info",
                        value="No task currently",
                        interactive=False
                    )

            # The length slider is meaningless for single images, so hide it while the box is ticked.
            is_image.change(
                lambda x: gr.update(visible=not x),
                inputs=[is_image],
                outputs=[length_slider],
            )

            # Input order must match the positional parameters of controller.generate.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    height_slider,
                    is_image,
                    length_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
662
+
663
+
664
def post_eas(
    prompt_textbox, negative_prompt_textbox,
    sampler_dropdown, sample_step_slider, width_slider, height_slider,
    is_image, length_slider, cfg_scale_slider, seed_textbox,
):
    """POST one generation request to the remote EAS inference service.

    The endpoint is ``$EAS_URL/easyanimate/infer_forward`` and the
    ``Authorization`` header is taken from ``$EAS_TOKEN``. The model-selection
    fields are sent as the fixed string ``"none"`` and the LoRA alpha as 0.55;
    all other fields are the arguments passed through unchanged.

    :return: the decoded JSON body of the response.
    """
    datas = {
        "base_model_path": "none",
        "motion_module_path": "none",
        "lora_model_path": "none",
        "lora_alpha_slider": 0.55,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "is_image": is_image,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "seed_textbox": seed_textbox,
    }
    # The token can be found in the service's public-endpoint invocation info;
    # see the general public-network invocation section for details.
    # Using the session as a context manager closes its connection pool once
    # the request is done instead of leaking one pool per call.
    with requests.session() as session:
        session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})

        response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas)
        outputs = response.json()
    return outputs
692
+
693
+
694
class EasyAnimateController_HuggingFace:
    """Controller that delegates generation to the remote service via ``post_eas``.

    No model is loaded locally; the constructor only prepares the output
    directory. ``edition``, ``config_path`` and ``model_name`` are accepted
    for signature parity with the other controllers and are not used here.
    """

    def __init__(self, edition, config_path, model_name, savedir_sample):
        """Remember the output directory and make sure it exists."""
        self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

    def generate(
        self,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        height_slider,
        is_image,
        length_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Request a sample remotely, store it on disk and return the UI updates.

        The response's ``"base64_encoding"`` field is decoded and written to
        ``<savedir_sample>/<NNN>.png`` (image request, or a length of 1) or
        ``<NNN>.mp4`` otherwise, where ``NNN`` is the zero-padded count of
        directory entries plus one.

        :return: ``(image update, video update, "Success")``.
        """
        response = post_eas(
            prompt_textbox, negative_prompt_textbox,
            sampler_dropdown, sample_step_slider, width_slider, height_slider,
            is_image, length_slider, cfg_scale_slider, seed_textbox
        )
        payload = base64.b64decode(response["base64_encoding"])

        out_dir = self.savedir_sample
        if not os.path.exists(out_dir):
            os.makedirs(out_dir, exist_ok=True)
        prefix = str(len(os.listdir(out_dir)) + 1).zfill(3)

        # A single frame is stored as a still image, anything longer as a video.
        wants_image = is_image or length_slider == 1
        extension = ".png" if wants_image else ".mp4"
        save_sample_path = os.path.join(out_dir, prefix + extension)
        with open(save_sample_path, "wb") as file:
            file.write(payload)

        if wants_image:
            return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
735
+
736
+
737
def ui_huggingface(edition, config_path, model_name, savedir_sample):
    """Build the Gradio demo whose generation runs on the remote service.

    :param edition: edition tag; ``"v1"`` selects a different set of slider
        ranges and hides the "Generate Image" checkbox.
    :param config_path: forwarded to ``EasyAnimateController_HuggingFace``.
    :param model_name: forwarded to ``EasyAnimateController_HuggingFace``.
    :param savedir_sample: forwarded to ``EasyAnimateController_HuggingFace``.
    :return: ``(demo, controller)`` - the ``gr.Blocks`` app and its controller.
    """
    controller = EasyAnimateController_HuggingFace(edition, config_path, model_name, savedir_sample)

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # EasyAnimate: Integrated generation of baseline scheme for videos and images.
            Generate your videos easily
            [Github](https://github.com/aigc-apps/EasyAnimate/)
            """
        )
        with gr.Column(variant="panel"):
            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)

                    if edition == "v1":
                        width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
                        height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
                        with gr.Row():
                            # The checkbox exists (generate() needs the input) but is hidden for v1.
                            is_image = gr.Checkbox(False, label="Generate Image", visible=False)
                            length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
                        cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
                    else:
                        width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
                        height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
                        with gr.Column():
                            # NOTE(review): the text says "frame rate" but the limit that
                            # is enforced below is the frame count (length_slider max 81).
                            gr.Markdown(
                                """
                                To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
                                If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
                                """
                            )
                            with gr.Row():
                                is_image = gr.Checkbox(False, label="Generate Image")
                                length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
                        cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=43)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # random.randint needs integer bounds: a float such as 1e8 is
                        # deprecated since Python 3.10 and raises TypeError from 3.12.
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                with gr.Column():
                    result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
                    result_video = gr.Video(label="Generated Animation", interactive=False)
                    infer_progress = gr.Textbox(
                        label="Generation Info",
                        value="No task currently",
                        interactive=False
                    )

            # The length slider is meaningless for single images, so hide it while the box is ticked.
            is_image.change(
                lambda x: gr.update(visible=not x),
                inputs=[is_image],
                outputs=[length_slider],
            )

            # Input order must match the positional parameters of controller.generate.
            generate_button.click(
                fn=controller.generate,
                inputs=[
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    height_slider,
                    is_image,
                    length_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
            )
    return demo, controller
easyanimate/utils/__init__.py ADDED
File without changes
easyanimate/utils/diffusion_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+
10
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Arguments broadcast against each other, so a batch can be compared with a
    scalar. At least one of the four arguments must be a Tensor; non-Tensor
    log-variances are converted onto that Tensor's device.
    """
    reference = None
    for candidate in (mean1, logvar1, mean2, logvar2):
        if isinstance(candidate, th.Tensor):
            reference = candidate
            break
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() needs a Tensor; plain broadcasting would not convert scalars for it.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value, device=reference.device)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = th.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
40
+
41
+
42
def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal CDF, evaluated
    element-wise on the tensor x.
    """
    cubic_term = x + 0.044715 * th.pow(x, 3)
    tanh_arg = np.sqrt(2.0 / np.pi) * cubic_term
    return 0.5 * (1.0 + th.tanh(tanh_arg))
48
+
49
+
50
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Evaluate the standard-normal log-density of x after standardising it with
    the given means and log standard deviations.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    standardized = (x - means) * th.exp(-log_scales)
    unit_normal = th.distributions.Normal(th.zeros_like(x), th.ones_like(x))
    return unit_normal.log_prob(standardized)
64
+
65
+
66
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of x under a Gaussian discretised into bins of half-width
    1/255 around each value.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    half_bin = 1.0 / 255.0
    residual = x - means
    precision = th.exp(-log_scales)

    # CDF at the upper and lower edge of the bin that contains x.
    upper_cdf = approx_standard_normal_cdf(precision * (residual + half_bin))
    lower_cdf = approx_standard_normal_cdf(precision * (residual - half_bin))

    # Every log argument is clamped so that a vanishing probability stays finite.
    log_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_above_lower = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_bin_mass = th.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # The outermost bins absorb the whole tail on their side.
    interior_or_top = th.where(x > 0.999, log_above_lower, log_bin_mass)
    log_probs = th.where(x < -0.999, log_upper, interior_or_top)
    assert log_probs.shape == x.shape
    return log_probs
easyanimate/utils/gaussian_diffusion.py ADDED
@@ -0,0 +1,1008 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import enum
8
+ import math
9
+
10
+ import numpy as np
11
+ import torch as th
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
16
+
17
+
18
def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
23
+
24
+
25
class ModelMeanType(enum.Enum):
    """
    The quantity the model's output is interpreted as.
    """

    # Values are spelled out; they equal what enum.auto() assigns in this order.
    PREVIOUS_X = 1  # the model predicts x_{t-1}
    START_X = 2  # the model predicts x_0
    EPSILON = 3  # the model predicts epsilon
33
+
34
+
35
class ModelVarType(enum.Enum):
    """
    How the variance of the model's output distribution is obtained.
    LEARNED_RANGE lets the model predict a value between FIXED_SMALL and
    FIXED_LARGE, which makes its job easier.
    """

    # Values are spelled out; they equal what enum.auto() assigns in this order.
    LEARNED = 1
    FIXED_SMALL = 2
    FIXED_LARGE = 3
    LEARNED_RANGE = 4
46
+
47
+
48
class LossType(enum.Enum):
    """Training objective selector."""

    # Values are spelled out; they equal what enum.auto() assigns in this order.
    MSE = 1  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = 2  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = 3  # use the variational lower-bound
    RESCALED_KL = 4  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        """Return True for the two variational-bound objectives (KL, RESCALED_KL)."""
        return self is LossType.KL or self is LossType.RESCALED_KL
58
+
59
+
60
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
61
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
62
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
63
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
64
+ return betas
65
+
66
+
67
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    Deprecated API for creating beta schedules; get_named_beta_schedule() is
    the newer library of schedules.

    Supported names: "quad", "linear", "warmup10", "warmup50", "const", "jsd".
    Any other name raises NotImplementedError.
    """
    steps = num_diffusion_timesteps
    if beta_schedule == "quad":
        # Linear in sqrt(beta), squared back afterwards.
        sqrt_ramp = np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64)
        betas = sqrt_ramp ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, steps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, steps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, steps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(steps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        countdown = np.linspace(steps, 1, steps, dtype=np.float64)
        betas = 1.0 / countdown
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (steps,)
    return betas
98
+
99
+
100
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return the pre-defined beta schedule registered under schedule_name.

    The library holds schedules that stay similar in the limit of
    num_diffusion_timesteps. Schedules may be added, but existing ones should
    never be removed or changed, to maintain backwards compatibility.
    """
    if schedule_name == "squaredcos_cap_v2":
        def _cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, _cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al, extended to work for any number of
    # diffusion steps.
    scale = 1000 / num_diffusion_timesteps
    return get_beta_schedule(
        "linear",
        beta_start=scale * 0.0001,
        beta_end=scale * 0.02,
        num_diffusion_timesteps=num_diffusion_timesteps,
    )
125
+
126
+
127
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretise a continuous alpha_bar(t) function, t in [0, 1], into a beta
    schedule; alpha_bar defines the cumulative product of (1-beta) over time.
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    steps = num_diffusion_timesteps
    # Step i covers the interval [i / steps, (i + 1) / steps].
    capped = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return np.array(capped)
144
+
145
+
146
+ class GaussianDiffusion:
147
+ """
148
+ Utilities for training and sampling diffusion models.
149
+ Original ported from this codebase:
150
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
151
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
152
+ starting at T and going to 1.
153
+ """
154
+
155
+ def __init__(
156
+ self,
157
+ *,
158
+ betas,
159
+ model_mean_type,
160
+ model_var_type,
161
+ loss_type,
162
+ snr=False,
163
+ return_startx=False,
164
+ ):
165
+
166
+ self.model_mean_type = model_mean_type
167
+ self.model_var_type = model_var_type
168
+ self.loss_type = loss_type
169
+ self.snr = snr
170
+ self.return_startx = return_startx
171
+
172
+ # Use float64 for accuracy.
173
+ betas = np.array(betas, dtype=np.float64)
174
+ self.betas = betas
175
+ assert len(betas.shape) == 1, "betas must be 1-D"
176
+ assert (betas > 0).all() and (betas <= 1).all()
177
+
178
+ self.num_timesteps = int(betas.shape[0])
179
+
180
+ alphas = 1.0 - betas
181
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
182
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
183
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
184
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
185
+
186
+ # calculations for diffusion q(x_t | x_{t-1}) and others
187
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
188
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
189
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
190
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
191
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
192
+
193
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
194
+ self.posterior_variance = (
195
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
196
+ )
197
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
198
+ self.posterior_log_variance_clipped = np.log(
199
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
200
+ ) if len(self.posterior_variance) > 1 else np.array([])
201
+
202
+ self.posterior_mean_coef1 = (
203
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
204
+ )
205
+ self.posterior_mean_coef2 = (
206
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
207
+ )
208
+
209
+ def q_mean_variance(self, x_start, t):
210
+ """
211
+ Get the distribution q(x_t | x_0).
212
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
213
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
214
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
215
+ """
216
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
217
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
218
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
219
+ return mean, variance, log_variance
220
+
221
+ def q_sample(self, x_start, t, noise=None):
222
+ """
223
+ Diffuse the data for a given number of diffusion steps.
224
+ In other words, sample from q(x_t | x_0).
225
+ :param x_start: the initial data batch.
226
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
227
+ :param noise: if specified, the split-out normal noise.
228
+ :return: A noisy version of x_start.
229
+ """
230
+ if noise is None:
231
+ noise = th.randn_like(x_start)
232
+ assert noise.shape == x_start.shape
233
+ return (
234
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
235
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
236
+ )
237
+
238
def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Compute the mean and variance of the diffusion posterior
    q(x_{t-1} | x_t, x_0).

    :param x_start: the x_0 tensor.
    :param x_t: the noisy tensor at step t; must have x_start's shape.
    :param t: a tensor of timestep indices.
    :return: a tuple (posterior_mean, posterior_variance,
             posterior_log_variance_clipped).
    """
    assert x_start.shape == x_t.shape
    shape = x_t.shape
    coef_x0 = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
    coef_xt = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
    posterior_mean = coef_x0 * x_start + coef_xt * x_t
    posterior_variance = _extract_into_tensor(self.posterior_variance, t, shape)
    posterior_log_variance_clipped = _extract_into_tensor(
        self.posterior_log_variance_clipped, t, shape
    )
    # All three results must share the batch dimension of the input.
    batch = x_start.shape[0]
    assert all(
        item.shape[0] == batch
        for item in (posterior_mean, posterior_variance, posterior_log_variance_clipped)
    )
    return posterior_mean, posterior_variance, posterior_log_variance_clipped
259
+
260
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
    """
    Run the model to obtain p(x_{t-1} | x_t) together with a prediction of
    the initial sample x_0.

    :param model: a callable invoked as model(x, timestep=t, **model_kwargs);
                  it may return a tensor or a (tensor, extra) tuple.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D tensor of timesteps with shape (N,).
    :param clip_denoised: if True, clamp the x_0 prediction into [-1, 1].
    :param denoised_fn: if not None, a function applied to the x_0
                        prediction before clamping.
    :param model_kwargs: if not None, a dict of extra keyword arguments
                         passed to the model (e.g. conditioning).
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
             - 'extra': the second element of the model's tuple output, or
               None when the model returned a single value.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs

    B, C = x.shape[:2]
    assert t.shape == (B,)
    raw = model(x, timestep=t, **model_kwargs)
    if isinstance(raw, tuple):
        model_output, extra = raw
    else:
        model_output, extra = raw, None

    var_type = self.model_var_type
    if var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
        # The model emits 2*C channels: the mean part and the variance part.
        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = th.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
        # model_var_values in [-1, 1] interpolates between min_log and max_log.
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = th.exp(model_log_variance)
    elif var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
        # For FIXED_LARGE the first entry is taken from the posterior
        # variance to get a better decoder log likelihood.
        large_variance = np.append(self.posterior_variance[1], self.betas[1:])
        schedules = {
            ModelVarType.FIXED_LARGE: (large_variance, np.log(large_variance)),
            ModelVarType.FIXED_SMALL: (
                self.posterior_variance,
                self.posterior_log_variance_clipped,
            ),
        }
        variance_table, log_variance_table = schedules[var_type]
        model_variance = _extract_into_tensor(variance_table, t, x.shape)
        model_log_variance = _extract_into_tensor(log_variance_table, t, x.shape)
    else:
        model_variance = th.zeros_like(model_output)
        model_log_variance = th.zeros_like(model_output)

    # Build the x_0 prediction; any mean type other than START_X is treated
    # as an epsilon prediction here.
    if self.model_mean_type == ModelMeanType.START_X:
        x0_guess = model_output
    else:
        x0_guess = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
    if denoised_fn is not None:
        x0_guess = denoised_fn(x0_guess)
    pred_xstart = x0_guess.clamp(-1, 1) if clip_denoised else x0_guess

    model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

    assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
    return {
        "mean": model_mean,
        "variance": model_variance,
        "log_variance": model_log_variance,
        "pred_xstart": pred_xstart,
        "extra": extra,
    }
340
+
341
def _predict_xstart_from_eps(self, x_t, t, eps):
    """
    Recover the x_0 estimate implied by a noise prediction.

    :param x_t: the noisy tensor at step t.
    :param t: a tensor of timestep indices.
    :param eps: the predicted noise; must have x_t's shape.
    :return: the x_0 estimate, with x_t's shape.
    """
    assert x_t.shape == eps.shape
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    return recip * x_t - recipm1 * eps
347
+
348
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
    """
    Recover the noise estimate implied by an x_0 prediction (the inverse of
    _predict_xstart_from_eps).

    :param x_t: the noisy tensor at step t.
    :param t: a tensor of timestep indices.
    :param pred_xstart: the predicted x_0.
    :return: the implied noise tensor.
    """
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    numerator = recip * x_t - pred_xstart
    return numerator / recipm1
352
+
353
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.
    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

    :param cond_fn: callable invoked as cond_fn(x, t, **model_kwargs).
    :param p_mean_var: dict providing at least 'mean' and 'variance'.
    :param x: the tensor at time t.
    :param t: a tensor of timestep indices.
    :param model_kwargs: optional dict of extra keyword arguments for
                         cond_fn; None is treated as an empty dict.
    :return: the shifted mean, computed in float32.
    """
    # The default None cannot be **-unpacked; fall back to no extra kwargs.
    if model_kwargs is None:
        model_kwargs = {}
    gradient = cond_fn(x, t, **model_kwargs)
    return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
363
+
364
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute what the p_mean_variance output would have been, should the
    model's score function be conditioned by cond_fn.
    See condition_mean() for details on cond_fn.
    Unlike condition_mean(), this instead uses the conditioning strategy
    from Song et al (2020).

    :param cond_fn: callable invoked as cond_fn(x, t, **model_kwargs).
    :param p_mean_var: dict providing at least 'pred_xstart'; it is not
                       modified (a shallow copy is returned).
    :param x: the tensor at time t.
    :param t: a tensor of timestep indices.
    :param model_kwargs: optional dict of extra keyword arguments for
                         cond_fn; None is treated as an empty dict.
    :return: a copy of p_mean_var with 'pred_xstart' and 'mean' replaced.
    """
    # The default None cannot be **-unpacked; fall back to no extra kwargs.
    if model_kwargs is None:
        model_kwargs = {}
    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

    eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
    eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

    out = p_mean_var.copy()
    out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
    out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
    return out
381
+
382
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Draw x_{t-1} from the model at the given timestep.

    :param model: the model to sample from.
    :param x: the current tensor x_t.
    :param t: a tensor of timestep indices, with 0 being the first
              diffusion step.
    :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
    :param denoised_fn: if not None, a function applied to the x_start
                        prediction before it is used to sample.
    :param cond_fn: if not None, a gradient function used to shift the
                    mean via condition_mean().
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model (e.g. conditioning).
    :return: a dict containing the following keys:
             - 'sample': a random sample from the model.
             - 'pred_xstart': a prediction of x_0.
    """
    stats = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    # Per-sample switch that suppresses the noise term when t == 0.
    trailing_ones = (1,) * (x.dim() - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    if cond_fn is not None:
        stats["mean"] = self.condition_mean(cond_fn, stats, x, t, model_kwargs=model_kwargs)
    std = th.exp(0.5 * stats["log_variance"])
    sample = stats["mean"] + nonzero_mask * std * noise
    return {"sample": sample, "pred_xstart": stats["pred_xstart"]}
424
+
425
def p_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Generate samples from the model by running the full reverse chain.

    :param model: the model module.
    :param shape: the shape of the samples, (N, C, H, W).
    :param noise: if specified, the starting noise; should have shape
                  `shape`.
    :param clip_denoised: if True, clip x_start predictions to [-1, 1].
    :param denoised_fn: if not None, a function applied to the x_start
                        prediction before it is used to sample.
    :param cond_fn: if not None, a gradient function that acts similarly
                    to the model.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model (e.g. conditioning).
    :param device: if specified, the device to create the samples on;
                   otherwise a model parameter's device is used.
    :param progress: if True, show a tqdm progress bar.
    :return: the 'sample' entry of the last step yielded by
             p_sample_loop_progressive().
    """
    steps = self.p_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
    )
    last = None
    # Exhaust the generator, keeping only the final step's output.
    for last in steps:
        pass
    return last["sample"]
469
+
470
def p_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
):
    """
    Run the reverse chain and yield the intermediate result of every
    timestep, from num_timesteps - 1 down to 0.

    Arguments are the same as p_sample_loop().
    Yields dicts, each being the return value of p_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    current = th.randn(*shape, device=device) if noise is None else noise
    schedule = list(reversed(range(self.num_timesteps)))

    if progress:
        # Imported lazily so tqdm stays an optional dependency.
        from tqdm.auto import tqdm

        schedule = tqdm(schedule)

    batch = shape[0]
    for step in schedule:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            result = self.p_sample(
                model,
                current,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
            )
            yield result
            # Read the sample only after the consumer resumes the generator.
            current = result["sample"]
515
+
516
def ddim_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Draw x_{t-1} from the model using DDIM.

    Same usage as p_sample(), plus `eta`, which scales the stochastic
    term sigma (eta == 0 makes the noise term vanish).
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    x0_pred = out["pred_xstart"]
    # The model usually predicts epsilon, but it is re-derived from the
    # x_0 prediction so every parameterization is handled the same way.
    eps = self._predict_eps_from_xstart(x, t, x0_pred)

    alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
    alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
    sigma = (
        eta
        * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
        * th.sqrt(1 - alpha_bar / alpha_bar_prev)
    )
    # Equation 12.
    noise = th.randn_like(x)
    direction = th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
    mean_pred = x0_pred * th.sqrt(alpha_bar_prev) + direction
    # Per-sample switch that suppresses the noise term when t == 0.
    trailing_ones = (1,) * (x.dim() - 1)
    nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
    sample = mean_pred + nonzero_mask * sigma * noise
    return {"sample": sample, "pred_xstart": x0_pred}
564
+
565
def ddim_reverse_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    eta=0.0,
):
    """
    Compute x_{t+1} from the model using the DDIM reverse ODE.

    Arguments match ddim_sample(); `eta` must be 0.0 because only the
    deterministic path is supported.
    """
    assert eta == 0.0, "Reverse ODE only for deterministic path"
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    if cond_fn is not None:
        out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

    x0_pred = out["pred_xstart"]
    # The model usually predicts epsilon, but it is re-derived from the
    # x_0 prediction so every parameterization is handled the same way.
    recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
    recipm1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
    eps = (recip * x - x0_pred) / recipm1
    alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

    # Equation 12, reversed.
    mean_pred = x0_pred * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

    return {"sample": mean_pred, "pred_xstart": x0_pred}
602
+
603
def ddim_sample_loop(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Generate samples from the model using DDIM.

    Same usage as p_sample_loop(), with the extra `eta` forwarded to every
    DDIM step.

    :return: the 'sample' entry of the last step yielded by
             ddim_sample_loop_progressive().
    """
    steps = self.ddim_sample_loop_progressive(
        model,
        shape,
        noise=noise,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        cond_fn=cond_fn,
        model_kwargs=model_kwargs,
        device=device,
        progress=progress,
        eta=eta,
    )
    last = None
    # Exhaust the generator, keeping only the final step's output.
    for last in steps:
        pass
    return last["sample"]
635
+
636
def ddim_sample_loop_progressive(
    self,
    model,
    shape,
    noise=None,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
    device=None,
    progress=False,
    eta=0.0,
):
    """
    Run DDIM sampling and yield the intermediate result of every timestep,
    from num_timesteps - 1 down to 0.

    Same usage as p_sample_loop_progressive(); each yielded dict is the
    return value of ddim_sample().
    """
    if device is None:
        device = next(model.parameters()).device
    assert isinstance(shape, (tuple, list))
    current = th.randn(*shape, device=device) if noise is None else noise
    schedule = list(reversed(range(self.num_timesteps)))

    if progress:
        # Imported lazily so tqdm stays an optional dependency.
        from tqdm.auto import tqdm

        schedule = tqdm(schedule)

    batch = shape[0]
    for step in schedule:
        t = th.tensor([step] * batch, device=device)
        with th.no_grad():
            result = self.ddim_sample(
                model,
                current,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                cond_fn=cond_fn,
                model_kwargs=model_kwargs,
                eta=eta,
            )
            yield result
            # Read the sample only after the consumer resumes the generator.
            current = result["sample"]
681
+
682
def _vb_terms_bpd(
    self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
    """
    Compute one term of the variational lower bound.

    Results are divided by log(2), so they are expressed in bits rather
    than nats, which allows comparison with other papers.

    :return: a dict with the following keys:
             - 'output': a shape [N] tensor of NLLs or KLs.
             - 'pred_xstart': the x_0 predictions.
    """
    posterior = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
    true_mean, true_log_variance_clipped = posterior[0], posterior[2]
    out = self.p_mean_variance(
        model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
    )
    nats_per_bit = np.log(2.0)

    kl = normal_kl(
        true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
    )
    kl = mean_flat(kl) / nats_per_bit

    decoder_nll = -discretized_gaussian_log_likelihood(
        x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
    )
    assert decoder_nll.shape == x_start.shape
    decoder_nll = mean_flat(decoder_nll) / nats_per_bit

    # At t == 0 use the decoder NLL; everywhere else use
    # KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)).
    output = th.where((t == 0), decoder_nll, kl)
    return {"output": output, "pred_xstart": out["pred_xstart"]}
714
+
715
def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
    """
    Compute training losses for a single timestep.

    :param model: the model to evaluate loss on; called as
                  model(x_t, timestep=t, **model_kwargs) and indexed with [0].
    :param x_start: the [N x C x ...] tensor of inputs.
    :param timestep: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
                         The keys 'mask_ratio' and 'mask_loss_coef' are also
                         read here to enable the masked loss.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :param skip_noise: if True, x_start is used directly as x_t and no noise
                       is drawn.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
             When self.return_startx is set and the model predicts epsilon,
             a tuple (eps, pred_xstart, x_t) is returned instead.
    :raises NotImplementedError: for an unsupported loss type, or when
             self.snr is set with a mean type other than START_X / EPSILON.
    """
    t = timestep
    if model_kwargs is None:
        model_kwargs = {}
    if skip_noise:
        x_t = x_start
    else:
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
        model_output = model(x_t, timestep=t, **model_kwargs)[0]

        if isinstance(model_output, dict) and model_output.get('x', None) is not None:
            output = model_output['x']
        else:
            output = model_output

        if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
            return self._extracted_from_training_losses_diffusers(x_t, output, t)

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            output, model_var_values = th.split(output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction (hence the detach).
            frozen_out = th.cat([output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out, **kwargs: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        target = {
            ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0],
            ModelMeanType.START_X: x_start,
            ModelMeanType.EPSILON: noise,
        }[self.model_mean_type]
        assert output.shape == target.shape == x_start.shape
        if self.snr:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                pred_startx = output
            elif self.model_mean_type == ModelMeanType.EPSILON:
                pred_noise = output
                pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
            else:
                # Without this branch the code below would fail with an
                # UnboundLocalError on pred_noise / pred_startx.
                raise NotImplementedError(self.model_mean_type)

            # Broadcast t over every non-batch axis, whatever the input rank
            # (image [N, C, H, W] or video [N, C, F, H, W]).
            t_full = t.reshape(-1, *([1] * (pred_startx.dim() - 1))).expand(pred_startx.shape)
            # Steps above the hard-coded threshold 249 regress the noise,
            # the remaining steps regress x_0.
            target = th.where(t_full > 249, noise, x_start)
            output = th.where(t_full > 249, pred_noise, pred_startx)
        loss = (target - output) ** 2
        if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
            assert 'mask' in model_output
            # Pool the per-pixel loss down to one value per patch so it can
            # be weighted by the per-patch mask.
            loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
            mask = model_output['mask']
            unmask = 1 - mask
            terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
            if model_kwargs['mask_loss_coef'] > 0:
                terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
        else:
            terms["mse"] = mean_flat(loss)
        terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
        if "mae" in terms:
            terms["loss"] = terms["loss"] + terms["mae"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
823
+
824
def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
    """
    Compute training losses for a single timestep, for models that take a
    `return_dict` keyword (the model is called with return_dict=False and
    its first output is used).

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param timestep: a batch of timestep indices.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :param noise: if specified, the specific Gaussian noise to try to remove.
    :param skip_noise: if True, x_start is used directly as x_t and no noise
                       is drawn.
    :return: a dict with the key "loss" containing a tensor of shape [N].
             Some mean or variance settings may also have other keys.
             When self.return_startx is set and the model predicts epsilon,
             a tuple (eps, pred_xstart, x_t) is returned instead.
    :raises NotImplementedError: for an unsupported loss type, or when
             self.snr is set with a mean type other than START_X / EPSILON.
    """
    t = timestep
    if model_kwargs is None:
        model_kwargs = {}
    if skip_noise:
        x_t = x_start
    else:
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

    terms = {}

    if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
        terms["loss"] = self._vb_terms_bpd(
            model=model,
            x_start=x_start,
            x_t=x_t,
            t=t,
            clip_denoised=False,
            model_kwargs=model_kwargs,
        )["output"]
        if self.loss_type == LossType.RESCALED_KL:
            terms["loss"] *= self.num_timesteps
    elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
        output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]

        if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
            return self._extracted_from_training_losses_diffusers(x_t, output, t)

        if self.model_var_type in [
            ModelVarType.LEARNED,
            ModelVarType.LEARNED_RANGE,
        ]:
            B, C = x_t.shape[:2]
            assert output.shape == (B, C * 2, *x_t.shape[2:])
            output, model_var_values = th.split(output, C, dim=1)
            # Learn the variance using the variational bound, but don't let
            # it affect our mean prediction (hence the detach).
            frozen_out = th.cat([output.detach(), model_var_values], dim=1)
            terms["vb"] = self._vb_terms_bpd(
                model=lambda *args, r=frozen_out, **kwargs: r,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
            )["output"]
            if self.loss_type == LossType.RESCALED_MSE:
                # Divide by 1000 for equivalence with initial implementation.
                # Without a factor of 1/1000, the VB term hurts the MSE term.
                terms["vb"] *= self.num_timesteps / 1000.0

        target = {
            ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                x_start=x_start, x_t=x_t, t=t
            )[0],
            ModelMeanType.START_X: x_start,
            ModelMeanType.EPSILON: noise,
        }[self.model_mean_type]
        assert output.shape == target.shape == x_start.shape
        if self.snr:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
                pred_startx = output
            elif self.model_mean_type == ModelMeanType.EPSILON:
                pred_noise = output
                pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
            else:
                # Without this branch the code below would fail with an
                # UnboundLocalError on pred_noise / pred_startx.
                raise NotImplementedError(self.model_mean_type)

            # Broadcast t over every non-batch axis, whatever the input rank.
            # A fixed 4-D index would fail on 5-D video latents, or silently
            # align the batch axis with the channel axis when N == C.
            t_full = t.reshape(-1, *([1] * (pred_startx.dim() - 1))).expand(pred_startx.shape)
            # Steps above the hard-coded threshold 249 regress the noise,
            # the remaining steps regress x_0.
            target = th.where(t_full > 249, noise, x_start)
            output = th.where(t_full > 249, pred_noise, pred_startx)
        loss = (target - output) ** 2
        terms["mse"] = mean_flat(loss)
        # No "mae" term is ever produced on this path, so the total is the
        # MSE plus the optional variational-bound term.
        terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
    else:
        raise NotImplementedError(self.loss_type)

    return terms
917
+
918
def _extracted_from_training_losses_diffusers(self, x_t, output, t):
    """
    Split a 2*C-channel model output and derive the x_0 estimate from its
    first half.

    :param x_t: the noisy input, shape [N x C x ...].
    :param output: the model output, shape [N x 2C x ...].
    :param t: a batch of timestep indices.
    :return: a tuple (eps, pred_xstart, x_t), where eps is the first C
             channels of `output`.
    """
    batch, channels = x_t.shape[:2]
    assert output.shape == (batch, channels * 2, *x_t.shape[2:])
    eps, _unused_var_values = th.split(output, channels, dim=1)
    pred_xstart = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=eps)
    return eps, pred_xstart, x_t
923
+
924
def _prior_bpd(self, x_start):
    """
    Compute the prior KL term of the variational lower bound, measured in
    bits-per-dim.

    This term can't be optimized, as it only depends on the encoder.

    :param x_start: the [N x C x ...] tensor of inputs.
    :return: a batch of [N] KL values (in bits), one per batch element.
    """
    last_step = self.num_timesteps - 1
    t = th.tensor([last_step] * x_start.shape[0], device=x_start.device)
    qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
    # KL between q(x_T | x_0) and a zero-mean, unit-variance normal.
    kl_prior = normal_kl(
        mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
    )
    return mean_flat(kl_prior) / np.log(2.0)
939
+
940
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
    """
    Compute the entire variational lower-bound, measured in bits-per-dim,
    as well as other related quantities.

    :param model: the model to evaluate loss on.
    :param x_start: the [N x C x ...] tensor of inputs.
    :param clip_denoised: if True, clip denoised samples.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
                         pass to the model. This can be used for conditioning.
    :return: a dict containing the following keys:
             - total_bpd: the total variational lower-bound, per batch element.
             - prior_bpd: the prior term in the lower-bound.
             - vb: an [N x T] tensor of terms in the lower-bound.
             - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
             - mse: an [N x T] tensor of epsilon MSEs for each timestep.
    """
    device = x_start.device
    batch_size = x_start.shape[0]

    vb_terms, x0_errors, eps_errors = [], [], []
    # Walk the chain from the last timestep down to 0.
    for step in reversed(range(self.num_timesteps)):
        t_batch = th.tensor([step] * batch_size, device=device)
        noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
        # Evaluate the VLB term of this timestep without tracking gradients.
        with th.no_grad():
            out = self._vb_terms_bpd(
                model,
                x_start=x_start,
                x_t=x_t,
                t=t_batch,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
            )
        pred_xstart = out["pred_xstart"]
        vb_terms.append(out["output"])
        x0_errors.append(mean_flat((pred_xstart - x_start) ** 2))
        eps = self._predict_eps_from_xstart(x_t, t_batch, pred_xstart)
        eps_errors.append(mean_flat((eps - noise) ** 2))

    vb = th.stack(vb_terms, dim=1)
    prior_bpd = self._prior_bpd(x_start)
    return {
        "total_bpd": vb.sum(dim=1) + prior_bpd,
        "prior_bpd": prior_bpd,
        "vb": vb,
        "xstart_mse": th.stack(x0_errors, dim=1),
        "mse": th.stack(eps_errors, dim=1),
    }
994
+
995
+
996
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a float32 tensor of shape `broadcast_shape`, on the device of
             `timesteps`, where every element of row i holds
             arr[timesteps[i]].
    """
    # Cast to float32 on the host before the transfer: the values are the
    # same as casting afterwards, but no float64 tensor ever reaches the
    # device (some backends cannot hold float64) and half as much data moves.
    table = th.from_numpy(arr).float().to(device=timesteps.device)
    res = table[timesteps]
    # Append singleton axes until the rank matches the requested shape.
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    # Adding zeros materializes the broadcast into a fresh tensor.
    return res + th.zeros(broadcast_shape, device=timesteps.device)
easyanimate/utils/lora_utils.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
class LoRAModule(torch.nn.Module):
    """
    Low-rank adapter attached to an existing Linear or Conv2d layer.

    Replaces the forward method of the original module instead of replacing
    the original module itself (see apply_to()).
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """
        Build the down/up projection pair for `org_module`.

        :param lora_name: identifier stored on the module.
        :param org_module: the layer to adapt; Conv2d layers (including
                           subclasses) get convolutional projections, every
                           other layer is treated as Linear.
        :param multiplier: global factor applied to the LoRA branch output.
        :param lora_dim: rank of the low-rank decomposition.
        :param alpha: scaling numerator; if alpha == 0 or None, alpha is
                      rank (no scaling). May be a tensor.
        :param dropout: optional dropout probability on the down projection.
        :param rank_dropout: optional probability of dropping each rank.
        :param module_dropout: optional probability of skipping the whole
                               LoRA branch for a forward pass.
        """
        super().__init__()
        self.lora_name = lora_name

        # isinstance also covers Conv2d subclasses; the class-name check is
        # kept so anything matched previously is still matched.
        is_conv2d = isinstance(org_module, torch.nn.Conv2d) or org_module.__class__.__name__ == "Conv2d"

        self.lora_dim = lora_dim
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if isinstance(alpha, torch.Tensor):
            # Cast to float first (bf16 cannot be converted to numpy) and
            # move to CPU so tensors living on an accelerator also work.
            alpha = alpha.detach().float().cpu().numpy()
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # Same initialization as microsoft's LoRA: the up projection starts
        # at zero, so the adapter is initially a no-op.
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed again in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Redirect the original module's forward to this adapter's forward."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x, *args, **kwargs):
        """
        Return the original layer's output plus the scaled LoRA branch.

        Extra positional and keyword arguments are accepted but not
        forwarded to the original forward.
        """
        weight_dtype = x.dtype
        org_forwarded = self.org_forward(x)

        # module dropout: occasionally skip the LoRA branch entirely
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x.to(self.lora_down.weight.dtype))

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Hashes the 0x10000 bytes that start at offset 0x100000 of the seekable
    binary stream ``b`` and returns the first 8 hex digits of the SHA-256
    digest.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[:8]
122
+
123
+
124
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    The first 8 bytes of the stream ``b`` are read as a little-endian
    integer ``n``; everything from offset ``n + 8`` to the end of the stream
    is hashed with SHA-256 and the full hex digest is returned.
    """
    chunk_size = 1024 * 1024
    hasher = hashlib.sha256()

    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")

    # Skip the length prefix and the header it describes.
    b.seek(header_length + 8)
    while True:
        chunk = b.read(chunk_size)
        if chunk == b"":
            break
        hasher.update(chunk)

    return hasher.hexdigest()
139
+
140
+
141
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    Returns:
        A ``(model_hash, legacy_hash)`` tuple computed over an in-memory
        safetensors serialization of ``tensors``.
    """
    # Writing user metadata to the file can change the result of
    # sd_models.model_hash(), so only the training metadata (keys starting
    # with "ss_"), which is meant to be immutable, takes part in the hash.
    training_metadata = {}
    for key, value in metadata.items():
        if key.startswith("ss_"):
            training_metadata[key] = value

    serialized = safetensors.torch.save(tensors, training_metadata)
    stream = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(stream)
    legacy_hash = addnet_hash_legacy(stream)
    return model_hash, legacy_hash
156
+
157
+
158
class LoRANetwork(torch.nn.Module):
    """Container that creates, applies, loads and saves a set of LoRA modules.

    LoRA modules are created for the Linear / 1x1 Conv2d children of every
    module whose class name appears in one of the target lists below, both
    in the text encoder(s) and in the transformer (passed as ``unet``).
    """

    TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        add_lora_in_attn_temporal: bool = False,
        varbose: Optional[bool] = False,
    ) -> None:
        """Build (but do not yet apply) the LoRA modules.

        Args:
            text_encoder: One text encoder or a list of them.
            unet: Root module searched for transformer target modules.
            multiplier: Multiplier handed to every created LoRA module.
            lora_dim: Rank used for every created LoRA module.
            alpha: Alpha value handed to every created LoRA module.
            dropout: Neuron dropout probability handed to every module.
            module_class: Factory used to create each LoRA module.
            add_lora_in_attn_temporal: When False, children whose name
                contains "attn_temporal" are skipped.
            varbose: Accepted but not used in this constructor.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        print(f"neuron dropout: p={self.dropout}")

        def create_modules(is_unet, root_module, target_replace_modules):
            """Return ``(loras, skipped_names)`` for one root module."""
            prefix = self.LORA_PREFIX_TRANSFORMER if is_unet else self.LORA_PREFIX_TEXT_ENCODER
            created = []
            skipped_names = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ not in target_replace_modules:
                    continue
                for child_name, child_module in module.named_modules():
                    child_cls = child_module.__class__.__name__
                    is_linear = child_cls in ("Linear", "LoRACompatibleLinear")
                    is_conv2d = child_cls in ("Conv2d", "LoRACompatibleConv")
                    is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                    if "attn_temporal" in child_name and not add_lora_in_attn_temporal:
                        continue
                    if not (is_linear or is_conv2d):
                        continue

                    lora_name = ".".join((prefix, name, child_name)).replace(".", "_")

                    # Only Linear and 1x1 Conv2d layers receive a rank; any
                    # other Conv2d keeps dim=None and is passed over silently.
                    supported = is_linear or is_conv2d_1x1
                    dim = self.lora_dim if supported else None
                    alpha = self.alpha if supported else None

                    if dim is None or dim == 0:
                        if supported:
                            skipped_names.append(lora_name)
                        continue

                    created.append(
                        module_class(
                            lora_name,
                            child_module,
                            self.multiplier,
                            dim,
                            alpha,
                            dropout=dropout,
                        )
                    )
            return created, skipped_names

        if type(text_encoder) == list:
            encoders = text_encoder
        else:
            encoders = [text_encoder]

        self.text_encoder_loras = []
        skipped_te = []
        for encoder in encoders:
            encoder_loras, encoder_skipped = create_modules(
                False, encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            )
            self.text_encoder_loras.extend(encoder_loras)
            skipped_te += encoder_skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # Every LoRA module must carry a unique name.
        seen_names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in seen_names, f"duplicated lora name: {lora.lora_name}"
            seen_names.add(lora.lora_name)

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Apply the created LoRA modules and register them as submodules.

        ``text_encoder`` and ``unet`` are accepted but not read here. When a
        flag is False, the corresponding list of LoRA modules is discarded.
        """
        if not apply_text_encoder:
            self.text_encoder_loras = []
        else:
            print("enable LoRA for text encoder")

        if not apply_unet:
            self.unet_loras = []
        else:
            print("enable LoRA for U-Net")

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def set_multiplier(self, multiplier):
        """Store ``multiplier`` and propagate it to every LoRA module."""
        self.multiplier = multiplier
        for lora in [*self.text_encoder_loras, *self.unet_loras]:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load a state dict from ``file`` non-strictly and return the load info.

        ``.safetensors`` files are read with safetensors; anything else goes
        through ``torch.load`` mapped to CPU.
        """
        extension = os.path.splitext(file)[1]
        if extension == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
        return self.load_state_dict(weights_sd, False)

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Return optimizer parameter groups for the LoRA modules.

        One group is produced per non-empty LoRA list (text encoder first,
        then transformer); a group gets an ``"lr"`` entry only when the
        matching learning rate is not None. ``default_lr`` is not used.
        """
        self.requires_grad_(True)

        groups = []
        for loras, lr in (
            (self.text_encoder_loras, text_encoder_lr),
            (self.unet_loras, unet_lr),
        ):
            if not loras:
                continue
            group = {"params": [param for lora in loras for param in lora.parameters()]}
            if lr is not None:
                group["lr"] = lr
            groups.append(group)

        return groups

    def enable_gradient_checkpointing(self):
        """No-op placeholder."""
        pass

    def get_trainable_params(self):
        """Return the parameters of this network."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Write the state dict to ``file``.

        Args:
            file: Target path; a ``.safetensors`` extension selects the
                safetensors writer, anything else uses ``torch.save``.
            dtype: When not None, every tensor is detached, cloned, moved to
                CPU and cast to this dtype before saving.
            metadata: Optional metadata dict; an empty dict is treated as
                None. For safetensors output the two precalculated hashes
                are added to it before writing.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] != ".safetensors":
            torch.save(state_dict, file)
            return

        from safetensors.torch import save_file

        # Precalculate model hashes to save time on indexing
        if metadata is None:
            metadata = {}
        model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
        metadata["sshs_model_hash"] = model_hash
        metadata["sshs_legacy_hash"] = legacy_hash

        save_file(state_dict, file, metadata)
340
+
341
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
    transformer,
    neuron_dropout: Optional[float] = None,
    add_lora_in_attn_temporal: bool = False,
    **kwargs,
):
    """Construct a ``LoRANetwork`` with defaults filled in.

    ``network_dim`` falls back to 4 and ``network_alpha`` to 1.0 when they
    are None. Extra keyword arguments are accepted and ignored.

    Returns:
        The newly created ``LoRANetwork`` (not yet applied to the models).
    """
    lora_dim = 4 if network_dim is None else network_dim  # default
    lora_alpha = 1.0 if network_alpha is None else network_alpha

    return LoRANetwork(
        text_encoder,
        transformer,
        multiplier=multiplier,
        lora_dim=lora_dim,
        alpha=lora_alpha,
        dropout=neuron_dropout,
        add_lora_in_attn_temporal=add_lora_in_attn_temporal,
        varbose=True,
    )
367
+
368
def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
    """Merge LoRA weights into the layers of ``pipeline`` in place.

    Args:
        pipeline: Object exposing ``text_encoder`` and ``transformer``.
        lora_path: File loaded with ``load_file`` when ``state_dict`` is None.
        multiplier: Factor applied to every merged weight delta.
        device: Device used for loading and for the target weights.
        dtype: Dtype the LoRA up/down weights are cast to before merging.
        state_dict: Optional already-loaded LoRA state dict.
        transformer_only: When True, entries whose layer key contains
            "lora_te" are skipped.

    Returns:
        The same ``pipeline`` object, with its weights modified.
    """
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path, device=device)

    # Group the flat keys as {layer_key: {"lora_up.weight": ..., ...}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    for layer, elems in updates.items():

        if "lora_te" in layer:
            if transformer_only:
                continue
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Walk down the module tree. Attribute names may themselves contain
        # underscores, so a failed lookup glues the next fragment onto the
        # current name and retries. __getattr__ is called directly so that
        # ordinary methods (e.g. "to") never count as a successful lookup.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except Exception:
                if len(layer_infos) == 0:
                    print('Error loading layer')
                fragment = layer_infos.pop(0)
                temp_name = temp_name + "_" + fragment if len(temp_name) > 0 else fragment
                continue
            if len(layer_infos) == 0:
                break
            temp_name = layer_infos.pop(0)

        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems['alpha'].item() / weight_up.shape[1] if 'alpha' in elems.keys() else 1.0

        curr_layer.weight.data = curr_layer.weight.data.to(device)
        if len(weight_up.shape) == 4:
            # 4-D weights: drop the two trailing axes, multiply, restore them.
            delta = torch.mm(
                weight_up.squeeze(3).squeeze(2),
                weight_down.squeeze(3).squeeze(2),
            ).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += multiplier * alpha * delta

    return pipeline
424
+
425
+ # TODO: Refactor with merge_lora.
426
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.

    Loads the LoRA file at ``lora_path`` and subtracts
    ``multiplier * alpha * (up @ down)`` from every matching layer weight of
    ``pipeline.text_encoder`` / ``pipeline.transformer`` in place.

    Returns:
        The same ``pipeline`` object, with its weights modified.
    """
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    state_dict = load_file(lora_path, device=device)

    # Group the flat keys as {layer_key: {"lora_up.weight": ..., ...}}.
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    for layer, elems in updates.items():

        if "lora_te" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # Walk down the module tree. Attribute names may themselves contain
        # underscores, so a failed lookup glues the next fragment onto the
        # current name and retries. __getattr__ is called directly so that
        # ordinary methods (e.g. "to") never count as a successful lookup.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
            except Exception:
                if len(layer_infos) == 0:
                    print('Error loading layer')
                fragment = layer_infos.pop(0)
                temp_name = temp_name + "_" + fragment if len(temp_name) > 0 else fragment
                continue
            if len(layer_infos) == 0:
                break
            temp_name = layer_infos.pop(0)

        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems['alpha'].item() / weight_up.shape[1] if 'alpha' in elems.keys() else 1.0

        curr_layer.weight.data = curr_layer.weight.data.to(device)
        if len(weight_up.shape) == 4:
            # 4-D weights: drop the two trailing axes, multiply, restore them.
            delta = torch.mm(
                weight_up.squeeze(3).squeeze(2),
                weight_down.squeeze(3).squeeze(2),
            ).unsqueeze(2).unsqueeze(3)
        else:
            delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data -= multiplier * alpha * delta

    return pipeline
easyanimate/utils/respace.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps in the
                        "ddimN" case, or if a section holds fewer steps than
                        the count requested from it.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            # Use the first integer stride that yields exactly the requested count.
            for stride in range(1, num_timesteps):
                candidate = range(0, num_timesteps, stride)
                if len(candidate) == desired_count:
                    return set(candidate)
            # Report the count that was asked for, not the size of the
            # original process.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the process into len(section_counts) sections; the first `extra`
    # sections get one additional step so the sizes add up to num_timesteps.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        # Fractional stride so the picks span the section from its first to
        # its last step.
        frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
        cur_idx = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)
60
+
61
+
62
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Build the full process once to read its cumulative alphas, then
        # derive the betas of the shortened process from the retained steps.
        full_process = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        previous_cumprod = 1.0
        spaced_betas = []
        for step, cumprod in enumerate(full_process.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            spaced_betas.append(1 - cumprod / previous_cumprod)
            previous_cumprod = cumprod
            self.timestep_map.append(step)
        kwargs["betas"] = np.array(spaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def training_losses_diffusers(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().training_losses_diffusers(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` so it is called with original-process timesteps.

        A model that is already wrapped is returned unchanged.
        """
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(model, self.timestep_map, self.original_num_steps)
        return model

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
117
+
118
+
119
class _WrappedModel:
    """Callable that remaps timestep indices before invoking the real model."""

    def __init__(self, model, timestep_map, original_num_steps):
        self.original_num_steps = original_num_steps
        self.timestep_map = timestep_map
        self.model = model

    def __call__(self, x, timestep, **kwargs):
        # Look up each index of the spaced process in the map to get the
        # matching timestep of the original process; no rescaling is applied.
        lookup = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
        return self.model(x, timestep=lookup[timestep], **kwargs)
easyanimate/utils/utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import imageio
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+ import cv2
8
+ from einops import rearrange
9
+ from PIL import Image
10
+
11
+
12
def color_transfer(sc, dc):
    """
    Transfer the color distribution of dc onto sc.

    Both images are converted from RGB to LAB, the per-channel mean and
    standard deviation of ``sc`` are matched to those of ``dc``, and the
    result is converted back to RGB.

    Args:
        sc (numpy.ndarray): input image to be transferred.
        dc (numpy.ndarray): reference image.

    Returns:
        numpy.ndarray: ``sc`` with the color distribution of ``dc``.
    """

    def get_mean_and_std(img):
        # Per-channel statistics, rounded to 2 decimals and flattened to 1-D.
        x_mean, x_std = cv2.meanStdDev(img)
        x_mean = np.hstack(np.around(x_mean, 2))
        x_std = np.hstack(np.around(x_std, 2))
        return x_mean, x_std

    sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
    s_mean, s_std = get_mean_and_std(sc)
    dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
    t_mean, t_std = get_mean_and_std(dc)

    # A flat source channel has zero standard deviation; dividing by it would
    # produce NaN/inf. Its centred values are all ~0 anyway, so a unit divisor
    # simply maps that channel to the reference mean.
    s_std = np.where(s_std == 0, 1.0, s_std)

    img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
    img_n = np.clip(img_n, 0, 255)
    dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
    return dst
39
+
40
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
    """Tile a batch of videos into a grid per frame and save it as a video/GIF.

    Args:
        videos: Tensor rearranged as "b c t h w"; one grid image is built
            per time step from the batch entries.
        path: Output file path. Missing parent directories are created.
        rescale: When True, values are mapped from [-1, 1] to [0, 1] before
            conversion to 8-bit.
        n_rows: Passed as ``nrow`` to ``torchvision.utils.make_grid``.
        fps: Frame rate used by the imageio backend.
        imageio_backend: When True the file is written with imageio;
            otherwise a GIF is written with PIL (an ``.mp4`` path is
            renamed to ``.gif``).
        color_transfer_post_process: When True, every frame after the first
            is color-matched to the first frame.
    """
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, h, w) -> (h, w, c); a trailing single channel is squeezed away.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(Image.fromarray(x))

    if color_transfer_post_process:
        for i in range(1, len(outputs)):
            outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))

    # os.makedirs("") raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    if imageio_backend:
        if path.endswith("mp4"):
            imageio.mimsave(path, outputs, fps=fps)
        else:
            imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
    else:
        if path.endswith("mp4"):
            path = path.replace('.mp4', '.gif')
        # The first frame is the image being saved; append only the remaining
        # frames so it is not written twice.
        outputs[0].save(path, format='GIF', append_images=outputs[1:], save_all=True, duration=100, loop=0)
easyanimate/vae/LICENSE ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
2
+
3
+ CreativeML Open RAIL-M
4
+ dated August 22, 2022
5
+
6
+ Section I: PREAMBLE
7
+
8
+ Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
9
+
10
+ Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
11
+
12
+ In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
13
+
14
+ Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
15
+
16
+ This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
17
+
18
+ NOW THEREFORE, You and Licensor agree as follows:
19
+
20
+ 1. Definitions
21
+
22
+ - "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
23
+ - "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
24
+ - "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
25
+ - "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
26
+ - "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
27
+ - "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
28
+ - "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
29
+ - "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
30
+ - "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
31
+ - "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
32
+ - "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
33
+ - "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
34
+
35
+ Section II: INTELLECTUAL PROPERTY RIGHTS
36
+
37
+ Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
40
+ 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
41
+
42
+ Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
43
+
44
+ 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
45
+ Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
46
+ You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
47
+ You must cause any modified files to carry prominent notices stating that You changed the files;
48
+ You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
49
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
50
+ 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
51
+ 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
52
+
53
+ Section IV: OTHER PROVISIONS
54
+
55
+ 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
56
+ 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
57
+ 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
58
+ 10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
59
+ 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
60
+ 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
61
+
62
+ END OF TERMS AND CONDITIONS
63
+
64
+
65
+
66
+
67
+ Attachment A
68
+
69
+ Use Restrictions
70
+
71
+ You agree not to use the Model or Derivatives of the Model:
72
+ - In any way that violates any applicable national, federal, state, local or international law or regulation;
73
+ - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
74
+ - To generate or disseminate verifiably false information and/or content with the purpose of harming others;
75
+ - To generate or disseminate personal identifiable information that can be used to harm an individual;
76
+ - To defame, disparage or otherwise harass others;
77
+ - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
78
+ - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
79
+ - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
80
+ - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
81
+ - To provide medical advice and medical results interpretation;
82
+ - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
easyanimate/vae/README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## VAE Training
2
+
3
+ English | [简体中文](./README_zh-CN.md)
4
+
5
+ After completing data preprocessing, we can obtain the following dataset:
6
+
7
+ ```
8
+ 📦 project/
9
+ ├── 📂 datasets/
10
+ │ ├── 📂 internal_datasets/
11
+ │ ├── 📂 videos/
12
+ │ │ ├── 📄 00000001.mp4
13
+ │ │ ├── 📄 00000001.jpg
14
+ │ │ └── 📄 .....
15
+ │ └── 📄 json_of_internal_datasets.json
16
+ ```
17
+
18
+ The json_of_internal_datasets.json is a standard JSON file. The file_path in the json can be set as a relative path, as shown below:
19
+ ```json
20
+ [
21
+ {
22
+ "file_path": "videos/00000001.mp4",
23
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
24
+ "type": "video"
25
+ },
26
+ {
27
+ "file_path": "train/00000001.jpg",
28
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
29
+ "type": "image"
30
+ },
31
+ .....
32
+ ]
33
+ ```
34
+
35
+ You can also set the path as an absolute path, as follows:
36
+ ```json
37
+ [
38
+ {
39
+ "file_path": "/mnt/data/videos/00000001.mp4",
40
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
41
+ "type": "video"
42
+ },
43
+ {
44
+ "file_path": "/mnt/data/train/00000001.jpg",
45
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
46
+ "type": "image"
47
+ },
48
+ .....
49
+ ]
50
+ ```
51
+
52
+ ## Train Video VAE
53
+ First, we need to edit the config in ```easyanimate/vae/configs/autoencoder```. The default config is ```autoencoder_kl_32x32x4_slice.yaml```. We need to set some parameters in the yaml file.
54
+
55
+ - ```data_json_path``` corresponds to the JSON file of the dataset.
56
+ - ```data_root``` corresponds to the root path of the dataset. If you want to use absolute path in json file, please delete this line.
57
+ - ```ckpt_path``` corresponds to the pretrained weights of the vae.
58
+ - ```gpus``` and ```num_nodes``` need to be set according to the actual configuration of your machine.
59
+
60
+ Then we run the shell file as follows:
61
+ ```
62
+ sh scripts/train_vae.sh
63
+ ```
easyanimate/vae/README_zh-CN.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## VAE 训练
2
+
3
+ [English](./README.md) | 简体中文
4
+
5
+ 在完成数据预处理后,你可以获得这样的数据格式:
6
+
7
+ ```
8
+ 📦 project/
9
+ ├── 📂 datasets/
10
+ │ ├── 📂 internal_datasets/
11
+ │ ├── 📂 videos/
12
+ │ │ ├── 📄 00000001.mp4
13
+ │ │ ├── 📄 00000001.jpg
14
+ │ │ └── 📄 .....
15
+ │ └── 📄 json_of_internal_datasets.json
16
+ ```
17
+
18
+ json_of_internal_datasets.json是一个标准的json文件。json中的file_path可以被设置为相对路径,如下所示:
19
+ ```json
20
+ [
21
+ {
22
+ "file_path": "videos/00000001.mp4",
23
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
24
+ "type": "video"
25
+ },
26
+ {
27
+ "file_path": "train/00000001.jpg",
28
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
29
+ "type": "image"
30
+ },
31
+ .....
32
+ ]
33
+ ```
34
+
35
+ 你也可以将路径设置为绝对路径:
36
+ ```json
37
+ [
38
+ {
39
+ "file_path": "/mnt/data/videos/00000001.mp4",
40
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
41
+ "type": "video"
42
+ },
43
+ {
44
+ "file_path": "/mnt/data/train/00000001.jpg",
45
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
46
+ "type": "image"
47
+ },
48
+ .....
49
+ ]
50
+ ```
51
+
52
+ ## 训练 Video VAE
53
+ 我们首先需要修改 ```easyanimate/vae/configs/autoencoder``` 中的配置文件。默认的配置文件是 ```autoencoder_kl_32x32x4_slice.yaml```。你需要修改以下参数:
54
+
55
+ - ```data_json_path``` json 文件所在的路径。
56
+ - ```data_root``` 数据的根目录。如果你在json file中使用了绝对路径,请设置为空。
57
+ - ```ckpt_path``` 预训练的vae模型路径。
58
+ - ```gpus``` 以及 ```num_nodes``` 需要设置为你机器的实际gpu数目。
59
+
60
+ 运行以下的脚本来训练vae:
61
+ ```
62
+ sh scripts/train_vae.sh
63
+ ```
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
4
+ params:
5
+ monitor: train/rec_loss
6
+ ckpt_path: models/videoVAE_omnigen_8x8x4_from_vae-ft-mse-840000-ema-pruned.ckpt
7
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
8
+ "SpatialTemporalDownBlock3D",)
9
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
10
+ "SpatialTemporalUpBlock3D",)
11
+ lossconfig:
12
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
13
+ params:
14
+ disc_start: 50001
15
+ kl_weight: 1.0e-06
16
+ disc_weight: 0.5
17
+ l2_loss_weight: 0.1
18
+ l1_loss_weight: 1.0
19
+ perceptual_weight: 1.0
20
+
21
+ data:
22
+ target: train_vae.DataModuleFromConfig
23
+
24
+ params:
25
+ batch_size: 2
26
+ wrap: true
27
+ num_workers: 4
28
+ train:
29
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
30
+ params:
31
+ data_json_path: pretrain.json
32
+ data_root: /your_data_root # This is used in relative path
33
+ size: 128
34
+ degradation: pil_nearest
35
+ video_size: 128
36
+ video_len: 9
37
+ slice_interval: 1
38
+ validation:
39
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
40
+ params:
41
+ data_json_path: pretrain.json
42
+ data_root: /your_data_root # This is used in relative path
43
+ size: 128
44
+ degradation: pil_nearest
45
+ video_size: 128
46
+ video_len: 9
47
+ slice_interval: 1
48
+
49
+ lightning:
50
+ callbacks:
51
+ image_logger:
52
+ target: train_vae.ImageLogger
53
+ params:
54
+ batch_frequency: 5000
55
+ max_images: 8
56
+ increase_log_steps: True
57
+
58
+ trainer:
59
+ benchmark: True
60
+ accumulate_grad_batches: 1
61
+ gpus: "0"
62
+ num_nodes: 1
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
4
+ params:
5
+ slice_compression_vae: true
6
+ mini_batch_encoder: 8
7
+ mini_batch_decoder: 2
8
+ monitor: train/rec_loss
9
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
10
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
11
+ "SpatialTemporalDownBlock3D",)
12
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
13
+ "SpatialTemporalUpBlock3D",)
14
+ lossconfig:
15
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
16
+ params:
17
+ disc_start: 50001
18
+ kl_weight: 1.0e-06
19
+ disc_weight: 0.5
20
+ l2_loss_weight: 0.0
21
+ l1_loss_weight: 1.0
22
+ perceptual_weight: 1.0
23
+
24
+ data:
25
+ target: train_vae.DataModuleFromConfig
26
+
27
+ params:
28
+ batch_size: 1
29
+ wrap: true
30
+ num_workers: 8
31
+ train:
32
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
33
+ params:
34
+ data_json_path: pretrain.json
35
+ data_root: /your_data_root # This is used in relative path
36
+ size: 256
37
+ degradation: pil_nearest
38
+ video_size: 256
39
+ video_len: 25
40
+ slice_interval: 1
41
+ validation:
42
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
43
+ params:
44
+ data_json_path: pretrain.json
45
+ data_root: /your_data_root # This is used in relative path
46
+ size: 256
47
+ degradation: pil_nearest
48
+ video_size: 256
49
+ video_len: 25
50
+ slice_interval: 1
51
+
52
+ lightning:
53
+ callbacks:
54
+ image_logger:
55
+ target: train_vae.ImageLogger
56
+ params:
57
+ batch_frequency: 5000
58
+ max_images: 8
59
+ increase_log_steps: True
60
+
61
+ trainer:
62
+ benchmark: True
63
+ accumulate_grad_batches: 1
64
+ gpus: "0"
65
+ num_nodes: 1
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
4
+ params:
5
+ slice_compression_vae: true
6
+ train_decoder_only: true
7
+ mini_batch_encoder: 8
8
+ mini_batch_decoder: 2
9
+ monitor: train/rec_loss
10
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
11
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
12
+ "SpatialTemporalDownBlock3D",)
13
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
14
+ "SpatialTemporalUpBlock3D",)
15
+ lossconfig:
16
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
17
+ params:
18
+ disc_start: 50001
19
+ kl_weight: 1.0e-06
20
+ disc_weight: 0.5
21
+ l2_loss_weight: 1.0
22
+ l1_loss_weight: 0.0
23
+ perceptual_weight: 1.0
24
+
25
+ data:
26
+ target: train_vae.DataModuleFromConfig
27
+
28
+ params:
29
+ batch_size: 1
30
+ wrap: true
31
+ num_workers: 8
32
+ train:
33
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
34
+ params:
35
+ data_json_path: pretrain.json
36
+ data_root: /your_data_root # This is used in relative path
37
+ size: 256
38
+ degradation: pil_nearest
39
+ video_size: 256
40
+ video_len: 25
41
+ slice_interval: 1
42
+ validation:
43
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
44
+ params:
45
+ data_json_path: pretrain.json
46
+ data_root: /your_data_root # This is used in relative path
47
+ size: 256
48
+ degradation: pil_nearest
49
+ video_size: 256
50
+ video_len: 25
51
+ slice_interval: 1
52
+
53
+ lightning:
54
+ callbacks:
55
+ image_logger:
56
+ target: train_vae.ImageLogger
57
+ params:
58
+ batch_frequency: 5000
59
+ max_images: 8
60
+ increase_log_steps: True
61
+
62
+ trainer:
63
+ benchmark: True
64
+ accumulate_grad_batches: 1
65
+ gpus: "0"
66
+ num_nodes: 1
easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
4
+ params:
5
+ slice_compression_vae: true
6
+ mini_batch_encoder: 8
7
+ mini_batch_decoder: 1
8
+ monitor: train/rec_loss
9
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
10
+ down_block_types: ("SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
11
+ "SpatialTemporalDownBlock3D",)
12
+ up_block_types: ("SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
13
+ "SpatialTemporalUpBlock3D",)
14
+ lossconfig:
15
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
16
+ params:
17
+ disc_start: 50001
18
+ kl_weight: 1.0e-06
19
+ disc_weight: 0.5
20
+ l2_loss_weight: 0.0
21
+ l1_loss_weight: 1.0
22
+ perceptual_weight: 1.0
23
+
24
+
25
+ data:
26
+ target: train_vae.DataModuleFromConfig
27
+
28
+ params:
29
+ batch_size: 1
30
+ wrap: true
31
+ num_workers: 8
32
+ train:
33
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
34
+ params:
35
+ data_json_path: pretrain.json
36
+ data_root: /your_data_root # This is used in relative path
37
+ size: 256
38
+ degradation: pil_nearest
39
+ video_size: 256
40
+ video_len: 33
41
+ slice_interval: 1
42
+ validation:
43
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
44
+ params:
45
+ data_json_path: pretrain.json
46
+ data_root: /your_data_root # This is used in relative path
47
+ size: 256
48
+ degradation: pil_nearest
49
+ video_size: 256
50
+ video_len: 33
51
+ slice_interval: 1
52
+
53
+ lightning:
54
+ callbacks:
55
+ image_logger:
56
+ target: train_vae.ImageLogger
57
+ params:
58
+ batch_frequency: 5000
59
+ max_images: 8
60
+ increase_log_steps: True
61
+
62
+ trainer:
63
+ benchmark: True
64
+ accumulate_grad_batches: 1
65
+ gpus: "0"
66
+ num_nodes: 1
easyanimate/vae/environment.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ldm
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.11.0
10
+ - torchvision=0.12.0
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - diffusers
15
+ - opencv-python==4.1.2.30
16
+ - pudb==2019.2
17
+ - invisible-watermark
18
+ - imageio==2.9.0
19
+ - imageio-ffmpeg==0.4.2
20
+ - pytorch-lightning==1.4.2
21
+ - omegaconf==2.1.1
22
+ - test-tube>=0.7.5
23
+ - streamlit>=0.73.1
24
+ - einops==0.3.0
25
+ - torch-fidelity==0.3.0
26
+ - transformers==4.19.2
27
+ - torchmetrics==0.6.0
28
+ - kornia==0.6
29
+ - -e .
easyanimate/vae/ldm/data/__init__.py ADDED
File without changes
easyanimate/vae/ldm/data/base.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+ from torch.utils.data import (ChainDataset, ConcatDataset, Dataset,
4
+ IterableDataset)
5
+
6
+
7
class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Base interface that lets text2img IterableDatasets be chained together.

    The constructor only records the values it is given; subclasses provide
    the actual iteration by implementing ``__iter__``.
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.size = size
        self.num_records = num_records
        # Both attributes start out referring to the very same id collection.
        self.valid_ids = self.sample_ids = valid_ids

        dataset_name = self.__class__.__name__
        print(f'{dataset_name} dataset contains {self.__len__()} examples.')

    def __len__(self):
        # The length is whatever record count the caller declared.
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
easyanimate/vae/ldm/data/dataset_callback.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-*- encoding:utf-8 -*-
2
+ from pytorch_lightning.callbacks import Callback
3
+
4
class DatasetCallback(Callback):
    """Callback that carries the batch sampler position across checkpoints.

    The position is written into the checkpoint dict under the key
    ``'sampler_pos_start'`` on save, read back on load, and handed to the
    training batch sampler (as ``sampler_pos_reload``) the first time
    training starts.
    """

    def __init__(self):
        self.preload_used_idx_flag = False
        self.sampler_pos_start = 0

    def on_train_start(self, trainer, pl_module):
        # Only push the restored position once per callback instance.
        if self.preload_used_idx_flag:
            return
        self.preload_used_idx_flag = True
        batch_sampler = trainer.train_dataloader.batch_sampler
        batch_sampler.sampler_pos_reload = self.sampler_pos_start

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        if trainer.train_dataloader is None:
            return
        # Save sampler_pos_start parameters in the checkpoint
        batch_sampler = trainer.train_dataloader.batch_sampler
        checkpoint['sampler_pos_start'] = batch_sampler.sampler_pos_start

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        if 'sampler_pos_start' not in checkpoint:
            print('The sampler_pos_start is not in checkpoint')
            return
        # Restore sampler_pos_start parameters from the checkpoint
        self.sampler_pos_start = checkpoint.get('sampler_pos_start', 0)
        print('Load sampler_pos_start from checkpoint, sampler_pos_start = %d' % self.sampler_pos_start)
easyanimate/vae/ldm/data/dataset_image_video.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import pickle
5
+ import random
6
+ import shutil
7
+ import tarfile
8
+ from functools import partial
9
+
10
+ import albumentations
11
+ import cv2
12
+ import numpy as np
13
+ import PIL
14
+ import torchvision.transforms.functional as TF
15
+ import yaml
16
+ from decord import VideoReader
17
+ from func_timeout import FunctionTimedOut, func_set_timeout
18
+ from omegaconf import OmegaConf
19
+ from PIL import Image
20
+ from torch.utils.data import (BatchSampler, Dataset, Sampler)
21
+ from tqdm import tqdm
22
+
23
+ from ..modules.image_degradation import (degradation_fn_bsr,
24
+ degradation_fn_bsr_light)
25
+
26
+
27
class ImageVideoSampler(BatchSampler):
    """A sampler wrapper that groups indices of the same content type into a batch.

    Indices drawn from the base sampler are routed into one bucket per
    content type (``'image'`` or ``'video'``); as soon as a bucket holds
    ``batch_size`` indices it is yielded and replaced by a fresh list.
    Partially filled buckets are kept on the instance when iteration ends.

    Args:
        sampler (Sampler): Base sampler.
        dataset (Dataset): Dataset providing data information; it is queried
            through ``dataset.data.get_type(idx)``.
        batch_size (int): Size of mini-batch.
        drop_last (bool): Stored on the instance; ``__iter__`` does not read it.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not (isinstance(batch_size, int) and batch_size > 0):
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # sampler_pos_start counts consumed sampler positions (modulo the
        # sampler length); sampler_pos_reload is the position to fast-forward
        # to on the next iteration (0 means "no fast-forward").
        self.sampler_pos_start = 0
        self.sampler_pos_reload = 0

        self.num_samples_random = len(self.sampler)
        # one bucket of pending indices per content type
        self.bucket = {'image': [], 'video': []}

    def set_epoch(self, epoch):
        # Forward the epoch only to samplers that support it.
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        for position, sample_idx in enumerate(self.sampler):
            if self.sampler_pos_reload != 0 and self.sampler_pos_reload < self.num_samples_random:
                if position < self.sampler_pos_reload:
                    # Skip positions consumed before the reload point.
                    self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
                    continue
                if position == self.sampler_pos_reload:
                    self.sampler_pos_reload = 0

            kind = self.dataset.data.get_type(sample_idx)
            self.bucket[kind].append(sample_idx)

            # yield a batch of indices of the same content type; the video
            # bucket is checked first and at most one batch leaves per step
            for ready_kind in ('video', 'image'):
                if len(self.bucket[ready_kind]) == self.batch_size:
                    yield self.bucket[ready_kind]
                    self.bucket[ready_kind] = []
                    break
            self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
87
+
88
class ImageVideoDataset(Dataset):
    """Mixed image/video dataset returning a crop plus its degraded low-resolution copy.

    Adapted from ImageNetSR. Loading a sample can fail (a video read that
    exceeds the timeout, an unreadable file, ...); in that case a different,
    randomly drawn index is tried until one sample loads successfully.

    Subclasses must implement ``get_base()`` and return a list of annotation
    dicts with at least a ``'file_path'`` key and an optional ``'type'`` key
    (``'image'`` by default, or ``'video'``).
    """

    def __init__(self, size=None, video_size=128, video_len=25,
                 degradation=None, downscale_f=4, random_crop=True, min_crop_f=0.25, max_crop_f=1.,
                 s_t=None, slice_interval=None, data_root=None
                 ):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1.  crops a crop of size s from image either as random or center crop
        2.  resizes crop to size with cv2.area_interpolation
        3.  degrades resized crop with degradation_fn

        :param size: resizing to size after cropping (images)
        :param video_size: resizing to size after cropping (video frames)
        :param video_len: number of frames sampled from each video
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param random_crop: random crop when true, center crop otherwise
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param s_t: stored on the instance; not used inside this class
        :param slice_interval: frame stride used to size the sampled clip;
          ``None`` is treated as 1 when a video is loaded
        :param data_root: optional directory that ``file_path`` entries are joined onto
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.s_t = s_t
        self.slice_interval = slice_interval

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
        self.video_rescaler = albumentations.SmallestMaxSize(max_size=video_size, interpolation=cv2.INTER_AREA)
        self.video_len = video_len
        self.video_size = video_size
        self.data_root = data_root

        self.pil_interpolation = False  # gets reset later in case interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def get_type(self, index):
        """Return the content type of entry *index* (``'image'`` when unspecified)."""
        return self.base[index].get('type', 'image')

    def _resolve_path(self, file_path):
        """Join *file_path* onto ``data_root`` when a root directory is configured."""
        if self.data_root is not None:
            return os.path.join(self.data_root, file_path)
        return file_path

    def _build_cropper(self, min_side_len):
        """Create a square cropper whose side is a random fraction of *min_side_len*."""
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)
        if self.center_crop:
            cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        # Kept as an attribute because the original implementation exposed it.
        self.cropper = cropper
        return cropper

    def _degrade(self, image):
        """Produce the low-resolution counterpart of a uint8 image array."""
        if self.pil_interpolation:
            low_res = self.degradation_process(PIL.Image.fromarray(image))
            return np.array(low_res).astype(np.uint8)
        return self.degradation_process(image=image)["image"]

    def _load_video_example(self, example):
        """Read, crop, rescale and degrade ``video_len`` frames of one video entry."""
        video_reader = VideoReader(self._resolve_path(example['file_path']))
        video_length = len(video_reader)

        # A missing slice_interval used to raise TypeError here, which the
        # retry loop swallowed forever; fall back to a stride of 1 instead.
        interval = 1 if self.slice_interval is None else self.slice_interval
        clip_length = min(video_length, (self.video_len - 1) * interval + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)

        pixel_values = video_reader.get_batch(batch_index).asnumpy()
        del video_reader

        min_side_len = min(pixel_values[0].shape[:2])
        cropper = self._build_cropper(min_side_len)

        # Move the frame axis last so a single crop window is applied to
        # every frame, then move it back to the front.
        frames = np.transpose(pixel_values, (1, 2, 3, 0))
        frames = cropper(image=frames)["image"]
        frames = np.transpose(frames, (3, 0, 1, 2))

        out_images = []
        LR_out_images = []
        for frame in frames:
            image = self.video_rescaler(image=frame)["image"]
            out_images.append(image[None, :, :, :])
            LR_out_images.append(self._degrade(image)[None, :, :, :])

        return {
            'image': (np.concatenate(out_images) / 127.5 - 1.0).astype(np.float32),
            'LR_image': (np.concatenate(LR_out_images) / 127.5 - 1.0).astype(np.float32),
        }

    def _load_image_example(self, example):
        """Read, crop, rescale and degrade one image entry."""
        # The context manager closes the underlying file once the pixels
        # have been converted; the original left the handle open.
        with Image.open(self._resolve_path(example['file_path'])) as handle:
            image = np.array(handle.convert("RGB")).astype(np.uint8)

        cropper = self._build_cropper(min(image.shape[:2]))
        image = cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]
        LR_image = self._degrade(image)

        return {
            "image": (image/127.5 - 1.0).astype(np.float32),
            "LR_image": (LR_image/127.5 - 1.0).astype(np.float32),
        }

    def __getitem__(self, i):
        """Return ``{'image': ..., 'LR_image': ...}`` (float32, scaled to [-1, 1]).

        The loader is chosen from the type of entry *i*. When loading fails,
        another index is drawn uniformly at random and the same loader is
        retried until it succeeds.

        :raises ValueError: if the entry's type is neither ``'video'`` nor ``'image'``
        """
        content_type = self.base[i].get('type', 'image')

        if content_type == 'video':
            # Each attempt is aborted after 3 seconds.
            load_video = func_set_timeout(3)(self._load_video_example)
            while True:
                try:
                    return load_video(self.base[i])
                except FunctionTimedOut:
                    print("stt catch: Function 'extract failed' timed out.")
                except Exception as e:
                    print('stt catch', e)
                i = random.randint(0, len(self) - 1)

        if content_type == 'image':
            while True:
                try:
                    return self._load_image_example(self.base[i])
                except Exception as e:
                    print("catch", e)
                i = random.randint(0, len(self) - 1)

        # Previously an unknown type fell through and silently returned None.
        raise ValueError(f"Unsupported content type {content_type!r} at index {i}")
264
+
265
class CustomSRTrain(ImageVideoDataset):
    """Training split: every annotation stored in the JSON file at ``data_json_path``."""

    def __init__(self, data_json_path, **kwargs):
        # Must be assigned before super().__init__(), which calls get_base().
        self.data_json_path = data_json_path
        super().__init__(**kwargs)

    def get_base(self):
        """Load and return all annotations as a list."""
        # The context manager closes the file; the original leaked the handle.
        with open(self.data_json_path) as json_file:
            return list(json.load(json_file))
272
+
273
class CustomSRValidation(ImageVideoDataset):
    """Validation split: the first 100 and the last 100 annotations of the JSON file.

    For files with fewer than 200 entries the two slices overlap, so some
    entries appear twice.
    """

    def __init__(self, data_json_path, **kwargs):
        # Must be assigned before super().__init__(), which calls get_base().
        self.data_json_path = data_json_path
        super().__init__(**kwargs)

    def get_base(self):
        """Load the annotations once and return head + tail slices."""
        # The original parsed the file twice through two unclosed handles.
        with open(self.data_json_path) as json_file:
            annotations = list(json.load(json_file))
        return annotations[:100] + annotations[-100:]
easyanimate/vae/ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class LambdaWarmUpCosineScheduler:
    """
    Multiplier schedule: linear warm-up from lr_start to lr_max, then cosine
    decay from lr_max down to lr_min.
    note: use with a base_lr of 1.0
    """

    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the multiplier for step ``n`` and remember it in ``last_lr``."""
        report_every = self.verbosity_interval
        if report_every > 0 and n % report_every == 0:
            # Reports the value from the previous call, not the one computed below.
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")

        warm_up = self.lr_warm_up_steps
        if n < warm_up:
            # Linear ramp lr_start -> lr_max.
            value = (self.lr_max - self.lr_start) / warm_up * n + self.lr_start
        else:
            # Fraction of the decay phase completed, capped at 1.0.
            progress = (n - warm_up) / (self.lr_max_decay_steps - warm_up)
            progress = min(progress, 1.0)
            value = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                1 + np.cos(progress * np.pi))
        self.last_lr = value
        return value

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
34
+
35
+
36
class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.

    Every argument except ``verbosity_interval`` is a sequence with one entry
    per cycle. Within a cycle the multiplier ramps linearly from ``f_start``
    to ``f_max`` over ``warm_up_steps`` and then follows a cosine from
    ``f_max`` down to ``f_min`` until the cycle ends.
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cum_cycles[i] is the global step at which cycle i starts.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle that contains global step ``n``.

        A step equal to a cycle boundary belongs to the earlier cycle. Steps
        beyond the last boundary map to the last cycle; previously this method
        fell through and returned ``None``, which made ``schedule`` fail when
        it used the result as an index.
        """
        for interval, upper in enumerate(self.cum_cycles[1:]):
            if n <= upper:
                return interval
        return len(self.cycle_lengths) - 1

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and store it in ``last_f``."""
        cycle = self.find_in_interval(n)
        # Convert to a step count relative to the start of the cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            # Capped so steps past the cycle end hold the multiplier at f_min.
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
79
+
80
+
81
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Per-cycle schedule with a linear warm-up followed by a linear decay."""

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and store it in ``last_f``."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        f_low = self.f_min[cycle]
        f_high = self.f_max[cycle]
        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear ramp f_start -> f_max.
            factor = (f_high - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Scaled by the share of the whole cycle (warm-up included) still remaining.
            length = self.cycle_lengths[cycle]
            factor = f_low + (f_high - f_low) * (length - n) / (length)
        self.last_f = factor
        return factor
98
+
easyanimate/vae/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
easyanimate/vae/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,701 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+
9
+ from ...util import instantiate_from_config
10
+
11
+
12
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".

    :param timesteps: 1-D tensor of timestep values.
    :param embedding_dim: width of the returned embedding.
    :return: float tensor of shape ``(len(timesteps), embedding_dim)`` laid out
             as ``[sin | cos]``, with one trailing zero column when
             ``embedding_dim`` is odd.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # Guard the denominator: for embedding_dim 2 or 3 half_dim - 1 is 0, which
    # used to raise ZeroDivisionError. With a single frequency the exponent is
    # multiplied by arange(1) == [0], so the divisor's value does not matter.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    # Outer product: one row per timestep, one column per frequency.
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
31
+
32
+
33
def nonlinearity(x):
    """Swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
36
+
37
+
38
def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm (eps=1e-6) over ``in_channels`` channels."""
    return torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
40
+
41
+
42
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # Channel-preserving conv; padding=1 keeps the spatial size.
        self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                    kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
58
+
59
+
60
class Downsample(nn.Module):
    """Halve the spatial size with a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # padding=0 here: forward() pads asymmetrically by hand because
        # torch convolutions only support symmetric padding.
        self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                    kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # One zero column on the right and one zero row at the bottom.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
80
+
81
+
82
class ResnetBlock(nn.Module):
    """Residual block: two norm -> swish -> 3x3 conv stages plus a skip path.

    An optional timestep-embedding projection is added after the first conv.
    When the channel count changes, the skip path goes through a 3x3
    (``conv_shortcut=True``) or 1x1 convolution so the two can be summed.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)
        # The projection only exists when an embedding width is configured.
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                     kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                    kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        hidden = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the per-sample embedding over the spatial dimensions.
            hidden = hidden + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        hidden = self.conv2(self.dropout(nonlinearity(self.norm2(hidden))))

        if self.in_channels == self.out_channels:
            return x + hidden
        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(x) + hidden
142
+
143
+
144
+
145
+
146
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual add."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            # 1x1 convolution that keeps the channel count.
            return torch.nn.Conv2d(in_channels, in_channels,
                                   kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # compute attention
        batch, channels, height, width = query.shape
        positions = height * width
        query = query.reshape(batch, channels, positions).permute(0, 2, 1)  # b,hw,c
        key = key.reshape(batch, channels, positions)                       # b,c,hw
        weights = torch.bmm(query, key)  # b,hw,hw: weights[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        weights = weights * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # attend to values
        value = value.reshape(batch, channels, positions)
        weights = weights.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        attended = torch.bmm(value, weights)  # b,c,hw (hw of q)
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)
199
+
200
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads
        # One bias-free 1x1 conv produces queries, keys and values together.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        projected = self.to_qkv(x)
        q, k, v = rearrange(projected, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        # Keys are normalised across the spatial positions.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(attended)
217
+
218
class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        # A single head as wide as the input, so only in_channels is needed.
        super().__init__(in_channels, heads=1, dim_head=in_channels)
222
+
223
def make_attn(in_channels, attn_type="vanilla"):
    """Create the attention module selected by ``attn_type``.

    ``"vanilla"`` gives an AttnBlock, ``"linear"`` a LinAttnBlock and
    ``"none"`` an identity module. Any other value fails the assertion.
    """
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "none":
        return nn.Identity(in_channels)
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    return LinAttnBlock(in_channels)
232
+
233
+
234
class Encoder(nn.Module):
    """Convolutional encoder: stacked ResnetBlocks with optional attention,
    downsampling between resolution levels, a middle stage and an output conv.

    ``out_ch`` is required by the signature but is not used in this class.
    ``temb_ch`` is fixed to 0, so the ResnetBlocks are built without a
    timestep projection and ``forward`` passes ``temb=None``.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        # use_linear_attn overrides whatever attn_type was passed.
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Leading 1 so that level i reads ch*in_ch_mult[i] channels.
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                # One attention module per res block at the listed resolutions.
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last ends with a downsample.
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        # double_z doubles the number of output channels.
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Run ``x`` through the down levels, the middle stage and the output conv."""
        # timestep embedding
        temb = None

        # downsampling
        # hs collects every intermediate activation; only hs[-1] is read back.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
326
+
327
+
328
class Decoder(nn.Module):
    """Convolutional decoder: input conv, middle stage, then resolution levels
    of ResnetBlocks (with optional attention) joined by upsampling.

    ``in_channels`` is stored but not otherwise used in this class.
    ``temb_ch`` is fixed to 0, so the ResnetBlocks are built without a
    timestep projection and ``forward`` passes ``temb=None``.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        # use_linear_attn overrides whatever attn_type was passed.
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        # Recorded for the message below; forward() does not enforce it.
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        # Levels are built from the lowest resolution up, with one more
        # ResnetBlock per level than num_res_blocks.
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # Level 0 is the final resolution and has no upsample.
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        """Decode ``z``; returns the pre-output features when ``give_pre_end`` is set."""
        #assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
435
+
436
+
437
class SimpleDecoder(nn.Module):
    """Small decoder: 1x1 conv, three ResnetBlocks, 1x1 conv, one 2x upsample,
    then norm -> swish -> 3x3 output conv."""

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, in_channels, 1),
            ResnetBlock(in_channels=in_channels,
                        out_channels=2 * in_channels,
                        temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=2 * in_channels,
                        out_channels=4 * in_channels,
                        temb_channels=0, dropout=0.0),
            ResnetBlock(in_channels=4 * in_channels,
                        out_channels=2 * in_channels,
                        temb_channels=0, dropout=0.0),
            nn.Conv2d(2*in_channels, in_channels, 1),
            Upsample(in_channels, with_conv=True),
        ]
        self.model = nn.ModuleList(layers)
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Positions 1-3 hold the ResnetBlocks, which take an extra temb argument.
        resnet_positions = (1, 2, 3)
        for position, layer in enumerate(self.model):
            x = layer(x, None) if position in resnet_positions else layer(x)

        return self.conv_out(nonlinearity(self.norm_out(x)))
471
+
472
+
473
class UpsampleDecoder(nn.Module):
    """Stack of ResnetBlock groups with a conv-Upsample between consecutive
    groups, followed by norm -> swish -> 3x3 output conv."""

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2,2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()

        channels = in_channels
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            level_out = ch * ch_mult[level]
            group = []
            for _ in range(self.num_res_blocks + 1):
                group.append(ResnetBlock(in_channels=channels,
                                         out_channels=level_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                channels = level_out
            self.res_blocks.append(nn.ModuleList(group))
            # No upsample after the final group.
            if level != last_level:
                self.upsample_blocks.append(Upsample(channels, True))

        # end
        self.norm_out = Normalize(channels)
        self.conv_out = torch.nn.Conv2d(channels, out_channels,
                                        kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling
        h = x
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            for index in range(self.num_res_blocks + 1):
                h = self.res_blocks[level][index](h, None)
            if level != last_level:
                h = self.upsample_blocks[level](h)
        return self.conv_out(nonlinearity(self.norm_out(h)))
519
+
520
+
521
class LatentRescaler(nn.Module):
    """Rescale a feature map by ``factor``: residual blocks, interpolation,
    attention, more residual blocks, then a 1x1 output conv."""

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(in_channels, mid_channels,
                                 kernel_size=3, stride=1, padding=1)

        def residual_stack():
            # `depth` channel-preserving ResnetBlocks without timestep projection.
            return nn.ModuleList([ResnetBlock(in_channels=mid_channels,
                                              out_channels=mid_channels,
                                              temb_channels=0,
                                              dropout=0.0) for _ in range(depth)])

        self.res_block1 = residual_stack()
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = residual_stack()

        self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        # Target size is the current size times factor, rounded to whole pixels.
        new_height = int(round(x.shape[2]*self.factor))
        new_width = int(round(x.shape[3]*self.factor))
        x = torch.nn.functional.interpolate(x, size=(new_height, new_width))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        return self.conv_out(x)
556
+
557
+
558
class MergedRescaleEncoder(nn.Module):
    """An Encoder followed by a LatentRescaler that maps to ``out_ch`` channels."""

    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # The encoder emits this many channels (double_z is off).
        encoder_width = ch * ch_mult[-1]
        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
                               z_channels=encoder_width, double_z=False, resolution=resolution,
                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
                               out_ch=None)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=encoder_width,
                                       mid_channels=encoder_width, out_channels=out_ch, depth=rescale_module_depth)

    def forward(self, x):
        return self.rescaler(self.encoder(x))
575
+
576
+
577
class MergedRescaleDecoder(nn.Module):
    """A LatentRescaler that widens and rescales the latent, followed by a Decoder."""

    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
        super().__init__()
        # Channel count handed from the rescaler to the decoder.
        decoder_in = z_channels*ch_mult[-1]
        self.decoder = Decoder(out_ch=out_ch, z_channels=decoder_in, attn_resolutions=attn_resolutions, dropout=dropout,
                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
                               ch_mult=ch_mult, resolution=resolution, ch=ch)
        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=decoder_in,
                                       out_channels=decoder_in, depth=rescale_module_depth)

    def forward(self, x):
        return self.decoder(self.rescaler(x))
592
+
593
+
594
class Upsampler(nn.Module):
    """A LatentRescaler followed by a Decoder sized from ``in_size`` and ``out_size``."""

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        # Number of decoder levels comes from the integer size ratio.
        level_count = int(np.log2(out_size//in_size))+1
        # Rescale factor is 1 plus the remainder of out_size / in_size.
        factor_up = 1.+ (out_size % in_size)
        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
                                       out_channels=in_channels)
        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
                               attn_resolutions=[], in_channels=None, ch=in_channels,
                               ch_mult=[ch_mult for _ in range(level_count)])

    def forward(self, x):
        return self.decoder(self.rescaler(x))
611
+
612
+
613
class Resize(nn.Module):
    """Resize a feature map by interpolation.

    Args:
        in_channels: accepted for interface compatibility; unused because the
            learned variant is not implemented.
        learned: requesting a learned (convolutional) resize raises
            ``NotImplementedError``.
        mode: interpolation mode passed to ``torch.nn.functional.interpolate``.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # ``__name__`` here: the original ``self.__class__.__name`` was
            # name-mangled to ``_Resize__name`` inside the class body and
            # raised AttributeError before the intended error was reached.
            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
            # The conv construction that followed this raise was unreachable
            # and has been dropped.
            raise NotImplementedError()

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for a scale of 1.0, otherwise the interpolated tensor."""
        if scale_factor==1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
        return x
635
+
636
class FirstStagePostProcessor(nn.Module):
    """Post-process the latents of a frozen first-stage model.

    Encodes the input with ``pretrained_model``, projects the latent to
    ``n_channels``, then applies one ResnetBlock plus one average-pool
    Downsample per entry of ``ch_mult``.

    Args:
        ch_mult: channel multipliers, one per ResnetBlock/Downsample pair.
        in_channels: channel count of the first-stage latent.
        pretrained_model: an already built first-stage model; used when
            ``pretrained_config`` is None.
        reshape: when true, the output is flattened to ``b (h w) c``.
        n_channels: base width; defaults to ``pretrained_model.encoder.ch``.
        dropout: dropout rate of the ResnetBlocks.
        pretrained_config: config from which the first-stage model is built
            and frozen; takes precedence over ``pretrained_model``.
    """

    def __init__(self, ch_mult:list, in_channels,
                 pretrained_model:nn.Module=None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            # The config is known to be non-None in this branch.
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
                              stride=1,padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        """Build the first-stage model from ``config``, set eval mode, freeze its parameters."""
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self,x):
        """Encode ``x`` with the first-stage model and return a tensor latent."""
        c = self.pretrained_model.encode(x)
        # The original tested isinstance(c, DiagonalGaussianDistribution), a
        # name this module never imports, so every call raised NameError.
        # Any non-tensor result is treated as a distribution exposing
        # .mode() -- assumption, confirm against the first-stage model's encode().
        if not isinstance(c, torch.Tensor):
            c = c.mode()
        return c

    def forward(self,x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        for submodel, downmodel in zip(self.model,self.downsampler):
            z = submodel(z,temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z,'b c h w -> b (h w) c')
        return z
701
+
easyanimate/vae/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import math
12
+ import os
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from einops import repeat
18
+
19
+ from ...util import instantiate_from_config
20
+
21
+
22
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta (noise variance) schedule as a float64 NumPy array.

    :param schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``, ``"sqrt"``.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value for the linspace-based schedules.
    :param linear_end: last value for the linspace-based schedules.
    :param cosine_s: offset used by the cosine schedule.
    :return: array of ``n_timestep`` betas.
    :raises ValueError: if ``schedule`` is not a known name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta), squared back afterwards.
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # torch.clamp keeps betas a tensor, so the .numpy() call below is
        # always valid; np.clip on a tensor depends on NumPy/torch wrapping.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
45
+
46
+
47
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """Select the DDPM timesteps visited by a DDIM sampler.

    ``'uniform'`` takes every ``num_ddpm_timesteps // num_ddim_timesteps``-th
    step; ``'quad'`` spaces the steps quadratically up to 80% of the DDPM
    range. The selected indices are returned shifted by one.
    """
    if ddim_discr_method == 'quad':
        upper = np.sqrt(num_ddpm_timesteps * .8)
        selected = ((np.linspace(0, upper, num_ddim_timesteps)) ** 2).astype(int)
    elif ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        selected = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = selected + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
62
+
63
+
64
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """Return ``(sigmas, alphas, alphas_prev)`` for the chosen DDIM timesteps.

    ``alphas`` are the cumulative alphas at ``ddim_timesteps``; ``alphas_prev``
    is the same sequence shifted by one step, starting from ``alphacums[0]``.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + earlier)

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
76
+
77
+
78
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar function into a beta schedule.
    alpha_bar describes the cumulative product of (1-beta) over t in [0,1].
    :param num_diffusion_timesteps: how many betas to generate.
    :param alpha_bar: a callable taking t in [0,1] and returning the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: upper bound applied to every beta; keep it below 1 to
                     prevent singularities.
    :return: a NumPy array with one beta per timestep.
    """
    steps = num_diffusion_timesteps
    # Each beta is derived from the ratio of alpha_bar at the two ends of the
    # i-th interval and then capped at max_beta.
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
95
+
96
+
97
def extract_into_tensor(a, t, x_shape):
    """
    Gather the entries of `a` at indices `t` and reshape them for broadcasting.
    :param a: tensor to gather from along its last dimension.
    :param t: index tensor; its first dimension is used as the batch size.
    :param x_shape: shape of the target tensor; only its length is used.
    :return: the gathered values reshaped to (batch, 1, ..., 1) with
             len(x_shape) dimensions.
    """
    batch_size, *_rest = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *singleton_dims)
101
+
102
+
103
def checkpoint(func, inputs, params, flag):
    """
    Run `func` with optional gradient checkpointing: intermediate activations
    are not cached, trading extra backward-pass compute for lower memory.
    :param func: the function to evaluate.
    :param inputs: the sequence of arguments passed to `func`.
    :param params: a sequence of parameters `func` depends on without taking
                   them as explicit arguments.
    :param flag: if False, checkpointing is disabled and `func` is called
                 directly.
    """
    if not flag:
        return func(*inputs)
    # Inputs and parameters travel as one flat argument list; the input count
    # tells CheckpointFunction where to split them again.
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
118
+
119
+
120
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that runs `run_function` without recording a graph in
    the forward pass and re-runs it with gradients enabled in the backward
    pass.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` entries of args are the function inputs; the
        # remainder are parameters the function depends on implicitly.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            result = ctx.run_function(*ctx.input_tensors)
        return result

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs so the recomputation builds a fresh graph
        # rooted at them.
        ctx.input_tensors = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
        with torch.enable_grad():
            # Pass views rather than the detached tensors themselves: the
            # first op in run_function may modify its input in place, which
            # is not allowed on detach()'d tensors.
            aliases = [t.view_as(t) for t in ctx.input_tensors]
            recomputed = ctx.run_function(*aliases)
        grads = torch.autograd.grad(
            recomputed,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del recomputed
        # No gradients for run_function and length.
        return (None, None) + grads
150
+
151
+
152
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of timesteps.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and repeat each timestep
                        value `dim` times instead.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # An odd dim leaves one column unfilled; pad it with zeros.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding
173
+
174
+
175
def zero_module(module):
    """
    Set every parameter of `module` to zero in place.
    :param module: the module whose parameters are zeroed.
    :return: the same module, for call chaining.
    """
    for param in module.parameters():
        # detach() shares storage, so the in-place write reaches the parameter.
        param.detach().zero_()
    return module
182
+
183
+
184
def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place.
    :param module: the module whose parameters are scaled.
    :param scale: the multiplier applied to each parameter.
    :return: the same module, for call chaining.
    """
    for param in module.parameters():
        # detach() shares storage, so the in-place write reaches the parameter.
        param.detach().mul_(scale)
    return module
191
+
192
+
193
def mean_flat(tensor):
    """
    Average a tensor over every dimension except the first (batch) one.
    :param tensor: the input tensor.
    :return: the result of reducing all dimensions from index 1 onward.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
198
+
199
+
200
def normalization(channels):
    """
    Build the standard normalization layer: a GroupNorm32 with 32 groups.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    norm_layer = GroupNorm32(32, channels)
    return norm_layer
207
+
208
+
209
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
210
class SiLU(nn.Module):
    """SiLU activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
213
+
214
+
215
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and returns the input's dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
218
+
219
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution module of the requested dimensionality.
    :param dims: 1, 2 or 3, selecting nn.Conv1d, nn.Conv2d or nn.Conv3d.
    :param args: positional arguments forwarded to the module constructor.
    :param kwargs: keyword arguments forwarded to the module constructor.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
230
+
231
+
232
def linear(*args, **kwargs):
    """
    Build an nn.Linear module.
    :param args: positional arguments forwarded to nn.Linear.
    :param kwargs: keyword arguments forwarded to nn.Linear.
    """
    layer = nn.Linear(*args, **kwargs)
    return layer
237
+
238
+
239
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average pooling module of the requested dimensionality.
    :param dims: 1, 2 or 3, selecting nn.AvgPool1d, nn.AvgPool2d or
                 nn.AvgPool3d.
    :param args: positional arguments forwarded to the module constructor.
    :param kwargs: keyword arguments forwarded to the module constructor.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
250
+
251
+
252
class HybridConditioner(nn.Module):
    """
    Pair of conditioners: one for concatenation conditioning and one for
    cross-attention conditioning, each built from its own config.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """
        Run both conditioners on their respective inputs.
        :return: a dict with keys 'c_concat' and 'c_crossattn', each mapping
                 to a single-element list holding the conditioner output.
        """
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
263
+
264
+
265
def noise_like(shape, device, repeat=False):
    """
    Draw standard normal noise of the given shape on `device`.
    :param shape: the shape of the returned tensor.
    :param device: the device on which the noise is created.
    :param repeat: if True, draw a single sample of shape (1, *shape[1:]) and
                   repeat it shape[0] times along the first dimension.
    :return: a tensor of shape `shape`.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tiling = (1,) * (len(shape) - 1)
        return single.repeat(shape[0], *tiling)
    return torch.randn(shape, device=device)
easyanimate/vae/ldm/modules/distributions/__init__.py ADDED
File without changes
easyanimate/vae/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
class AbstractDistribution:
    """Base interface for distributions; both methods must be overridden."""

    def sample(self):
        """Draw a sample. Not implemented in the base class."""
        raise NotImplementedError()

    def mode(self):
        """Return the mode. Not implemented in the base class."""
        raise NotImplementedError()
11
+
12
+
13
class DiracDistribution(AbstractDistribution):
    """Distribution concentrated on one value; sample and mode both return it."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        """Return the stored value."""
        return self.value

    def mode(self):
        """Return the stored value."""
        return self.value
22
+
23
+
24
class DiagonalGaussianDistribution(object):
    """
    Gaussian with diagonal covariance, parameterized by one tensor that is
    split in half along dim 1 into the mean and the log-variance.
    :param parameters: tensor holding mean and log-variance stacked on dim 1.
    :param deterministic: if True, std and var are set to zero, and kl() and
                          nll() return a single zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp the log-variance so the exponentials below stay finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        """Return mean + std * noise, with noise drawn from a standard normal."""
        # The noise is generated on the CPU and then moved, exactly as before,
        # so seeded results do not change.
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal when `other` is None.
        :param other: another DiagonalGaussianDistribution, or None.
        :return: the divergence summed over dims 1, 2 and 3 (so 4-D
                 parameters are expected), or a one-element zero tensor on the
                 parameters' device when deterministic.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample` under this distribution.
        :param sample: tensor to evaluate, broadcastable against the mean.
        :param dims: dimensions to sum over.
        :return: the summed negative log-likelihood, or a one-element zero
                 tensor on the parameters' device when deterministic.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean

    def _zero(self):
        # One-element zero on the parameters' device, so callers can combine
        # it with other tensors on that device.
        return torch.tensor([0.0], device=self.parameters.device)
63
+
64
+
65
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    KL divergence between two gaussians given their means and log-variances.
    Arguments broadcast against each other, so batches can be compared to
    scalars; at least one argument must be a Tensor.
    """
    candidates = (mean1, logvar1, mean2, logvar2)
    reference = next((c for c in candidates if isinstance(c, torch.Tensor)), None)
    assert reference is not None, "at least one argument must be a Tensor"

    # torch.exp() needs Tensor inputs, so scalar log-variances are converted
    # to match the dtype and device of the first Tensor argument.
    if not isinstance(logvar1, torch.Tensor):
        logvar1 = torch.tensor(logvar1).to(reference)
    if not isinstance(logvar2, torch.Tensor):
        logvar2 = torch.tensor(logvar2).to(reference)

    squared_diff = (mean1 - mean2) ** 2
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + squared_diff * torch.exp(-logvar2)
    )