diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2419622a32d5727e47e604727ff53d1b2ae503
--- /dev/null
+++ b/app.py
@@ -0,0 +1,40 @@
+import time
+
+from easyanimate.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
+from easyanimate.ui.ui import ui_modelscope, ui, ui_huggingface
+
+if __name__ == "__main__":
+ # Choose the ui mode
+ ui_mode = "huggingface"
+ # Server ip
+ server_name = "0.0.0.0"
+ server_port = 7860
+
+ # Params below is used when ui_mode = "modelscope"
+ edition = "v2"
+ config_path = "config/easyanimate_video_magvit_motion_module_v2.yaml"
+ model_name = "models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
+ savedir_sample = "samples"
+
+ if ui_mode == "modelscope":
+ demo, controller = ui_modelscope(edition, config_path, model_name, savedir_sample)
+ elif ui_mode == "huggingface":
+ demo, controller = ui_huggingface(edition, config_path, model_name, savedir_sample)
+ else:
+ demo, controller = ui()
+
+ # launch gradio
+ app, _, _ = demo.queue(status_update_rate=1).launch(
+ server_name=server_name,
+ server_port=server_port,
+ prevent_thread_lock=True
+ )
+
+ # launch api
+ infer_forward_api(None, app, controller)
+ update_diffusion_transformer_api(None, app, controller)
+ update_edition_api(None, app, controller)
+
+ # not close the python
+ while True:
+ time.sleep(5)
\ No newline at end of file
diff --git a/config/easyanimate_image_magvit_v2.yaml b/config/easyanimate_image_magvit_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a781081885f300a557b28c9eeb30afa78cc8112
--- /dev/null
+++ b/config/easyanimate_image_magvit_v2.yaml
@@ -0,0 +1,8 @@
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: true
\ No newline at end of file
diff --git a/config/easyanimate_image_normal_v1.yaml b/config/easyanimate_image_normal_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8b926c1c15586e94c5458f80f8468b6651327ccb
--- /dev/null
+++ b/config/easyanimate_image_normal_v1.yaml
@@ -0,0 +1,8 @@
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: false
\ No newline at end of file
diff --git a/config/easyanimate_image_slicevae_v3.yaml b/config/easyanimate_image_slicevae_v3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e41b63d64e605a70ef6be20f309e0c383a177495
--- /dev/null
+++ b/config/easyanimate_image_slicevae_v3.yaml
@@ -0,0 +1,9 @@
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: true
+ slice_compression_vae: true
\ No newline at end of file
diff --git a/config/easyanimate_video_casual_motion_module_v1.yaml b/config/easyanimate_video_casual_motion_module_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4ed53304ec0f78b764dfc671d2efb7e7af91cf1c
--- /dev/null
+++ b/config/easyanimate_video_casual_motion_module_v1.yaml
@@ -0,0 +1,27 @@
+transformer_additional_kwargs:
+ patch_3d: false
+ fake_3d: false
+ casual_3d: true
+ casual_3d_upsampler_index: [16, 20]
+ time_patch_size: 4
+ basic_block_type: "motionmodule"
+ time_position_encoding_before_transformer: false
+ motion_module_type: "VanillaGrid"
+
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 4096
+ temporal_attention_dim_div: 1
+ block_size: 2
+
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: false
\ No newline at end of file
diff --git a/config/easyanimate_video_long_sequence_v1.yaml b/config/easyanimate_video_long_sequence_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0538352aaf2a11faa1bd374b78dcf47435c8df37
--- /dev/null
+++ b/config/easyanimate_video_long_sequence_v1.yaml
@@ -0,0 +1,14 @@
+transformer_additional_kwargs:
+ patch_3d: false
+ fake_3d: false
+ basic_block_type: "selfattentiontemporal"
+ time_position_encoding_before_transformer: true
+
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: false
\ No newline at end of file
diff --git a/config/easyanimate_video_magvit_motion_module_v2.yaml b/config/easyanimate_video_magvit_motion_module_v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..723ad1e0aee1a8c3c6c77c3ce0549a0ea4babb90
--- /dev/null
+++ b/config/easyanimate_video_magvit_motion_module_v2.yaml
@@ -0,0 +1,26 @@
+transformer_additional_kwargs:
+ patch_3d: false
+ fake_3d: false
+ basic_block_type: "motionmodule"
+ time_position_encoding_before_transformer: false
+ motion_module_type: "Vanilla"
+ enable_uvit: true
+
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 4096
+ temporal_attention_dim_div: 1
+ block_size: 1
+
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: true
+ mini_batch_encoder: 9
\ No newline at end of file
diff --git a/config/easyanimate_video_motion_module_v1.yaml b/config/easyanimate_video_motion_module_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..add62459d6ada289f2f8c61e571876d5f99ec5ac
--- /dev/null
+++ b/config/easyanimate_video_motion_module_v1.yaml
@@ -0,0 +1,24 @@
+transformer_additional_kwargs:
+ patch_3d: false
+ fake_3d: false
+ basic_block_type: "motionmodule"
+ time_position_encoding_before_transformer: false
+ motion_module_type: "VanillaGrid"
+
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 4096
+ temporal_attention_dim_div: 1
+ block_size: 2
+
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: false
\ No newline at end of file
diff --git a/config/easyanimate_video_slicevae_motion_module_v3.yaml b/config/easyanimate_video_slicevae_motion_module_v3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e0e1bac132accd6bf5905ee59643004e2c401e9f
--- /dev/null
+++ b/config/easyanimate_video_slicevae_motion_module_v3.yaml
@@ -0,0 +1,27 @@
+transformer_additional_kwargs:
+ patch_3d: false
+ fake_3d: false
+ basic_block_type: "motionmodule"
+ time_position_encoding_before_transformer: false
+ motion_module_type: "Vanilla"
+ enable_uvit: true
+
+ motion_module_kwargs:
+ num_attention_heads: 8
+ num_transformer_block: 1
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
+ temporal_position_encoding: true
+ temporal_position_encoding_max_len: 4096
+ temporal_attention_dim_div: 1
+ block_size: 1
+
+noise_scheduler_kwargs:
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ steps_offset: 1
+
+vae_kwargs:
+ enable_magvit: true
+ slice_compression_vae: true
+ mini_batch_encoder: 8
\ No newline at end of file
diff --git a/easyanimate/__init__.py b/easyanimate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/api/api.py b/easyanimate/api/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..54a69c9511717550cc7d607fdadc58d29a918e16
--- /dev/null
+++ b/easyanimate/api/api.py
@@ -0,0 +1,96 @@
+import io
+import base64
+import torch
+import gradio as gr
+
+from fastapi import FastAPI
+from io import BytesIO
+
+# Function to encode a file to Base64
+def encode_file_to_base64(file_path):
+ with open(file_path, "rb") as file:
+ # Encode the data to Base64
+ file_base64 = base64.b64encode(file.read())
+ return file_base64
+
+def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
+ @app.post("/easyanimate/update_edition")
+ def _update_edition_api(
+ datas: dict,
+ ):
+ edition = datas.get('edition', 'v2')
+
+ try:
+ controller.update_edition(
+ edition
+ )
+ comment = "Success"
+ except Exception as e:
+ torch.cuda.empty_cache()
+ comment = f"Error. error information is {str(e)}"
+
+ return {"message": comment}
+
+def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
+ @app.post("/easyanimate/update_diffusion_transformer")
+ def _update_diffusion_transformer_api(
+ datas: dict,
+ ):
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
+
+ try:
+ controller.update_diffusion_transformer(
+ diffusion_transformer_path
+ )
+ comment = "Success"
+ except Exception as e:
+ torch.cuda.empty_cache()
+ comment = f"Error. error information is {str(e)}"
+
+ return {"message": comment}
+
+def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
+ @app.post("/easyanimate/infer_forward")
+ def _infer_forward_api(
+ datas: dict,
+ ):
+ base_model_path = datas.get('base_model_path', 'none')
+ motion_module_path = datas.get('motion_module_path', 'none')
+ lora_model_path = datas.get('lora_model_path', 'none')
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
+ prompt_textbox = datas.get('prompt_textbox', None)
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
+ sample_step_slider = datas.get('sample_step_slider', 30)
+ width_slider = datas.get('width_slider', 672)
+ height_slider = datas.get('height_slider', 384)
+ is_image = datas.get('is_image', False)
+ length_slider = datas.get('length_slider', 144)
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+ seed_textbox = datas.get("seed_textbox", 43)
+
+ try:
+ save_sample_path, comment = controller.generate(
+ "",
+ base_model_path,
+ motion_module_path,
+ lora_model_path,
+ lora_alpha_slider,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ is_api = True,
+ )
+ except Exception as e:
+ torch.cuda.empty_cache()
+ save_sample_path = ""
+ comment = f"Error. error information is {str(e)}"
+
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
\ No newline at end of file
diff --git a/easyanimate/api/post_infer.py b/easyanimate/api/post_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b4cba71505359c8ccbe8175c6a212a4783c3215
--- /dev/null
+++ b/easyanimate/api/post_infer.py
@@ -0,0 +1,94 @@
+import base64
+import json
+import sys
+import time
+from datetime import datetime
+from io import BytesIO
+
+import cv2
+import requests
+import base64
+
+
+def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
+ datas = json.dumps({
+ "diffusion_transformer_path": diffusion_transformer_path
+ })
+ r = requests.post(f'{url}/easyanimate/update_diffusion_transformer', data=datas, timeout=1500)
+ data = r.content.decode('utf-8')
+ return data
+
+def post_update_edition(edition, url='http://0.0.0.0:7860'):
+ datas = json.dumps({
+ "edition": edition
+ })
+ r = requests.post(f'{url}/easyanimate/update_edition', data=datas, timeout=1500)
+ data = r.content.decode('utf-8')
+ return data
+
+def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
+ datas = json.dumps({
+ "base_model_path": "none",
+ "motion_module_path": "none",
+ "lora_model_path": "none",
+ "lora_alpha_slider": 0.55,
+ "prompt_textbox": "This video shows Mount saint helens, washington - the stunning scenery of a rocky mountains during golden hours - wide shot. A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea.",
+ "negative_prompt_textbox": "Strange motion trajectory, a poor composition and deformed video, worst quality, normal quality, low quality, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera",
+ "sampler_dropdown": "Euler",
+ "sample_step_slider": 30,
+ "width_slider": 672,
+ "height_slider": 384,
+ "is_image": is_image,
+ "length_slider": length_slider,
+ "cfg_scale_slider": 6,
+ "seed_textbox": 43,
+ })
+ r = requests.post(f'{url}/easyanimate/infer_forward', data=datas, timeout=1500)
+ data = r.content.decode('utf-8')
+ return data
+
+if __name__ == '__main__':
+ # initiate time
+ now_date = datetime.now()
+ time_start = time.time()
+
+ # -------------------------- #
+ # Step 1: update edition
+ # -------------------------- #
+ edition = "v2"
+ outputs = post_update_edition(edition)
+ print('Output update edition: ', outputs)
+
+ # -------------------------- #
+ # Step 2: update edition
+ # -------------------------- #
+ diffusion_transformer_path = "/your-path/EasyAnimate/models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
+ outputs = post_diffusion_transformer(diffusion_transformer_path)
+ print('Output update edition: ', outputs)
+
+ # -------------------------- #
+ # Step 3: infer
+ # -------------------------- #
+ is_image = False
+ length_slider = 27
+ outputs = post_infer(is_image, length_slider)
+
+ # Get decoded data
+ outputs = json.loads(outputs)
+ base64_encoding = outputs["base64_encoding"]
+ decoded_data = base64.b64decode(base64_encoding)
+
+ if is_image or length_slider == 1:
+ file_path = "1.png"
+ else:
+ file_path = "1.mp4"
+ with open(file_path, "wb") as file:
+ file.write(decoded_data)
+
+ # End of record time
+ # The calculated time difference is the execution time of the program, expressed in seconds / s
+ time_end = time.time()
+ time_sum = (time_end - time_start) % 60
+ print('# --------------------------------------------------------- #')
+ print(f'# Total expenditure: {time_sum}s')
+ print('# --------------------------------------------------------- #')
\ No newline at end of file
diff --git a/easyanimate/data/bucket_sampler.py b/easyanimate/data/bucket_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5fded15beeb7f53bbf310a571f776cba932c52
--- /dev/null
+++ b/easyanimate/data/bucket_sampler.py
@@ -0,0 +1,379 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
+ Sized, TypeVar, Union)
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import BatchSampler, Dataset, Sampler
+
+ASPECT_RATIO_512 = {
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
+}
+ASPECT_RATIO_RANDOM_CROP_512 = {
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
+}
+ASPECT_RATIO_RANDOM_CROP_PROB = [
+ 1, 2,
+ 4, 4, 4, 4,
+ 8, 8, 8,
+ 4, 4, 4, 4,
+ 2, 1
+]
+ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
+
+def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
+ aspect_ratio = height / width
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
+ return ratios[closest_ratio], float(closest_ratio)
+
+def get_image_size_without_loading(path):
+ with Image.open(path) as img:
+ return img.size # (width, height)
+
+class RandomSampler(Sampler[int]):
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+
+ If with replacement, then user can specify :attr:`num_samples` to draw.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
+ generator (Generator): Generator used in sampling.
+ """
+
+ data_source: Sized
+ replacement: bool
+
+ def __init__(self, data_source: Sized, replacement: bool = False,
+ num_samples: Optional[int] = None, generator=None) -> None:
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.generator = generator
+ self._pos_start = 0
+
+ if not isinstance(self.replacement, bool):
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
+
+ @property
+ def num_samples(self) -> int:
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self) -> Iterator[int]:
+ n = len(self.data_source)
+ if self.generator is None:
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ generator = torch.Generator()
+ generator.manual_seed(seed)
+ else:
+ generator = self.generator
+
+ if self.replacement:
+ for _ in range(self.num_samples // 32):
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+ else:
+ for _ in range(self.num_samples // n):
+ xx = torch.randperm(n, generator=generator).tolist()
+ if self._pos_start >= n:
+ self._pos_start = 0
+ print("xx top 10", xx[:10], self._pos_start)
+ for idx in range(self._pos_start, n):
+ yield xx[idx]
+ self._pos_start = (self._pos_start + 1) % n
+ self._pos_start = 0
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+class AspectRatioBatchImageSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+ def __init__(
+ self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ train_folder: str = None,
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False,
+ config=None,
+ **kwargs
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.train_folder = train_folder
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+ self.config = config
+ # buckets for each aspect ratio
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+ # [str(k) for k, v in aspect_ratios]
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+ def __iter__(self):
+ for idx in self.sampler:
+ try:
+ image_dict = self.dataset[idx]
+
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
+ if width is None or height is None:
+ image_id, name = image_dict['file_path'], image_dict['text']
+ if self.train_folder is None:
+ image_dir = image_id
+ else:
+ image_dir = os.path.join(self.train_folder, image_id)
+
+ width, height = get_image_size_without_loading(image_dir)
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+ print(e)
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self._aspect_ratio_buckets[closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+
+class AspectRatioBatchSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+ def __init__(
+ self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ video_folder: str = None,
+ train_data_format: str = "webvid",
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False,
+ config=None,
+ **kwargs
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.video_folder = video_folder
+ self.train_data_format = train_data_format
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+ self.config = config
+ # buckets for each aspect ratio
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+ # [str(k) for k, v in aspect_ratios]
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+ def __iter__(self):
+ for idx in self.sampler:
+ try:
+ video_dict = self.dataset[idx]
+ width, more = video_dict.get("width", None), video_dict.get("height", None)
+
+ if width is None or height is None:
+ if self.train_data_format == "normal":
+ video_id, name = video_dict['file_path'], video_dict['text']
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+ else:
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+ cap = cv2.VideoCapture(video_dir)
+
+ # 获取视频尺寸
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+ print(e)
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self._aspect_ratio_buckets[closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+
+class AspectRatioBatchImageVideoSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+
+ def __init__(self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ train_folder: str = None,
+ aspect_ratios: dict = ASPECT_RATIO_512,
+ drop_last: bool = False
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.train_folder = train_folder
+ self.batch_size = batch_size
+ self.aspect_ratios = aspect_ratios
+ self.drop_last = drop_last
+
+ # buckets for each aspect ratio
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
+ self.bucket = {
+ 'image':{ratio: [] for ratio in aspect_ratios},
+ 'video':{ratio: [] for ratio in aspect_ratios}
+ }
+
+ def __iter__(self):
+ for idx in self.sampler:
+ content_type = self.dataset[idx].get('type', 'image')
+ if content_type == 'image':
+ try:
+ image_dict = self.dataset[idx]
+
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
+ if width is None or height is None:
+ image_id, name = image_dict['file_path'], image_dict['text']
+ if self.train_folder is None:
+ image_dir = image_id
+ else:
+ image_dir = os.path.join(self.train_folder, image_id)
+
+ width, height = get_image_size_without_loading(image_dir)
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+ print(e)
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self.bucket['image'][closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
+ else:
+ try:
+ video_dict = self.dataset[idx]
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
+
+ if width is None or height is None:
+ video_id, name = video_dict['file_path'], video_dict['text']
+ if self.train_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.train_folder, video_id)
+ cap = cv2.VideoCapture(video_dir)
+
+ # 获取视频尺寸
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
+
+ ratio = height / width # self.dataset[idx]
+ else:
+ height = int(height)
+ width = int(width)
+ ratio = height / width # self.dataset[idx]
+ except Exception as e:
+ print(e)
+ continue
+ # find the closest aspect ratio
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+ if closest_ratio not in self.current_available_bucket_keys:
+ continue
+ bucket = self.bucket['video'][closest_ratio]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(bucket) == self.batch_size:
+ yield bucket[:]
+ del bucket[:]
\ No newline at end of file
diff --git a/easyanimate/data/dataset_image.py b/easyanimate/data/dataset_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..098d49a4044f8daa351cd01b4cb1ec5415412e80
--- /dev/null
+++ b/easyanimate/data/dataset_image.py
@@ -0,0 +1,76 @@
+import json
+import os
+import random
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+
+class CC15M(Dataset):
+ def __init__(
+ self,
+ json_path,
+ video_folder=None,
+ resolution=512,
+ enable_bucket=False,
+ ):
+ print(f"loading annotations from {json_path} ...")
+ self.dataset = json.load(open(json_path, 'r'))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.enable_bucket = enable_bucket
+ self.video_folder = video_folder
+
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
+ self.pixel_transforms = transforms.Compose([
+ transforms.Resize(resolution[0]),
+ transforms.CenterCrop(resolution),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ video_id, name = video_dict['file_path'], video_dict['text']
+
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+
+ pixel_values = Image.open(video_dir).convert("RGB")
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+ except Exception as e:
+ print(e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ else:
+ pixel_values = np.array(pixel_values)
+
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+if __name__ == "__main__":
+ dataset = CC15M(
+ csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
+ resolution=512,
+ )
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/easyanimate/data/dataset_image_video.py b/easyanimate/data/dataset_image_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b50c22fa8b2d0d203f627c97c72535559e247e8
--- /dev/null
+++ b/easyanimate/data/dataset_image_video.py
@@ -0,0 +1,241 @@
+import csv
+import io
+import json
+import math
+import os
+import random
+from threading import Thread
+
+import albumentations
+import cv2
+import gc
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from func_timeout import func_timeout, FunctionTimedOut
+from decord import VideoReader
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+from contextlib import contextmanager
+
+VIDEO_READER_TIMEOUT = 20
+
+class ImageVideoSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+
+ def __init__(self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ drop_last: bool = False
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # buckets for each aspect ratio
+ self.bucket = {'image':[], 'video':[]}
+
+ def __iter__(self):
+ for idx in self.sampler:
+ content_type = self.dataset.dataset[idx].get('type', 'image')
+ self.bucket[content_type].append(idx)
+
+ # yield a batch of indices in the same aspect ratio group
+ if len(self.bucket['video']) == self.batch_size:
+ bucket = self.bucket['video']
+ yield bucket[:]
+ del bucket[:]
+ elif len(self.bucket['image']) == self.batch_size:
+ bucket = self.bucket['image']
+ yield bucket[:]
+ del bucket[:]
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+ vr = VideoReader(*args, **kwargs)
+ try:
+ yield vr
+ finally:
+ del vr
+ gc.collect()
+
+def get_video_reader_batch(video_reader, batch_index):
+ frames = video_reader.get_batch(batch_index).asnumpy()
+ return frames
+
+class ImageVideoDataset(Dataset):
+ def __init__(
+ self,
+ ann_path, data_root=None,
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+ image_sample_size=512,
+ video_repeat=0,
+ text_drop_ratio=0.001,
+ enable_bucket=False,
+ video_length_drop_start=0.1,
+ video_length_drop_end=0.9,
+ ):
+ # Loading annotations from files
+ print(f"loading annotations from {ann_path} ...")
+ if ann_path.endswith('.csv'):
+ with open(ann_path, 'r') as csvfile:
+ dataset = list(csv.DictReader(csvfile))
+ elif ann_path.endswith('.json'):
+ dataset = json.load(open(ann_path))
+
+ self.data_root = data_root
+
+ # It's used to balance num of images and videos.
+ self.dataset = []
+ for data in dataset:
+ if data.get('type', 'image') != 'video':
+ self.dataset.append(data)
+ if video_repeat > 0:
+ for _ in range(video_repeat):
+ for data in dataset:
+ if data.get('type', 'image') == 'video':
+ self.dataset.append(data)
+ del dataset
+
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+ # TODO: enable bucket training
+ self.enable_bucket = enable_bucket
+ self.text_drop_ratio = text_drop_ratio
+ self.video_length_drop_start = video_length_drop_start
+ self.video_length_drop_end = video_length_drop_end
+
+ # Video params
+ self.video_sample_stride = video_sample_stride
+ self.video_sample_n_frames = video_sample_n_frames
+ video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+ self.video_transforms = transforms.Compose(
+ [
+ transforms.Resize(video_sample_size[0]),
+ transforms.CenterCrop(video_sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ # Image params
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+ self.image_transforms = transforms.Compose([
+ transforms.Resize(min(self.image_sample_size)),
+ transforms.CenterCrop(self.image_sample_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+ ])
+
+ def get_batch(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+
+ if data_info.get('type', 'image')=='video':
+ video_id, text = data_info['file_path'], data_info['text']
+
+ if self.data_root is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.data_root, video_id)
+
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+ min_sample_n_frames = min(
+ self.video_sample_n_frames,
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start))
+ )
+ if min_sample_n_frames == 0:
+ raise ValueError(f"No Frames in video.")
+
+ video_length = int(self.video_length_drop_end * len(video_reader))
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+ try:
+ sample_args = (video_reader, batch_index)
+ pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = pixel_values
+
+ if not self.enable_bucket:
+ pixel_values = self.video_transforms(pixel_values)
+
+ # Random use no text generation
+ if random.random() < self.text_drop_ratio:
+ text = ''
+ return pixel_values, text, 'video'
+ else:
+ image_path, text = data_info['file_path'], data_info['text']
+ if self.data_root is not None:
+ image_path = os.path.join(self.data_root, image_path)
+ image = Image.open(image_path).convert('RGB')
+ if not self.enable_bucket:
+ image = self.image_transforms(image).unsqueeze(0)
+ else:
+ image = np.expand_dims(np.array(image), 0)
+ if random.random() < self.text_drop_ratio:
+ text = ''
+ return image, text, 'image'
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ data_info = self.dataset[idx % len(self.dataset)]
+ data_type = data_info.get('type', 'image')
+ while True:
+ sample = {}
+ try:
+ data_info_local = self.dataset[idx % len(self.dataset)]
+ data_type_local = data_info_local.get('type', 'image')
+ if data_type_local != data_type:
+ raise ValueError("data_type_local != data_type")
+
+ pixel_values, name, data_type = self.get_batch(idx)
+ sample["pixel_values"] = pixel_values
+ sample["text"] = name
+ sample["data_type"] = data_type
+ sample["idx"] = idx
+
+ if len(sample) > 0:
+ break
+ except Exception as e:
+ print(e, self.dataset[idx % len(self.dataset)])
+ idx = random.randint(0, self.length-1)
+ return sample
+
+if __name__ == "__main__":
+ dataset = ImageVideoDataset(
+ ann_path="test.json"
+ )
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/easyanimate/data/dataset_video.py b/easyanimate/data/dataset_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78367d0973ceb1abdcd005947612d16e2480831
--- /dev/null
+++ b/easyanimate/data/dataset_video.py
@@ -0,0 +1,262 @@
+import csv
+import gc
+import io
+import json
+import math
+import os
+import random
+from contextlib import contextmanager
+from threading import Thread
+
+import albumentations
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from decord import VideoReader
+from einops import rearrange
+from func_timeout import FunctionTimedOut, func_timeout
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+
+VIDEO_READER_TIMEOUT = 20
+
+def get_random_mask(shape):
+ f, c, h, w = shape
+
+ mask_index = np.random.randint(0, 4)
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+ if mask_index == 0:
+ mask[1:, :, :, :] = 1
+ elif mask_index == 1:
+ mask_frame_index = 1
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+ elif mask_index == 2:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
+ elif mask_index == 3:
+ center_x = torch.randint(0, w, (1,)).item()
+ center_y = torch.randint(0, h, (1,)).item()
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
+
+ start_x = max(center_x - block_size_x // 2, 0)
+ end_x = min(center_x + block_size_x // 2, w)
+ start_y = max(center_y - block_size_y // 2, 0)
+ end_y = min(center_y + block_size_y // 2, h)
+
+ mask_frame_before = np.random.randint(0, f // 2)
+ mask_frame_after = np.random.randint(f // 2, f)
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+ else:
+ raise ValueError(f"The mask_index {mask_index} is not define")
+ return mask
+
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+ vr = VideoReader(*args, **kwargs)
+ try:
+ yield vr
+ finally:
+ del vr
+ gc.collect()
+
+
+def get_video_reader_batch(video_reader, batch_index):
+ frames = video_reader.get_batch(batch_index).asnumpy()
+ return frames
+
+
+class WebVid10M(Dataset):
+ def __init__(
+ self,
+ csv_path, video_folder,
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ enable_bucket=False, enable_inpaint=False, is_image=False,
+ ):
+ print(f"loading annotations from {csv_path} ...")
+ with open(csv_path, 'r') as csvfile:
+ self.dataset = list(csv.DictReader(csvfile))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.video_folder = video_folder
+ self.sample_stride = sample_stride
+ self.sample_n_frames = sample_n_frames
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
+ self.is_image = is_image
+
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+ self.pixel_transforms = transforms.Compose([
+ transforms.Resize(sample_size[0]),
+ transforms.CenterCrop(sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ])
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+ video_reader = VideoReader(video_dir)
+ video_length = len(video_reader)
+
+ if not self.is_image:
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+ else:
+ batch_index = [random.randint(0, video_length - 1)]
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
+
+ if self.is_image:
+ pixel_values = pixel_values[0]
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+
+ except Exception as e:
+ print("Error info:", e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ if self.enable_inpaint:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+ else:
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ json_path, video_folder=None,
+ sample_size=256, sample_stride=4, sample_n_frames=16,
+ enable_bucket=False, enable_inpaint=False
+ ):
+ print(f"loading annotations from {json_path} ...")
+ self.dataset = json.load(open(json_path, 'r'))
+ self.length = len(self.dataset)
+ print(f"data scale: {self.length}")
+
+ self.video_folder = video_folder
+ self.sample_stride = sample_stride
+ self.sample_n_frames = sample_n_frames
+ self.enable_bucket = enable_bucket
+ self.enable_inpaint = enable_inpaint
+
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+ self.pixel_transforms = transforms.Compose(
+ [
+ transforms.Resize(sample_size[0]),
+ transforms.CenterCrop(sample_size),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ def get_batch(self, idx):
+ video_dict = self.dataset[idx]
+ video_id, name = video_dict['file_path'], video_dict['text']
+
+ if self.video_folder is None:
+ video_dir = video_id
+ else:
+ video_dir = os.path.join(self.video_folder, video_id)
+
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+ video_length = len(video_reader)
+
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+
+ try:
+ sample_args = (video_reader, batch_index)
+ pixel_values = func_timeout(
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+ )
+ except FunctionTimedOut:
+ raise ValueError(f"Read {idx} timeout.")
+ except Exception as e:
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+ if not self.enable_bucket:
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+ pixel_values = pixel_values / 255.
+ del video_reader
+ else:
+ pixel_values = pixel_values
+
+ return pixel_values, name
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, idx):
+ while True:
+ try:
+ pixel_values, name = self.get_batch(idx)
+ break
+
+ except Exception as e:
+ print("Error info:", e)
+ idx = random.randint(0, self.length-1)
+
+ if not self.enable_bucket:
+ pixel_values = self.pixel_transforms(pixel_values)
+ if self.enable_inpaint:
+ mask = get_random_mask(pixel_values.size())
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+ else:
+ sample = dict(pixel_values=pixel_values, text=name)
+ return sample
+
+
+if __name__ == "__main__":
+ if 1:
+ dataset = VideoDataset(
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
+ sample_size=256,
+ sample_stride=4, sample_n_frames=16,
+ )
+
+ if 0:
+ dataset = WebVid10M(
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+ sample_size=256,
+ sample_stride=4, sample_n_frames=16,
+ is_image=False,
+ )
+
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+ for idx, batch in enumerate(dataloader):
+ print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/easyanimate/models/attention.py b/easyanimate/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec11aba614dc22dbc31efcb85c7c2d4f9207412a
--- /dev/null
+++ b/easyanimate/models/attention.py
@@ -0,0 +1,1299 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention import AdaLayerNorm, FeedForward
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
+from diffusers.utils import USE_PEFT_BACKEND
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from einops import rearrange, repeat
+from torch import nn
+
+from .motion_module import get_motion_module
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+
+class KVCompressionCrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = True
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.kv_compression = nn.Conv2d(
+ query_dim,
+ query_dim,
+ groups=query_dim,
+ kernel_size=2,
+ stride=2,
+ bias=True
+ )
+ self.kv_compression_norm = nn.LayerNorm(query_dim)
+ init.constant_(self.kv_compression.weight, 1 / 4)
+ if self.kv_compression.bias is not None:
+ init.constant_(self.kv_compression.bias, 0)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
+ key = self.kv_compression(key)
+ key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
+ key = self.kv_compression_norm(key)
+ key = key.to(query.dtype)
+
+ value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
+ value = self.kv_compression(value)
+ value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
+ value = self.kv_compression_norm(value)
+ value = value.to(query.dtype)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
+ key = self.kv_compression(key)
+ key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
+ key = self.kv_compression_norm(key)
+ key = key.to(query.dtype)
+
+ value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
+ value = self.kv_compression(value)
+ value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
+ value = self.kv_compression_norm(value)
+ value = value.to(query.dtype)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class TemporalTransformerBlock(nn.Module):
+ r"""
+ A Temporal Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ # kv compression
+ kvcompression: Optional[bool] = False,
+ # motion module kwargs
+ motion_module_type = "VanillaGrid",
+ motion_module_kwargs = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.kvcompression = kvcompression
+ if kvcompression:
+ self.attn1 = KVCompressionCrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ print(self.attn1)
+
+ self.attn_temporal = get_motion_module(
+ in_channels = dim,
+ motion_module_type = motion_module_type,
+ motion_module_kwargs = motion_module_kwargs,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ num_frames: int = 16,
+ height: int = 32,
+ width: int = 32,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
+ if self.kvcompression:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ num_frames=1,
+ height=height,
+ width=width,
+ **cross_attention_kwargs,
+ )
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 2.75. Temp-Attention
+ if self.attn_temporal is not None:
+ attn_output = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=num_frames, h=height, w=width)
+ attn_output = self.attn_temporal(attn_output)
+ hidden_states = rearrange(attn_output, "b c f h w -> b (f h w) c")
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class SelfAttentionTemporalTransformerBlock(nn.Module):
+ r"""
+ A Temporal Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class KVCompressionTransformerBlock(nn.Module):
+ r"""
+ A Temporal Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ kvcompression: Optional[bool] = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if self.use_ada_layer_norm:
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif self.use_ada_layer_norm_zero:
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.kvcompression = kvcompression
+ if kvcompression:
+ self.attn1 = KVCompressionCrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ print(self.attn1)
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = (
+ AdaLayerNorm(dim, num_embeds_ada_norm)
+ if self.use_ada_layer_norm
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ )
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if self.use_ada_layer_norm_single:
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ num_frames: int = 16,
+ height: int = 32,
+ width: int = 32,
+ use_reentrant: bool = False,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.use_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 2. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ if self.kvcompression:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ height=height,
+ width=width,
+ **cross_attention_kwargs,
+ )
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.use_ada_layer_norm_single:
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 2.5 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.use_ada_layer_norm_single:
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ if not self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.use_ada_layer_norm_single:
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
+ )
+ else:
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.use_ada_layer_norm_single:
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(linear_cls(inner_dim, dim_out))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+ for module in self.net:
+ if isinstance(module, compatible_cls):
+ hidden_states = module(hidden_states, scale)
+ else:
+ hidden_states = module(hidden_states)
+ return hidden_states
diff --git a/easyanimate/models/autoencoder_magvit.py b/easyanimate/models/autoencoder_magvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..d693767d214101b0fe11a0ad04c8ce64ec987948
--- /dev/null
+++ b/easyanimate/models/autoencoder_magvit.py
@@ -0,0 +1,503 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalVAEMixin
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
+ AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
+from diffusers.models.autoencoders.vae import (DecoderOutput,
+ DiagonalGaussianDistribution)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from torch import nn
+
+from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
+from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
+
+
+def str_eval(item):
+ if type(item) == str:
+ return eval(item)
+ else:
+ return item
+
+class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ ch = 128,
+ ch_mult = [ 1,2,4,4 ],
+ use_gc_blocks = None,
+ down_block_types: tuple = None,
+ up_block_types: tuple = None,
+ mid_block_type: str = "MidBlock3D",
+ mid_block_use_attention: bool = True,
+ mid_block_attention_type: str = "3d",
+ mid_block_num_attention_heads: int = 1,
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ num_attention_heads: int = 1,
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ scaling_factor: float = 0.1825,
+ slice_compression_vae=False,
+ mini_batch_encoder=9,
+ mini_batch_decoder=3,
+ ):
+ super().__init__()
+ down_block_types = str_eval(down_block_types)
+ up_block_types = str_eval(up_block_types)
+ self.encoder = omnigen_Mag_Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ ch = ch,
+ ch_mult = ch_mult,
+ use_gc_blocks=use_gc_blocks,
+ mid_block_type=mid_block_type,
+ mid_block_use_attention=mid_block_use_attention,
+ mid_block_attention_type=mid_block_attention_type,
+ mid_block_num_attention_heads=mid_block_num_attention_heads,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ num_attention_heads=num_attention_heads,
+ double_z=True,
+ slice_compression_vae=slice_compression_vae,
+ mini_batch_encoder=mini_batch_encoder,
+ )
+
+ self.decoder = omnigen_Mag_Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ ch = ch,
+ ch_mult = ch_mult,
+ use_gc_blocks=use_gc_blocks,
+ mid_block_type=mid_block_type,
+ mid_block_use_attention=mid_block_use_attention,
+ mid_block_attention_type=mid_block_attention_type,
+ mid_block_num_attention_heads=mid_block_num_attention_heads,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ num_attention_heads=num_attention_heads,
+ slice_compression_vae=slice_compression_vae,
+ mini_batch_decoder=mini_batch_decoder,
+ )
+
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
+
+ self.slice_compression_vae = slice_compression_vae
+ self.mini_batch_encoder = mini_batch_encoder
+ self.mini_batch_decoder = mini_batch_decoder
+ self.use_slicing = False
+ self.use_tiling = False
+ self.tile_sample_min_size = 256
+ self.tile_overlap_factor = 0.25
+ self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
+ self.scaling_factor = scaling_factor
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
+ module.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
+ ) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
+ 1 - y / blend_extent
+ ) + b[:, :, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
+ ) -> torch.Tensor:
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
+ 1 - x / blend_extent
+ ) + b[:, :, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[3], overlap_size):
+ row = []
+ for j in range(0, x.shape[4], overlap_size):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_size,
+ j : j + self.tile_sample_min_size,
+ ]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ moments = torch.cat(result_rows, dim=3)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[3], overlap_size):
+ row = []
+ for j in range(0, z.shape[4], overlap_size):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + self.tile_latent_min_size,
+ j : j + self.tile_latent_min_size,
+ ]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=4))
+
+ dec = torch.cat(result_rows, dim=3)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
+ import json
+ import os
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ model = cls.from_config(config, **vae_additional_kwargs)
+ from diffusers.utils import WEIGHTS_NAME
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+ print(m, u)
+ return model
diff --git a/easyanimate/models/motion_module.py b/easyanimate/models/motion_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..10ced61cb53f833c7d734bcb10e28f53a4c0b1e2
--- /dev/null
+++ b/easyanimate/models/motion_module.py
@@ -0,0 +1,575 @@
+"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
+"""
+import math
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention import FeedForward
+from diffusers.utils.import_utils import is_xformers_available
+from einops import rearrange, repeat
+from torch import nn
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def set_use_memory_efficient_attention_xformers(
+ self, valid: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ self._use_memory_efficient_attention_xformers = valid
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+def get_motion_module(
+ in_channels,
+ motion_module_type: str,
+ motion_module_kwargs: dict,
+):
+ if motion_module_type == "Vanilla":
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
+ elif motion_module_type == "VanillaGrid":
+ return VanillaTemporalModule(in_channels=in_channels, grid=True, **motion_module_kwargs,)
+ else:
+ raise ValueError
+
+class VanillaTemporalModule(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads = 8,
+ num_transformer_block = 2,
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 4096,
+ temporal_attention_dim_div = 1,
+ zero_initialize = True,
+ block_size = 1,
+ grid = False,
+ ):
+ super().__init__()
+
+ self.temporal_transformer = TemporalTransformer3DModel(
+ in_channels=in_channels,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
+ num_layers=num_transformer_block,
+ attention_block_types=attention_block_types,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ grid=grid,
+ block_size=block_size,
+ )
+ if zero_initialize:
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+
+ def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
+ hidden_states = input_tensor
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+
+ output = hidden_states
+ return output
+
+class TemporalTransformer3DModel(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_attention_heads,
+ attention_head_dim,
+
+ num_layers,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 4096,
+ grid = False,
+ block_size = 1,
+ ):
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.block_size = block_size
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ attention_block_types=attention_block_types,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ block_size=block_size,
+ grid=grid,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Transformer Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, height=height, weight=weight)
+
+ # output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+
+ return output
+
+class TemporalTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
+ dropout = 0.0,
+ norm_num_groups = 32,
+ cross_attention_dim = 768,
+ activation_fn = "geglu",
+ attention_bias = False,
+ upcast_attention = False,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 4096,
+ block_size = 1,
+ grid = False,
+ ):
+ super().__init__()
+
+ attention_blocks = []
+ norms = []
+
+ for block_name in attention_block_types:
+ attention_blocks.append(
+ VersatileAttention(
+ attention_mode=block_name.split("_")[0],
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
+
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_position_encoding=temporal_position_encoding,
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+ block_size=block_size,
+ grid=grid,
+ )
+ )
+ norms.append(nn.LayerNorm(dim))
+
+ self.attention_blocks = nn.ModuleList(attention_blocks)
+ self.norms = nn.ModuleList(norms)
+
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.ff_norm = nn.LayerNorm(dim)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
+ norm_hidden_states = norm(hidden_states)
+ hidden_states = attention_block(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
+ video_length=video_length,
+ height=height,
+ weight=weight,
+ ) + hidden_states
+
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
+
+ output = hidden_states
+ return output
+
+class PositionalEncoding(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout = 0.,
+ max_len = 4096
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ x = x + self.pe[:, :x.size(1)]
+ return self.dropout(x)
+
+class VersatileAttention(CrossAttention):
+ def __init__(
+ self,
+ attention_mode = None,
+ cross_frame_attention_mode = None,
+ temporal_position_encoding = False,
+ temporal_position_encoding_max_len = 4096,
+ grid = False,
+ block_size = 1,
+ *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ assert attention_mode == "Temporal"
+
+ self.attention_mode = attention_mode
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
+
+ self.block_size = block_size
+ self.grid = grid
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0.,
+ max_len=temporal_position_encoding_max_len
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
+
+ def extra_repr(self):
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if self.attention_mode == "Temporal":
+ # for add pos_encoder
+ _, before_d, _c = hidden_states.size()
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ if self.pos_encoder is not None:
+ hidden_states = self.pos_encoder(hidden_states)
+
+ if self.grid:
+ hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
+ hidden_states = rearrange(hidden_states, "b f (h w) c -> b f h w c", h=height, w=weight)
+
+ hidden_states = rearrange(hidden_states, "b f (h n) (w m) c -> (b h w) (f n m) c", n=self.block_size, m=self.block_size)
+ d = before_d // self.block_size // self.block_size
+ else:
+ d = before_d
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+ else:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ bs = 512
+ new_hidden_states = []
+ for i in range(0, query.shape[0], bs):
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
+ new_hidden_states.append(hidden_states)
+ hidden_states = torch.cat(new_hidden_states, dim = 0)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if self.attention_mode == "Temporal":
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+ if self.grid:
+ hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
+ hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
+ hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
+
+ return hidden_states
\ No newline at end of file
diff --git a/easyanimate/models/patch.py b/easyanimate/models/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..48eca35f826327802d7afc620b1f6922c17c28e7
--- /dev/null
+++ b/easyanimate/models/patch.py
@@ -0,0 +1,426 @@
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+import math
+from einops import rearrange
+from torch import nn
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+class Patch1D(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ out_channels: Optional[int] = None,
+ stride: int = 2,
+ padding: int = 0,
+ name: str = "conv",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ self.name = name
+
+ if use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
+ init.constant_(self.conv.weight, 0.0)
+ with torch.no_grad():
+ for i in range(len(self.conv.weight)): self.conv.weight[i, i] = 1 / stride
+ init.constant_(self.conv.bias, 0.0)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ assert inputs.shape[1] == self.channels
+ return self.conv(inputs)
+
+class UnPatch1D(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool = False,
+ use_conv_transpose: bool = False,
+ out_channels: Optional[int] = None,
+ name: str = "conv",
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ assert inputs.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(inputs)
+
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+ if self.use_conv:
+ outputs = self.conv(outputs)
+
+ return outputs
+
+class Upsampler(nn.Module):
+ def __init__(
+ self,
+ spatial_upsample_factor: int = 1,
+ temporal_upsample_factor: int = 1,
+ ):
+ super().__init__()
+
+ self.spatial_upsample_factor = spatial_upsample_factor
+ self.temporal_upsample_factor = temporal_upsample_factor
+
+class TemporalUpsampler3D(Upsampler):
+ def __init__(self):
+ super().__init__(
+ spatial_upsample_factor=1,
+ temporal_upsample_factor=2,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.shape[2] > 1:
+ first_frame, x = x[:, :, :1], x[:, :, 1:]
+ x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.cat([first_frame, x], dim=2)
+ return x
+
+def cast_tuple(t, length = 1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+class CausalConv3d(nn.Conv3d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3, # : int | tuple[int, int, int],
+ stride=1, # : int | tuple[int, int, int] = 1,
+ padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
+ dilation=1, # : int | tuple[int, int, int] = 1,
+ **kwargs,
+ ):
+ kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
+ assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
+
+ stride = stride if isinstance(stride, tuple) else (stride,) * 3
+ assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
+
+ dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
+ assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
+
+ t_ks, h_ks, w_ks = kernel_size
+ _, h_stride, w_stride = stride
+ t_dilation, h_dilation, w_dilation = dilation
+
+ t_pad = (t_ks - 1) * t_dilation
+ # TODO: align with SD
+ if padding is None:
+ h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
+ w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+ elif isinstance(padding, int):
+ h_pad = w_pad = padding
+ else:
+ assert NotImplementedError
+
+ self.temporal_padding = t_pad
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=(0, h_pad, w_pad),
+ **kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, T, H, W)
+ x = F.pad(
+ x,
+ pad=(0, 0, 0, 0, self.temporal_padding, 0),
+ mode="replicate", # TODO: check if this is necessary
+ )
+ return super().forward(x)
+
+class PatchEmbed3D(nn.Module):
+ """3D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ time_patch_size=4,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv3d(
+ in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ latent = rearrange(latent, "b c f h w -> (b f) c h w")
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
+
+class PatchEmbedF3D(nn.Module):
+ """Fake 3D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.proj_t = Patch1D(
+ embed_dim, True, stride=patch_size
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+ b, c, f, h, w = latent.size()
+ latent = rearrange(latent, "b c f h w -> (b f) c h w")
+ latent = self.proj(latent)
+ latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)
+
+ latent = rearrange(latent, "b c f h w -> (b h w) c f")
+ latent = self.proj_t(latent)
+ latent = rearrange(latent, "(b h w) c f -> b c f h w", h=h//2, w=w//2)
+
+ latent = rearrange(latent, "b c f h w -> (b f) c h w")
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
+
+class CasualPatchEmbed3D(nn.Module):
+ """3D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ time_patch_size=4,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+
+ self.proj = CausalConv3d(
+ in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias, padding=None
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ # See:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+
+ def forward(self, latent):
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ latent = rearrange(latent, "b c f h w -> (b f) c h w")
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ # Interpolate positional embeddings if needed.
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed)
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
diff --git a/easyanimate/models/transformer2d.py b/easyanimate/models/transformer2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..836d8dad22541bc8b5d4f6fbec6b1dea2648c755
--- /dev/null
+++ b/easyanimate/models/transformer2d.py
@@ -0,0 +1,555 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormSingle
+from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
+ is_torch_version)
+from einops import rearrange
+from torch import nn
+
+try:
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
+except:
+ from diffusers.models.embeddings import \
+ CaptionProjection as PixArtAlphaTextProjection
+
+from .attention import (KVCompressionTransformerBlock,
+ SelfAttentionTemporalTransformerBlock,
+ TemporalTransformerBlock)
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ # block type
+ basic_block_type: str = "basic",
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.basic_block_type = basic_block_type
+ inner_dim = num_attention_heads * attention_head_dim
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = linear_cls(in_channels, inner_dim)
+ else:
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+ elif self.is_input_patches:
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale = max(interpolation_scale, 1)
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ basic_block = {
+ "basic": BasicTransformerBlock,
+ "kvcompression": KVCompressionTransformerBlock,
+ }[self.basic_block_type]
+ if self.basic_block_type == "kvcompression":
+ self.transformer_blocks = nn.ModuleList(
+ [
+ basic_block(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ kvcompression=False if d < 14 else True,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ else:
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = linear_cls(inner_dim, in_channels)
+ else:
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches and norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif self.is_input_patches and norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # Retrieve lora scale.
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ # 1. Input
+ if self.is_input_continuous:
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = (
+ self.proj_in(hidden_states, scale=lora_scale)
+ if not USE_PEFT_BACKEND
+ else self.proj_in(hidden_states)
+ )
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = (
+ self.proj_in(hidden_states, scale=lora_scale)
+ if not USE_PEFT_BACKEND
+ else self.proj_in(hidden_states)
+ )
+
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hidden_states = self.pos_embed(hidden_states)
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ batch_size = hidden_states.shape[0]
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+ args = {
+ "basic": [],
+ "kvcompression": [1, height, width],
+ }[self.basic_block_type]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ cross_attention_kwargs,
+ class_labels,
+ *args,
+ use_reentrant=False,
+ )
+ else:
+ kwargs = {
+ "basic": {},
+ "kvcompression": {"num_frames":1, "height":height, "width":width},
+ }[self.basic_block_type]
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ **kwargs
+ )
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = (
+ self.proj_out(hidden_states, scale=lora_scale)
+ if not USE_PEFT_BACKEND
+ else self.proj_out(hidden_states)
+ )
+ else:
+ hidden_states = (
+ self.proj_out(hidden_states, scale=lora_scale)
+ if not USE_PEFT_BACKEND
+ else self.proj_out(hidden_states)
+ )
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded 2D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
+ new_shape = model.state_dict()['pos_embed.proj.weight'].size()
+ state_dict['pos_embed.proj.weight'] = torch.tile(state_dict['proj_out.weight'], [1, 2, 1, 1])
+
+ if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
+ new_shape = model.state_dict()['proj_out.weight'].size()
+ state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
+
+ if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
+ new_shape = model.state_dict()['proj_out.bias'].size()
+ state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
+ print(f"### Postion Parameters: {sum(params) / 1e6} M")
+
+ return model
\ No newline at end of file
diff --git a/easyanimate/models/transformer3d.py b/easyanimate/models/transformer3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b566b13f82313d0bfd3bb355e73e56bebb9fc04
--- /dev/null
+++ b/easyanimate/models/transformer3d.py
@@ -0,0 +1,738 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import math
+import os
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn.init as init
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.embeddings import PatchEmbed, Timesteps, TimestepEmbedding
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormSingle
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
+from einops import rearrange
+from torch import nn
+from typing import Dict, Optional, Tuple
+
+from .attention import (SelfAttentionTemporalTransformerBlock,
+ TemporalTransformerBlock)
+from .patch import Patch1D, PatchEmbed3D, PatchEmbedF3D, UnPatch1D, TemporalUpsampler3D, CasualPatchEmbed3D
+
+try:
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
+except:
+ from diffusers.models.embeddings import \
+ CaptionProjection as PixArtAlphaTextProjection
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
+ self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+class AdaLayerNormSingle(nn.Module):
+ r"""
+ Norm layer adaptive layer norm single (adaLN-single).
+
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+ """
+
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+ )
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ batch_size: Optional[int] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ # No modulation happening here.
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+class TimePositionalEncoding(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ dropout = 0.,
+ max_len = 24
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout)
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ b, c, f, h, w = x.size()
+ x = rearrange(x, "b c f h w -> (b h w) f c")
+ x = x + self.pe[:, :x.size(1)]
+ x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
+ return self.dropout(x)
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A 3D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ # block type
+ basic_block_type: str = "motionmodule",
+ # enable_uvit
+ enable_uvit: bool = False,
+
+ # 3d patch params
+ patch_3d: bool = False,
+ fake_3d: bool = False,
+ time_patch_size: Optional[int] = None,
+
+ casual_3d: bool = False,
+ casual_3d_upsampler_index: Optional[list] = None,
+
+ # motion module kwargs
+ motion_module_type = "VanillaGrid",
+ motion_module_kwargs = None,
+
+ # time position encoding
+ time_position_encoding_before_transformer = False
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.enable_uvit = enable_uvit
+ inner_dim = num_attention_heads * attention_head_dim
+ self.basic_block_type = basic_block_type
+ self.patch_3d = patch_3d
+ self.fake_3d = fake_3d
+ self.casual_3d = casual_3d
+ self.casual_3d_upsampler_index = casual_3d_upsampler_index
+
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
+ assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
+ interpolation_scale = max(interpolation_scale, 1)
+
+ if self.casual_3d:
+ self.pos_embed = CasualPatchEmbed3D(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ time_patch_size=self.time_patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+ elif self.patch_3d:
+ if self.fake_3d:
+ self.pos_embed = PatchEmbedF3D(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+ else:
+ self.pos_embed = PatchEmbed3D(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ time_patch_size=self.time_patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+ else:
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ # 3. Define transformers blocks
+ if self.basic_block_type == "motionmodule":
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ elif self.basic_block_type == "kvcompression_motionmodule":
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ kvcompression=False if d < 14 else True,
+ motion_module_type=motion_module_type,
+ motion_module_kwargs=motion_module_kwargs,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ elif self.basic_block_type == "selfattentiontemporal":
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SelfAttentionTemporalTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+ else:
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ double_self_attention=double_self_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ attention_type=attention_type,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ if self.casual_3d:
+ self.unpatch1d = TemporalUpsampler3D()
+ elif self.patch_3d and self.fake_3d:
+ self.unpatch1d = UnPatch1D(inner_dim, True)
+
+ if self.enable_uvit:
+ self.long_connect_fc = nn.ModuleList(
+ [
+ nn.Linear(inner_dim, inner_dim, True) for d in range(13)
+ ]
+ )
+ for index in range(13):
+ self.long_connect_fc[index] = zero_module(self.long_connect_fc[index])
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ if self.patch_3d and not self.fake_3d:
+ self.proj_out_2 = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
+ else:
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+ elif norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ if self.patch_3d and not self.fake_3d:
+ self.proj_out = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
+ else:
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ # 5. PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+
+ self.caption_projection = None
+ if caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+ self.gradient_checkpointing = False
+
+ self.time_position_encoding_before_transformer = time_position_encoding_before_transformer
+ if self.time_position_encoding_before_transformer:
+ self.t_pos = TimePositionalEncoding(max_len = 4096, d_model = inner_dim)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ inpaint_latents: torch.Tensor = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer3DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if inpaint_latents is not None:
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
+ # 1. Input
+ if self.casual_3d:
+ video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ elif self.patch_3d:
+ video_length, height, width = hidden_states.shape[-3] // self.time_patch_size, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ else:
+ video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
+
+ hidden_states = self.pos_embed(hidden_states)
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ batch_size = hidden_states.shape[0] // video_length
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+ hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
+
+ # hidden_states
+ # bs, c, f, h, w => b (f h w ) c
+ if self.time_position_encoding_before_transformer:
+ hidden_states = self.t_pos(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ # 2. Blocks
+ if self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ skips = []
+ skip_index = 0
+ for index, block in enumerate(self.transformer_blocks):
+ if self.enable_uvit:
+ if index >= 15:
+ long_connect = self.long_connect_fc[skip_index](skips.pop())
+ hidden_states = hidden_states + long_connect
+ skip_index += 1
+
+ if self.casual_3d_upsampler_index is not None and index in self.casual_3d_upsampler_index:
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
+ hidden_states = self.unpatch1d(hidden_states)
+ video_length = (video_length - 1) * 2 + 1
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=video_length, h=height, w=width)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ args = {
+ "basic": [],
+ "motionmodule": [video_length, height, width],
+ "selfattentiontemporal": [video_length, height, width],
+ "kvcompression_motionmodule": [video_length, height, width],
+ }[self.basic_block_type]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ cross_attention_kwargs,
+ class_labels,
+ *args,
+ **ckpt_kwargs,
+ )
+ else:
+ kwargs = {
+ "basic": {},
+ "motionmodule": {"num_frames":video_length, "height":height, "width":width},
+ "selfattentiontemporal": {"num_frames":video_length, "height":height, "width":width},
+ "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
+ }[self.basic_block_type]
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ **kwargs
+ )
+
+ if self.enable_uvit:
+ if index < 13:
+ skips.append(hidden_states)
+
+ if self.fake_3d and self.patch_3d:
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> (b h w) c f", f=video_length, w=width, h=height)
+ hidden_states = self.unpatch1d(hidden_states)
+ hidden_states = rearrange(hidden_states, "(b h w) c f -> b (f h w) c", w=width, h=height)
+
+ # 3. Output
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ if self.patch_3d:
+ if self.fake_3d:
+ hidden_states = hidden_states.reshape(
+ shape=(-1, video_length * self.patch_size, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
+ else:
+ hidden_states = hidden_states.reshape(
+ shape=(-1, video_length, height, width, self.time_patch_size, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nfhwopqc->ncfohpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, video_length * self.time_patch_size, height * self.patch_size, width * self.patch_size)
+ )
+ else:
+ hidden_states = hidden_states.reshape(
+ shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ from diffusers.utils import WEIGHTS_NAME
+ model = cls.from_config(config, **transformer_additional_kwargs)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
+ if os.path.exists(model_file_safetensors):
+ from safetensors.torch import load_file, safe_open
+ state_dict = load_file(model_file_safetensors)
+ else:
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+ if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
+ new_shape = model.state_dict()['pos_embed.proj.weight'].size()
+ if len(new_shape) == 5:
+ state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+ state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
+ else:
+ model.state_dict()['pos_embed.proj.weight'][:, :4, :, :] = state_dict['pos_embed.proj.weight']
+ model.state_dict()['pos_embed.proj.weight'][:, 4:, :, :] = 0
+ state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
+
+ if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
+ new_shape = model.state_dict()['proj_out.weight'].size()
+ state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
+
+ if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
+ new_shape = model.state_dict()['proj_out.bias'].size()
+ state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
+
+ tmp_state_dict = {}
+ for key in state_dict:
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+ tmp_state_dict[key] = state_dict[key]
+ else:
+ print(key, "Size don't match, skip")
+ state_dict = tmp_state_dict
+
+ m, u = model.load_state_dict(state_dict, strict=False)
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+
+ params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
+ print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
+
+ return model
\ No newline at end of file
diff --git a/easyanimate/pipeline/pipeline_easyanimate.py b/easyanimate/pipeline/pipeline_easyanimate.py
new file mode 100644
index 0000000000000000000000000000000000000000..3285a3d0dc2ec677501432f6388496b492123b3d
--- /dev/null
+++ b/easyanimate/pipeline/pipeline_easyanimate.py
@@ -0,0 +1,847 @@
+# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import copy
+import re
+import urllib.parse as ul
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers import DiffusionPipeline, ImagePipelineOutput
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
+ is_bs4_available, is_ftfy_available, logging,
+ replace_example_docstring)
+from diffusers.utils.torch_utils import randn_tensor
+from einops import rearrange
+from tqdm import tqdm
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ..models.transformer3d import Transformer3DModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import EasyAnimatePipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+@dataclass
+class EasyAnimatePipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+class EasyAnimatePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer3DModel`]):
+ A text conditioned `Transformer3DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: Transformer3DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index
+ else:
+ masked_feature = emb * mask[:, None, :, None]
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+ if self.vae.quant_conv.weight.ndim==5:
+ mini_batch_encoder = self.vae.mini_batch_encoder
+ mini_batch_decoder = self.vae.mini_batch_decoder
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ else:
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
+ if video.size()[2] <= mini_batch_encoder:
+ return video
+ prefix_index_before = mini_batch_encoder // 2
+ prefix_index_after = mini_batch_encoder - prefix_index_before
+ pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
+
+ if self.vae.slice_compression_vae:
+ latents = self.vae.encode(pixel_values)[0]
+ latents = latents.sample()
+ else:
+ new_pixel_values = []
+ for i in range(0, pixel_values.shape[2], mini_batch_encoder):
+ with torch.no_grad():
+ pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
+ pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
+ pixel_values_bs = pixel_values_bs.sample()
+ new_pixel_values.append(pixel_values_bs)
+ latents = torch.cat(new_pixel_values, dim = 2)
+
+ if self.vae.slice_compression_vae:
+ middle_video = self.vae.decode(latents)[0]
+ else:
+ middle_video = []
+ for i in range(0, latents.shape[2], mini_batch_decoder):
+ with torch.no_grad():
+ start_index = i
+ end_index = i + mini_batch_decoder
+ latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
+ middle_video.append(latents_bs)
+ middle_video = torch.cat(middle_video, 2)
+ video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
+ return video
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
+ if self.vae.quant_conv.weight.ndim==5:
+ mini_batch_encoder = self.vae.mini_batch_encoder
+ mini_batch_decoder = self.vae.mini_batch_decoder
+ if self.vae.slice_compression_vae:
+ video = self.vae.decode(latents)[0]
+ else:
+ video = []
+ for i in range(0, latents.shape[2], mini_batch_decoder):
+ with torch.no_grad():
+ start_index = i
+ end_index = i + mini_batch_decoder
+ latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
+ video.append(latents_bs)
+ video = torch.cat(video, 2)
+ video = video.clamp(-1, 1)
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
+ else:
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ video = []
+ for frame_idx in tqdm(range(latents.shape[0])):
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+ video = torch.cat(video)
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ video = video.cpu().float().numpy()
+ return video
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ video_length: Optional[int] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "latent",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ) -> Union[EasyAnimatePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ video_length,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # Post-processing
+ video = self.decode_latents(latents)
+
+ # Convert to tensor
+ if output_type == "latent":
+ video = torch.from_numpy(video)
+
+ if not return_dict:
+ return video
+
+ return EasyAnimatePipelineOutput(videos=video)
\ No newline at end of file
diff --git a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb62f4ef40601235948b85a7b4a51d8b0e0242b4
--- /dev/null
+++ b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py
@@ -0,0 +1,984 @@
+# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import copy
+import urllib.parse as ul
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from diffusers import DiffusionPipeline, ImagePipelineOutput
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKL
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
+ is_bs4_available, is_ftfy_available, logging,
+ replace_example_docstring)
+from diffusers.utils.torch_utils import randn_tensor
+from einops import rearrange
+from tqdm import tqdm
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ..models.transformer3d import Transformer3DModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import EasyAnimatePipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+ if hasattr(encoder_output, "latent_dist"):
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+@dataclass
+class EasyAnimatePipelineOutput(BaseOutput):
+ videos: Union[torch.Tensor, np.ndarray]
+
+class EasyAnimateInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer3DModel`]):
+ A text conditioned `Transformer3DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKL,
+ transformer: Transformer3DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=True)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index
+ else:
+ masked_feature = emb * mask[:, None, :, None]
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ video=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_video_latents=False,
+ ):
+ if self.vae.quant_conv.weight.ndim==5:
+ shape = (batch_size, num_channels_latents, int(video_length // 5 * 2) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ else:
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if return_video_latents or (latents is None and not is_strength_max):
+ video = video.to(device=device, dtype=dtype)
+
+ if video.shape[1] == 4:
+ video_latents = video
+ else:
+ video_length = video.shape[2]
+ video = rearrange(video, "b c f h w -> (b f) c h w")
+ video_latents = self._encode_vae_image(image=video, generator=generator)
+ video_latents = rearrange(video_latents, "(b f) c h w -> b c f h w", f=video_length)
+ video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1)
+
+ if latents is None:
+ rand_device = "cpu" if device.type == "mps" else device
+
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
+ else:
+ noise = latents.to(device)
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_video_latents:
+ outputs += (video_latents,)
+
+ return outputs
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
+ if self.vae.quant_conv.weight.ndim==5:
+ mini_batch_decoder = 2
+ # Decoder
+ video = []
+ for i in range(0, latents.shape[2], mini_batch_decoder):
+ with torch.no_grad():
+ start_index = i
+ end_index = i + mini_batch_decoder
+ latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
+ video.append(latents_bs)
+
+ # Smooth
+ mini_batch_encoder = 5
+ video = torch.cat(video, 2).cpu()
+ for i in range(mini_batch_encoder, video.shape[2], mini_batch_encoder):
+ origin_before = copy.deepcopy(video[:, :, i - 1, :, :])
+ origin_after = copy.deepcopy(video[:, :, i, :, :])
+
+ video[:, :, i - 1, :, :] = origin_before * 0.75 + origin_after * 0.25
+ video[:, :, i, :, :] = origin_before * 0.25 + origin_after * 0.75
+ video = video.clamp(-1, 1)
+ else:
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
+ # video = self.vae.decode(latents).sample
+ video = []
+ for frame_idx in tqdm(range(latents.shape[0])):
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+ video = torch.cat(video)
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+ video = (video / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ video = video.cpu().float().numpy()
+ return video
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ video_length = mask.shape[2]
+
+ mask = mask.to(device=device, dtype=self.vae.dtype)
+ if self.vae.quant_conv.weight.ndim==5:
+ bs = 1
+ new_mask = []
+ for i in range(0, mask.shape[0], bs):
+ mini_batch = 5
+ new_mask_mini_batch = []
+ for j in range(0, mask.shape[2], mini_batch):
+ mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
+ mask_bs = self.vae.encode(mask_bs)[0]
+ mask_bs = mask_bs.sample()
+ new_mask_mini_batch.append(mask_bs)
+ new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
+ new_mask.append(new_mask_mini_batch)
+ mask = torch.cat(new_mask, dim = 0)
+ mask = mask * 0.1825
+
+ else:
+ if mask.shape[1] == 4:
+ mask = mask
+ else:
+ video_length = mask.shape[2]
+ mask = rearrange(mask, "b c f h w -> (b f) c h w")
+ mask = self._encode_vae_image(mask, generator=generator)
+ mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
+
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
+ if self.vae.quant_conv.weight.ndim==5:
+ bs = 1
+ new_mask_pixel_values = []
+ for i in range(0, masked_image.shape[0], bs):
+ mini_batch = 5
+ new_mask_pixel_values_mini_batch = []
+ for j in range(0, masked_image.shape[2], mini_batch):
+ mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+ mask_pixel_values_bs = mask_pixel_values_bs.sample()
+ new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
+ new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
+ new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+ masked_image_latents = masked_image_latents * 0.1825
+
+ else:
+ if masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ video_length = mask.shape[2]
+ masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+ masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ video_length: Optional[int] = None,
+ video: Union[torch.FloatTensor] = None,
+ mask_video: Union[torch.FloatTensor] = None,
+ masked_video_latents: Union[torch.FloatTensor] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "latent",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ mask_feature: bool = True,
+ max_sequence_length: int = 120
+ ) -> Union[EasyAnimatePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ if video is not None:
+ video_length = video.shape[2]
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ init_video = init_video.to(dtype=torch.float32)
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
+ else:
+ init_video = None
+
+ # Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_transformer = self.transformer.config.in_channels
+ return_image_latents = num_channels_transformer == 4
+
+ # 5. Prepare latents.
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ video_length,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ video=init_video,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_video_latents=return_image_latents,
+ )
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+ latents_dtype = latents.dtype
+
+ if mask_video is not None:
+ # Prepare mask latent variables
+ video_length = video.shape[2]
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
+ mask_condition = mask_condition.to(dtype=torch.float32)
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
+
+ if masked_video_latents is None:
+ masked_video = init_video * (mask_condition < 0.5) + torch.ones_like(init_video) * (mask_condition > 0.5) * -1
+ else:
+ masked_video = masked_video_latents
+
+ mask, masked_video_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_video,
+ batch_size,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+ else:
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+ # Check that sizes of mask, masked image and latents match
+ if num_channels_transformer == 12:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_video_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
+ f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.transformer` or your `mask_image` or `image` input."
+ )
+ elif num_channels_transformer == 4:
+ raise ValueError(
+ f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_transformer == 12:
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_video_latents_input = (
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+ )
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ inpaint_latents=inpaint_latents.to(latent_model_input.dtype),
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # Post-processing
+ video = self.decode_latents(latents)
+
+ # Convert to tensor
+ if output_type == "latent":
+ video = torch.from_numpy(video)
+
+ if not return_dict:
+ return video
+
+ return EasyAnimatePipelineOutput(videos=video)
\ No newline at end of file
diff --git a/easyanimate/pipeline/pipeline_pixart_magvit.py b/easyanimate/pipeline/pipeline_pixart_magvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b92abff0a8a3bee2ce8331dc103709d097c768bc
--- /dev/null
+++ b/easyanimate/pipeline/pipeline_pixart_magvit.py
@@ -0,0 +1,983 @@
+# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import Transformer2DModel
+from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
+ ImagePipelineOutput)
+from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.utils import (BACKENDS_MAPPING, deprecate, is_bs4_available,
+ is_ftfy_available, logging,
+ replace_example_docstring)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import T5EncoderModel, T5Tokenizer
+
+from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PixArtAlphaPipeline
+
+ >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
+ >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+ASPECT_RATIO_1024_BIN = {
+ "0.25": [512.0, 2048.0],
+ "0.28": [512.0, 1856.0],
+ "0.32": [576.0, 1792.0],
+ "0.33": [576.0, 1728.0],
+ "0.35": [576.0, 1664.0],
+ "0.4": [640.0, 1600.0],
+ "0.42": [640.0, 1536.0],
+ "0.48": [704.0, 1472.0],
+ "0.5": [704.0, 1408.0],
+ "0.52": [704.0, 1344.0],
+ "0.57": [768.0, 1344.0],
+ "0.6": [768.0, 1280.0],
+ "0.68": [832.0, 1216.0],
+ "0.72": [832.0, 1152.0],
+ "0.78": [896.0, 1152.0],
+ "0.82": [896.0, 1088.0],
+ "0.88": [960.0, 1088.0],
+ "0.94": [960.0, 1024.0],
+ "1.0": [1024.0, 1024.0],
+ "1.07": [1024.0, 960.0],
+ "1.13": [1088.0, 960.0],
+ "1.21": [1088.0, 896.0],
+ "1.29": [1152.0, 896.0],
+ "1.38": [1152.0, 832.0],
+ "1.46": [1216.0, 832.0],
+ "1.67": [1280.0, 768.0],
+ "1.75": [1344.0, 768.0],
+ "2.0": [1408.0, 704.0],
+ "2.09": [1472.0, 704.0],
+ "2.4": [1536.0, 640.0],
+ "2.5": [1600.0, 640.0],
+ "3.0": [1728.0, 576.0],
+ "4.0": [2048.0, 512.0],
+}
+
+ASPECT_RATIO_512_BIN = {
+ "0.25": [256.0, 1024.0],
+ "0.28": [256.0, 928.0],
+ "0.32": [288.0, 896.0],
+ "0.33": [288.0, 864.0],
+ "0.35": [288.0, 832.0],
+ "0.4": [320.0, 800.0],
+ "0.42": [320.0, 768.0],
+ "0.48": [352.0, 736.0],
+ "0.5": [352.0, 704.0],
+ "0.52": [352.0, 672.0],
+ "0.57": [384.0, 672.0],
+ "0.6": [384.0, 640.0],
+ "0.68": [416.0, 608.0],
+ "0.72": [416.0, 576.0],
+ "0.78": [448.0, 576.0],
+ "0.82": [448.0, 544.0],
+ "0.88": [480.0, 544.0],
+ "0.94": [480.0, 512.0],
+ "1.0": [512.0, 512.0],
+ "1.07": [512.0, 480.0],
+ "1.13": [544.0, 480.0],
+ "1.21": [544.0, 448.0],
+ "1.29": [576.0, 448.0],
+ "1.38": [576.0, 416.0],
+ "1.46": [608.0, 416.0],
+ "1.67": [640.0, 384.0],
+ "1.75": [672.0, 384.0],
+ "2.0": [704.0, 352.0],
+ "2.09": [736.0, 352.0],
+ "2.4": [768.0, 320.0],
+ "2.5": [800.0, 320.0],
+ "3.0": [864.0, 288.0],
+ "4.0": [1024.0, 256.0],
+}
+
+ASPECT_RATIO_256_BIN = {
+ "0.25": [128.0, 512.0],
+ "0.28": [128.0, 464.0],
+ "0.32": [144.0, 448.0],
+ "0.33": [144.0, 432.0],
+ "0.35": [144.0, 416.0],
+ "0.4": [160.0, 400.0],
+ "0.42": [160.0, 384.0],
+ "0.48": [176.0, 368.0],
+ "0.5": [176.0, 352.0],
+ "0.52": [176.0, 336.0],
+ "0.57": [192.0, 336.0],
+ "0.6": [192.0, 320.0],
+ "0.68": [208.0, 304.0],
+ "0.72": [208.0, 288.0],
+ "0.78": [224.0, 288.0],
+ "0.82": [224.0, 272.0],
+ "0.88": [240.0, 272.0],
+ "0.94": [240.0, 256.0],
+ "1.0": [256.0, 256.0],
+ "1.07": [256.0, 240.0],
+ "1.13": [272.0, 240.0],
+ "1.21": [272.0, 224.0],
+ "1.29": [288.0, 224.0],
+ "1.38": [288.0, 208.0],
+ "1.46": [304.0, 208.0],
+ "1.67": [320.0, 192.0],
+ "1.75": [336.0, 192.0],
+ "2.0": [352.0, 176.0],
+ "2.09": [368.0, 176.0],
+ "2.4": [384.0, 160.0],
+ "2.5": [400.0, 160.0],
+ "3.0": [432.0, 144.0],
+ "4.0": [512.0, 128.0],
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class PixArtAlphaMagvitPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using PixArt-Alpha.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + "\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ ) # noqa
+
+ _optional_components = ["tokenizer", "text_encoder"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLMagvit,
+ transformer: Transformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
+ def mask_text_embeddings(self, emb, mask):
+ if emb.shape[0] == 1:
+ keep_index = mask.sum().item()
+ return emb[:, :, :keep_index, :], keep_index
+ else:
+ masked_feature = emb * mask[:, None, :, None]
+ return masked_feature, emb.shape[2]
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
+ string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+ """
+
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+
+ if device is None:
+ device = self._execution_device
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+
+ if prompt_embeds is None:
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens = [negative_prompt] * batch_size
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_attention_mask=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ negative_prompt_attention_mask = uncond_input.attention_mask
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warn("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip adresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+ """Returns binned height and width."""
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+ orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+ # Check if resizing is needed
+ if orig_height != new_height or orig_width != new_width:
+ ratio = max(new_height / orig_height, new_width / orig_width)
+ resized_width = int(orig_width * ratio)
+ resized_height = int(orig_height * ratio)
+
+ # Resize
+ samples = F.interpolate(
+ samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ # Center Crop
+ start_x = (resized_width - new_width) // 2
+ end_x = start_x + new_width
+ start_y = (resized_height - new_height) // 2
+ end_y = start_y + new_height
+ samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+ return samples
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 20,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.5,
+ num_images_per_prompt: Optional[int] = 1,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ clean_caption: bool = True,
+ use_resolution_binning: bool = False,
+ max_sequence_length: int = 120,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ if "mask_feature" in kwargs:
+ deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+ # 1. Check inputs. Raise error if not correct
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ elif self.transformer.config.sample_size == 64:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+ elif self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_steps,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Prepare micro-conditions.
+ added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+ if self.transformer.config.sample_size == 128:
+ resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+ aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+ resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
+ aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+ if do_classifier_free_guidance:
+ resolution = torch.cat([resolution, resolution], dim=0)
+ aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+ added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ current_timestep = t
+ if not torch.is_tensor(current_timestep):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = latent_model_input.device.type == "mps"
+ if isinstance(current_timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
+ elif len(current_timestep.shape) == 0:
+ current_timestep = current_timestep[None].to(latent_model_input.device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=current_timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+ else:
+ noise_pred = noise_pred
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ if self.vae.quant_conv.weight.ndim==5:
+ latents = latents.unsqueeze(2)
+ latents = latents.float()
+ self.vae.post_quant_conv = self.vae.post_quant_conv.float()
+ self.vae.decoder = self.vae.decoder.float()
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ if self.vae.quant_conv.weight.ndim==5:
+ image = image.permute(0,2,1,3,4).flatten(0, 1)
+
+ if use_resolution_binning:
+ image = self.resize_and_crop_tensor(image, orig_width, orig_height)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/easyanimate/ui/ui.py b/easyanimate/ui/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e799fdf1328f06c4f7e0d3df35861af4d7232f4
--- /dev/null
+++ b/easyanimate/ui/ui.py
@@ -0,0 +1,818 @@
+"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
+"""
+import gc
+import json
+import os
+import random
+import base64
+import requests
+from datetime import datetime
+from glob import glob
+
+import gradio as gr
+import torch
+import numpy as np
+from diffusers import (AutoencoderKL, DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+ PNDMScheduler)
+from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
+from diffusers.utils.import_utils import is_xformers_available
+from omegaconf import OmegaConf
+from safetensors import safe_open
+from transformers import T5EncoderModel, T5Tokenizer
+
+from easyanimate.models.transformer3d import Transformer3DModel
+from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
+from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
+from easyanimate.utils.utils import save_videos_grid
+from PIL import Image
+
+sample_idx = 0
+scheduler_dict = {
+ "Euler": EulerDiscreteScheduler,
+ "Euler A": EulerAncestralDiscreteScheduler,
+ "DPM++": DPMSolverMultistepScheduler,
+ "PNDM": PNDMScheduler,
+ "DDIM": DDIMScheduler,
+}
+
+css = """
+.toolbutton {
+ margin-buttom: 0em 0em 0em 0em;
+ max-width: 2.5em;
+ min-width: 2.5em !important;
+ height: 2.5em;
+}
+"""
+
+class EasyAnimateController:
+ def __init__(self):
+ # config dirs
+ self.basedir = os.getcwd()
+ self.config_dir = os.path.join(self.basedir, "config")
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+ self.savedir_sample = os.path.join(self.savedir, "sample")
+ self.edition = "v2"
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
+ os.makedirs(self.savedir, exist_ok=True)
+
+ self.diffusion_transformer_list = []
+ self.motion_module_list = []
+ self.personalized_model_list = []
+
+ self.refresh_diffusion_transformer()
+ self.refresh_motion_module()
+ self.refresh_personalized_model()
+
+ # config models
+ self.tokenizer = None
+ self.text_encoder = None
+ self.vae = None
+ self.transformer = None
+ self.pipeline = None
+ self.motion_module_path = "none"
+ self.base_model_path = "none"
+ self.lora_model_path = "none"
+
+ self.weight_dtype = torch.bfloat16
+
+ def refresh_diffusion_transformer(self):
+ self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
+
+ def refresh_motion_module(self):
+ motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+ def refresh_personalized_model(self):
+ personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
+ def update_edition(self, edition):
+ print("Update edition of EasyAnimate")
+ self.edition = edition
+ if edition == "v1":
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
+ return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
+ gr.update(visible=False), gr.update(value=512, minimum=384, maximum=704, step=32), \
+ gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
+ else:
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
+ return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+ gr.update(visible=True), gr.update(value=672, minimum=128, maximum=1280, step=16), \
+ gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
+
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
+ print("Update diffusion transformer")
+ if diffusion_transformer_dropdown == "none":
+ return gr.Dropdown.update()
+ if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
+ Choosen_AutoencoderKL = AutoencoderKLMagvit
+ else:
+ Choosen_AutoencoderKL = AutoencoderKL
+ self.vae = Choosen_AutoencoderKL.from_pretrained(
+ diffusion_transformer_dropdown,
+ subfolder="vae",
+ ).to(self.weight_dtype)
+ self.transformer = Transformer3DModel.from_pretrained_2d(
+ diffusion_transformer_dropdown,
+ subfolder="transformer",
+ transformer_additional_kwargs=OmegaConf.to_container(self.inference_config.transformer_additional_kwargs)
+ ).to(self.weight_dtype)
+ self.tokenizer = T5Tokenizer.from_pretrained(diffusion_transformer_dropdown, subfolder="tokenizer")
+ self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
+
+ # Get pipeline
+ self.pipeline = EasyAnimatePipeline(
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ transformer=self.transformer,
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ )
+ self.pipeline.enable_model_cpu_offload()
+ print("Update diffusion transformer done")
+ return gr.Dropdown.update()
+
+ def update_motion_module(self, motion_module_dropdown):
+ self.motion_module_path = motion_module_dropdown
+ print("Update motion module")
+ if motion_module_dropdown == "none":
+ return gr.Dropdown.update()
+ if self.transformer is None:
+ gr.Info(f"Please select a pretrained model path.")
+ return gr.Dropdown.update(value=None)
+ else:
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
+ if motion_module_dropdown.endswith(".safetensors"):
+ from safetensors.torch import load_file, safe_open
+ motion_module_state_dict = load_file(motion_module_dropdown)
+ else:
+ if not os.path.isfile(motion_module_dropdown):
+ raise RuntimeError(f"{motion_module_dropdown} does not exist")
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
+ missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
+ print("Update motion module done.")
+ return gr.Dropdown.update()
+
+ def update_base_model(self, base_model_dropdown):
+ self.base_model_path = base_model_dropdown
+ print("Update base model")
+ if base_model_dropdown == "none":
+ return gr.Dropdown.update()
+ if self.transformer is None:
+ gr.Info(f"Please select a pretrained model path.")
+ return gr.Dropdown.update(value=None)
+ else:
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
+ base_model_state_dict = {}
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ base_model_state_dict[key] = f.get_tensor(key)
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
+ print("Update base done")
+ return gr.Dropdown.update()
+
+ def update_lora_model(self, lora_model_dropdown):
+ print("Update lora model")
+ if lora_model_dropdown == "none":
+ self.lora_model_path = "none"
+ return gr.Dropdown.update()
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+ self.lora_model_path = lora_model_dropdown
+ return gr.Dropdown.update()
+
+ def generate(
+ self,
+ diffusion_transformer_dropdown,
+ motion_module_dropdown,
+ base_model_dropdown,
+ lora_model_dropdown,
+ lora_alpha_slider,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ is_api = False,
+ ):
+ global sample_idx
+ if self.transformer is None:
+ raise gr.Error(f"Please select a pretrained model path.")
+
+ if self.base_model_path != base_model_dropdown:
+ self.update_base_model(base_model_dropdown)
+
+ if self.motion_module_path != motion_module_dropdown:
+ self.update_motion_module(motion_module_dropdown)
+
+ if self.lora_model_path != lora_model_dropdown:
+ print("Update lora model")
+ self.update_lora_model(lora_model_dropdown)
+
+ if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
+
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ if self.lora_model_path != "none":
+ # lora part
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+ self.pipeline.to("cuda")
+
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+ else: seed_textbox = np.random.randint(0, 1e10)
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
+
+ try:
+ sample = self.pipeline(
+ prompt_textbox,
+ negative_prompt = negative_prompt_textbox,
+ num_inference_steps = sample_step_slider,
+ guidance_scale = cfg_scale_slider,
+ width = width_slider,
+ height = height_slider,
+ video_length = length_slider if not is_image else 1,
+ generator = generator
+ ).videos
+ except Exception as e:
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ if self.lora_model_path != "none":
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+ if is_api:
+ return "", f"Error. error information is {str(e)}"
+ else:
+ return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
+
+ # lora part
+ if self.lora_model_path != "none":
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+ sample_config = {
+ "prompt": prompt_textbox,
+ "n_prompt": negative_prompt_textbox,
+ "sampler": sampler_dropdown,
+ "num_inference_steps": sample_step_slider,
+ "guidance_scale": cfg_scale_slider,
+ "width": width_slider,
+ "height": height_slider,
+ "video_length": length_slider,
+ "seed_textbox": seed_textbox
+ }
+ json_str = json.dumps(sample_config, indent=4)
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+ f.write(json_str)
+ f.write("\n\n")
+
+ if not os.path.exists(self.savedir_sample):
+ os.makedirs(self.savedir_sample, exist_ok=True)
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+ prefix = str(index).zfill(3)
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ if is_image or length_slider == 1:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+
+ image = sample[0, :, 0]
+ image = image.transpose(0, 1).transpose(1, 2)
+ image = (image * 255).numpy().astype(np.uint8)
+ image = Image.fromarray(image)
+ image.save(save_sample_path)
+
+ if is_api:
+ return save_sample_path, "Success"
+ else:
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+ else:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+ save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
+
+ if is_api:
+ return save_sample_path, "Success"
+ else:
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui():
+ controller = EasyAnimateController()
+
+ with gr.Blocks(css=css) as demo:
+ gr.Markdown(
+ """
+ # EasyAnimate: Integrated generation of baseline scheme for videos and images.
+ Generate your videos easily
+ [Github](https://github.com/aigc-apps/EasyAnimate/)
+ """
+ )
+ with gr.Column(variant="panel"):
+ gr.Markdown(
+ """
+ ### 1. EasyAnimate Edition (select easyanimate edition first).
+ """
+ )
+ with gr.Row():
+ easyanimate_edition_dropdown = gr.Dropdown(
+ label="The config of EasyAnimate Edition",
+ choices=["v1", "v2"],
+ value="v2",
+ interactive=True,
+ )
+ gr.Markdown(
+ """
+ ### 2. Model checkpoints (select pretrained model path).
+ """
+ )
+ with gr.Row():
+ diffusion_transformer_dropdown = gr.Dropdown(
+ label="Pretrained Model Path",
+ choices=controller.diffusion_transformer_list,
+ value="none",
+ interactive=True,
+ )
+ diffusion_transformer_dropdown.change(
+ fn=controller.update_diffusion_transformer,
+ inputs=[diffusion_transformer_dropdown],
+ outputs=[diffusion_transformer_dropdown]
+ )
+
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def refresh_diffusion_transformer():
+ controller.refresh_diffusion_transformer()
+ return gr.Dropdown.update(choices=controller.diffusion_transformer_list)
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
+
+ with gr.Row():
+ motion_module_dropdown = gr.Dropdown(
+ label="Select motion module",
+ choices=controller.motion_module_list,
+ value="none",
+ interactive=True,
+ visible=False
+ )
+
+ motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
+ def update_motion_module():
+ controller.refresh_motion_module()
+ return gr.Dropdown.update(choices=controller.motion_module_list)
+ motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
+
+ base_model_dropdown = gr.Dropdown(
+ label="Select base Dreambooth model (optional)",
+ choices=controller.personalized_model_list,
+ value="none",
+ interactive=True,
+ )
+
+ lora_model_dropdown = gr.Dropdown(
+ label="Select LoRA model (optional)",
+ choices=["none"] + controller.personalized_model_list,
+ value="none",
+ interactive=True,
+ )
+
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
+
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_personalized_model():
+ controller.refresh_personalized_model()
+ return [
+ gr.Dropdown.update(choices=controller.personalized_model_list),
+ gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
+ ]
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
+ with gr.Column(variant="panel"):
+ gr.Markdown(
+ """
+ ### 3. Configs for Generation.
+ """
+ )
+
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=100, step=1)
+
+ width_slider = gr.Slider(label="Width", value=672, minimum=128, maximum=1280, step=16)
+ height_slider = gr.Slider(label="Height", value=384, minimum=128, maximum=1280, step=16)
+ with gr.Row():
+ is_image = gr.Checkbox(False, label="Generate Image")
+ length_slider = gr.Slider(label="Animation length", value=144, minimum=9, maximum=144, step=9)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
+
+ with gr.Row():
+ seed_textbox = gr.Textbox(label="Seed", value=43)
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
+
+ generate_button = gr.Button(value="Generate", variant='primary')
+
+ with gr.Column():
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
+ result_video = gr.Video(label="Generated Animation", interactive=False)
+ infer_progress = gr.Textbox(
+ label="Generation Info",
+ value="No task currently",
+ interactive=False
+ )
+
+ is_image.change(
+ lambda x: gr.update(visible=not x),
+ inputs=[is_image],
+ outputs=[length_slider],
+ )
+ easyanimate_edition_dropdown.change(
+ fn=controller.update_edition,
+ inputs=[easyanimate_edition_dropdown],
+ outputs=[
+ easyanimate_edition_dropdown,
+ diffusion_transformer_dropdown,
+ motion_module_dropdown,
+ motion_module_refresh_button,
+ is_image,
+ width_slider,
+ height_slider,
+ length_slider,
+ ]
+ )
+ generate_button.click(
+ fn=controller.generate,
+ inputs=[
+ diffusion_transformer_dropdown,
+ motion_module_dropdown,
+ base_model_dropdown,
+ lora_model_dropdown,
+ lora_alpha_slider,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ ],
+ outputs=[result_image, result_video, infer_progress]
+ )
+ return demo, controller
+
+
+class EasyAnimateController_Modelscope:
+ def __init__(self, edition, config_path, model_name, savedir_sample):
+ # Config and model path
+ weight_dtype = torch.bfloat16
+ self.savedir_sample = savedir_sample
+ os.makedirs(self.savedir_sample, exist_ok=True)
+
+ self.edition = edition
+ self.inference_config = OmegaConf.load(config_path)
+ # Get Transformer
+ self.transformer = Transformer3DModel.from_pretrained_2d(
+ model_name,
+ subfolder="transformer",
+ transformer_additional_kwargs=OmegaConf.to_container(self.inference_config['transformer_additional_kwargs'])
+ ).to(weight_dtype)
+ if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
+ Choosen_AutoencoderKL = AutoencoderKLMagvit
+ else:
+ Choosen_AutoencoderKL = AutoencoderKL
+ self.vae = Choosen_AutoencoderKL.from_pretrained(
+ model_name,
+ subfolder="vae"
+ ).to(weight_dtype)
+ self.tokenizer = T5Tokenizer.from_pretrained(
+ model_name,
+ subfolder="tokenizer"
+ )
+ self.text_encoder = T5EncoderModel.from_pretrained(
+ model_name,
+ subfolder="text_encoder",
+ torch_dtype=weight_dtype
+ )
+ self.pipeline = EasyAnimatePipeline(
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ transformer=self.transformer,
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ )
+ self.pipeline.enable_model_cpu_offload()
+ print("Update diffusion transformer done")
+
+ def generate(
+ self,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox
+ ):
+ if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
+
+ self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ self.pipeline.to("cuda")
+
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+ else: seed_textbox = np.random.randint(0, 1e10)
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
+
+ try:
+ sample = self.pipeline(
+ prompt_textbox,
+ negative_prompt = negative_prompt_textbox,
+ num_inference_steps = sample_step_slider,
+ guidance_scale = cfg_scale_slider,
+ width = width_slider,
+ height = height_slider,
+ video_length = length_slider if not is_image else 1,
+ generator = generator
+ ).videos
+ except Exception as e:
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
+
+ if not os.path.exists(self.savedir_sample):
+ os.makedirs(self.savedir_sample, exist_ok=True)
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+ prefix = str(index).zfill(3)
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+ if is_image or length_slider == 1:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+
+ image = sample[0, :, 0]
+ image = image.transpose(0, 1).transpose(1, 2)
+ image = (image * 255).numpy().astype(np.uint8)
+ image = Image.fromarray(image)
+ image.save(save_sample_path)
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+ else:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+ save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui_modelscope(edition, config_path, model_name, savedir_sample):
+ controller = EasyAnimateController_Modelscope(edition, config_path, model_name, savedir_sample)
+
+ with gr.Blocks(css=css) as demo:
+ gr.Markdown(
+ """
+ # EasyAnimate: Integrated generation of baseline scheme for videos and images.
+ Generate your videos easily
+ [Github](https://github.com/aigc-apps/EasyAnimate/)
+ """
+ )
+ with gr.Column(variant="panel"):
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+ sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
+
+ if edition == "v1":
+ width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
+ height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
+ with gr.Row():
+ is_image = gr.Checkbox(False, label="Generate Image", visible=False)
+ length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
+ else:
+ width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
+ height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
+ with gr.Column():
+ gr.Markdown(
+ """
+ To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
+ If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
+ """
+ )
+ with gr.Row():
+ is_image = gr.Checkbox(False, label="Generate Image")
+ length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
+
+ with gr.Row():
+ seed_textbox = gr.Textbox(label="Seed", value=43)
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
+
+ generate_button = gr.Button(value="Generate", variant='primary')
+
+ with gr.Column():
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
+ result_video = gr.Video(label="Generated Animation", interactive=False)
+ infer_progress = gr.Textbox(
+ label="Generation Info",
+ value="No task currently",
+ interactive=False
+ )
+
+ is_image.change(
+ lambda x: gr.update(visible=not x),
+ inputs=[is_image],
+ outputs=[length_slider],
+ )
+
+ generate_button.click(
+ fn=controller.generate,
+ inputs=[
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ ],
+ outputs=[result_image, result_video, infer_progress]
+ )
+ return demo, controller
+
+
+def post_eas(
+ prompt_textbox, negative_prompt_textbox,
+ sampler_dropdown, sample_step_slider, width_slider, height_slider,
+ is_image, length_slider, cfg_scale_slider, seed_textbox,
+):
+ datas = {
+ "base_model_path": "none",
+ "motion_module_path": "none",
+ "lora_model_path": "none",
+ "lora_alpha_slider": 0.55,
+ "prompt_textbox": prompt_textbox,
+ "negative_prompt_textbox": negative_prompt_textbox,
+ "sampler_dropdown": sampler_dropdown,
+ "sample_step_slider": sample_step_slider,
+ "width_slider": width_slider,
+ "height_slider": height_slider,
+ "is_image": is_image,
+ "length_slider": length_slider,
+ "cfg_scale_slider": cfg_scale_slider,
+ "seed_textbox": seed_textbox,
+ }
+ # Token可以在公网地址调用信息中获取,详情请参见通用公网调用部分。
+ session = requests.session()
+ session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
+
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas)
+ outputs = response.json()
+ return outputs
+
+
+class EasyAnimateController_HuggingFace:
+ def __init__(self, edition, config_path, model_name, savedir_sample):
+ self.savedir_sample = savedir_sample
+ os.makedirs(self.savedir_sample, exist_ok=True)
+
+ def generate(
+ self,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox
+ ):
+ outputs = post_eas(
+ prompt_textbox, negative_prompt_textbox,
+ sampler_dropdown, sample_step_slider, width_slider, height_slider,
+ is_image, length_slider, cfg_scale_slider, seed_textbox
+ )
+ base64_encoding = outputs["base64_encoding"]
+ decoded_data = base64.b64decode(base64_encoding)
+
+ if not os.path.exists(self.savedir_sample):
+ os.makedirs(self.savedir_sample, exist_ok=True)
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+ prefix = str(index).zfill(3)
+
+ if is_image or length_slider == 1:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+ with open(save_sample_path, "wb") as file:
+ file.write(decoded_data)
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+ else:
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+ with open(save_sample_path, "wb") as file:
+ file.write(decoded_data)
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui_huggingface(edition, config_path, model_name, savedir_sample):
+ controller = EasyAnimateController_HuggingFace(edition, config_path, model_name, savedir_sample)
+
+ with gr.Blocks(css=css) as demo:
+ gr.Markdown(
+ """
+ # EasyAnimate: Integrated generation of baseline scheme for videos and images.
+ Generate your videos easily
+ [Github](https://github.com/aigc-apps/EasyAnimate/)
+ """
+ )
+ with gr.Column(variant="panel"):
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+ sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
+
+ if edition == "v1":
+ width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
+ height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
+ with gr.Row():
+ is_image = gr.Checkbox(False, label="Generate Image", visible=False)
+ length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
+ else:
+ width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
+ height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
+ with gr.Column():
+ gr.Markdown(
+ """
+ To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
+ If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
+ """
+ )
+ with gr.Row():
+ is_image = gr.Checkbox(False, label="Generate Image")
+ length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
+
+ with gr.Row():
+ seed_textbox = gr.Textbox(label="Seed", value=43)
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
+
+ generate_button = gr.Button(value="Generate", variant='primary')
+
+ with gr.Column():
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
+ result_video = gr.Video(label="Generated Animation", interactive=False)
+ infer_progress = gr.Textbox(
+ label="Generation Info",
+ value="No task currently",
+ interactive=False
+ )
+
+ is_image.change(
+ lambda x: gr.update(visible=not x),
+ inputs=[is_image],
+ outputs=[length_slider],
+ )
+
+ generate_button.click(
+ fn=controller.generate,
+ inputs=[
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ height_slider,
+ is_image,
+ length_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ ],
+ outputs=[result_image, result_video, infer_progress]
+ )
+ return demo, controller
\ No newline at end of file
diff --git a/easyanimate/utils/__init__.py b/easyanimate/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/utils/diffusion_utils.py b/easyanimate/utils/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5e06e31d237c0535143fd3623794f769f24dc2e
--- /dev/null
+++ b/easyanimate/utils/diffusion_utils.py
@@ -0,0 +1,92 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = next(
+ (
+ obj
+ for obj in (mean1, logvar1, mean2, logvar2)
+ if isinstance(obj, th.Tensor)
+ ),
+ None,
+ )
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
+ normalized_x
+ )
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
\ No newline at end of file
diff --git a/easyanimate/utils/gaussian_diffusion.py b/easyanimate/utils/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..669b580debfc097b80717a0b1a575414fc5c0c8f
--- /dev/null
+++ b/easyanimate/utils/gaussian_diffusion.py
@@ -0,0 +1,1008 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+import torch.nn.functional as F
+from einops import rearrange
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self in [LossType.KL, LossType.RESCALED_KL]
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ snr=False,
+ return_startx=False,
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.snr = snr
+ self.return_startx = return_startx
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, timestep=t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ elif self.model_var_type in [ModelVarType.FIXED_LARGE, ModelVarType.FIXED_SMALL]:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+ else:
+ model_variance = th.zeros_like(model_output)
+ model_log_variance = th.zeros_like(model_output)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ return x.clamp(-1, 1) if clip_denoised else x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ img = noise if noise is not None else th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ img = noise if noise is not None else th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ t = timestep
+ if model_kwargs is None:
+ model_kwargs = {}
+ if skip_noise:
+ x_t = x_start
+ else:
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
+ model_output = model(x_t, timestep=t, **model_kwargs)[0]
+
+ if isinstance(model_output, dict) and model_output.get('x', None) is not None:
+ output = model_output['x']
+ else:
+ output = model_output
+
+ if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
+ return self._extracted_from_training_losses_diffusers(x_t, output, t)
+ # self.model_var_type = ModelVarType.LEARNED_RANGE:4
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
+ output, model_var_values = th.split(output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let it affect our mean prediction.
+ frozen_out = th.cat([output.detach(), model_var_values], dim=1)
+ # vb variational bound
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out, **kwargs: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert output.shape == target.shape == x_start.shape
+ if self.snr:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
+ pred_startx = output
+ elif self.model_mean_type == ModelMeanType.EPSILON:
+ pred_noise = output
+ pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
+ # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
+ # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
+
+ t = t[:, None, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
+ # best
+ target = th.where(t > 249, noise, x_start)
+ output = th.where(t > 249, pred_noise, pred_startx)
+ loss = (target - output) ** 2
+ if model_kwargs.get('mask_ratio', False) and model_kwargs['mask_ratio'] > 0:
+ assert 'mask' in model_output
+ loss = F.avg_pool2d(loss.mean(dim=1), model.model.module.patch_size).flatten(1)
+ mask = model_output['mask']
+ unmask = 1 - mask
+ terms['mse'] = mean_flat(loss * unmask) * unmask.shape[1]/unmask.sum(1)
+ if model_kwargs['mask_loss_coef'] > 0:
+ terms['mae'] = model_kwargs['mask_loss_coef'] * mean_flat(loss * mask) * mask.shape[1]/mask.sum(1)
+ else:
+ terms["mse"] = mean_flat(loss)
+ terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
+ if "mae" in terms:
+ terms["loss"] = terms["loss"] + terms["mae"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def training_losses_diffusers(self, model, x_start, timestep, model_kwargs=None, noise=None, skip_noise=False):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ t = timestep
+ if model_kwargs is None:
+ model_kwargs = {}
+ if skip_noise:
+ x_t = x_start
+ else:
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type in [LossType.KL, LossType.RESCALED_KL]:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type in [LossType.MSE, LossType.RESCALED_MSE]:
+ output = model(x_t, timestep=t, **model_kwargs, return_dict=False)[0]
+
+ if self.return_startx and self.model_mean_type == ModelMeanType.EPSILON:
+ return self._extracted_from_training_losses_diffusers(x_t, output, t)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
+ output, model_var_values = th.split(output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let it affect our mean prediction.
+ frozen_out = th.cat([output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out, **kwargs: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert output.shape == target.shape == x_start.shape
+ if self.snr:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_noise = self._predict_eps_from_xstart(x_t=x_t, t=t, pred_xstart=output)
+ pred_startx = output
+ elif self.model_mean_type == ModelMeanType.EPSILON:
+ pred_noise = output
+ pred_startx = self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output)
+ # terms["mse_eps"] = mean_flat((noise - pred_noise) ** 2)
+ # terms["mse_x0"] = mean_flat((x_start - pred_startx) ** 2)
+
+ t = t[:, None, None, None].expand(pred_startx.shape) # [128, 4, 32, 32]
+ # best
+ target = th.where(t > 249, noise, x_start)
+ output = th.where(t > 249, pred_noise, pred_startx)
+ loss = (target - output) ** 2
+ terms["mse"] = mean_flat(loss)
+ terms["loss"] = terms["mse"] + terms["vb"] if "vb" in terms else terms["mse"]
+ if "mae" in terms:
+ terms["loss"] = terms["loss"] + terms["mae"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _extracted_from_training_losses_diffusers(self, x_t, output, t):
+ B, C = x_t.shape[:2]
+ assert output.shape == (B, C * 2, *x_t.shape[2:])
+ output = th.split(output, C, dim=1)[0]
+ return output, self._predict_xstart_from_eps(x_t=x_t, t=t, eps=output), x_t
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
\ No newline at end of file
diff --git a/easyanimate/utils/lora_utils.py b/easyanimate/utils/lora_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9673a7d51d9949f49a10ea50382d9d0117cb2aa6
--- /dev/null
+++ b/easyanimate/utils/lora_utils.py
@@ -0,0 +1,476 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+# https://github.com/bmaltais/kohya_ss
+
+import hashlib
+import math
+import os
+from collections import defaultdict
+from io import BytesIO
+from typing import List, Optional, Type, Union
+
+import safetensors.torch
+import torch
+import torch.utils.checkpoint
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from safetensors.torch import load_file
+from transformers import T5EncoderModel
+
+
+class LoRAModule(torch.nn.Module):
+ """
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
+ """
+
+ def __init__(
+ self,
+ lora_name,
+ org_module: torch.nn.Module,
+ multiplier=1.0,
+ lora_dim=4,
+ alpha=1,
+ dropout=None,
+ rank_dropout=None,
+ module_dropout=None,
+ ):
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
+ super().__init__()
+ self.lora_name = lora_name
+
+ if org_module.__class__.__name__ == "Conv2d":
+ in_dim = org_module.in_channels
+ out_dim = org_module.out_channels
+ else:
+ in_dim = org_module.in_features
+ out_dim = org_module.out_features
+
+ self.lora_dim = lora_dim
+ if org_module.__class__.__name__ == "Conv2d":
+ kernel_size = org_module.kernel_size
+ stride = org_module.stride
+ padding = org_module.padding
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+ else:
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+ if type(alpha) == torch.Tensor:
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+ self.scale = alpha / self.lora_dim
+ self.register_buffer("alpha", torch.tensor(alpha))
+
+ # same as microsoft's
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+ torch.nn.init.zeros_(self.lora_up.weight)
+
+ self.multiplier = multiplier
+ self.org_module = org_module # remove in applying
+ self.dropout = dropout
+ self.rank_dropout = rank_dropout
+ self.module_dropout = module_dropout
+
+ def apply_to(self):
+ self.org_forward = self.org_module.forward
+ self.org_module.forward = self.forward
+ del self.org_module
+
+ def forward(self, x, *args, **kwargs):
+ weight_dtype = x.dtype
+ org_forwarded = self.org_forward(x)
+
+ # module dropout
+ if self.module_dropout is not None and self.training:
+ if torch.rand(1) < self.module_dropout:
+ return org_forwarded
+
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
+
+ # normal dropout
+ if self.dropout is not None and self.training:
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+ # rank dropout
+ if self.rank_dropout is not None and self.training:
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+ if len(lx.size()) == 3:
+ mask = mask.unsqueeze(1) # for Text Encoder
+ elif len(lx.size()) == 4:
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
+ lx = lx * mask
+
+ # scaling for rank dropout: treat as if the rank is changed
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
+ else:
+ scale = self.scale
+
+ lx = self.lora_up(lx)
+
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
+
+
+def addnet_hash_legacy(b):
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+ m = hashlib.sha256()
+
+ b.seek(0x100000)
+ m.update(b.read(0x10000))
+ return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
+ save time on indexing the model later."""
+
+ # Because writing user metadata to the file can change the result of
+ # sd_models.model_hash(), only retain the training metadata for purposes of
+ # calculating the hash, as they are meant to be immutable
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+ bytes = safetensors.torch.save(tensors, metadata)
+ b = BytesIO(bytes)
+
+ model_hash = addnet_hash_safetensors(b)
+ legacy_hash = addnet_hash_legacy(b)
+ return model_hash, legacy_hash
+
+
+class LoRANetwork(torch.nn.Module):
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel"]
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF"]
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ def __init__(
+ self,
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
+ unet,
+ multiplier: float = 1.0,
+ lora_dim: int = 4,
+ alpha: float = 1,
+ dropout: Optional[float] = None,
+ module_class: Type[object] = LoRAModule,
+ add_lora_in_attn_temporal: bool = False,
+ varbose: Optional[bool] = False,
+ ) -> None:
+ super().__init__()
+ self.multiplier = multiplier
+
+ self.lora_dim = lora_dim
+ self.alpha = alpha
+ self.dropout = dropout
+
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+ print(f"neuron dropout: p={self.dropout}")
+
+ # create module instances
+ def create_modules(
+ is_unet: bool,
+ root_module: torch.nn.Module,
+ target_replace_modules: List[torch.nn.Module],
+ ) -> List[LoRAModule]:
+ prefix = (
+ self.LORA_PREFIX_TRANSFORMER
+ if is_unet
+ else self.LORA_PREFIX_TEXT_ENCODER
+ )
+ loras = []
+ skipped = []
+ for name, module in root_module.named_modules():
+ if module.__class__.__name__ in target_replace_modules:
+ for child_name, child_module in module.named_modules():
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+ if not add_lora_in_attn_temporal:
+ if "attn_temporal" in child_name:
+ continue
+
+ if is_linear or is_conv2d:
+ lora_name = prefix + "." + name + "." + child_name
+ lora_name = lora_name.replace(".", "_")
+
+ dim = None
+ alpha = None
+
+ if is_linear or is_conv2d_1x1:
+ dim = self.lora_dim
+ alpha = self.alpha
+
+ if dim is None or dim == 0:
+ if is_linear or is_conv2d_1x1:
+ skipped.append(lora_name)
+ continue
+
+ lora = module_class(
+ lora_name,
+ child_module,
+ self.multiplier,
+ dim,
+ alpha,
+ dropout=dropout,
+ )
+ loras.append(lora)
+ return loras, skipped
+
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+ self.text_encoder_loras = []
+ skipped_te = []
+ for i, text_encoder in enumerate(text_encoders):
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+ self.text_encoder_loras.extend(text_encoder_loras)
+ skipped_te += skipped
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+ # assertion
+ names = set()
+ for lora in self.text_encoder_loras + self.unet_loras:
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+ names.add(lora.lora_name)
+
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+ if apply_text_encoder:
+ print("enable LoRA for text encoder")
+ else:
+ self.text_encoder_loras = []
+
+ if apply_unet:
+ print("enable LoRA for U-Net")
+ else:
+ self.unet_loras = []
+
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.apply_to()
+ self.add_module(lora.lora_name, lora)
+
+ def set_multiplier(self, multiplier):
+ self.multiplier = multiplier
+ for lora in self.text_encoder_loras + self.unet_loras:
+ lora.multiplier = self.multiplier
+
+ def load_weights(self, file):
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import load_file
+
+ weights_sd = load_file(file)
+ else:
+ weights_sd = torch.load(file, map_location="cpu")
+ info = self.load_state_dict(weights_sd, False)
+ return info
+
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+ self.requires_grad_(True)
+ all_params = []
+
+ def enumerate_params(loras):
+ params = []
+ for lora in loras:
+ params.extend(lora.parameters())
+ return params
+
+ if self.text_encoder_loras:
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
+ if text_encoder_lr is not None:
+ param_data["lr"] = text_encoder_lr
+ all_params.append(param_data)
+
+ if self.unet_loras:
+ param_data = {"params": enumerate_params(self.unet_loras)}
+ if unet_lr is not None:
+ param_data["lr"] = unet_lr
+ all_params.append(param_data)
+
+ return all_params
+
+ def enable_gradient_checkpointing(self):
+ pass
+
+ def get_trainable_params(self):
+ return self.parameters()
+
+ def save_weights(self, file, dtype, metadata):
+ if metadata is not None and len(metadata) == 0:
+ metadata = None
+
+ state_dict = self.state_dict()
+
+ if dtype is not None:
+ for key in list(state_dict.keys()):
+ v = state_dict[key]
+ v = v.detach().clone().to("cpu").to(dtype)
+ state_dict[key] = v
+
+ if os.path.splitext(file)[1] == ".safetensors":
+ from safetensors.torch import save_file
+
+ # Precalculate model hashes to save time on indexing
+ if metadata is None:
+ metadata = {}
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ save_file(state_dict, file, metadata)
+ else:
+ torch.save(state_dict, file)
+
+def create_network(
+ multiplier: float,
+ network_dim: Optional[int],
+ network_alpha: Optional[float],
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
+ transformer,
+ neuron_dropout: Optional[float] = None,
+ add_lora_in_attn_temporal: bool = False,
+ **kwargs,
+):
+ if network_dim is None:
+ network_dim = 4 # default
+ if network_alpha is None:
+ network_alpha = 1.0
+
+ network = LoRANetwork(
+ text_encoder,
+ transformer,
+ multiplier=multiplier,
+ lora_dim=network_dim,
+ alpha=network_alpha,
+ dropout=neuron_dropout,
+ add_lora_in_attn_temporal=add_lora_in_attn_temporal,
+ varbose=True,
+ )
+ return network
+
+def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ if state_dict is None:
+ state_dict = load_file(lora_path, device=device)
+ else:
+ state_dict = state_dict
+ updates = defaultdict(dict)
+ for key, value in state_dict.items():
+ layer, elem = key.split('.', 1)
+ updates[layer][elem] = value
+
+ for layer, elems in updates.items():
+
+ if "lora_te" in layer:
+ if transformer_only:
+ continue
+ else:
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+ curr_layer = pipeline.text_encoder
+ else:
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+ curr_layer = pipeline.transformer
+
+ temp_name = layer_infos.pop(0)
+ while len(layer_infos) > -1:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name)
+ if len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ elif len(layer_infos) == 0:
+ break
+ except Exception:
+ if len(layer_infos) == 0:
+ print('Error loading layer')
+ if len(temp_name) > 0:
+ temp_name += "_" + layer_infos.pop(0)
+ else:
+ temp_name = layer_infos.pop(0)
+
+ weight_up = elems['lora_up.weight'].to(dtype)
+ weight_down = elems['lora_down.weight'].to(dtype)
+ if 'alpha' in elems.keys():
+ alpha = elems['alpha'].item() / weight_up.shape[1]
+ else:
+ alpha = 1.0
+
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
+ if len(weight_up.shape) == 4:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(
+ 2).unsqueeze(3)
+ else:
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+
+ return pipeline
+
+# TODO: Refactor with merge_lora.
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
+ LORA_PREFIX_UNET = "lora_unet"
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
+ state_dict = load_file(lora_path, device=device)
+
+ updates = defaultdict(dict)
+ for key, value in state_dict.items():
+ layer, elem = key.split('.', 1)
+ updates[layer][elem] = value
+
+ for layer, elems in updates.items():
+
+ if "lora_te" in layer:
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+ curr_layer = pipeline.text_encoder
+ else:
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+ curr_layer = pipeline.transformer
+
+ temp_name = layer_infos.pop(0)
+ while len(layer_infos) > -1:
+ try:
+ curr_layer = curr_layer.__getattr__(temp_name)
+ if len(layer_infos) > 0:
+ temp_name = layer_infos.pop(0)
+ elif len(layer_infos) == 0:
+ break
+ except Exception:
+ if len(layer_infos) == 0:
+ print('Error loading layer')
+ if len(temp_name) > 0:
+ temp_name += "_" + layer_infos.pop(0)
+ else:
+ temp_name = layer_infos.pop(0)
+
+ weight_up = elems['lora_up.weight'].to(dtype)
+ weight_down = elems['lora_down.weight'].to(dtype)
+ if 'alpha' in elems.keys():
+ alpha = elems['alpha'].item() / weight_up.shape[1]
+ else:
+ alpha = 1.0
+
+ curr_layer.weight.data = curr_layer.weight.data.to(device)
+ if len(weight_up.shape) == 4:
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+ weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ else:
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+
+ return pipeline
\ No newline at end of file
diff --git a/easyanimate/utils/respace.py b/easyanimate/utils/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbdfe77eb05dcd9858b5b23106810642e6a514c2
--- /dev/null
+++ b/easyanimate/utils/respace.py
@@ -0,0 +1,131 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses_diffusers(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, timestep, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
+ new_ts = map_tensor[timestep]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, timestep=new_ts, **kwargs)
\ No newline at end of file
diff --git a/easyanimate/utils/utils.py b/easyanimate/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fea292414ed7d8590189338d80cb434d5e735ea
--- /dev/null
+++ b/easyanimate/utils/utils.py
@@ -0,0 +1,64 @@
+import os
+
+import imageio
+import numpy as np
+import torch
+import torchvision
+import cv2
+from einops import rearrange
+from PIL import Image
+
+
+def color_transfer(sc, dc):
+ """
+ Transfer color distribution from of sc, referred to dc.
+
+ Args:
+ sc (numpy.ndarray): input image to be transfered.
+ dc (numpy.ndarray): reference image
+
+ Returns:
+ numpy.ndarray: Transferred color distribution on the sc.
+ """
+
+ def get_mean_and_std(img):
+ x_mean, x_std = cv2.meanStdDev(img)
+ x_mean = np.hstack(np.around(x_mean, 2))
+ x_std = np.hstack(np.around(x_std, 2))
+ return x_mean, x_std
+
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
+ s_mean, s_std = get_mean_and_std(sc)
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
+ t_mean, t_std = get_mean_and_std(dc)
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
+ np.putmask(img_n, img_n > 255, 255)
+ np.putmask(img_n, img_n < 0, 0)
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
+ return dst
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
+ videos = rearrange(videos, "b c t h w -> t b c h w")
+ outputs = []
+ for x in videos:
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ if rescale:
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
+ x = (x * 255).numpy().astype(np.uint8)
+ outputs.append(Image.fromarray(x))
+
+ if color_transfer_post_process:
+ for i in range(1, len(outputs)):
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ if imageio_backend:
+ if path.endswith("mp4"):
+ imageio.mimsave(path, outputs, fps=fps)
+ else:
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
+ else:
+ if path.endswith("mp4"):
+ path = path.replace('.mp4', '.gif')
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
diff --git a/easyanimate/vae/LICENSE b/easyanimate/vae/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0e609df0d8cd3b5d11a1ea962a56b604b70846a5
--- /dev/null
+++ b/easyanimate/vae/LICENSE
@@ -0,0 +1,82 @@
+Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
+
+CreativeML Open RAIL-M
+dated August 22, 2022
+
+Section I: PREAMBLE
+
+Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
+
+Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
+
+In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
+
+Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+
+This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+
+- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
+- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
+- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
+- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
+- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
+- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
+- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
+
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
+Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
+You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
+You must cause any modified files to carry prominent notices stating that You changed the files;
+You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+
+
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+- In any way that violates any applicable national, federal, state, local or international law or regulation;
+- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
+- To generate or disseminate personal identifiable information that can be used to harm an individual;
+- To defame, disparage or otherwise harass others;
+- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
+- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
+- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
+- To provide medical advice and medical results interpretation;
+- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
diff --git a/easyanimate/vae/README.md b/easyanimate/vae/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6327120907f80b48adfa171638dbbdce29863c0
--- /dev/null
+++ b/easyanimate/vae/README.md
@@ -0,0 +1,63 @@
+## VAE Training
+
+English | [简体中文](./README_zh-CN.md)
+
+After completing data preprocessing, we can obtain the following dataset:
+
+```
+📦 project/
+├── 📂 datasets/
+│ ├── 📂 internal_datasets/
+│ ├── 📂 videos/
+│ │ ├── 📄 00000001.mp4
+│ │ ├── 📄 00000001.jpg
+│ │ └── 📄 .....
+│ └── 📄 json_of_internal_datasets.json
+```
+
+The json_of_internal_datasets.json is a standard JSON file. The file_path in the json can to be set as relative path, as shown in below:
+```json
+[
+ {
+ "file_path": "videos/00000001.mp4",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "video"
+ },
+ {
+ "file_path": "train/00000001.jpg",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "image"
+ },
+ .....
+]
+```
+
+You can also set the path as absolute path as follow:
+```json
+[
+ {
+ "file_path": "/mnt/data/videos/00000001.mp4",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "video"
+ },
+ {
+ "file_path": "/mnt/data/train/00000001.jpg",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "image"
+ },
+ .....
+]
+```
+
+## Train Video VAE
+We need to set config in ```easyanimate/vae/configs/autoencoder``` at first. The default config is ```autoencoder_kl_32x32x4_slice.yaml```. We need to set the some params in yaml file.
+
+- ```data_json_path``` corresponds to the JSON file of the dataset.
+- ```data_root``` corresponds to the root path of the dataset. If you want to use absolute path in json file, please delete this line.
+- ```ckpt_path``` corresponds to the pretrained weights of the vae.
+- ```gpus``` and num_nodes need to be set as the actual situation of your machine.
+
+The we run shell file as follow:
+```
+sh scripts/train_vae.sh
+```
\ No newline at end of file
diff --git a/easyanimate/vae/README_zh-CN.md b/easyanimate/vae/README_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..6bc046712c9aeb94c9e2bb7e1ff4c712f9059b75
--- /dev/null
+++ b/easyanimate/vae/README_zh-CN.md
@@ -0,0 +1,63 @@
+## VAE 训练
+
+[English](./README.md) | 简体中文
+
+在完成数据预处理后,你可以获得这样的数据格式:
+
+```
+📦 project/
+├── 📂 datasets/
+│ ├── 📂 internal_datasets/
+│ ├── 📂 videos/
+│ │ ├── 📄 00000001.mp4
+│ │ ├── 📄 00000001.jpg
+│ │ └── 📄 .....
+│ └── 📄 json_of_internal_datasets.json
+```
+
+json_of_internal_datasets.json是一个标准的json文件。json中的file_path可以被设置为相对路径,如下所示:
+```json
+[
+ {
+ "file_path": "videos/00000001.mp4",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "video"
+ },
+ {
+ "file_path": "train/00000001.jpg",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "image"
+ },
+ .....
+]
+```
+
+你也可以将路径设置为绝对路径:
+```json
+[
+ {
+ "file_path": "/mnt/data/videos/00000001.mp4",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "video"
+ },
+ {
+ "file_path": "/mnt/data/train/00000001.jpg",
+ "text": "A group of young men in suits and sunglasses are walking down a city street.",
+ "type": "image"
+ },
+ .....
+]
+```
+
+## 训练 Video VAE
+我们首先需要修改 ```easyanimate/vae/configs/autoencoder``` 中的配置文件。默认的配置文件是 ```autoencoder_kl_32x32x4_slice.yaml```。你需要修改以下参数:
+
+- ```data_json_path``` json file 所在的目录。
+- ```data_root``` 数据的根目录。如果你在json file中使用了绝对路径,请设置为空。
+- ```ckpt_path``` 预训练的vae模型路径。
+- ```gpus``` 以及 ```num_nodes``` 需要设置为你机器的实际gpu数目。
+
+运行以下的脚本来训练vae:
+```
+sh scripts/train_vae.sh
+```
\ No newline at end of file
diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cfd0e0a5eaf70b07094f3038a47a8b4f839bf08a
--- /dev/null
+++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_mag.yaml
@@ -0,0 +1,62 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
+ params:
+ monitor: train/rec_loss
+ ckpt_path: models/videoVAE_omnigen_8x8x4_from_vae-ft-mse-840000-ema-pruned.ckpt
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",)
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",)
+ lossconfig:
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
+ params:
+ disc_start: 50001
+ kl_weight: 1.0e-06
+ disc_weight: 0.5
+ l2_loss_weight: 0.1
+ l1_loss_weight: 1.0
+ perceptual_weight: 1.0
+
+data:
+ target: train_vae.DataModuleFromConfig
+
+ params:
+ batch_size: 2
+ wrap: true
+ num_workers: 4
+ train:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 128
+ degradation: pil_nearest
+ video_size: 128
+ video_len: 9
+ slice_interval: 1
+ validation:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 128
+ degradation: pil_nearest
+ video_size: 128
+ video_len: 9
+ slice_interval: 1
+
+lightning:
+ callbacks:
+ image_logger:
+ target: train_vae.ImageLogger
+ params:
+ batch_frequency: 5000
+ max_images: 8
+ increase_log_steps: True
+
+ trainer:
+ benchmark: True
+ accumulate_grad_batches: 1
+ gpus: "0"
+ num_nodes: 1
\ No newline at end of file
diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fe607c582d0da2edca038b74667374feae7ac45
--- /dev/null
+++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice.yaml
@@ -0,0 +1,65 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
+ params:
+ slice_compression_vae: true
+ mini_batch_encoder: 8
+ mini_batch_decoder: 2
+ monitor: train/rec_loss
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",)
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",)
+ lossconfig:
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
+ params:
+ disc_start: 50001
+ kl_weight: 1.0e-06
+ disc_weight: 0.5
+ l2_loss_weight: 0.0
+ l1_loss_weight: 1.0
+ perceptual_weight: 1.0
+
+data:
+ target: train_vae.DataModuleFromConfig
+
+ params:
+ batch_size: 1
+ wrap: true
+ num_workers: 8
+ train:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 25
+ slice_interval: 1
+ validation:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 25
+ slice_interval: 1
+
+lightning:
+ callbacks:
+ image_logger:
+ target: train_vae.ImageLogger
+ params:
+ batch_frequency: 5000
+ max_images: 8
+ increase_log_steps: True
+
+ trainer:
+ benchmark: True
+ accumulate_grad_batches: 1
+ gpus: "0"
+ num_nodes: 1
\ No newline at end of file
diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..32fa32b2997b6526216472bbbd0c03c04d3cfd3e
--- /dev/null
+++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_decoder_only.yaml
@@ -0,0 +1,66 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
+ params:
+ slice_compression_vae: true
+ train_decoder_only: true
+ mini_batch_encoder: 8
+ mini_batch_decoder: 2
+ monitor: train/rec_loss
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
+ down_block_types: ("SpatialDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",)
+ up_block_types: ("SpatialUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",)
+ lossconfig:
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
+ params:
+ disc_start: 50001
+ kl_weight: 1.0e-06
+ disc_weight: 0.5
+ l2_loss_weight: 1.0
+ l1_loss_weight: 0.0
+ perceptual_weight: 1.0
+
+data:
+ target: train_vae.DataModuleFromConfig
+
+ params:
+ batch_size: 1
+ wrap: true
+ num_workers: 8
+ train:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 25
+ slice_interval: 1
+ validation:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 25
+ slice_interval: 1
+
+lightning:
+ callbacks:
+ image_logger:
+ target: train_vae.ImageLogger
+ params:
+ batch_frequency: 5000
+ max_images: 8
+ increase_log_steps: True
+
+ trainer:
+ benchmark: True
+ accumulate_grad_batches: 1
+ gpus: "0"
+ num_nodes: 1
\ No newline at end of file
diff --git a/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25f05f2ba60ae285140db6330970d5523413268b
--- /dev/null
+++ b/easyanimate/vae/configs/autoencoder/autoencoder_kl_32x32x4_slice_t_downsample_8.yaml
@@ -0,0 +1,66 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: easyanimate.vae.ldm.models.omnigen_casual3dcnn.AutoencoderKLMagvit_fromOmnigen
+ params:
+ slice_compression_vae: true
+ mini_batch_encoder: 8
+ mini_batch_decoder: 1
+ monitor: train/rec_loss
+ ckpt_path: models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512/vae/diffusion_pytorch_model.safetensors
+ down_block_types: ("SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D", "SpatialTemporalDownBlock3D",
+ "SpatialTemporalDownBlock3D",)
+ up_block_types: ("SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D", "SpatialTemporalUpBlock3D",
+ "SpatialTemporalUpBlock3D",)
+ lossconfig:
+ target: easyanimate.vae.ldm.modules.losses.LPIPSWithDiscriminator
+ params:
+ disc_start: 50001
+ kl_weight: 1.0e-06
+ disc_weight: 0.5
+ l2_loss_weight: 0.0
+ l1_loss_weight: 1.0
+ perceptual_weight: 1.0
+
+
+data:
+ target: train_vae.DataModuleFromConfig
+
+ params:
+ batch_size: 1
+ wrap: true
+ num_workers: 8
+ train:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRTrain
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 33
+ slice_interval: 1
+ validation:
+ target: easyanimate.vae.ldm.data.dataset_image_video.CustomSRValidation
+ params:
+ data_json_path: pretrain.json
+ data_root: /your_data_root # This is used in relative path
+ size: 256
+ degradation: pil_nearest
+ video_size: 256
+ video_len: 33
+ slice_interval: 1
+
+lightning:
+ callbacks:
+ image_logger:
+ target: train_vae.ImageLogger
+ params:
+ batch_frequency: 5000
+ max_images: 8
+ increase_log_steps: True
+
+ trainer:
+ benchmark: True
+ accumulate_grad_batches: 1
+ gpus: "0"
+ num_nodes: 1
\ No newline at end of file
diff --git a/easyanimate/vae/environment.yaml b/easyanimate/vae/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6279f233d83cb52ed93c77f99c5de10ae866c9fd
--- /dev/null
+++ b/easyanimate/vae/environment.yaml
@@ -0,0 +1,29 @@
+name: ldm
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.8.5
+ - pip=20.3
+ - cudatoolkit=11.3
+ - pytorch=1.11.0
+ - torchvision=0.12.0
+ - numpy=1.19.2
+ - pip:
+ - albumentations==0.4.3
+ - diffusers
+ - opencv-python==4.1.2.30
+ - pudb==2019.2
+ - invisible-watermark
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.4.2
+ - omegaconf==2.1.1
+ - test-tube>=0.7.5
+ - streamlit>=0.73.1
+ - einops==0.3.0
+ - torch-fidelity==0.3.0
+ - transformers==4.19.2
+ - torchmetrics==0.6.0
+ - kornia==0.6
+ - -e .
diff --git a/easyanimate/vae/ldm/data/__init__.py b/easyanimate/vae/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/vae/ldm/data/base.py b/easyanimate/vae/ldm/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ba123313253e625fe14f34a24e3c056ec4cf056
--- /dev/null
+++ b/easyanimate/vae/ldm/data/base.py
@@ -0,0 +1,25 @@
+from abc import abstractmethod
+
+from torch.utils.data import (ChainDataset, ConcatDataset, Dataset,
+ IterableDataset)
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/data/dataset_callback.py b/easyanimate/vae/ldm/data/dataset_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d59435644faca534f24d7619d7b9cf3071c22c
--- /dev/null
+++ b/easyanimate/vae/ldm/data/dataset_callback.py
@@ -0,0 +1,25 @@
+#-*- encoding:utf-8 -*-
+from pytorch_lightning.callbacks import Callback
+
+class DatasetCallback(Callback):
+ def __init__(self):
+ self.sampler_pos_start = 0
+ self.preload_used_idx_flag = False
+
+ def on_train_start(self, trainer, pl_module):
+ if not self.preload_used_idx_flag:
+ self.preload_used_idx_flag = True
+ trainer.train_dataloader.batch_sampler.sampler_pos_reload = self.sampler_pos_start
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ if trainer.train_dataloader is not None:
+ # Save sampler_pos_start parameters in the checkpoint
+ checkpoint['sampler_pos_start'] = trainer.train_dataloader.batch_sampler.sampler_pos_start
+
+ def on_load_checkpoint(self, trainer, pl_module, checkpoint):
+ # Restore sampler_pos_start parameters from the checkpoint
+ if 'sampler_pos_start' in checkpoint:
+ self.sampler_pos_start = checkpoint.get('sampler_pos_start', 0)
+ print('Load sampler_pos_start from checkpoint, sampler_pos_start = %d' % self.sampler_pos_start)
+ else:
+ print('The sampler_pos_start is not in checkpoint')
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/data/dataset_image_video.py b/easyanimate/vae/ldm/data/dataset_image_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd0c4f6572a34659c81ca1b0d44e1acb5a58ee05
--- /dev/null
+++ b/easyanimate/vae/ldm/data/dataset_image_video.py
@@ -0,0 +1,281 @@
+import glob
+import json
+import os
+import pickle
+import random
+import shutil
+import tarfile
+from functools import partial
+
+import albumentations
+import cv2
+import numpy as np
+import PIL
+import torchvision.transforms.functional as TF
+import yaml
+from decord import VideoReader
+from func_timeout import FunctionTimedOut, func_set_timeout
+from omegaconf import OmegaConf
+from PIL import Image
+from torch.utils.data import (BatchSampler, Dataset, Sampler)
+from tqdm import tqdm
+
+from ..modules.image_degradation import (degradation_fn_bsr,
+ degradation_fn_bsr_light)
+
+
+class ImageVideoSampler(BatchSampler):
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+ Args:
+ sampler (Sampler): Base sampler.
+ dataset (Dataset): Dataset providing data information.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``.
+ aspect_ratios (dict): The predefined aspect ratios.
+ """
+
+ def __init__(self,
+ sampler: Sampler,
+ dataset: Dataset,
+ batch_size: int,
+ drop_last: bool = False
+ ) -> None:
+ if not isinstance(sampler, Sampler):
+ raise TypeError('sampler should be an instance of ``Sampler``, '
+ f'but got {sampler}')
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError('batch_size should be a positive integer value, '
+ f'but got batch_size={batch_size}')
+ self.sampler = sampler
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ self.sampler_pos_start = 0
+ self.sampler_pos_reload = 0
+
+ self.num_samples_random = len(self.sampler)
+ # buckets for each aspect ratio
+ self.bucket = {'image':[], 'video':[]}
+
+ def set_epoch(self, epoch):
+ if hasattr(self.sampler, "set_epoch"):
+ self.sampler.set_epoch(epoch)
+
+ def __iter__(self):
+ for index_sampler, idx in enumerate(self.sampler):
+ if self.sampler_pos_reload != 0 and self.sampler_pos_reload < self.num_samples_random:
+ if index_sampler < self.sampler_pos_reload:
+ self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
+ continue
+ elif index_sampler == self.sampler_pos_reload:
+ self.sampler_pos_reload = 0
+
+ content_type = self.dataset.data.get_type(idx)
+ bucket = self.bucket[content_type]
+ bucket.append(idx)
+ # yield a batch of indices in the same aspect ratio group
+ if len(self.bucket['video']) == self.batch_size:
+ yield self.bucket['video']
+ self.bucket['video'] = []
+ elif len(self.bucket['image']) == self.batch_size:
+ yield self.bucket['image']
+ self.bucket['image'] = []
+ self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
+
+class ImageVideoDataset(Dataset):
+ # update __getitem__() from ImageNetSR. If timeout for Pandas70M, throw exception.
+ # If caught exception(timeout or others), try another index until successful and return.
+ def __init__(self, size=None, video_size=128, video_len=25,
+ degradation=None, downscale_f=4, random_crop=True, min_crop_f=0.25, max_crop_f=1.,
+ s_t=None, slice_interval=None, data_root=None
+ ):
+ """
+ Imagenet Superresolution Dataloader
+ Performs following ops in order:
+ 1. crops a crop of size s from image either as random or center crop
+ 2. resizes crop to size with cv2.area_interpolation
+ 3. degrades resized crop with degradation_fn
+
+ :param size: resizing to size after cropping
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+ :param downscale_f: Low Resolution Downsample factor
+ :param min_crop_f: determines crop size s,
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
+ :param max_crop_f: ""
+ :param data_root:
+ :param random_crop:
+ """
+ self.base = self.get_base()
+ assert size
+ assert (size / downscale_f).is_integer()
+ self.size = size
+ self.LR_size = int(size / downscale_f)
+ self.min_crop_f = min_crop_f
+ self.max_crop_f = max_crop_f
+ assert(max_crop_f <= 1.)
+ self.center_crop = not random_crop
+ self.s_t = s_t
+ self.slice_interval = slice_interval
+
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+ self.video_rescaler = albumentations.SmallestMaxSize(max_size=video_size, interpolation=cv2.INTER_AREA)
+ self.video_len = video_len
+ self.video_size = video_size
+ self.data_root = data_root
+
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+
+ if degradation == "bsrgan":
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+ elif degradation == "bsrgan_light":
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+ else:
+ interpolation_fn = {
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
+ }[degradation]
+
+ self.pil_interpolation = degradation.startswith("pil_")
+
+ if self.pil_interpolation:
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+ else:
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+ interpolation=interpolation_fn)
+
+ def __len__(self):
+ return len(self.base)
+
+ def get_type(self, index):
+ return self.base[index].get('type', 'image')
+
+ def __getitem__(self, i):
+ @func_set_timeout(3) # time wait 3 seconds
+ def get_video_item(example):
+ if self.data_root is not None:
+ video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
+ else:
+ video_reader = VideoReader(example['file_path'])
+ video_length = len(video_reader)
+
+ clip_length = min(video_length, (self.video_len - 1) * self.slice_interval + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
+
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
+
+ del video_reader
+ out_images = []
+ LR_out_images = []
+ min_side_len = min(pixel_values[0].shape[:2])
+
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ imgs = np.transpose(pixel_values, (1, 2, 3, 0))
+ imgs = self.cropper(image=imgs)["image"]
+ imgs = np.transpose(imgs, (3, 0, 1, 2))
+ for img in imgs:
+ image = self.video_rescaler(image=img)["image"]
+ out_images.append(image[None, :, :, :])
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+ LR_out_images.append(LR_image[None, :, :, :])
+
+ example = {}
+ example['image'] = (np.concatenate(out_images) / 127.5 - 1.0).astype(np.float32)
+ example['LR_image'] = (np.concatenate(LR_out_images) / 127.5 - 1.0).astype(np.float32)
+ return example
+
+ example = self.base[i]
+ if example.get('type', 'image') == 'video':
+ while True:
+ try:
+ example = self.base[i]
+ return get_video_item(example)
+ except FunctionTimedOut:
+ print("stt catch: Function 'extract failed' timed out.")
+ i = random.randint(0, self.__len__() - 1)
+ except Exception as e:
+ print('stt catch', e)
+ i = random.randint(0, self.__len__() - 1)
+ elif example.get('type', 'image') == 'image':
+ while True:
+ try:
+ example = self.base[i]
+ if self.data_root is not None:
+ image = Image.open(os.path.join(self.data_root, example['file_path']))
+ else:
+ image = Image.open(example['file_path'])
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ min_side_len = min(image.shape[:2])
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ image = self.cropper(image=image)["image"]
+
+ image = self.image_rescaler(image=image)["image"]
+
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+
+ example = {}
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+ return example
+ except Exception as e:
+ print("catch", e)
+ i = random.randint(0, self.__len__() - 1)
+
+class CustomSRTrain(ImageVideoDataset):
+ def __init__(self, data_json_path, **kwargs):
+ self.data_json_path = data_json_path
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ return [ann for ann in json.load(open(self.data_json_path))]
+
+class CustomSRValidation(ImageVideoDataset):
+ def __init__(self, data_json_path, **kwargs):
+ self.data_json_path = data_json_path
+ super().__init__(**kwargs)
+ self.data_json_path = data_json_path
+
+ def get_base(self):
+ return [ann for ann in json.load(open(self.data_json_path))][:100] + \
+ [ann for ann in json.load(open(self.data_json_path))][-100:]
diff --git a/easyanimate/vae/ldm/lr_scheduler.py b/easyanimate/vae/ldm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade
--- /dev/null
+++ b/easyanimate/vae/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n,**kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
diff --git a/easyanimate/vae/ldm/modules/diffusionmodules/__init__.py b/easyanimate/vae/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/vae/ldm/modules/diffusionmodules/model.py b/easyanimate/vae/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5e181581a936b35f1c5801f6a669e96955357d
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,701 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ...util import instantiate_from_config
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
diff --git a/easyanimate/vae/ldm/modules/diffusionmodules/util.py b/easyanimate/vae/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..54b198c78f1aff15faead5e0b5e70623cc84e5a6
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,268 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import math
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import repeat
+
+from ...util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/modules/distributions/__init__.py b/easyanimate/vae/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/vae/ldm/modules/distributions/distributions.py b/easyanimate/vae/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec017492bdf70b3fdc16ed1ab83a8a92ff9087c
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import numpy as np
+import torch
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/easyanimate/vae/ldm/modules/ema.py b/easyanimate/vae/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..3657208473f12355e44a2d3a8b114dae79d215bc
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/ema.py
@@ -0,0 +1,115 @@
+#-*- encoding:utf-8 -*-
+import torch
+from torch import nn
+from pytorch_lightning.callbacks import Callback
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self,model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
+
+class EMACallback(Callback):
+ def __init__(self, decay=0.9999):
+ self.decay = decay
+ self.shadow_params = {}
+
+ def on_train_start(self, trainer, pl_module):
+ # initialize shadow parameters for original models
+ total_ema_cnt = 0
+ for name, param in pl_module.named_parameters():
+ if name not in self.shadow_params:
+ self.shadow_params[name] = param.data.clone()
+ else: # already in dict, maybe load from checkpoint
+ pass
+ print('will calc ema for param: %s' % name)
+ total_ema_cnt += 1
+ print('total_ema_cnt=%d' % total_ema_cnt)
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ # Update the shadow params at the end of each epoch
+ for name, param in pl_module.named_parameters():
+ assert name in self.shadow_params
+ new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow_params[name]
+ self.shadow_params[name] = new_average.clone()
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ # Save EMA parameters in the checkpoint
+ checkpoint['ema_params'] = self.shadow_params
+
+ def on_load_checkpoint(self, trainer, pl_module, checkpoint):
+ # Restore EMA parameters from the checkpoint
+ if 'ema_params' in checkpoint:
+ self.shadow_params = checkpoint.get('ema_params', {})
+ for k in self.shadow_params:
+ self.shadow_params[k] = self.shadow_params[k].cuda()
+ print('load shadow params from checkpoint, cnt=%d' % len(self.shadow_params))
+ else:
+ print('ema_params is not in checkpoint')
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/modules/image_degradation/__init__.py b/easyanimate/vae/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95d680ad3c687c7fe3c45033cac0c137a6bb44e
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,3 @@
+from .bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from .bsrgan_light import \
+ degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/easyanimate/vae/ldm/modules/image_degradation/bsrgan.py b/easyanimate/vae/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0bb065950773bcebf47854a25ed7ff5bba0caf6
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import random
+from functools import partial
+
+import albumentations
+import cv2
+import numpy as np
+import scipy
+import scipy.stats as ss
+import torch
+from scipy import ndimage
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+
+from . import utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ use_shuffle: the degradation shuffle
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/easyanimate/vae/ldm/modules/image_degradation/bsrgan_light.py b/easyanimate/vae/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b5be11372c9fa41a72a228e46edf0e88f53d2f0
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,650 @@
+# -*- coding: utf-8 -*-
+import random
+from functools import partial
+
+import albumentations
+import cv2
+import numpy as np
+import scipy
+import scipy.stats as ss
+import torch
+from scipy import ndimage
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+
+from . import utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+
+ if i == 0:
+ pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/easyanimate/vae/ldm/modules/image_degradation/utils/test.png b/easyanimate/vae/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/easyanimate/vae/ldm/modules/image_degradation/utils/test.png differ
diff --git a/easyanimate/vae/ldm/modules/image_degradation/utils_image.py b/easyanimate/vae/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d72b07394c3bbfa021b6980ee21f754fccd7633
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,918 @@
+import math
+import os
+import random
+from datetime import datetime
+
+import cv2
+import numpy as np
+import torch
+from torchvision.utils import make_grid
+
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
+ will be splitted.
+ Args:
+ original_dataroot:
+ taget_dataroot:
+ p_size: size of small images
+ p_overlap: patch size in training is a good choice
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channles (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channles (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(unit)
+# numpy(single) <---> tensor
+# numpy(unit) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(unit)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(unit) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flipe and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augmet_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/modules/losses/__init__.py b/easyanimate/vae/ldm/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a8b4f66d63552b5b80df394eeb6b2c16564fe5d
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/losses/__init__.py
@@ -0,0 +1 @@
+from .contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/easyanimate/vae/ldm/modules/losses/contperceptual.py b/easyanimate/vae/ldm/modules/losses/contperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2dfcda26f92ca356cdbc26c74e5333b29eba502
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/losses/contperceptual.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
+from ..vaemodules.discriminator import Discriminator3D
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator3d = Discriminator3D(
+ in_channels=disc_in_channels,
+ block_out_channels=(64, 128, 256)
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.l1_loss_weight = l1_loss_weight
+ self.l2_loss_weight = l2_loss_weight
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+
+ if inputs.ndim==4:
+ inputs = inputs.unsqueeze(2)
+ if reconstructions.ndim==4:
+ reconstructions = reconstructions.unsqueeze(2)
+
+ inputs_ori = inputs
+ reconstructions_ori = reconstructions
+
+ # get new loss_weight
+ loss_weights = 1
+ # b, _ ,f, _, _ = reconstructions.size()
+ # loss_weights = torch.ones([b, f]).view(b, 1, f, 1, 1)
+ # loss_weights[:, :, 0] = 3
+ # for i in range(1, f, 8):
+ # loss_weights[:, :, i - 1] = 3
+ # loss_weights[:, :, i] = 3
+ # loss_weights[:, :, -1] = 3
+ # loss_weights = loss_weights.permute(0, 2, 1, 3, 4).flatten(0, 1).to(reconstructions.device)
+
+ inputs = inputs.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ reconstructions = reconstructions.permute(0, 2, 1, 3, 4).flatten(0, 1)
+
+ rec_loss = 0
+ if self.l1_loss_weight > 0:
+ rec_loss += torch.abs(inputs.contiguous() - reconstructions.contiguous()) * self.l1_loss_weight
+ if self.l2_loss_weight > 0:
+ rec_loss += F.mse_loss(inputs.contiguous(), reconstructions.contiguous(), reduction="none") * self.l2_loss_weight
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ rec_loss = rec_loss * loss_weights
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ logits_fake_3d = self.discriminator3d(reconstructions_ori.contiguous())
+ g_loss = -torch.mean(logits_fake) - torch.mean(logits_fake_3d)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+ logits_real_3d = self.discriminator3d(inputs_ori.contiguous().detach())
+ logits_fake_3d = self.discriminator3d(reconstructions_ori.contiguous().detach())
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + disc_factor * self.disc_loss(logits_real_3d, logits_fake_3d)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
+
diff --git a/easyanimate/vae/ldm/modules/losses/vqperceptual.py b/easyanimate/vae/ldm/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb8241f9c5eb7a43ed7e31efa70ccfaff768f3c
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/losses/vqperceptual.py
@@ -0,0 +1,167 @@
+import torch
+import torch.nn.functional as F
+from einops import repeat
+from taming.modules.discriminator.model import (NLayerDiscriminator,
+ weights_init)
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+from torch import nn
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+def l1(x, y):
+ return torch.abs(x-y)
+
+
+def l2(x, y):
+ return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+ pixel_loss="l1"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+ if not exists(codebook_loss):
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/easyanimate/vae/ldm/modules/vaemodules/__init__.py b/easyanimate/vae/ldm/modules/vaemodules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/easyanimate/vae/ldm/modules/vaemodules/activations.py b/easyanimate/vae/ldm/modules/vaemodules/activations.py
new file mode 100755
index 0000000000000000000000000000000000000000..ab2d6a919ec262c370cec812f87ab928a96dee8b
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/activations.py
@@ -0,0 +1,27 @@
+import torch.nn as nn
+
+ACTIVATION_FUNCTIONS = {
+ "elu": nn.ELU(),
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
diff --git a/easyanimate/vae/ldm/modules/vaemodules/attention.py b/easyanimate/vae/ldm/modules/vaemodules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b4daccf8722cab65e080cf01af001f8ae041588
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/attention.py
@@ -0,0 +1,479 @@
+import inspect
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .attention_processors import AttnProcessor, AttnProcessor2_0
+from .common import SpatialNorm3D
+
+
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ nheads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ head_dim (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: int = None,
+ nheads: int = 8,
+ head_dim: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm = None,
+ cross_attention_norm_num_groups: int = 32,
+ added_kv_proj_dim = None,
+ norm_num_groups = None,
+ spatial_norm_dim = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ processor = None,
+ out_dim: int = None,
+ ):
+ super().__init__()
+
+ self.query_dim = query_dim
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.inner_dim = out_dim if out_dim is not None else head_dim * nheads
+ self.nheads = out_dim // head_dim if out_dim is not None else nheads
+ self.out_dim = out_dim if out_dim is not None else query_dim
+
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ self.scale_qk = scale_qk
+ self.scale = head_dim ** -0.5 if scale_qk else 1.0
+
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm3D(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+
+ self.to_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+ self.dropout = nn.Dropout(dropout)
+
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_processor(self, processor: AttnProcessor) -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ self._modules.pop("processor")
+
+ self.processor = processor
+ self._attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.nheads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // nheads, seq_len, dim * nheads]`. `nheads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.nheads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, nheads, dim // nheads]` `nheads` is
+ the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * nheads, seq_len, dim // nheads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.nheads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+ return tensor
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: torch.FloatTensor = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in self._attn_parameters]
+ # if len(unused_kwargs) > 0:
+ # logger.warning(
+ # f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ # )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in self._attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+
+class SpatialAttention(Attention):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: torch.FloatTensor = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ is_image = hidden_states.ndim == 4
+ if is_image:
+ hidden_states = rearrange(hidden_states, "b c h w -> b c 1 h w")
+
+ bsz, h = hidden_states.shape[0], hidden_states.shape[3]
+ hidden_states = rearrange(hidden_states, "b c t h w -> (b t) (h w) c")
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b t) (h w) c")
+
+ if attention_mask is not None:
+ attention_mask = rearrange(attention_mask, "b t h w -> (b t) (h w)")
+
+ hidden_states = super().forward(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = rearrange(hidden_states, "(b t) (h w) c -> b c t h w", b=bsz, h=h)
+
+ if is_image:
+ hidden_states = rearrange(hidden_states, "b c 1 h w -> b c h w")
+
+ return hidden_states
+
+
+class TemporalAttention(Attention):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: torch.FloatTensor = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ bsz, h = hidden_states.shape[0], hidden_states.shape[3]
+ hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c")
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> (b h w) t c")
+
+ if attention_mask is not None:
+ attention_mask = rearrange(attention_mask, "b t h w -> (b h w) t")
+
+ hidden_states = super().forward(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = rearrange(hidden_states, "(b h w) t c -> b c t h w", b=bsz, h=h)
+
+ return hidden_states
+
+
+class Attention3D(Attention):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: torch.FloatTensor = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ t, h = hidden_states.shape[2], hidden_states.shape[3]
+ hidden_states = rearrange(hidden_states, "b c t h w -> b (t h w) c")
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = rearrange(encoder_hidden_states, "b c t h w -> b (t h w) c")
+
+ if attention_mask is not None:
+ attention_mask = rearrange(attention_mask, "b t h w -> b (t h w)")
+
+ hidden_states = super().forward(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = rearrange(hidden_states, "b (t h w) c -> b c t h w", t=t, h=h)
+
+ return hidden_states
diff --git a/easyanimate/vae/ldm/modules/vaemodules/attention_processors.py b/easyanimate/vae/ldm/modules/vaemodules/attention_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d7c05b88e379e8687640e42b46f380a4ded01cf
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/attention_processors.py
@@ -0,0 +1,139 @@
+from typing import TYPE_CHECKING
+
+import torch
+import torch.nn.functional as F
+
+if TYPE_CHECKING:
+ from .attention import Attention
+
+class AttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states,
+ attention_mask,
+ temb = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb = None)
+
+ # B, L, C
+ assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
+ hidden_states = hidden_states.transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ hidden_states = attn.to_out(hidden_states)
+ hidden_states = attn.dropout(hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states,
+ attention_mask,
+ temb = None,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb = None)
+
+ # B, L, C
+ assert hidden_states.ndim == 3, f"Hidden states must be 3-dimensional, got {hidden_states.ndim}"
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.nheads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2))
+ hidden_states = hidden_states.transpose(1, 2)
+
+ query: torch.Tensor = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key: torch.Tensor = attn.to_k(encoder_hidden_states)
+ value: torch.Tensor = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.nheads
+
+ query = query.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.nheads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.nheads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ hidden_states = attn.to_out(hidden_states)
+ hidden_states = attn.dropout(hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
diff --git a/easyanimate/vae/ldm/modules/vaemodules/common.py b/easyanimate/vae/ldm/modules/vaemodules/common.py
new file mode 100755
index 0000000000000000000000000000000000000000..a49999dd518e4439cb1b184972f8ac1d66abd3cd
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/common.py
@@ -0,0 +1,260 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from .activations import get_activation
+
+
+def cast_tuple(t, length = 1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+def divisible_by(num, den):
+ return (num % den) == 0
+
+def is_odd(n):
+ return not divisible_by(n, 2)
+
+class CausalConv3d(nn.Conv3d):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size=3, # : int | tuple[int, int, int],
+ stride=1, # : int | tuple[int, int, int] = 1,
+ padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
+ dilation=1, # : int | tuple[int, int, int] = 1,
+ **kwargs,
+ ):
+ kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
+ assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
+
+ stride = stride if isinstance(stride, tuple) else (stride,) * 3
+ assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
+
+ dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
+ assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
+
+ t_ks, h_ks, w_ks = kernel_size
+ _, h_stride, w_stride = stride
+ t_dilation, h_dilation, w_dilation = dilation
+
+ t_pad = (t_ks - 1) * t_dilation
+ # TODO: align with SD
+ if padding is None:
+ h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
+ w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+ elif isinstance(padding, int):
+ h_pad = w_pad = padding
+ else:
+ assert NotImplementedError
+
+ self.temporal_padding = t_pad
+ self.temporal_padding_origin = math.ceil(((t_ks - 1) * w_dilation + (1 - w_stride)) / 2)
+ self.padding_flag = 0
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=(0, h_pad, w_pad),
+ **kwargs,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, T, H, W)
+ if self.padding_flag == 0:
+ x = F.pad(
+ x,
+ pad=(0, 0, 0, 0, self.temporal_padding, 0),
+ mode="replicate", # TODO: check if this is necessary
+ )
+ else:
+ x = F.pad(
+ x,
+ pad=(0, 0, 0, 0, self.temporal_padding_origin, self.temporal_padding_origin),
+ )
+ return super().forward(x)
+
+ def set_padding_one_frame(self):
+ def _set_padding_one_frame(name, module):
+ if hasattr(module, 'padding_flag'):
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+ module.padding_flag = 1
+ for sub_name, sub_mod in module.named_children():
+ _set_padding_one_frame(sub_name, sub_mod)
+ for name, module in self.named_children():
+ _set_padding_one_frame(name, module)
+
+ def set_padding_more_frame(self):
+ def _set_padding_more_frame(name, module):
+ if hasattr(module, 'padding_flag'):
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+ module.padding_flag = 2
+ for sub_name, sub_mod in module.named_children():
+ _set_padding_more_frame(sub_name, sub_mod)
+ for name, module in self.named_children():
+ _set_padding_more_frame(name, module)
+
+class ResidualBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ non_linearity: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ self.output_scale_factor = output_scale_factor
+
+ self.norm1 = nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=in_channels,
+ eps=norm_eps,
+ affine=True,
+ )
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ self.norm2 = nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=out_channels,
+ eps=norm_eps,
+ affine=True,
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = self.shortcut(x)
+
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ x = self.dropout(x)
+ x = self.conv2(x)
+
+ return (x + shortcut) / self.output_scale_factor
+
+
+class ResidualBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ non_linearity: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ self.output_scale_factor = output_scale_factor
+
+ self.norm1 = nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=in_channels,
+ eps=norm_eps,
+ affine=True,
+ )
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)
+
+ self.norm2 = nn.GroupNorm(
+ num_groups=norm_num_groups,
+ num_channels=out_channels,
+ eps=norm_eps,
+ affine=True,
+ )
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)
+
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = self.shortcut(x)
+
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ x = self.dropout(x)
+ x = self.conv2(x)
+ return (x + shortcut) / self.output_scale_factor
+
+
+class SpatialNorm2D(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+
+ self.norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+class SpatialNorm3D(SpatialNorm2D):
+ def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+ batch_size = f.shape[0]
+ f = rearrange(f, "b c t h w -> (b t) c h w")
+ zq = rearrange(zq, "b c t h w -> (b t) c h w")
+
+ x = super().forward(f, zq)
+
+ x = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
+
+ return x
diff --git a/easyanimate/vae/ldm/modules/vaemodules/discriminator.py b/easyanimate/vae/ldm/modules/vaemodules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a88e07b3bdcd6e89ff417443d709ddf7b0829418
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/discriminator.py
@@ -0,0 +1,214 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from .downsamplers import BlurPooling2D, BlurPooling3D
+
+
+class DiscriminatorBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.output_scale_factor = output_scale_factor
+
+ self.norm1 = nn.BatchNorm2d(in_channels)
+
+ self.nonlinearity = nn.LeakyReLU(0.2)
+
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ if add_downsample:
+ self.downsampler = BlurPooling2D(out_channels, out_channels)
+ else:
+ self.downsampler = nn.Identity()
+
+ self.norm2 = nn.BatchNorm2d(out_channels)
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+
+ if add_downsample:
+ self.shortcut = nn.Sequential(
+ BlurPooling2D(in_channels, in_channels),
+ nn.Conv2d(in_channels, out_channels, kernel_size=1),
+ )
+ else:
+ self.shortcut = nn.Identity()
+
+ self.spatial_downsample_factor = 2
+ self.temporal_downsample_factor = 1
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = self.shortcut(x)
+
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ x = self.dropout(x)
+ x = self.downsampler(x)
+ x = self.conv2(x)
+
+ return (x + shortcut) / self.output_scale_factor
+
+
+class Discriminator2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ block_out_channels = (64,),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ output_channels = block_out_channels[0]
+ for i, out_channels in enumerate(block_out_channels):
+ input_channels = output_channels
+ output_channels = out_channels
+ is_final_block = i == len(block_out_channels) - 1
+
+ self.blocks.append(
+ DiscriminatorBlock2D(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ output_scale_factor=math.sqrt(2),
+ add_downsample=not is_final_block,
+ )
+ )
+
+ self.conv_norm_out = nn.BatchNorm2d(block_out_channels[-1])
+ self.conv_act = nn.LeakyReLU(0.2)
+
+ self.conv_out = nn.Conv2d(block_out_channels[-1], 1, kernel_size=3, padding=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, H, W)
+ x = self.conv_in(x)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.conv_out(x)
+
+ return x
+
+
+class DiscriminatorBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.output_scale_factor = output_scale_factor
+
+ self.norm1 = nn.GroupNorm(32, in_channels)
+
+ self.nonlinearity = nn.LeakyReLU(0.2)
+
+ self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ if add_downsample:
+ self.downsampler = BlurPooling3D(out_channels, out_channels)
+ else:
+ self.downsampler = nn.Identity()
+
+ self.norm2 = nn.GroupNorm(32, out_channels)
+
+ self.dropout = nn.Dropout(dropout)
+
+ self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
+
+ if add_downsample:
+ self.shortcut = nn.Sequential(
+ BlurPooling3D(in_channels, in_channels),
+ nn.Conv3d(in_channels, out_channels, kernel_size=1),
+ )
+ else:
+ self.shortcut = nn.Sequential(
+ nn.Conv3d(in_channels, out_channels, kernel_size=1),
+ )
+
+ self.spatial_downsample_factor = 2
+ self.temporal_downsample_factor = 2
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = self.shortcut(x)
+
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ x = self.dropout(x)
+ x = self.downsampler(x)
+ x = self.conv2(x)
+
+ return (x + shortcut) / self.output_scale_factor
+
+
+class Discriminator3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ block_out_channels = (64,),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1, stride=2)
+
+ self.blocks = nn.ModuleList([])
+
+ output_channels = block_out_channels[0]
+ for i, out_channels in enumerate(block_out_channels):
+ input_channels = output_channels
+ output_channels = out_channels
+ is_final_block = i == len(block_out_channels) - 1
+
+ self.blocks.append(
+ DiscriminatorBlock3D(
+ in_channels=input_channels,
+ out_channels=output_channels,
+ output_scale_factor=math.sqrt(2),
+ add_downsample=not is_final_block,
+ )
+ )
+
+ self.conv_norm_out = nn.GroupNorm(32, block_out_channels[-1])
+ self.conv_act = nn.LeakyReLU(0.2)
+
+ self.conv_out = nn.Conv3d(block_out_channels[-1], 1, kernel_size=3, padding=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, T, H, W)
+ x = self.conv_in(x)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.conv_out(x)
+
+ return x
diff --git a/easyanimate/vae/ldm/modules/vaemodules/down_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/down_blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..48b2c1fd057cb70f1b61ca26aaa22c9dd691c0b1
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/down_blocks.py
@@ -0,0 +1,533 @@
+import torch
+import torch.nn as nn
+
+from .attention import SpatialAttention, TemporalAttention
+from .common import ResidualBlock3D
+from .downsamplers import (SpatialDownsampler3D, SpatialTemporalDownsampler3D,
+ TemporalDownsampler3D)
+from .gc_block import GlobalContextBlock
+
+
+def get_down_block(
+ down_block_type: str,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int,
+ act_fn: str,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+) -> nn.Module:
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ )
+ elif down_block_type == "SpatialDownBlock3D":
+ return SpatialDownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "SpatialAttnDownBlock3D":
+ return SpatialAttnDownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ attention_head_dim=out_channels // num_attention_heads,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "TemporalDownBlock3D":
+ return TemporalDownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "TemporalAttnDownBlock3D":
+ return TemporalAttnDownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ attention_head_dim=out_channels // num_attention_heads,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "SpatialTemporalDownBlock3D":
+ return SpatialTemporalDownBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_downsample=add_downsample,
+ )
+ else:
+ raise ValueError(f"Unknown down block type: {down_block_type}")
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ self.spatial_downsample_factor = 1
+ self.temporal_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ return x
+
+
+class SpatialDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
+ self.spatial_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.spatial_downsample_factor = 1
+
+ self.temporal_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
+
+
+class TemporalDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
+ self.temporal_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.temporal_downsample_factor = 1
+
+ self.spatial_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
+
+
+class SpatialTemporalDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = SpatialTemporalDownsampler3D(out_channels, out_channels)
+ self.spatial_downsample_factor = 2
+ self.temporal_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.spatial_downsample_factor = 1
+ self.temporal_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
+
+
+class SpatialAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ self.attentions = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ self.attentions.append(
+ SpatialAttention(
+ out_channels,
+ nheads=out_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
+ self.spatial_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.spatial_downsample_factor = 1
+
+ self.temporal_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv, attn in zip(self.convs, self.attentions):
+ x = conv(x)
+ x = attn(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
+
+
+class TemporalDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
+ self.temporal_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.temporal_downsample_factor = 1
+
+ self.spatial_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
+
+
+class TemporalAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ self.attentions = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ self.attentions.append(
+ TemporalAttention(
+ out_channels,
+ nheads=out_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_downsample:
+ self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
+ self.temporal_downsample_factor = 2
+ else:
+ self.downsampler = None
+ self.temporal_downsample_factor = 1
+
+ self.spatial_downsample_factor = 1
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv, attn in zip(self.convs, self.attentions):
+ x = conv(x)
+ x = attn(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.downsampler is not None:
+ x = self.downsampler(x)
+
+ return x
diff --git a/easyanimate/vae/ldm/modules/vaemodules/downsamplers.py b/easyanimate/vae/ldm/modules/vaemodules/downsamplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..027fe9b189af57e4adcc8a5dc659ad2d7a88ae4d
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/downsamplers.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .common import CausalConv3d
+
+
+class Downsampler(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ spatial_downsample_factor: int = 1,
+ temporal_downsample_factor: int = 1,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.spatial_downsample_factor = spatial_downsample_factor
+ self.temporal_downsample_factor = temporal_downsample_factor
+
+
+class SpatialDownsampler3D(Downsampler):
+ def __init__(self, in_channels: int, out_channels):
+ if out_channels is None:
+ out_channels = in_channels
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ spatial_downsample_factor=2,
+ temporal_downsample_factor=1,
+ )
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=(1, 2, 2),
+ padding=0,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, (0, 1, 0, 1))
+ return self.conv(x)
+
+
+class TemporalDownsampler3D(Downsampler):
+ def __init__(self, in_channels: int, out_channels):
+ if out_channels is None:
+ out_channels = in_channels
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ spatial_downsample_factor=1,
+ temporal_downsample_factor=2,
+ )
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=(2, 1, 1),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.conv(x)
+
+
+class SpatialTemporalDownsampler3D(Downsampler):
+ def __init__(self, in_channels: int, out_channels):
+ if out_channels is None:
+ out_channels = in_channels
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ spatial_downsample_factor=2,
+ temporal_downsample_factor=2,
+ )
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=(2, 2, 2),
+ padding=0,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, (0, 1, 0, 1))
+ return self.conv(x)
+
+
+class BlurPooling2D(Downsampler):
+ def __init__(self, in_channels: int, out_channels):
+ if out_channels is None:
+ out_channels = in_channels
+
+ assert in_channels == out_channels
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ spatial_downsample_factor=2,
+ temporal_downsample_factor=1,
+ )
+
+ filt = torch.tensor([1, 2, 1], dtype=torch.float32)
+ filt = torch.einsum("i,j -> ij", filt, filt)
+ filt = filt / filt.sum()
+ filt = filt[None, None].repeat(out_channels, 1, 1, 1)
+
+ self.register_buffer("filt", filt)
+ self.filt: torch.Tensor
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, H, W)
+ return F.conv2d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
+
+
+class BlurPooling3D(Downsampler):
+ def __init__(self, in_channels: int, out_channels):
+ if out_channels is None:
+ out_channels = in_channels
+
+ assert in_channels == out_channels
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ spatial_downsample_factor=2,
+ temporal_downsample_factor=2,
+ )
+
+ filt = torch.tensor([1, 2, 1], dtype=torch.float32)
+ filt = torch.einsum("i,j,k -> ijk", filt, filt, filt)
+ filt = filt / filt.sum()
+ filt = filt[None, None].repeat(out_channels, 1, 1, 1, 1)
+
+ self.register_buffer("filt", filt)
+ self.filt: torch.Tensor
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # x: (B, C, T, H, W)
+ return F.conv3d(x, self.filt, stride=2, padding=1, groups=self.in_channels)
diff --git a/easyanimate/vae/ldm/modules/vaemodules/gc_block.py b/easyanimate/vae/ldm/modules/vaemodules/gc_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d2eeb5c839f32e7ab6e393c7e76c68504e1358
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/gc_block.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class GlobalContextBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ min_channels: int = 16,
+ init_bias: float = -10.,
+ fusion_type: str = "mul",
+ ):
+ super().__init__()
+
+ assert fusion_type in ("mul", "add"), f"Unsupported fusion type: {fusion_type}"
+ self.fusion_type = fusion_type
+
+ self.conv_ctx = nn.Conv2d(in_channels, 1, kernel_size=1)
+
+ num_channels = max(min_channels, out_channels // 2)
+
+ if fusion_type == "mul":
+ self.conv_mul = nn.Sequential(
+ nn.Conv2d(in_channels, num_channels, kernel_size=1),
+ nn.LayerNorm([num_channels, 1, 1]), # TODO: LayerNorm or GroupNorm?
+ nn.LeakyReLU(0.1),
+ nn.Conv2d(num_channels, out_channels, kernel_size=1),
+ nn.Sigmoid(),
+ )
+
+ nn.init.zeros_(self.conv_mul[-2].weight)
+ nn.init.constant_(self.conv_mul[-2].bias, init_bias)
+ else:
+ self.conv_add = nn.Sequential(
+ nn.Conv2d(in_channels, num_channels, kernel_size=1),
+ nn.LayerNorm([num_channels, 1, 1]), # TODO: LayerNorm or GroupNorm?
+ nn.LeakyReLU(0.1),
+ nn.Conv2d(num_channels, out_channels, kernel_size=1),
+ )
+
+ nn.init.zeros_(self.conv_add[-1].weight)
+ nn.init.constant_(self.conv_add[-1].bias, init_bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ is_image = x.ndim == 4
+ if is_image:
+ x = rearrange(x, "b c h w -> b c 1 h w")
+
+ # x: (B, C, T, H, W)
+ orig_x = x
+ batch_size = x.shape[0]
+
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+
+ ctx = self.conv_ctx(x)
+ ctx = rearrange(ctx, "b c h w -> b c (h w)")
+ ctx = F.softmax(ctx, dim=-1)
+
+ flattened_x = rearrange(x, "b c h w -> b c (h w)")
+
+ x = torch.einsum("b c1 n, b c2 n -> b c2 c1", ctx, flattened_x)
+ x = rearrange(x, "... -> ... 1")
+
+ if self.fusion_type == "mul":
+ mul_term = self.conv_mul(x)
+ mul_term = rearrange(mul_term, "(b t) c h w -> b c t h w", b=batch_size)
+ x = orig_x * mul_term
+ else:
+ add_term = self.conv_add(x)
+ add_term = rearrange(add_term, "(b t) c h w -> b c t h w", b=batch_size)
+ x = orig_x + add_term
+
+ if is_image:
+ x = rearrange(x, "b c 1 h w -> b c h w")
+
+ return x
diff --git a/easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..12776f43ec6d31651847e93e7a18ef2cae82c9ba
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+
+from .attention import Attention3D, SpatialAttention, TemporalAttention
+from .common import ResidualBlock3D
+
+
+def get_mid_block(
+ mid_block_type: str,
+ in_channels: int,
+ num_layers: int,
+ act_fn: str,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ add_attention: bool = True,
+ attention_type: str = "3d",
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+) -> nn.Module:
+ if mid_block_type == "MidBlock3D":
+ return MidBlock3D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ add_attention=add_attention,
+ attention_type=attention_type,
+ attention_head_dim=in_channels // num_attention_heads,
+ output_scale_factor=output_scale_factor,
+ )
+ else:
+ raise ValueError(f"Unknown mid block type: {mid_block_type}")
+
+
+class MidBlock3D(nn.Module):
+ """
+ A 3D UNet mid-block [`MidBlock3D`] with multiple residual blocks and optional attention blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+ act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in the group normalization layers of the resnet blocks.
+ norm_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+ attention_type: (`str`, *optional*, defaults to `3d`): The type of attention to use. Defaults to `3d`.
+ attention_head_dim (`int`, *optional*, defaults to 1):
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
+ the number of input channels.
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+ Returns:
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+ in_channels, temporal_length, height, width)`.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ add_attention: bool = True,
+ attention_type: str = "3d",
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+
+ norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)
+
+ self.convs = nn.ModuleList([
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ ])
+
+ self.attentions = nn.ModuleList([])
+ for _ in range(num_layers - 1):
+ if add_attention:
+ if attention_type == "3d":
+ self.attentions.append(
+ Attention3D(
+ in_channels,
+ nheads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+ elif attention_type == "spatial_temporal":
+ self.attentions.append(
+ nn.ModuleList([
+ SpatialAttention(
+ in_channels,
+ nheads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ ),
+ TemporalAttention(
+ in_channels,
+ nheads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ ),
+ ])
+ )
+ elif attention_type == "spatial":
+ self.attentions.append(
+ SpatialAttention(
+ in_channels,
+ nheads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+ elif attention_type == "temporal":
+ self.attentions.append(
+ TemporalAttention(
+ in_channels,
+ nheads=in_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+ else:
+ raise ValueError(f"Unknown attention type: {attention_type}")
+ else:
+ self.attentions.append(None)
+
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ hidden_states = self.convs[0](hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.convs[1:]):
+ if attn is not None:
+ if self.attention_type == "spatial_temporal":
+ spatial_attn, temporal_attn = attn
+ hidden_states = spatial_attn(hidden_states)
+ hidden_states = temporal_attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
diff --git a/easyanimate/vae/ldm/modules/vaemodules/up_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/up_blocks.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee781cbc9dd913ea25b3c676abacc0b38e7903eb
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/up_blocks.py
@@ -0,0 +1,395 @@
+import torch
+import torch.nn as nn
+
+from .attention import SpatialAttention, TemporalAttention
+from .common import ResidualBlock3D
+from .gc_block import GlobalContextBlock
+from .upsamplers import (SpatialTemporalUpsampler3D, SpatialUpsampler3D,
+ TemporalUpsampler3D)
+
+
+def get_up_block(
+ up_block_type: str,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int,
+ act_fn: str,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+) -> nn.Module:
+ if up_block_type == "SpatialUpBlock3D":
+ return SpatialUpBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "SpatialAttnUpBlock3D":
+ return SpatialAttnUpBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ attention_head_dim=out_channels // num_attention_heads,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "TemporalUpBlock3D":
+ return TemporalUpBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "TemporalAttnUpBlock3D":
+ return TemporalAttnUpBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ attention_head_dim=out_channels // num_attention_heads,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "SpatialTemporalUpBlock3D":
+ return SpatialTemporalUpBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers=num_layers,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ add_gc_block=add_gc_block,
+ add_upsample=add_upsample,
+ )
+ else:
+ raise ValueError(f"Unknown up block type: {up_block_type}")
+
+
+class SpatialUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+
+ if add_upsample:
+ self.upsampler = SpatialUpsampler3D(in_channels, in_channels)
+ else:
+ self.upsampler = None
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(in_channels, in_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.upsampler is not None:
+ x = self.upsampler(x)
+
+ return x
+
+
+class SpatialAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ self.attentions = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ self.attentions.append(
+ SpatialAttention(
+ out_channels,
+ nheads=out_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_upsample:
+ self.upsampler = SpatialUpsampler3D(out_channels, out_channels)
+ else:
+ self.upsampler = None
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv, attn in zip(self.convs, self.attentions):
+ x = conv(x)
+ x = attn(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.upsampler is not None:
+ x = self.upsampler(x)
+
+ return x
+
+
+class TemporalUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_upsample:
+ self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
+ else:
+ self.upsampler = None
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.upsampler is not None:
+ x = self.upsampler(x)
+
+ return x
+
+
+class TemporalAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ self.attentions = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ self.attentions.append(
+ TemporalAttention(
+ out_channels,
+ nheads=out_channels // attention_head_dim,
+ head_dim=attention_head_dim,
+ bias=True,
+ upcast_softmax=True,
+ norm_num_groups=norm_num_groups,
+ eps=norm_eps,
+ rescale_output_factor=output_scale_factor,
+ residual_connection=True,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_upsample:
+ self.upsampler = TemporalUpsampler3D(out_channels, out_channels)
+ else:
+ self.upsampler = None
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv, attn in zip(self.convs, self.attentions):
+ x = conv(x)
+ x = attn(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.upsampler is not None:
+ x = self.upsampler(x)
+
+ return x
+
+
+class SpatialTemporalUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ dropout: float = 0.0,
+ output_scale_factor: float = 1.0,
+ add_gc_block: bool = False,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+
+ self.convs = nn.ModuleList([])
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.convs.append(
+ ResidualBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ non_linearity=act_fn,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ dropout=dropout,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+
+ if add_gc_block:
+ self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
+ else:
+ self.gc_block = None
+
+ if add_upsample:
+ self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels)
+ else:
+ self.upsampler = None
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.gc_block is not None:
+ x = self.gc_block(x)
+
+ if self.upsampler is not None:
+ x = self.upsampler(x)
+
+ return x
diff --git a/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py b/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..16288f13ec1cd70b3ad3e15a27a6fe676884e25e
--- /dev/null
+++ b/easyanimate/vae/ldm/modules/vaemodules/upsamplers.py
@@ -0,0 +1,202 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from .common import CausalConv3d
+
+
+class Upsampler(nn.Module):
+ def __init__(
+ self,
+ spatial_upsample_factor: int = 1,
+ temporal_upsample_factor: int = 1,
+ ):
+ super().__init__()
+
+ self.spatial_upsample_factor = spatial_upsample_factor
+ self.temporal_upsample_factor = temporal_upsample_factor
+
+
+class SpatialUpsampler3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(spatial_upsample_factor=2)
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class SpatialUpsamplerD2S3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(spatial_upsample_factor=2)
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ kernel_size=3,
+ )
+
+ o, i, t, h, w = self.conv.weight.shape
+ conv_weight = torch.empty(o // 4, i, t, h, w)
+ nn.init.kaiming_normal_(conv_weight)
+ conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
+ self.conv.weight.data.copy_(conv_weight)
+
+ nn.init.zeros_(self.conv.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ x = rearrange(x, "b (c p1 p2) t h w -> b c t (h p1) (w p2)", p1=2, p2=2)
+ return x
+
+
+class TemporalUpsampler3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(
+ spatial_upsample_factor=1,
+ temporal_upsample_factor=2,
+ )
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if x.shape[2] > 1:
+ first_frame, x = x[:, :, :1], x[:, :, 1:]
+ x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.cat([first_frame, x], dim=2)
+ x = self.conv(x)
+ return x
+
+
+class TemporalUpsamplerD2S3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(
+ spatial_upsample_factor=1,
+ temporal_upsample_factor=2,
+ )
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels * 2,
+ kernel_size=3,
+ )
+
+ o, i, t, h, w = self.conv.weight.shape
+ conv_weight = torch.empty(o // 2, i, t, h, w)
+ nn.init.kaiming_normal_(conv_weight)
+ conv_weight = repeat(conv_weight, "o ... -> (o 2) ...")
+ self.conv.weight.data.copy_(conv_weight)
+
+ nn.init.zeros_(self.conv.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ x = rearrange(x, "b (c p1) t h w -> b c (t p1) h w", p1=2)
+ x = x[:, :, 1:]
+ return x
+
+
+class SpatialTemporalUpsampler3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(
+ spatial_upsample_factor=2,
+ temporal_upsample_factor=2,
+ )
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ )
+
+ self.padding_flag = 0
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.interpolate(x, scale_factor=(1, 2, 2), mode="nearest")
+ x = self.conv(x)
+
+ if self.padding_flag == 0:
+ if x.shape[2] > 1:
+ first_frame, x = x[:, :, :1], x[:, :, 1:]
+ x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+ x = torch.cat([first_frame, x], dim=2)
+ elif self.padding_flag == 2:
+ x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
+ return x
+
+ def set_padding_one_frame(self):
+ def _set_padding_one_frame(name, module):
+ if hasattr(module, 'padding_flag'):
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+ module.padding_flag = 1
+ for sub_name, sub_mod in module.named_children():
+ _set_padding_one_frame(sub_name, sub_mod)
+ for name, module in self.named_children():
+ _set_padding_one_frame(name, module)
+
+ def set_padding_more_frame(self):
+ def _set_padding_more_frame(name, module):
+ if hasattr(module, 'padding_flag'):
+ print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
+ module.padding_flag = 2
+ for sub_name, sub_mod in module.named_children():
+ _set_padding_more_frame(sub_name, sub_mod)
+ for name, module in self.named_children():
+ _set_padding_more_frame(name, module)
+
+class SpatialTemporalUpsamplerD2S3D(Upsampler):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__(
+ spatial_upsample_factor=2,
+ temporal_upsample_factor=2,
+ )
+
+ if out_channels is None:
+ out_channels = in_channels
+
+ self.conv = CausalConv3d(
+ in_channels=in_channels,
+ out_channels=out_channels * 8,
+ kernel_size=3,
+ )
+
+ o, i, t, h, w = self.conv.weight.shape
+ conv_weight = torch.empty(o // 8, i, t, h, w)
+ nn.init.kaiming_normal_(conv_weight)
+ conv_weight = repeat(conv_weight, "o ... -> (o 8) ...")
+ self.conv.weight.data.copy_(conv_weight)
+
+ nn.init.zeros_(self.conv.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ x = rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", p1=2, p2=2, p3=2)
+ x = x[:, :, 1:]
+ return x
diff --git a/easyanimate/vae/ldm/util.py b/easyanimate/vae/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..475711753da5e19f32e4e80f410335c27d797140
--- /dev/null
+++ b/easyanimate/vae/ldm/util.py
@@ -0,0 +1,201 @@
+import importlib
+import multiprocessing as mp
+from collections import abc
+from functools import partial
+from inspect import isfunction
+from queue import Queue
+from threading import Thread
+
+import numpy as np
+import torch
+from einops import rearrange
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+ # create dummy dataset instance
+
+ # run prefetching
+ if idx_to_fn:
+ res = func(data, worker_id=idx)
+ else:
+ res = func(data)
+ Q.put([idx, res])
+ Q.put("Done")
+
+
+def parallel_data_prefetch(
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+ # if target_data_type not in ["ndarray", "list"]:
+ # raise ValueError(
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+ # )
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
+ elif isinstance(data, abc.Iterable):
+ if isinstance(data, dict):
+ print(
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ )
+ data = list(data.values())
+ if target_data_type == "ndarray":
+ data = np.asarray(data)
+ else:
+ data = list(data)
+ else:
+ raise TypeError(
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+ )
+
+ if cpu_intensive:
+ Q = mp.Queue(1000)
+ proc = mp.Process
+ else:
+ Q = Queue(1000)
+ proc = Thread
+ # spawn processes
+ if target_data_type == "ndarray":
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(np.array_split(data, n_proc))
+ ]
+ else:
+ step = (
+ int(len(data) / n_proc + 1)
+ if len(data) % n_proc != 0
+ else int(len(data) / n_proc)
+ )
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(
+ [data[i: i + step] for i in range(0, len(data), step)]
+ )
+ ]
+ processes = []
+ for i in range(n_proc):
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+ processes += [p]
+
+ # start processes
+ print(f"Start prefetching...")
+ import time
+
+ start = time.time()
+ gather_res = [[] for _ in range(n_proc)]
+ try:
+ for p in processes:
+ p.start()
+
+ k = 0
+ while k < n_proc:
+ # get result
+ res = Q.get()
+ if res == "Done":
+ k += 1
+ else:
+ gather_res[res[0]] = res[1]
+
+ except Exception as e:
+ print("Exception: ", e)
+ for p in processes:
+ p.terminate()
+
+ raise e
+ finally:
+ for p in processes:
+ p.join()
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+ if target_data_type == 'ndarray':
+ if not isinstance(gather_res[0], np.ndarray):
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+ # order outputs
+ return np.concatenate(gather_res, axis=0)
+ elif target_data_type == 'list':
+ out = []
+ for r in gather_res:
+ out.extend(r)
+ return out
+ else:
+ return gather_res
diff --git a/easyanimate/vae/setup.py b/easyanimate/vae/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..baaf21c9fc5238f4d495fb0b5d5f5e10cc481078
--- /dev/null
+++ b/easyanimate/vae/setup.py
@@ -0,0 +1,13 @@
+from setuptools import find_packages, setup
+
+setup(
+ name='latent-diffusion',
+ version='0.0.1',
+ description='',
+ packages=find_packages(),
+ install_requires=[
+ 'torch',
+ 'numpy',
+ 'tqdm',
+ ],
+)
\ No newline at end of file
diff --git a/easyanimate/video_caption/README.md b/easyanimate/video_caption/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..92c6f9abbac2e39c12996a849299457a2c1dbfbe
--- /dev/null
+++ b/easyanimate/video_caption/README.md
@@ -0,0 +1,90 @@
+# Video Caption
+EasyAnimate uses multi-modal LLMs to generate captions for frames extracted from the video firstly, and then employs LLMs to summarize and refine the generated frame captions into the final video caption. By leveraging [sglang](https://github.com/sgl-project/sglang)/[vLLM](https://github.com/vllm-project/vllm) and [accelerate distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference), the entire processing could be very fast.
+
+English | [简体中文](./README_zh-CN.md)
+
+## Quick Start
+1. Cloud usage: AliyunDSW/Docker
+
+ Check [README.md](../../README.md#quick-start) for details.
+
+2. Local usage
+
+ ```shell
+ # Install EasyAnimate requirements firstly.
+ cd EasyAnimate && pip install -r requirements.txt
+
+ # Install additional requirements for video caption.
+ cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+
+ # Use DDP instead of DP in EasyOCR detection.
+ site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
+ cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py
+
+ # We strongly recommend using Docker unless you can properly handle the dependency between vllm with torch(cuda).
+ ```
+
+## Data preprocessing
+Data preprocessing can be divided into three parts:
+
+- Video cut.
+- Video cleaning.
+- Video caption.
+
+The input for data preprocessing can be a video folder or a metadata file (txt/csv/jsonl) containing the video path column. Please check `get_video_path_list` function in [utils/video_utils.py](utils/video_utils.py) for details.
+
+For easier understanding, we use one data from Panda70m as an example for data preprocessing, [Download here](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4). Please download the video and push it in "datasets/panda_70m/before_vcut/"
+
+```
+📦 datasets/
+├── 📂 panda_70m/
+│ └── 📂 before_vcut/
+│ └── 📄 --C66yU3LjM_2.mp4
+```
+
+1. Video cut
+
+ For long video cut, EasyAnimate utilizes PySceneDetect to identify scene changes within the video and performs scene cutting based on certain threshold values to ensure consistency in the themes of the video segments. After cutting, we only keep segments with lengths ranging from 3 to 10 seconds for model training.
+
+ We have completed the parameters for ```stage_1_video_cut.sh```, so I can run it directly using the command sh ```stage_1_video_cut.sh```. After executing ```stage_1_video_cut.sh```, we obtained short videos in ```easyanimate/video_caption/datasets/panda_70m/train```.
+
+ ```shell
+ sh stage_1_video_cut.sh
+ ```
+2. Video cleaning
+
+ Following SVD's data preparation process, EasyAnimate provides a simple yet effective data processing pipeline for high-quality data filtering and labeling. It also supports distributed processing to accelerate the speed of data preprocessing. The overall process is as follows:
+
+ - Duration filtering: Analyze the basic information of the video to filter out low-quality videos that are short in duration or low in resolution. This filtering result is corresponding to the video cut (3s ~ 10s videos).
+ - Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of uniformly distributed 4 frames.
+ - Text filtering: Use easyocr to calculate the text proportion of middle frames to filter out videos with a large proportion of text.
+ - Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly.
+
+ The process file of **Aesthetic filtering** is ```compute_video_frame_quality.py```. After executing ```compute_video_frame_quality.py```, we obtained the file ```datasets/panda_70m/aesthetic_score.jsonl```, where each line corresponds to the aesthetic score of each video.
+
+ The process file of **Text filtering** is ```compute_text_score.py```. After executing ```compute_text_score.py```, we obtained the file ```datasets/panda_70m/text_score.jsonl```, where each line corresponds to the text score of each video.
+
+ The process file of **Motion filtering** is ```compute_motion_score.py```. Motion filtering is based on Aesthetic filtering and Text filtering; only samples that meet certain aesthetic scores and text scores will undergo calculation for the Motion score. After executing ```compute_motion_score.py```, we obtained the file ```datasets/panda_70m/motion_score.jsonl```, where each line corresponds to the motion score of each video.
+
+ Then we need to filter videos by motion scores. After executing ```filter_videos_by_motion_score.py```, we get the file ```datasets/panda_70m/train.jsonl```, which includes the video we need to caption.
+
+ We have completed the parameters for stage_2_filter_data.sh, so I can run it directly using the command sh stage_2_filter_data.sh.
+
+ ```shell
+ sh stage_2_filter_data.sh
+ ```
+3. Video caption
+
+ Video captioning is carried out in two stages. The first stage involves extracting frames from a video and generating descriptions for them. Subsequently, a large language model is used to summarize these descriptions into a caption.
+
+ We have conducted a detailed and manual comparison of open sourced multi-modal LLMs such as [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) and etc. And we found that [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) is capable of generating more detailed captions with fewer hallucinations. Additionally, it is supported by serving engines like [sglang](https://github.com/sgl-project/sglang) and [lmdepoly](https://github.com/InternLM/lmdeploy), enabling faster inference.
+
+ Firstly, we use ```caption_video_frame.py``` to generate frame captions. Then, we use ```caption_summary.py``` to generate summary captions.
+
+ We have completed the parameters for stage_3_video_caption.sh, so I can run it directly using the command sh stage_3_video_caption.sh. After executing ```stage_3_video_cut.sh```, we obtained last json ```train_panda_70m.json``` for easyanimate training.
+
+ ```shell
+ sh stage_3_video_caption.sh
+ ```
+
+ If you cannot access to Huggingface, you can run `export HF_ENDPOINT=https://hf-mirror.com` before the above command to download the summary caption model automatically.
\ No newline at end of file
diff --git a/easyanimate/video_caption/README_zh-CN.md b/easyanimate/video_caption/README_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1bd34f717dd23b77a4d9eb2890f2349b0b6d958
--- /dev/null
+++ b/easyanimate/video_caption/README_zh-CN.md
@@ -0,0 +1,90 @@
+# 数据预处理
+
+
+EasyAnimate 对数据进行了场景切分、视频过滤和视频打标来得到高质量的有标注视频训练使用。使用多模态大型语言模型(LLMs)为从视频中提取的帧生成字幕,然后利用LLMs将生成的帧字幕总结并细化为最终的视频字幕。通过利用sglang/vLLM和加速分布式推理,高效完成视频的打标。
+
+[English](./README.md) | 简体中文
+
+## 快速开始
+1. 云上使用: 阿里云DSW/Docker
+ 参考 [README.md](../../README_zh-CN.md#quick-start) 查看更多细节。
+
+2. 本地安装
+
+ ```shell
+ # Install EasyAnimate requirements firstly.
+ cd EasyAnimate && pip install -r requirements.txt
+
+ # Install additional requirements for video caption.
+ cd easyanimate/video_caption && pip install -r requirements.txt --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+
+ # Use DDP instead of DP in EasyOCR detection.
+ site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
+ cp -v easyocr_detection_patched.py $site_pkg_path/easyocr/detection.py
+
+ # We strongly recommend using Docker unless you can properly handle the dependency between vllm with torch(cuda).
+ ```
+
+## 数据预处理
+数据预处理可以分为一下三步:
+
+- 视频切分
+- 视频过滤
+- 视频打标
+
+数据预处理的输入可以是视频文件夹或包含视频路径列的元数据文件(txt/csv/jsonl格式)。详情请查看[utils/video_utils.py](utils/video_utils.py) 文件中的 `get_video_path_list` 函数。
+
+为了便于理解,我们以Panda70m的一个数据为例进行数据预处理,点击[这里](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/asset/v2/--C66yU3LjM_2.mp4)下载视频。请下载视频并放在下面的路径:"datasets/panda_70m/before_vcut/"
+
+```
+📦 datasets/
+├── 📂 panda_70m/
+│ └── 📂 before_vcut/
+│ └── 📄 --C66yU3LjM_2.mp4
+```
+
+1. 视频切分
+
+ 对于长视频剪辑,EasyAnimate 利用 PySceneDetect 来识别视频中的场景变化,并根据特定的阈值进行场景切割,以确保视频片段主题的一致性。切割后,我们只保留长度在3到10秒之间的片段,用于模型训练。
+
+ 我们整理了完整的方案在 ```stage_1_video_cut.sh``` 文件中, 您可以直接运行```stage_1_video_cut.sh```. 执行完成后可以在 ```easyanimate/video_caption/datasets/panda_70m/train``` 文件夹中查看结果。
+
+ ```shell
+ sh stage_1_video_cut.sh
+ ```
+2. 视频过滤
+
+ 遵循SVD([Stable Video Diffusion](https://github.com/Stability-AI/generative-models))的数据准备流程,EasyAnimate 提供了一个简单而有效的数据处理管道,用于高质量数据的过滤和标记。我们还支持分布式处理来加快数据预处理的速度。整个过程如下::
+
+ - 时长过滤: 分析视频的基本信息,筛选出时长过短或分辨率过低的低质量视频。我们保留3秒至10秒的视频。
+ - 美学过滤: 通过计算均匀分布的4帧的平均审美分数,过滤掉内容质量差的视频(模糊、暗淡等)。
+ - 文本过滤: 使用 [easyocr](https://github.com/JaidedAI/EasyOCR) 来计算中间帧的文本比例,以筛选出含有大量文本的视频。
+ - 运动过滤: 计算帧间光流差异,以筛选出移动过慢或过快的视频。
+
+ **美学过滤** 的代码在 ```compute_video_frame_quality.py```. 执行 ```compute_video_frame_quality.py```,我们可以生成 ```datasets/panda_70m/aesthetic_score.jsonl```文件, 计算每条视频的美学得分。
+
+ **文本过滤** 的代码在 ```compute_text_score.py```. 执行```compute_text_score.py```, 我们可以生成 ```datasets/panda_70m/text_score.jsonl```文件, 计算每个视频的文字占比。
+
+ **运动过滤** 的代码在 ```compute_motion_score.py```. 运动过滤基于审美过滤和文本过滤;只有达到一定审美分数和文本分数的样本才会进行运动分数的计算。 执行 ```compute_motion_score.py```, 我们可以生成 ```datasets/panda_70m/motion_score.jsonl```, 计算每条视频的运动得分。
+
+ 接着执行 ```filter_videos_by_motion_score.py```来得过滤视频。我们最终得到筛选后需要打标的 ```datasets/panda_70m/train.jsonl```文件。
+
+ 我们将视频过滤的流程整理为 ```stage_2_filter_data.sh```,直接执行该脚本来完成视频数据的过滤。
+
+ ```shell
+ sh stage_2_filter_data.sh
+ ```
+3. 视频打标
+
+
+ 视频打标生成分为两个阶段。第一阶段涉及从视频中提取帧并为它们生成描述。随后,使用大型语言模型将这些描述汇总成一条字幕。
+
+ 我们详细对比了现有的多模态大语言模型(诸如[Qwen-VL](https://huggingface.co/Qwen/Qwen-VL), [ShareGPT4V-7B](https://huggingface.co/Lin-Chen/ShareGPT4V-7B), [deepseek-vl-7b-chat](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat))生成文本描述的效果。 最终选择 [llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) 来进行视频文本描述的生成,它能生成详细的描述并有更少的幻觉。此外,我们引入 [sglang](https://github.com/sgl-project/sglang),[lmdepoly](https://github.com/InternLM/lmdeploy), 来加速推理的过程。
+
+ 首先,我们用 ```caption_video_frame.py``` 来生成文本描述,并用 ```caption_summary.py``` 来总结描述信息。我们将上述过程整理在 ```stage_3_video_caption.sh```, 直接运行它来生成视频的文本描述。我们最终得到 ```train_panda_70m.json``` 用于EasyAnmate 的训练。
+
+ ```shell
+ sh stage_3_video_caption.sh
+ ```
+
+ 请注意,如遇网络问题,您可以设置 `export HF_ENDPOINT=https://hf-mirror.com` 来自动下载视频打标模型。
\ No newline at end of file
diff --git a/easyanimate/video_caption/caption_summary.py b/easyanimate/video_caption/caption_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c99f44cb44cb286cc86f047471fcb51fc1478c
--- /dev/null
+++ b/easyanimate/video_caption/caption_summary.py
@@ -0,0 +1,134 @@
+import argparse
+import os
+import re
+from tqdm import tqdm
+
+import pandas as pd
+from vllm import LLM, SamplingParams
+
+from utils.logger import logger
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Recaption the video frame.")
+ parser.add_argument(
+ "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
+ )
+ parser.add_argument(
+ "--video_path_column",
+ type=str,
+ default="video_path",
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="sampled_frame_caption",
+ help="The column contains the sampled_frame_caption.",
+ )
+ parser.add_argument(
+ "--remove_quotes",
+ action="store_true",
+ help="Whether to remove quotes from caption.",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=10,
+ required=False,
+ help="The batch size for the video caption.",
+ )
+ parser.add_argument(
+ "--summary_model_name",
+ type=str,
+ default="mistralai/Mistral-7B-Instruct-v0.2",
+ )
+ parser.add_argument(
+ "--summary_prompt",
+ type=str,
+ default=(
+ "You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, "
+ "which you need to summarize it into a description of the video clip."
+ "Please provide your video description following these requirements: "
+ "1. Describe the basic and necessary information of the video in the third person, be as concise as possible. "
+ "2. Output the video description directly. Begin with 'In this video'. "
+ "3. Limit the video description within 100 words. "
+ "Here is the mid-frame description: "
+ ),
+ )
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+ parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.")
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.video_metadata_path.endswith(".csv"):
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
+ elif args.video_metadata_path.endswith(".jsonl"):
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+ else:
+ raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+ video_path_list = video_metadata_df[args.video_path_column].tolist()
+ sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ if os.path.exists(args.saved_path):
+ if args.saved_path.endswith(".csv"):
+ saved_metadata_df = pd.read_csv(args.saved_path)
+ elif args.saved_path.endswith(".jsonl"):
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+ video_path_list = list(set(video_path_list) - set(saved_video_path_list))
+ video_metadata_df.set_index(args.video_path_column, inplace=True)
+ video_metadata_df = video_metadata_df.loc[video_path_list]
+ sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+
+ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
+ summary_model = LLM(model=args.summary_model_name, trust_remote_code=True)
+
+ result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}
+
+ for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
+ batch_video_path = video_path_list[i: i + args.batch_size]
+ batch_caption = sampled_frame_caption_list[i : i + args.batch_size]
+ batch_prompt = []
+ for caption in batch_caption:
+ if args.remove_quotes:
+ caption = re.sub(r'(["\']).*?\1', "", caption)
+ batch_prompt.append("user:" + args.summary_prompt + str(caption) + "\n assistant:")
+ batch_output = summary_model.generate(batch_prompt, sampling_params)
+
+ result_dict["video_path"].extend(batch_video_path)
+ result_dict["summary_model"].extend([args.summary_model_name] * len(batch_caption))
+ result_dict["summary_caption"].extend([output.outputs[0].text.rstrip() for output in batch_output])
+
+ # Save the metadata every args.saved_freq.
+ if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
+ result_df = pd.DataFrame(result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = True if not os.path.exists(args.saved_path) else False
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+
+ result_dict = {"video_path": [], "summary_model": [], "summary_caption": []}
+
+ result_df = pd.DataFrame(result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = True if not os.path.exists(args.saved_path) else False
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save the final result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/easyanimate/video_caption/caption_video_frame.py b/easyanimate/video_caption/caption_video_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..09ce26831b3ffddf04e05ef929c5b9fa62dd9f49
--- /dev/null
+++ b/easyanimate/video_caption/caption_video_frame.py
@@ -0,0 +1,267 @@
+import argparse
+import copy
+import os
+
+import pandas as pd
+from accelerate import PartialState
+from accelerate.utils import gather_object
+from natsort import natsorted
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+from utils.logger import logger
+from utils.video_dataset import VideoDataset, collate_fn
+from utils.video_utils import get_video_path_list, extract_frames
+
+
+ACCELERATE_SUPPORTED_MODELS = ["Qwen-VL-Chat", "internlm-xcomposer2-vl-7b"]
+SGLANG_SUPPORTED_MODELS = ["llava-v1.6-vicuna-7b"]
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Recaption the video frame.")
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+ parser.add_argument(
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl/txt)."
+ )
+ parser.add_argument(
+ "--video_path_column",
+ type=str,
+ default="video_path",
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=10,
+ required=False,
+ help="The batch size for the video dataset.",
+ )
+ parser.add_argument(
+ "--frame_sample_method",
+ type=str,
+ choices=["mid", "uniform"],
+ default="mid",
+ )
+ parser.add_argument(
+ "--num_sampled_frames",
+ type=int,
+ default=1,
+ help="num_sampled_frames",
+ )
+ parser.add_argument(
+ "--image_caption_model_name",
+ type=str,
+ choices=ACCELERATE_SUPPORTED_MODELS + SGLANG_SUPPORTED_MODELS,
+ default="internlm-xcomposer2-vl-7b",
+ )
+ parser.add_argument(
+ "--image_caption_model_quantized", type=bool, default=True, help="Whether to use the quantized image caption model."
+ )
+ parser.add_argument(
+ "--image_caption_prompt",
+ type=str,
+ default="Describe this image and its style in a very detailed manner.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ required=True,
+ help="The directory to create the subfolder (named with the video name) to indicate the video has been processed.",
+ )
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+ parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.")
+
+ args = parser.parse_args()
+ return args
+
+
+def accelerate_inference(args, video_path_list):
+ from utils.image_captioner_awq import QwenVLChat, InternLMXComposer2
+
+ state = PartialState()
+ device = state.device
+ if state.num_processes == 1:
+ device = "cuda:0"
+ if args.image_caption_model_name == "internlm-xcomposer2-vl-7b":
+ image_caption_model = InternLMXComposer2(device=device, quantized=args.image_caption_model_quantized)
+ elif args.image_caption_model_name == "Qwen-VL-Chat":
+ image_caption_model = QwenVLChat(device=device, quantized=args.image_caption_model_quantized)
+
+ # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released.
+ index = len(video_path_list) - len(video_path_list) % state.num_processes
+ logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.")
+ video_path_list = video_path_list[:index]
+
+ if state.is_main_process:
+ os.makedirs(args.output_dir, exist_ok=True)
+ result_list = []
+ with state.split_between_processes(video_path_list) as splitted_video_path_list:
+ for i, video_path in enumerate(tqdm(splitted_video_path_list, desc=f"{state.device}")):
+ video_id = os.path.splitext(os.path.basename(video_path))[0]
+ try:
+ if not os.path.exists(video_path):
+ print(f"Video {video_id} does not exist. Pass it.")
+ continue
+ sampled_frame_list, sampled_frame_idx_list = extract_frames(video_path, num_sample_frames=args.num_sample_frames)
+ except Exception as e:
+ print(f"Failed to extract frames from video {video_id}. Error is {e}.")
+
+ video_recaption_output_dir = os.path.join(args.output_dir, video_id)
+ if os.path.exists(video_recaption_output_dir):
+ print(f"Video {video_id} has been processed. Pass it.")
+ continue
+ else:
+ os.makedirs(video_recaption_output_dir)
+
+ caption_list = []
+ for frame, frame_idx in zip(sampled_frame_list, sampled_frame_idx_list):
+ frame_path = f"{args.output_dir}/{video_id}_{frame_idx}.png"
+ frame.save(frame_path)
+ try:
+ response, _ = image_caption_model(args.image_caption_prompt, frame_path)
+ except Exception as e:
+ print(f"Failed to caption video {video_id}. Error is {e}.")
+ finally:
+ os.remove(frame_path)
+ caption_list.append(response)
+
+ result_meta = {}
+ if args.video_folder == "":
+ result_meta[args.video_path_column] = video_path
+ else:
+ result_meta[args.video_path_column] = os.path.basename(video_path)
+ result_meta["image_caption_model"] = args.image_caption_model_name
+ result_meta["prompt"] = args.image_caption_prompt
+ result_meta["sampled_frame_idx"] = sampled_frame_idx_list
+ result_meta["sampled_frame_caption"] = caption_list
+ result_list.append(copy.deepcopy(result_meta))
+
+ # Save the metadata in the main process.
+ if i != 0 and i % args.saved_freq == 0:
+ state.wait_for_everyone()
+ gathered_result_list = gather_object(result_list)
+ if state.is_main_process:
+ result_df = pd.DataFrame(gathered_result_list)
+ if args.saved_path.endswith(".csv"):
+ result_df.to_csv(args.saved_path, index=False)
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True)
+ print(f"Save result to {args.saved_path}.")
+
+ # Wait for all processes to finish and gather the final result.
+ state.wait_for_everyone()
+ gathered_result_list = gather_object(result_list)
+ # Save the metadata in the main process.
+ if state.is_main_process:
+ result_df = pd.DataFrame(gathered_result_list)
+ if args.saved_path.endswith(".csv"):
+ result_df.to_csv(args.saved_path, index=False)
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True)
+ print(f"Save the final result to {args.saved_path}.")
+
+
+def sglang_inference(args, video_path_list):
+ from utils.image_captioner_sglang import LLaVASRT
+
+ if args.image_caption_model_name == "llava-v1.6-vicuna-7b":
+ image_caption_model = LLaVASRT()
+
+ result_dict = {
+ "video_path": [],
+ "image_caption_model": [],
+ "prompt": [],
+ 'sampled_frame_idx': [],
+ "sampled_frame_caption": []
+ }
+
+ video_dataset = VideoDataset(
+ video_path_list=video_path_list,
+ sample_method=args.frame_sample_method,
+ num_sampled_frames=args.num_sampled_frames
+ )
+ video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=16, collate_fn=collate_fn)
+ for idx, batch in enumerate(tqdm(video_loader)):
+ if len(batch) == 0:
+ continue
+ batch_video_path, batch_frame_idx = batch["video_path"], batch["sampled_frame_idx"]
+ # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C].
+ batch_frame = []
+ for item_sampled_frame in batch["sampled_frame"]:
+ batch_frame.extend([frame for frame in item_sampled_frame])
+
+ try:
+ response_list, _ = image_caption_model([args.image_caption_prompt] * len(batch_frame), batch_frame)
+ response_list = [response_list[i:i + args.num_sampled_frames] for i in range(0, len(response_list), args.num_sampled_frames)]
+ except Exception as e:
+ logger.error(f"Failed to caption video {batch_video_path}. Error is {e}.")
+
+ result_dict["video_path"].extend(batch_video_path)
+ result_dict["image_caption_model"].extend([args.image_caption_model_name] * len(batch_video_path))
+ result_dict["prompt"].extend([args.image_caption_prompt] * len(batch_video_path))
+ result_dict["sampled_frame_idx"].extend(batch_frame_idx)
+ result_dict["sampled_frame_caption"].extend(response_list)
+
+ # Save the metadata in the main process.
+ if idx != 0 and idx % args.saved_freq == 0:
+ result_df = pd.DataFrame(result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = True if not os.path.exists(args.saved_path) else False
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+
+ result_dict = {
+ "video_path": [],
+ "image_caption_model": [],
+ "prompt": [],
+ 'sampled_frame_idx': [],
+ "sampled_frame_caption": []
+ }
+
+ if len(result_dict["video_path"]) != 0:
+ result_df = pd.DataFrame(result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = True if not os.path.exists(args.saved_path) else False
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save the final result to {args.saved_path}.")
+
+
+def main():
+ args = parse_args()
+
+ video_path_list = get_video_path_list(
+ video_folder=args.video_folder,
+ video_metadata_path=args.video_metadata_path,
+ video_path_column=args.video_path_column
+ )
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ if os.path.exists(args.saved_path):
+ if args.saved_path.endswith(".csv"):
+ saved_metadata_df = pd.read_csv(args.saved_path)
+ elif args.saved_path.endswith(".jsonl"):
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+ saved_video_path_list = [os.path.join(args.video_folder, path) for path in saved_video_path_list]
+ video_path_list = list(set(video_path_list) - set(saved_video_path_list))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+
+ if args.image_caption_model_name in SGLANG_SUPPORTED_MODELS:
+ sglang_inference(args, video_path_list)
+ elif args.image_caption_model_name in ACCELERATE_SUPPORTED_MODELS:
+ accelerate_inference(args, video_path_list)
+ else:
+ raise ValueError(f"The {args.image_caption_model_name} is not supported.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/easyanimate/video_caption/compute_motion_score.py b/easyanimate/video_caption/compute_motion_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f8afbac2ef979fcaa15fa094ebcf5109e85d25c
--- /dev/null
+++ b/easyanimate/video_caption/compute_motion_score.py
@@ -0,0 +1,196 @@
+import ast
+import argparse
+import gc
+import os
+from contextlib import contextmanager
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from natsort import natsorted
+from tqdm import tqdm
+
+from utils.logger import logger
+from utils.video_utils import get_video_path_list
+
+
+@contextmanager
+def VideoCapture(video_path):
+ cap = cv2.VideoCapture(video_path)
+ try:
+ yield cap
+ finally:
+ cap.release()
+ del cap
+ gc.collect()
+
+
+def compute_motion_score(video_path):
+ video_motion_scores = []
+ sampling_fps = 2
+
+ try:
+ with VideoCapture(video_path) as cap:
+ fps = cap.get(cv2.CAP_PROP_FPS)
+ valid_fps = min(max(sampling_fps, 1), fps)
+ frame_interval = int(fps / valid_fps)
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ # if cannot get the second frame, use the last one
+ frame_interval = min(frame_interval, total_frames - 1)
+
+ prev_frame = None
+ frame_count = -1
+ while cap.isOpened():
+ ret, frame = cap.read()
+ frame_count += 1
+
+ if not ret:
+ break
+
+ # skip middle frames
+ if frame_count % frame_interval != 0:
+ continue
+
+ gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ if prev_frame is None:
+ prev_frame = gray_frame
+ continue
+
+ flow = cv2.calcOpticalFlowFarneback(
+ prev_frame,
+ gray_frame,
+ None,
+ pyr_scale=0.5,
+ levels=3,
+ winsize=15,
+ iterations=3,
+ poly_n=5,
+ poly_sigma=1.2,
+ flags=0,
+ )
+ mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
+ frame_motion_score = np.mean(mag)
+ video_motion_scores.append(frame_motion_score)
+ prev_frame = gray_frame
+
+ video_meta_info = {
+ "video_path": Path(video_path).name,
+ "motion_score": round(float(np.mean(video_motion_scores)), 5),
+ }
+ return video_meta_info
+
+ except Exception as e:
+ print(f"Compute motion score for video {video_path} with error: {e}.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Compute the motion score of the videos.")
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+ parser.add_argument(
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+ )
+ parser.add_argument(
+ "--video_path_column",
+ type=str,
+ default="video_path",
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+ )
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+ parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
+ parser.add_argument("--n_jobs", type=int, default=1, help="The number of concurrent processes.")
+
+ parser.add_argument(
+ "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+ )
+ parser.add_argument("--asethetic_score_threshold", type=float, default=4.0, help="The asethetic score threshold.")
+ parser.add_argument(
+ "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
+ )
+ parser.add_argument("--text_score_threshold", type=float, default=0.02, help="The text threshold.")
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ video_path_list = get_video_path_list(
+ video_folder=args.video_folder,
+ video_metadata_path=args.video_metadata_path,
+ video_path_column=args.video_path_column
+ )
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ if os.path.exists(args.saved_path):
+ if args.saved_path.endswith(".csv"):
+ saved_metadata_df = pd.read_csv(args.saved_path)
+ elif args.saved_path.endswith(".jsonl"):
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+ saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+
+ if args.asethetic_score_metadata_path is not None:
+ if args.asethetic_score_metadata_path.endswith(".csv"):
+ asethetic_score_df = pd.read_csv(args.asethetic_score_metadata_path)
+ elif args.asethetic_score_metadata_path.endswith(".jsonl"):
+ asethetic_score_df = pd.read_json(args.asethetic_score_metadata_path, lines=True)
+
+ # In pandas, csv will save lists as strings, whereas jsonl will not.
+ asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply(
+ lambda x: ast.literal_eval(x) if isinstance(x, str) else x
+ )
+ asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x))
+ filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < args.asethetic_score_threshold]
+ filtered_video_path_list = filtered_asethetic_score_df[args.video_path_column].tolist()
+ filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Load {args.asethetic_score_metadata_path} and filter {len(filtered_video_path_list)} videos.")
+
+ if args.text_score_metadata_path is not None:
+ if args.text_score_metadata_path.endswith(".csv"):
+ text_score_df = pd.read_csv(args.text_score_metadata_path)
+ elif args.text_score_metadata_path.endswith(".jsonl"):
+ text_score_df = pd.read_json(args.text_score_metadata_path, lines=True)
+
+ filtered_text_score_df = text_score_df[text_score_df["text_score"] > args.text_score_threshold]
+ filtered_video_path_list = filtered_text_score_df[args.video_path_column].tolist()
+ filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Load {args.text_score_metadata_path} and filter {len(filtered_video_path_list)} videos.")
+
+ for i in tqdm(range(0, len(video_path_list), args.saved_freq)):
+ result_list = Parallel(n_jobs=args.n_jobs, backend="threading")(
+ delayed(compute_motion_score)(video_path) for video_path in tqdm(video_path_list[i: i + args.saved_freq])
+ )
+ result_list = [result for result in result_list if result is not None]
+ if len(result_list) == 0:
+ continue
+
+ result_df = pd.DataFrame(result_list)
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/easyanimate/video_caption/compute_text_score.py b/easyanimate/video_caption/compute_text_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1e8ec5302747289c050a975ff63fb8d5b242f29
--- /dev/null
+++ b/easyanimate/video_caption/compute_text_score.py
@@ -0,0 +1,198 @@
+import ast
+import argparse
+import os
+from pathlib import Path
+
+import easyocr
+import numpy as np
+import pandas as pd
+from accelerate import PartialState
+from accelerate.utils import gather_object
+from natsort import natsorted
+from tqdm import tqdm
+from torchvision.datasets.utils import download_url
+
+from utils.logger import logger
+from utils.video_utils import extract_frames, get_video_path_list
+
+
+def init_ocr_reader(root: str = "~/.cache/easyocr", device: str = "gpu"):
+ root = os.path.expanduser(root)
+ if not os.path.exists(root):
+ os.makedirs(root)
+ download_url(
+ "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/easyocr/craft_mlt_25k.pth",
+ root,
+ filename="craft_mlt_25k.pth",
+ md5="2f8227d2def4037cdb3b34389dcf9ec1",
+ )
+ ocr_reader = easyocr.Reader(
+ lang_list=["en", "ch_sim"],
+ gpu=device,
+ recognizer=False,
+ verbose=False,
+ model_storage_directory=root,
+ )
+
+ return ocr_reader
+
+
+def triangle_area(p1, p2, p3):
+ """Compute the triangle area according to its coordinates.
+ """
+ x1, y1 = p1
+ x2, y2 = p2
+ x3, y3 = p3
+ tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3)
+ return tri_area
+
+
+def compute_text_score(video_path, ocr_reader):
+ _, images = extract_frames(video_path, sample_method="mid")
+ images = [np.array(image) for image in images]
+
+ frame_ocr_area_ratios = []
+ for image in images:
+ # horizontal detected results and free-form detected
+ horizontal_list, free_list = ocr_reader.detect(np.asarray(image))
+ width, height = image.shape[0], image.shape[1]
+
+ total_area = width * height
+ # rectangles
+ rect_area = 0
+ for xmin, xmax, ymin, ymax in horizontal_list[0]:
+ if xmax < xmin or ymax < ymin:
+ continue
+ rect_area += (xmax - xmin) * (ymax - ymin)
+ # free-form
+ quad_area = 0
+ try:
+ for points in free_list[0]:
+ triangle1 = points[:3]
+ quad_area += triangle_area(*triangle1)
+ triangle2 = points[3:] + [points[0]]
+ quad_area += triangle_area(*triangle2)
+ except:
+ quad_area = 0
+ text_area = rect_area + quad_area
+
+ frame_ocr_area_ratios.append(text_area / total_area)
+
+ video_meta_info = {
+ "video_path": Path(video_path).name,
+ "text_score": round(np.mean(frame_ocr_area_ratios), 5),
+ }
+
+ return video_meta_info
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Compute the text score of the middle frame in the videos.")
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+ parser.add_argument(
+ "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+ )
+ parser.add_argument(
+ "--video_path_column",
+ type=str,
+ default="video_path",
+ help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+ )
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+ parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
+ parser.add_argument(
+ "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+ )
+ parser.add_argument("--asethetic_score_threshold", type=float, default=4.0, help="The asethetic score threshold.")
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ video_path_list = get_video_path_list(
+ video_folder=args.video_folder,
+ video_metadata_path=args.video_metadata_path,
+ video_path_column=args.video_path_column
+ )
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ if os.path.exists(args.saved_path):
+ if args.saved_path.endswith(".csv"):
+ saved_metadata_df = pd.read_csv(args.saved_path)
+ elif args.saved_path.endswith(".jsonl"):
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+ saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+
+ if args.asethetic_score_metadata_path is not None:
+ if args.asethetic_score_metadata_path.endswith(".csv"):
+ asethetic_score_df = pd.read_csv(args.asethetic_score_metadata_path)
+ elif args.asethetic_score_metadata_path.endswith(".jsonl"):
+ asethetic_score_df = pd.read_json(args.asethetic_score_metadata_path, lines=True)
+
+ # In pandas, csv will save lists as strings, whereas jsonl will not.
+ asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply(
+ lambda x: ast.literal_eval(x) if isinstance(x, str) else x
+ )
+ asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x))
+ filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < args.asethetic_score_threshold]
+ filtered_video_path_list = filtered_asethetic_score_df[args.video_path_column].tolist()
+ filtered_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in filtered_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Load {args.asethetic_score_metadata_path} and filter {len(filtered_video_path_list)} videos.")
+
+ state = PartialState()
+ ocr_reader = init_ocr_reader(device=state.device)
+
+ # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released.
+ index = len(video_path_list) - len(video_path_list) % state.num_processes
+ logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.")
+ video_path_list = video_path_list[:index]
+
+ result_list = []
+ with state.split_between_processes(video_path_list) as splitted_video_path_list:
+ for i, video_path in enumerate(tqdm(splitted_video_path_list)):
+ video_meta_info = compute_text_score(video_path, ocr_reader)
+ result_list.append(video_meta_info)
+ if i != 0 and i % args.saved_freq == 0:
+ state.wait_for_everyone()
+ gathered_result_list = gather_object(result_list)
+ if state.is_main_process:
+ result_df = pd.DataFrame(gathered_result_list)
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+ result_list = []
+
+ state.wait_for_everyone()
+ gathered_result_list = gather_object(result_list)
+ if state.is_main_process:
+ logger.info(len(gathered_result_list))
+ if len(gathered_result_list) != 0:
+ result_df = pd.DataFrame(gathered_result_list)
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save the final result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/easyanimate/video_caption/compute_video_frame_quality.py b/easyanimate/video_caption/compute_video_frame_quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebc9997c465c955d05eb3ec5c32af3862c2015b
--- /dev/null
+++ b/easyanimate/video_caption/compute_video_frame_quality.py
@@ -0,0 +1,176 @@
+import argparse
+import re
+import os
+
+import pandas as pd
+from accelerate import PartialState
+from accelerate.utils import gather_object
+from natsort import natsorted
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+import utils.image_evaluator as image_evaluator
+from utils.logger import logger
+from utils.video_dataset import VideoDataset, collate_fn
+from utils.video_utils import get_video_path_list
+
+
+def camel2snake(s: str) -> str:
+ """Convert camel case to snake case."""
+ if not re.match("^[A-Z]+$", s):
+ pattern = re.compile(r"(? 1
+
+ video_path_list = get_video_path_list(
+ video_folder=args.video_folder,
+ video_metadata_path=args.video_metadata_path,
+ video_path_column=args.video_path_column
+ )
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ caption_list = None
+ if args.video_metadata_path is not None and args.caption_column is not None:
+ if args.video_metadata_path.endswith(".csv"):
+ video_metadata_df = pd.read_csv(args.video_metadata_path)
+ elif args.video_metadata_path.endswith(".jsonl"):
+ video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+ else:
+ raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+ caption_list = video_metadata_df[args.caption_column].tolist()
+
+ if os.path.exists(args.saved_path):
+ if args.saved_path.endswith(".csv"):
+ saved_metadata_df = pd.read_csv(args.saved_path)
+ elif args.saved_path.endswith(".jsonl"):
+ saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+ saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+ saved_video_path_list = [os.path.join(args.video_folder, video_path) for video_path in saved_video_path_list]
+
+ video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+ # Sorting to guarantee the same result for each process.
+ video_path_list = natsorted(video_path_list)
+ logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+
+ logger.info("Initializing evaluator metrics...")
+ state = PartialState()
+ metric_fns = [getattr(image_evaluator, metric)(device=state.device) for metric in args.metrics]
+
+ # The workaround can be removed after https://github.com/huggingface/accelerate/pull/2781 is released.
+ index = len(video_path_list) - len(video_path_list) % state.num_processes
+ logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to avoid duplicates in state.split_between_processes.")
+ video_path_list = video_path_list[:index]
+
+ result_dict = {args.video_path_column: [], "sample_frame_idx": []}
+ for metric in args.metrics:
+ result_dict[camel2snake(metric)] = []
+
+ with state.split_between_processes(video_path_list) as splitted_video_path_list:
+ video_dataset = VideoDataset(
+ video_path_list=splitted_video_path_list,
+ sample_method="uniform",
+ num_sampled_frames=args.num_sampled_frames
+ )
+ video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=4, collate_fn=collate_fn)
+ for idx, batch in enumerate(tqdm(video_loader)):
+ if len(batch) == 0:
+ continue
+ batch_video_path = batch[args.video_path_column]
+ result_dict["sample_frame_idx"].extend(batch["sampled_frame_idx"])
+ # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C].
+ batch_frame = []
+ for item_sampled_frame in batch["sampled_frame"]:
+ batch_frame.extend([frame for frame in item_sampled_frame])
+ batch_caption = None
+ if caption_list is not None:
+ batch_caption = caption_list[i : i + args.batch_size]
+ # Compute the frame quality.
+ for i, metric in enumerate(args.metrics):
+ # [batch_size * num_sampled_frames] => [batch_size, num_sampled_frames]
+ quality_scores = metric_fns[i](batch_frame, batch_caption)
+ quality_scores = [round(score, 5) for score in quality_scores]
+ quality_scores = [quality_scores[j:j + args.num_sampled_frames] for j in range(0, len(quality_scores), args.num_sampled_frames)]
+ result_dict[camel2snake(metric)].extend(quality_scores)
+
+ saved_video_path_list = [os.path.basename(video_path) for video_path in batch_video_path]
+ result_dict[args.video_path_column].extend(saved_video_path_list)
+
+ # Save the metadata in the main process every saved_freq.
+ if (idx != 0) and (idx % args.saved_freq == 0):
+ state.wait_for_everyone()
+ gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+ if state.is_main_process:
+ result_df = pd.DataFrame(gathered_result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+ for k in result_dict.keys():
+ result_dict[k] = []
+
+ # Wait for all processes to finish and gather the final result.
+ state.wait_for_everyone()
+ gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+ # Save the metadata in the main process.
+ if state.is_main_process:
+ result_df = pd.DataFrame(gathered_result_dict)
+ if len(gathered_result_dict[args.video_path_column]) != 0:
+ result_df = pd.DataFrame(gathered_result_dict)
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save the final result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/easyanimate/video_caption/convert_jsonl_to_json.py b/easyanimate/video_caption/convert_jsonl_to_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..78b7ad99ce168fb0902237d5b44b0e89da1060c3
--- /dev/null
+++ b/easyanimate/video_caption/convert_jsonl_to_json.py
@@ -0,0 +1,40 @@
+import argparse
+import json
+import os
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Convert jsonl to json.")
+ parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+ parser.add_argument(
+ "--jsonl_load_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+ )
+ parser.add_argument("--save_path", type=str, default=None, help="The save path to the output results.")
+ args = parser.parse_args()
+ return args
+
+def main():
+ args = parse_args()
+
+ with open(args.jsonl_load_path, "r") as read:
+ _lines = read.readlines()
+
+ output = []
+ for line in _lines:
+ try:
+ line = json.loads(line.strip())
+ videoid, name = line['video_path'], line['summary_caption']
+ output.append(
+ {
+ "file_path": os.path.join(args.video_folder, videoid),
+ "text": name,
+ "type": "video",
+ }
+ )
+ except:
+ pass
+
+ with open(args.save_path, mode="w", encoding="utf-8") as f:
+ json.dump(output, f, indent=2)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/easyanimate/video_caption/easyocr_detection_patched.py b/easyanimate/video_caption/easyocr_detection_patched.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2cffa2b00c7c90aafcde307ce27307ed6e71dbf
--- /dev/null
+++ b/easyanimate/video_caption/easyocr_detection_patched.py
@@ -0,0 +1,114 @@
+"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.
+1. Disable DataParallel.
+"""
+import torch
+import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
+from PIL import Image
+from collections import OrderedDict
+
+import cv2
+import numpy as np
+from .craft_utils import getDetBoxes, adjustResultCoordinates
+from .imgproc import resize_aspect_ratio, normalizeMeanVariance
+from .craft import CRAFT
+
+def copyStateDict(state_dict):
+ if list(state_dict.keys())[0].startswith("module"):
+ start_idx = 1
+ else:
+ start_idx = 0
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = ".".join(k.split(".")[start_idx:])
+ new_state_dict[name] = v
+ return new_state_dict
+
+def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
+ if isinstance(image, np.ndarray) and len(image.shape) == 4: # image is batch of np arrays
+ image_arrs = image
+ else: # image is single numpy array
+ image_arrs = [image]
+
+ img_resized_list = []
+ # resize
+ for img in image_arrs:
+ img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
+ interpolation=cv2.INTER_LINEAR,
+ mag_ratio=mag_ratio)
+ img_resized_list.append(img_resized)
+ ratio_h = ratio_w = 1 / target_ratio
+ # preprocessing
+ x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
+ for n_img in img_resized_list]
+ x = torch.from_numpy(np.array(x))
+ x = x.to(device)
+
+ # forward pass
+ with torch.no_grad():
+ y, feature = net(x)
+
+ boxes_list, polys_list = [], []
+ for out in y:
+ # make score and link map
+ score_text = out[:, :, 0].cpu().data.numpy()
+ score_link = out[:, :, 1].cpu().data.numpy()
+
+ # Post-processing
+ boxes, polys, mapper = getDetBoxes(
+ score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
+
+ # coordinate adjustment
+ boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+ polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
+ if estimate_num_chars:
+ boxes = list(boxes)
+ polys = list(polys)
+ for k in range(len(polys)):
+ if estimate_num_chars:
+ boxes[k] = (boxes[k], mapper[k])
+ if polys[k] is None:
+ polys[k] = boxes[k]
+ boxes_list.append(boxes)
+ polys_list.append(polys)
+
+ return boxes_list, polys_list
+
+def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
+ net = CRAFT()
+
+ if device == 'cpu':
+ net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
+ if quantize:
+ try:
+ torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
+ except:
+ pass
+ else:
+ net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
+ # net = torch.nn.DataParallel(net).to(device)
+ net = net.to(device)
+ cudnn.benchmark = cudnn_benchmark
+
+ net.eval()
+ return net
+
+def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
+ result = []
+ estimate_num_chars = optimal_num_chars is not None
+ bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
+ image, text_threshold,
+ link_threshold, low_text, poly,
+ device, estimate_num_chars)
+ if estimate_num_chars:
+ polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
+ for polys in polys_list]
+
+ for polys in polys_list:
+ single_img_result = []
+ for i, box in enumerate(polys):
+ poly = np.array(box).astype(np.int32).reshape((-1))
+ single_img_result.append(poly)
+ result.append(single_img_result)
+
+ return result
diff --git a/easyanimate/video_caption/filter_videos_by_motion_score.py b/easyanimate/video_caption/filter_videos_by_motion_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..e622aaa405958e10bafbe30bf0f35d6b7b3063a4
--- /dev/null
+++ b/easyanimate/video_caption/filter_videos_by_motion_score.py
@@ -0,0 +1,55 @@
+import ast
+import argparse
+import gc
+import os
+from contextlib import contextmanager
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from natsort import natsorted
+from tqdm import tqdm
+
+from utils.logger import logger
+from utils.video_utils import get_video_path_list
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Filter the motion score of the videos.")
+ parser.add_argument(
+ "--motion_score_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+ )
+ parser.add_argument("--low_motion_score_threshold", type=float, default=3.0, help="The low motion score threshold.")
+ parser.add_argument("--high_motion_score_threshold", type=float, default=8.0, help="The high motion score threshold.")
+ parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+ raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+ if args.motion_score_metadata_path is not None:
+ if args.motion_score_metadata_path.endswith(".csv"):
+ motion_score_df = pd.read_csv(args.motion_score_metadata_path)
+ elif args.motion_score_metadata_path.endswith(".jsonl"):
+ motion_score_df = pd.read_json(args.motion_score_metadata_path, lines=True)
+
+ filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] > args.low_motion_score_threshold]
+ filtered_motion_score_df = filtered_motion_score_df[motion_score_df["motion_score"] < args.high_motion_score_threshold]
+
+ if args.saved_path.endswith(".csv"):
+ header = False if os.path.exists(args.saved_path) else True
+ filtered_motion_score_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+ elif args.saved_path.endswith(".jsonl"):
+ filtered_motion_score_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+ logger.info(f"Save result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/easyanimate/video_caption/requirements.txt b/easyanimate/video_caption/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b64e18452c0f6de98276e45b47559454de63f100
--- /dev/null
+++ b/easyanimate/video_caption/requirements.txt
@@ -0,0 +1,11 @@
+auto_gptq==0.6.0
+pandas>=2.0.0
+vllm==0.3.3
+sglang[srt]==0.1.13
+func_timeout
+easyocr==1.7.1
+git+https://github.com/openai/CLIP.git
+natsort
+joblib
+scenedetect
+av
diff --git a/easyanimate/video_caption/scenedetect_vcut.py b/easyanimate/video_caption/scenedetect_vcut.py
new file mode 100644
index 0000000000000000000000000000000000000000..b49c80b514c09038db3fa2a98c8ebbe4aa6c13a1
--- /dev/null
+++ b/easyanimate/video_caption/scenedetect_vcut.py
@@ -0,0 +1,235 @@
+import argparse
+import copy
+import json
+import os
+import shutil
+from multiprocessing import Pool
+
+from scenedetect import SceneManager, open_video
+from scenedetect.detectors import ContentDetector
+from scenedetect.video_splitter import split_video_ffmpeg
+from tqdm import tqdm
+
+from utils.video_utils import download_video, get_video_path_list
+
+tmp_file_dir = "./tmp"
+DEFAULT_FFMPEG_ARGS = '-c:v libx264 -preset veryfast -crf 22 -c:a aac'
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description = '''Cut video by PySceneDetect''')
+ parser.add_argument(
+ 'video',
+ type = str,
+ help = '''Input format:
+ 1. Local video file path.
+ 2. Video URL.
+ 3. Local root dir path of videos.
+ 4. Local txt file of video urls/local file path, line by line.
+ ''')
+ parser.add_argument(
+ '--threshold',
+ type = float,
+ nargs='+',
+ default = [10, 20, 30],
+ help = 'Threshold list the average change in pixel intensity must exceed to trigger a cut, one-to-one with frame_skip.')
+ parser.add_argument(
+ '--frame_skip',
+ type = int,
+ nargs='+',
+ default = [0, 1, 2],
+ help = 'Number list of frames to skip, coordinate with threshold \
+ (i.e. process every 1 in N+1 frames, where N is frame_skip, \
+ processing only 1/N+1 percent of the video, \
+ speeding up the detection time at the expense of accuracy). One-to-one with threshold.')
+ parser.add_argument(
+ '--min_seconds',
+ type = int,
+ default = 3,
+ help = 'Video cut must be longer then min_seconds.')
+ parser.add_argument(
+ '--max_seconds',
+ type = int,
+ default = 12,
+ help = 'Video cut must be longer then min_seconds.')
+ parser.add_argument(
+ '--save_dir',
+ type = str,
+ default = "",
+ help = 'Video scene cuts save dir, default value means reusing input video dir.')
+ parser.add_argument(
+ '--name_template',
+ type = str,
+ default = "$VIDEO_NAME-Scene-$SCENE_NUMBER.mp4",
+ help = 'Video scene cuts save name template.')
+ parser.add_argument(
+ '--num_processes',
+ type = int,
+ default = os.cpu_count() // 2,
+ help = 'Number of CPU cores to process the video scene cut.')
+ parser.add_argument(
+ "--save_json", action="store_true", help="Whether save json in datasets."
+ )
+ args = parser.parse_args()
+ return args
+
+
+def split_video_into_scenes(
+ video_path: str,
+ threshold: list[float] = [27.0],
+ frame_skip: list[int] = [0],
+ min_seconds: int = 3,
+ max_seconds: int = 8,
+ save_dir: str = "",
+ name_template: str = "$VIDEO_NAME-Scene-$SCENE_NUMBER.mp4",
+ save_json: bool = False ):
+ # SceneDetect video through casceded (threshold, FPS)
+ frame_points = []
+ frame_timecode = {}
+ fps = 25.0
+ for thre, f_skip in zip(threshold, frame_skip):
+ # Open our video, create a scene manager, and add a detector.
+ video = open_video(video_path, backend='pyav')
+ scene_manager = SceneManager()
+ scene_manager.add_detector(
+ # [ContentDetector, ThresholdDetector, AdaptiveDetector]
+ ContentDetector(threshold=thre, min_scene_len=10)
+ )
+ scene_manager.detect_scenes(video, frame_skip=f_skip, show_progress=False)
+ scene_list = scene_manager.get_scene_list()
+ for scene in scene_list:
+ for frame_time_code in scene:
+ frame_index = frame_time_code.get_frames()
+ if frame_index not in frame_points:
+ frame_points.append(frame_index)
+ frame_timecode[frame_index] = frame_time_code
+ fps = frame_time_code.get_framerate()
+ del video, scene_manager
+ frame_points = sorted(frame_points)
+ output_scene_list = []
+
+ # Detect No Scene Change
+ if len(frame_points) == 0:
+ video = open_video(video_path, backend='pyav')
+ frame_points = [0, video.duration.get_frames() - 1]
+ frame_timecode = {
+ frame_points[0]: video.base_timecode,
+ frame_points[-1]: video.base_timecode + video.base_timecode + video.duration
+ }
+ del video
+
+ for idx in range(len(frame_points) - 1):
+ # Limit save out min seconds
+ if frame_points[idx+1] - frame_points[idx] < fps * min_seconds:
+ continue
+ # Limit save out max seconds
+ elif frame_points[idx+1] - frame_points[idx] > fps * max_seconds:
+ tmp_start_timecode = copy.deepcopy(frame_timecode[frame_points[idx]])
+ tmp_end_timecode = copy.deepcopy(frame_timecode[frame_points[idx]]) + int(max_seconds * fps)
+ # Average cut by max seconds
+ while tmp_end_timecode.get_frames() <= frame_points[idx+1]:
+ output_scene_list.append((
+ copy.deepcopy(tmp_start_timecode),
+ copy.deepcopy(tmp_end_timecode)))
+ tmp_start_timecode += int(max_seconds * fps)
+ tmp_end_timecode += int(max_seconds * fps)
+ if tmp_end_timecode.get_frames() > frame_points[idx+1] and frame_points[idx+1] - tmp_start_timecode.get_frames() > fps * min_seconds:
+ output_scene_list.append((
+ copy.deepcopy(tmp_start_timecode),
+ frame_timecode[frame_points[idx+1]]))
+ del tmp_start_timecode, tmp_end_timecode
+ continue
+ output_scene_list.append((
+ frame_timecode[frame_points[idx]],
+ frame_timecode[frame_points[idx+1]]))
+
+ # Reuse video dir
+ if save_dir == "":
+ save_dir = os.path.dirname(video_path)
+ # Ensure save dir exists
+ elif not os.path.isdir(save_dir):
+ os.makedirs(save_dir)
+
+ clip_info_path = os.path.join(save_dir, os.path.splitext(os.path.basename(video_path))[0] + '.json')
+
+ output_file_template = os.path.join(save_dir, name_template)
+ split_video_ffmpeg(
+ video_path,
+ output_scene_list,
+ arg_override=DEFAULT_FFMPEG_ARGS,
+ output_file_template=output_file_template,
+ show_progress=False,
+ show_output=False) # ffmpeg print
+
+ if save_json:
+ # Save clip info
+ json.dump(
+ [(frame_timecode_tuple[0].get_timecode(), frame_timecode_tuple[1].get_timecode()) for frame_timecode_tuple in output_scene_list],
+ open(clip_info_path, 'w'),
+ indent=2
+ )
+
+ return clip_info_path
+
+
+def process_single_video(args):
+ video, threshold, frame_skip, min_seconds, max_seconds, save_dir, name_template, save_json = args
+ basename = os.path.splitext(os.path.basename(video))[0]
+ # Video URL
+ if video.startswith("http"):
+ save_path = os.path.join(tmp_file_dir, f"{basename}.mp4")
+ download_success = download_video(video, save_path)
+ if not download_success:
+ return
+ video = save_path
+ # Local video path
+ else:
+ if not os.path.isfile(video):
+ print(f"Video not exists: {video}")
+ return
+ # SceneDetect video cut
+ try:
+ split_video_into_scenes(
+ video_path=video,
+ threshold=threshold,
+ frame_skip=frame_skip,
+ min_seconds=min_seconds,
+ max_seconds=max_seconds,
+ save_dir=save_dir,
+ name_template=name_template,
+ save_json=save_json
+ )
+ except Exception as e:
+ print(e, video)
+
+
+def main():
+ # Args
+ args = parse_args()
+ video_input = args.video
+ threshold = args.threshold
+ frame_skip = args.frame_skip
+ min_seconds = args.min_seconds
+ max_seconds = args.max_seconds
+ save_dir = args.save_dir
+ name_template = args.name_template
+ num_processes = args.num_processes
+ save_json = args.save_json
+
+ assert len(threshold) == len(frame_skip), \
+ "Threshold must one-to-one match frame_skip."
+
+ video_list = get_video_path_list(video_input)
+ args_list = [
+ (video, threshold, frame_skip, min_seconds, max_seconds, save_dir, name_template, save_json)
+ for video in video_list
+ ]
+
+ with Pool(processes=num_processes) as pool:
+ with tqdm(total=len(video_list)) as progress_bar:
+ for _ in pool.imap_unordered(process_single_video, args_list):
+ progress_bar.update(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/easyanimate/video_caption/stage_1_video_cut.sh b/easyanimate/video_caption/stage_1_video_cut.sh
new file mode 100644
index 0000000000000000000000000000000000000000..817f3142ac61df896b04d2d0e04ef57cd020f115
--- /dev/null
+++ b/easyanimate/video_caption/stage_1_video_cut.sh
@@ -0,0 +1,11 @@
+export VIDEO_FOLDER="datasets/panda_70m/before_vcut/"
+export OUTPUT_FOLDER="datasets/panda_70m/train/"
+
+# Cut raw videos
+python scenedetect_vcut.py \
+ $VIDEO_FOLDER \
+ --threshold 10 20 30 \
+ --frame_skip 0 1 2 \
+ --min_seconds 3 \
+ --max_seconds 10 \
+ --save_dir $OUTPUT_FOLDER
\ No newline at end of file
diff --git a/easyanimate/video_caption/stage_2_filter_data.sh b/easyanimate/video_caption/stage_2_filter_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b3d9dc41d6b6d3ded0053c45d2cce10bc949ea90
--- /dev/null
+++ b/easyanimate/video_caption/stage_2_filter_data.sh
@@ -0,0 +1,39 @@
+export VIDEO_FOLDER="datasets/panda_70m/train"
+export FRAME_QUALITY_SAVE_PATH="datasets/panda_70m/aesthetic_score.jsonl"
+export TEXT_SCORE_SAVE_PATH="datasets/panda_70m/text_score.jsonl"
+export MOTION_SCORE_SAVE_PATH="datasets/panda_70m/motion_score.jsonl"
+export FILTER_BY_MOTION_SCORE_SAVE_PATH="datasets/panda_70m/train.jsonl"
+
+# Get asethetic score of all videos
+CUDA_VISIBLE_DEVICES="0" accelerate launch compute_video_frame_quality.py \
+ --video_folder=$VIDEO_FOLDER \
+ --video_path_column="video_path" \
+ --metrics="AestheticScore" \
+ --saved_freq=10 \
+ --saved_path=$FRAME_QUALITY_SAVE_PATH \
+ --batch_size=8
+
+# Get text score of all videos
+CUDA_VISIBLE_DEVICES="0" accelerate launch compute_text_score.py \
+ --video_folder=$VIDEO_FOLDER \
+ --video_path_column="video_path" \
+ --saved_freq=10 \
+ --saved_path=$TEXT_SCORE_SAVE_PATH \
+ --asethetic_score_metadata_path $FRAME_QUALITY_SAVE_PATH
+
+# Get motion score after filter videos by asethetic score and text score
+python compute_motion_score.py \
+ --video_folder=$VIDEO_FOLDER \
+ --video_path_column="video_path" \
+ --saved_freq=10 \
+ --saved_path=$MOTION_SCORE_SAVE_PATH \
+ --n_jobs=8 \
+ --asethetic_score_metadata_path $FRAME_QUALITY_SAVE_PATH \
+ --text_score_metadata_path $TEXT_SCORE_SAVE_PATH
+
+# Filter videos by motion score
+python filter_videos_by_motion_score.py \
+ --motion_score_metadata_path $MOTION_SCORE_SAVE_PATH \
+ --low_motion_score_threshold=3 \
+ --high_motion_score_threshold=8 \
+ --saved_path $FILTER_BY_MOTION_SCORE_SAVE_PATH
diff --git a/easyanimate/video_caption/stage_3_video_caption.sh b/easyanimate/video_caption/stage_3_video_caption.sh
new file mode 100644
index 0000000000000000000000000000000000000000..68bb0a870cf3934cd544b92eb14ffc2205a44d63
--- /dev/null
+++ b/easyanimate/video_caption/stage_3_video_caption.sh
@@ -0,0 +1,35 @@
+export CUDA_VISIBLE_DEVICES=0
+export VIDEO_FOLDER="datasets/panda_70m/train/"
+export MOTION_SCORE_META_PATH="datasets/panda_70m/train.jsonl"
+export VIDEO_FRAME_CAPTION_PATH="datasets/panda_70m/frame_caption.jsonl"
+export VIDEO_CAPTION_PATH="datasets/panda_70m/summary_caption.jsonl"
+export LAST_JSON_PATH="datasets/panda_70m/train_panda_70m.json"
+
+CUDA_VISIBLE_DEVICES="0" python caption_video_frame.py \
+ --video_metadata_path=$MOTION_SCORE_META_PATH \
+ --video_folder=$VIDEO_FOLDER \
+ --frame_sample_method="mid" \
+ --num_sampled_frames=1 \
+ --image_caption_model_name="llava-v1.6-vicuna-7b" \
+ --image_caption_prompt="Please describe this image in detail." \
+ --saved_path=$VIDEO_FRAME_CAPTION_PATH \
+ --output_dir="tmp"
+
+CUDA_VISIBLE_DEVICES="0" python caption_summary.py \
+ --video_metadata_path=$VIDEO_FRAME_CAPTION_PATH \
+ --video_path_column="video_path" \
+ --caption_column="sampled_frame_caption" \
+ --summary_model_name="Qwen/Qwen1.5-7B-Chat" \
+ --summary_prompt="You are a helpful video description generator. I'll give you a description of the middle frame of the video clip, \
+ which you need to summarize it into a description of the video clip. \
+ Please provide your video description following these requirements: \
+ 1. Describe the basic and necessary information of the video in the third person, be as concise as possible. \
+ 2. Output the video description directly. Begin with 'In this video'. \
+ 3. Limit the video description within 100 words. \
+ Here is the mid-frame description: " \
+ --saved_path=$VIDEO_CAPTION_PATH
+
+python convert_jsonl_to_json.py \
+ --video_folder=$VIDEO_FOLDER \
+ --jsonl_load_path=$VIDEO_CAPTION_PATH \
+ --save_path=$LAST_JSON_PATH
\ No newline at end of file
diff --git a/easyanimate/video_caption/utils/image_captioner_awq.py b/easyanimate/video_caption/utils/image_captioner_awq.py
new file mode 100644
index 0000000000000000000000000000000000000000..b65b5f45984b099e1f35c34ee5bc18606bb8f4e2
--- /dev/null
+++ b/easyanimate/video_caption/utils/image_captioner_awq.py
@@ -0,0 +1,90 @@
+from pathlib import Path
+from typing import Tuple
+
+import auto_gptq
+import torch
+from auto_gptq.modeling import BaseGPTQForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class QwenVLChat:
+ def __init__(self, device: str = "cuda:0", quantized: bool = False) -> None:
+ if quantized:
+ self.model = AutoModelForCausalLM.from_pretrained(
+ "Qwen/Qwen-VL-Chat-Int4", device_map=device, trust_remote_code=True
+ ).eval()
+ self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
+ else:
+ self.model = AutoModelForCausalLM.from_pretrained(
+ "Qwen/Qwen-VL-Chat", device_map=device, trust_remote_code=True, fp16=True
+ ).eval()
+ self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
+
+ def __call__(self, prompt: str, image: str) -> Tuple[str, str]:
+ query = self.tokenizer.from_list_format([{"image": image}, {"text": prompt}])
+ response, history = self.model.chat(self.tokenizer, query=query, history=[])
+
+ return response, history
+
+
+class InternLMXComposer2QForCausalLM(BaseGPTQForCausalLM):
+ layers_block_name = "model.layers"
+ outside_layer_modules = [
+ "vit",
+ "vision_proj",
+ "model.tok_embeddings",
+ "model.norm",
+ "output",
+ ]
+ inside_layer_modules = [
+ ["attention.wqkv.linear"],
+ ["attention.wo.linear"],
+ ["feed_forward.w1.linear", "feed_forward.w3.linear"],
+ ["feed_forward.w2.linear"],
+ ]
+
+
+class InternLMXComposer2:
+ def __init__(self, device: str = "cuda:0", quantized: bool = True):
+ if quantized:
+ auto_gptq.modeling._base.SUPPORTED_MODELS = ["internlm"]
+ self.model = InternLMXComposer2QForCausalLM.from_quantized(
+ "internlm/internlm-xcomposer2-vl-7b-4bit", trust_remote_code=True, device=device
+ ).eval()
+ self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-xcomposer2-vl-7b-4bit", trust_remote_code=True)
+ else:
+ # Setting fp16=True does not work. See https://huggingface.co/internlm/internlm-xcomposer2-vl-7b/discussions/1.
+ self.model = (
+ AutoModelForCausalLM.from_pretrained(
+ "internlm/internlm-xcomposer2-vl-7b", device_map=device, trust_remote_code=True
+ )
+ .eval()
+ .to(torch.float16)
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-xcomposer2-vl-7b", trust_remote_code=True)
+
+ def __call__(self, prompt: str, image: str):
+ if not prompt.startswith(""):
+ prompt = "" + prompt
+ with torch.cuda.amp.autocast(), torch.no_grad():
+ response, history = self.model.chat(self.tokenizer, query=prompt, image=image, history=[], do_sample=False)
+ return response, history
+
+
+if __name__ == "__main__":
+ image_folder = "demo/"
+ wildcard_list = ["*.jpg", "*.png"]
+ image_list = []
+ for wildcard in wildcard_list:
+ image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)])
+ qwen_vl_chat = QwenVLChat(device="cuda:0", quantized=True)
+ qwen_vl_prompt = "Please describe this image in detail."
+ for image in image_list:
+ response, _ = qwen_vl_chat(qwen_vl_prompt, image)
+ print(image, response)
+
+ internlm2_vl = InternLMXComposer2(device="cuda:0", quantized=False)
+ internlm2_vl_prompt = "Please describe this image in detail."
+ for image in image_list:
+ response, _ = internlm2_vl(internlm2_vl_prompt, image)
+ print(image, response)
diff --git a/easyanimate/video_caption/utils/image_captioner_sglang.py b/easyanimate/video_caption/utils/image_captioner_sglang.py
new file mode 100644
index 0000000000000000000000000000000000000000..050d787e2c13ecae246e4614485f9e6cd5c82646
--- /dev/null
+++ b/easyanimate/video_caption/utils/image_captioner_sglang.py
@@ -0,0 +1,90 @@
+import os
+import time
+from datetime import datetime
+from typing import List, Union
+from pathlib import Path
+
+import sglang as sgl
+from PIL import Image
+
+from utils.logger import logger
+
+TMP_DIR = "./tmp"
+
+
+def get_timestamp():
+ timestamp_ns = int(time.time_ns())
+ milliseconds = timestamp_ns // 1000000
+ formatted_time = datetime.fromtimestamp(milliseconds / 1000).strftime("%Y-%m-%d_%H-%M-%S-%f")[:-3]
+
+ return formatted_time
+
+
+class LLaVASRT:
+ def __init__(self, device: str = "cuda:0", quantized: bool = True):
+ self.runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.6-vicuna-7b", tokenizer_path="llava-hf/llava-1.5-7b-hf")
+ sgl.set_default_backend(self.runtime)
+ logger.info(
+ f"Start the SGLang runtime for llava-v1.6-vicuna-7b with chat template: {self.runtime.endpoint.chat_template.name}. "
+ "Input parameter device and quantized do not take effect."
+ )
+ if not os.path.exists(TMP_DIR):
+ os.makedirs(TMP_DIR, exist_ok=True)
+
+ @sgl.function
+ def image_qa(s, prompt: str, image: str):
+ s += sgl.user(sgl.image(image) + prompt)
+ s += sgl.assistant(sgl.gen("answer"))
+
+ def __call__(self, prompt: Union[str, List[str]], image: Union[str, Image.Image, List[str]]):
+ pil_input_flag = False
+ if isinstance(prompt, str) and (isinstance(image, str) or isinstance(image, Image.Image)):
+ if isinstance(image, Image.Image):
+ pil_input_flag = True
+ image_path = os.path.join(TMP_DIR, get_timestamp() + ".jpg")
+ image.save(image_path)
+ state = self.image_qa.run(prompt=prompt, image=image, max_new_tokens=256)
+ # Post-process.
+ if pil_input_flag:
+ os.remove(image)
+
+ return state["answer"], state
+ elif isinstance(prompt, list) and isinstance(image, list):
+ assert len(prompt) == len(image)
+ if isinstance(image[0], Image.Image):
+ pil_input_flag = True
+ image_path = [os.path.join(TMP_DIR, get_timestamp() + f"-{i}" + ".jpg") for i in range(len(image))]
+ for i in range(len(image)):
+ image[i].save(image_path[i])
+ image = image_path
+ batch_query = [{"prompt": p, "image": img} for p, img in zip(prompt, image)]
+ state = self.image_qa.run_batch(batch_query, max_new_tokens=256)
+ # Post-process.
+ if pil_input_flag:
+ for i in range(len(image)):
+ os.remove(image[i])
+
+ return [s["answer"] for s in state], state
+ else:
+ raise ValueError("Input prompt and image must be both strings or list of strings with the same length.")
+
+ def __del__(self):
+ self.runtime.shutdown()
+
+
+if __name__ == "__main__":
+ image_folder = "demo/"
+ wildcard_list = ["*.jpg", "*.png"]
+ image_list = []
+ for wildcard in wildcard_list:
+ image_list.extend([str(image_path) for image_path in Path(image_folder).glob(wildcard)])
+ # SGLang need the exclusive GPU and cannot re-initialize CUDA in forked subprocess.
+ llava_srt = LLaVASRT()
+ # Batch inference.
+ llava_srt_prompt = ["Please describe this image in detail."] * len(image_list)
+ response, _ = llava_srt(llava_srt_prompt, image_list)
+ print(response)
+ llava_srt_prompt = "Please describe this image in detail."
+ for image in image_list:
+ response, _ = llava_srt(llava_srt_prompt, image)
+ print(image, response)
\ No newline at end of file
diff --git a/easyanimate/video_caption/utils/image_evaluator.py b/easyanimate/video_caption/utils/image_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db62f7a12d36b0945bef63ef5b9a09cc1dec8e6
--- /dev/null
+++ b/easyanimate/video_caption/utils/image_evaluator.py
@@ -0,0 +1,130 @@
+import os
+from typing import List
+
+import clip
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.datasets.utils import download_url
+from transformers import AutoModel, AutoProcessor
+
+# All metrics.
+__all__ = ["AestheticScore", "CLIPScore"]
+
+_MODELS = {
+ "CLIP_ViT-L/14": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViT-L-14.pt",
+ "Aesthetics_V2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/sac%2Blogos%2Bava1-l14-linearMSE.pth",
+}
+_MD5 = {
+ "CLIP_ViT-L/14": "096db1af569b284eb76b3881534822d9",
+ "Aesthetics_V2": "b1047fd767a00134b8fd6529bf19521a",
+}
+
+
+# if you changed the MLP architecture during training, change it also here:
+class _MLP(nn.Module):
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.layers = nn.Sequential(
+ nn.Linear(self.input_size, 1024),
+ # nn.ReLU(),
+ nn.Dropout(0.2),
+ nn.Linear(1024, 128),
+ # nn.ReLU(),
+ nn.Dropout(0.2),
+ nn.Linear(128, 64),
+ # nn.ReLU(),
+ nn.Dropout(0.1),
+ nn.Linear(64, 16),
+ # nn.ReLU(),
+ nn.Linear(16, 1),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class AestheticScore:
+ """Compute LAION Aesthetics Score V2 based on openai/clip. Note that the default
+ inference dtype with GPUs is fp16 in openai/clip.
+
+ Ref:
+ 1. https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py.
+ 2. https://github.com/openai/CLIP/issues/30.
+ """
+
+ def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
+ # The CLIP model is loaded in the evaluation mode.
+ self.root = os.path.expanduser(root)
+ if not os.path.exists(self.root):
+ os.makedirs(self.root)
+ filename = "ViT-L-14.pt"
+ download_url(_MODELS["CLIP_ViT-L/14"], self.root, filename=filename, md5=_MD5["CLIP_ViT-L/14"])
+ self.clip_model, self.preprocess = clip.load(os.path.join(self.root, filename), device=device)
+ self.device = device
+ self._load_mlp()
+
+ def _load_mlp(self):
+ filename = "sac+logos+ava1-l14-linearMSE.pth"
+ download_url(_MODELS["Aesthetics_V2"], self.root, filename=filename, md5=_MD5["Aesthetics_V2"])
+ state_dict = torch.load(os.path.join(self.root, filename))
+ self.mlp = _MLP(768)
+ self.mlp.load_state_dict(state_dict)
+ self.mlp.to(self.device)
+ self.mlp.eval()
+
+ def __call__(self, images: List[Image.Image], texts=None) -> List[float]:
+ with torch.no_grad():
+ images = torch.stack([self.preprocess(image) for image in images]).to(self.device)
+ image_embs = F.normalize(self.clip_model.encode_image(images))
+ scores = self.mlp(image_embs.float()) # torch.float16 -> torch.float32, [N, 1]
+ return scores.squeeze().tolist()
+
+ def __repr__(self) -> str:
+ return "aesthetic_score"
+
+
+class CLIPScore:
+ """Compute CLIP scores for image-text pairs based on huggingface/transformers."""
+
+ def __init__(
+ self,
+ model_name_or_path: str = "openai/clip-vit-large-patch14",
+ torch_dtype=torch.float16,
+ device: str = "cpu",
+ ):
+ self.model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).eval().to(device)
+ self.processor = AutoProcessor.from_pretrained(model_name_or_path)
+ self.torch_dtype = torch_dtype
+ self.device = device
+
+ def __call__(self, images: List[Image.Image], texts: List[str]) -> List[float]:
+ assert len(images) == len(texts)
+ image_inputs = self.processor(images=images, return_tensors="pt") # {"pixel_values": }
+ if self.torch_dtype == torch.float16:
+ image_inputs["pixel_values"] = image_inputs["pixel_values"].half()
+ text_inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True) # {"inputs_id": }
+ image_inputs, text_inputs = image_inputs.to(self.device), text_inputs.to(self.device)
+ with torch.no_grad():
+ image_embs = F.normalize(self.model.get_image_features(**image_inputs))
+ text_embs = F.normalize(self.model.get_text_features(**text_inputs))
+ scores = text_embs @ image_embs.T # [N, N]
+
+ return scores.diagonal().tolist()
+
+ def __repr__(self) -> str:
+ return "clip_score"
+
+
+if __name__ == "__main__":
+ aesthetic_score = AestheticScore(device="cuda")
+ clip_score = CLIPScore(device="cuda")
+
+ paths = ["demo/splash_cl2_midframe.jpg"] * 3
+ texts = ["a joker", "a woman", "a man"]
+ images = [Image.open(p).convert("RGB") for p in paths]
+
+ print(aesthetic_score(images))
+ print(clip_score(images, texts))
\ No newline at end of file
diff --git a/easyanimate/video_caption/utils/logger.py b/easyanimate/video_caption/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..754eaf6b379aa39e8b9469c95e17c8ec8128e30d
--- /dev/null
+++ b/easyanimate/video_caption/utils/logger.py
@@ -0,0 +1,36 @@
+# Borrowed from sd-webui-controlnet/scripts/logging.py
+import copy
+import logging
+import sys
+
+
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ "DEBUG": "\033[0;36m", # CYAN
+ "INFO": "\033[0;32m", # GREEN
+ "WARNING": "\033[0;33m", # YELLOW
+ "ERROR": "\033[0;31m", # RED
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+ "RESET": "\033[0m", # RESET COLOR
+ }
+
+ def format(self, record):
+ colored_record = copy.copy(record)
+ levelname = colored_record.levelname
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+ return super().format(colored_record)
+
+
+# Create a new logger
+logger = logging.getLogger("VideoCaption")
+logger.propagate = False
+
+# Add handler if we don't have one.
+if not logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+ logger.addHandler(handler)
+
+# Configure logger
+logger.setLevel("INFO")
diff --git a/easyanimate/video_caption/utils/video_dataset.py b/easyanimate/video_caption/utils/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..537c4110627d9edb267f96df591c90351b7db0fb
--- /dev/null
+++ b/easyanimate/video_caption/utils/video_dataset.py
@@ -0,0 +1,83 @@
+from pathlib import Path
+
+import pandas as pd
+from func_timeout import FunctionTimedOut, func_timeout
+from torch.utils.data import DataLoader, Dataset
+
+from utils.logger import logger
+from utils.video_utils import get_video_path_list, extract_frames
+
+ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov"])
+VIDEO_READER_TIMEOUT = 10
+
+
+def collate_fn(batch):
+ batch = list(filter(lambda x: x is not None, batch))
+ if len(batch) != 0:
+ return {k: [item[k] for item in batch] for k in batch[0].keys()}
+ return {}
+
+
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ video_path_list=None,
+ video_folder=None,
+ video_metadata_path=None,
+ video_path_column=None,
+ sample_method="mid",
+ num_sampled_frames=1,
+ num_sample_stride=None,
+ ):
+ self.video_path_column = video_path_column
+ self.video_folder = video_folder
+ self.sample_method = sample_method
+ self.num_sampled_frames = num_sampled_frames
+ self.num_sample_stride = num_sample_stride
+
+ if video_path_list is not None:
+ self.video_path_list = video_path_list
+ self.metadata_df = pd.DataFrame({video_path_column: self.video_path_list})
+ else:
+ self.video_path_list = get_video_path_list(
+ video_folder=video_folder,
+ video_metadata_path=video_metadata_path,
+ video_path_column=video_path_column
+ )
+
+ def __getitem__(self, index):
+ # video_path = os.path.join(self.video_folder, str(self.video_path_list[index]))
+ video_path = self.video_path_list[index]
+ try:
+ sample_args = (video_path, self.sample_method, self.num_sampled_frames, self.num_sample_stride)
+ sampled_frame_idx_list, sampled_frame_list = func_timeout(
+ VIDEO_READER_TIMEOUT, extract_frames, args=sample_args
+ )
+ except FunctionTimedOut:
+ logger.warning(f"Read {video_path} timeout.")
+ return None
+ except Exception as e:
+ logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
+ return None
+ item = {
+ "video_path": Path(video_path).name,
+ "sampled_frame_idx": sampled_frame_idx_list,
+ "sampled_frame": sampled_frame_list,
+ }
+
+ return item
+
+ def __len__(self):
+ return len(self.video_path_list)
+
+
+if __name__ == "__main__":
+ video_folder = "your_video_folder"
+ video_dataset = VideoDataset(video_folder=video_folder)
+
+ video_dataloader = DataLoader(
+ video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
+ )
+ for idx, batch in enumerate(video_dataloader):
+ if len(batch) != 0:
+ print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"]))
\ No newline at end of file
diff --git a/easyanimate/video_caption/utils/video_utils.py b/easyanimate/video_caption/utils/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4219a208fd2bf9dfedba4daa77d9dc9dae373bdb
--- /dev/null
+++ b/easyanimate/video_caption/utils/video_utils.py
@@ -0,0 +1,108 @@
+import gc
+import os
+import random
+import urllib.request as request
+from contextlib import contextmanager
+from pathlib import Path
+from typing import List, Tuple, Optional
+
+import numpy as np
+import pandas as pd
+from decord import VideoReader
+from PIL import Image
+
+ALL_VIDEO_EXT = set([".mp4", ".webm", ".mkv", ".avi", ".flv", ".mov"])
+
+
+def get_video_path_list(
+ video_folder: Optional[str]=None,
+ video_metadata_path: Optional[str]=None,
+ video_path_column: Optional[str]=None
+) -> List[str]:
+ """Get all video (absolute) path list from the video folder or the video metadata file.
+
+ Args:
+ video_folder (str): The absolute path of the folder (including sub-folders) containing all the required video files.
+ video_metadata_path (str): The absolute path of the video metadata file containing video path list.
+ video_path_column (str): The column/key for the corresponding video path in the video metadata file (csv/jsonl).
+ """
+ if video_folder is None and video_metadata_path is None:
+ raise ValueError("Either the video_input or the video_metadata_path should be specified.")
+ if video_metadata_path is not None:
+ if video_metadata_path.endswith(".csv"):
+ if video_path_column is None:
+ raise ValueError("The video_path_column can not be None if provided a csv file.")
+ metadata_df = pd.read_csv(video_metadata_path)
+ video_path_list = metadata_df[video_path_column].tolist()
+ elif video_metadata_path.endswith(".jsonl"):
+ if video_path_column is None:
+ raise ValueError("The video_path_column can not be None if provided a jsonl file.")
+ metadata_df = pd.read_json(video_metadata_path, lines=True)
+ video_path_list = metadata_df[video_path_column].tolist()
+ elif video_metadata_path.endswith(".txt"):
+ with open(video_metadata_path, "r", encoding="utf-8") as f:
+ video_path_list = [line.strip() for line in f]
+ else:
+ raise ValueError("The video_metadata_path must end with `.csv`, `.jsonl` or `.txt`.")
+ if video_folder is not None:
+ video_path_list = [os.path.join(video_folder, video_path) for video_path in video_path_list]
+ return video_path_list
+
+ if os.path.isfile(video_folder):
+ video_path_list = []
+ if video_folder.endswith("mp4"):
+ video_path_list.append(video_folder)
+ elif video_folder.endswith("txt"):
+ with open(video_folder, "r") as file:
+ video_path_list += [line.strip() for line in file.readlines()]
+ return video_path_list
+
+ elif video_folder is not None:
+ video_path_list = []
+ for ext in ALL_VIDEO_EXT:
+ video_path_list.extend(Path(video_folder).rglob(f"*{ext}"))
+ video_path_list = [str(video_path) for video_path in video_path_list]
+ return video_path_list
+
+
+@contextmanager
+def video_reader(*args, **kwargs):
+ """A context manager to solve the memory leak of decord.
+ """
+ vr = VideoReader(*args, **kwargs)
+ try:
+ yield vr
+ finally:
+ del vr
+ gc.collect()
+
+
+def extract_frames(
+ video_path: str, sample_method: str = "mid", num_sampled_frames: int = -1, sample_stride: int = -1
+) -> Optional[Tuple[List[int], List[Image.Image]]]:
+ with video_reader(video_path, num_threads=2) as vr:
+ if sample_method == "mid":
+ sampled_frame_idx_list = [len(vr) // 2]
+ elif sample_method == "uniform":
+ sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
+ elif sample_method == "random":
+ clip_length = min(len(vr), (num_sampled_frames - 1) * sample_stride + 1)
+ start_idx = random.randint(0, len(vr) - clip_length)
+ sampled_frame_idx_list = np.linspace(start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int)
+ else:
+ raise ValueError("The sample_method must be mid, uniform or random.")
+ sampled_frame_list = vr.get_batch(sampled_frame_idx_list).asnumpy()
+ sampled_frame_list = [Image.fromarray(frame) for frame in sampled_frame_list]
+
+ return list(sampled_frame_idx_list), sampled_frame_list
+
+
+def download_video(
+ video_url: str,
+ save_path: str) -> bool:
+ try:
+ request.urlretrieve(video_url, save_path)
+ return os.path.isfile(save_path)
+ except Exception as e:
+ print(e, video_url)
+ return False
\ No newline at end of file