kxhit commited on
Commit
5f093a6
·
1 Parent(s): 6d86936
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 6DoF/CN_encoder.py +36 -0
  2. 6DoF/dataset.py +176 -0
  3. 6DoF/diffusers/__init__.py +281 -0
  4. 6DoF/diffusers/commands/__init__.py +27 -0
  5. 6DoF/diffusers/commands/diffusers_cli.py +41 -0
  6. 6DoF/diffusers/commands/env.py +84 -0
  7. 6DoF/diffusers/configuration_utils.py +664 -0
  8. 6DoF/diffusers/dependency_versions_check.py +47 -0
  9. 6DoF/diffusers/dependency_versions_table.py +44 -0
  10. 6DoF/diffusers/experimental/__init__.py +1 -0
  11. 6DoF/diffusers/experimental/rl/__init__.py +1 -0
  12. 6DoF/diffusers/experimental/rl/value_guided_sampling.py +152 -0
  13. 6DoF/diffusers/image_processor.py +366 -0
  14. 6DoF/diffusers/loaders.py +1492 -0
  15. 6DoF/diffusers/models/__init__.py +35 -0
  16. 6DoF/diffusers/models/activations.py +12 -0
  17. 6DoF/diffusers/models/attention.py +392 -0
  18. 6DoF/diffusers/models/attention_flax.py +446 -0
  19. 6DoF/diffusers/models/attention_processor.py +1684 -0
  20. 6DoF/diffusers/models/autoencoder_kl.py +411 -0
  21. 6DoF/diffusers/models/controlnet.py +705 -0
  22. 6DoF/diffusers/models/controlnet_flax.py +394 -0
  23. 6DoF/diffusers/models/cross_attention.py +94 -0
  24. 6DoF/diffusers/models/dual_transformer_2d.py +151 -0
  25. 6DoF/diffusers/models/embeddings.py +546 -0
  26. 6DoF/diffusers/models/embeddings_flax.py +95 -0
  27. 6DoF/diffusers/models/modeling_flax_pytorch_utils.py +118 -0
  28. 6DoF/diffusers/models/modeling_flax_utils.py +534 -0
  29. 6DoF/diffusers/models/modeling_pytorch_flax_utils.py +161 -0
  30. 6DoF/diffusers/models/modeling_utils.py +980 -0
  31. 6DoF/diffusers/models/prior_transformer.py +364 -0
  32. 6DoF/diffusers/models/resnet.py +877 -0
  33. 6DoF/diffusers/models/resnet_flax.py +124 -0
  34. 6DoF/diffusers/models/t5_film_transformer.py +321 -0
  35. 6DoF/diffusers/models/transformer_2d.py +343 -0
  36. 6DoF/diffusers/models/transformer_temporal.py +179 -0
  37. 6DoF/diffusers/models/unet_1d.py +255 -0
  38. 6DoF/diffusers/models/unet_1d_blocks.py +656 -0
  39. 6DoF/diffusers/models/unet_2d.py +329 -0
  40. 6DoF/diffusers/models/unet_2d_blocks.py +0 -0
  41. 6DoF/diffusers/models/unet_2d_blocks_flax.py +377 -0
  42. 6DoF/diffusers/models/unet_2d_condition.py +980 -0
  43. 6DoF/diffusers/models/unet_2d_condition_flax.py +357 -0
  44. 6DoF/diffusers/models/unet_3d_blocks.py +679 -0
  45. 6DoF/diffusers/models/unet_3d_condition.py +627 -0
  46. 6DoF/diffusers/models/vae.py +441 -0
  47. 6DoF/diffusers/models/vae_flax.py +869 -0
  48. 6DoF/diffusers/models/vq_model.py +167 -0
  49. 6DoF/diffusers/optimization.py +354 -0
  50. 6DoF/diffusers/pipeline_utils.py +29 -0
6DoF/CN_encoder.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ConvNextV2Model
2
+ import torch
3
+ from typing import Optional
4
+ import einops
5
+
6
+ class CN_encoder(ConvNextV2Model):
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+
10
+ def forward(
11
+ self,
12
+ pixel_values: torch.FloatTensor = None,
13
+ output_hidden_states: Optional[bool] = None,
14
+ return_dict: Optional[bool] = None,
15
+ ):
16
+ output_hidden_states = (
17
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
18
+ )
19
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20
+
21
+ if pixel_values is None:
22
+ raise ValueError("You have to specify pixel_values")
23
+
24
+ embedding_output = self.embeddings(pixel_values)
25
+
26
+ encoder_outputs = self.encoder(
27
+ embedding_output,
28
+ output_hidden_states=output_hidden_states,
29
+ return_dict=return_dict,
30
+ )
31
+
32
+ last_hidden_state = encoder_outputs[0]
33
+ image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
34
+ image_embeddings = self.layernorm(image_embeddings)
35
+
36
+ return image_embeddings
6DoF/dataset.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ from pathlib import Path
4
+ import torch
5
+ import torchvision
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import numpy as np
10
+ import webdataset as wds
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ import matplotlib.pyplot as plt
13
+ import sys
14
+
15
+ class ObjaverseDataLoader():
16
+ def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
17
+ self.root_dir = root_dir
18
+ self.batch_size = batch_size
19
+ self.num_workers = num_workers
20
+ self.total_view = total_view
21
+
22
+ image_transforms = [torchvision.transforms.Resize((256, 256)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize([0.5], [0.5])]
25
+ self.image_transforms = torchvision.transforms.Compose(image_transforms)
26
+
27
+ def train_dataloader(self):
28
+ dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
29
+ image_transforms=self.image_transforms)
30
+ # sampler = DistributedSampler(dataset)
31
+ return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
32
+ # sampler=sampler)
33
+
34
+ def val_dataloader(self):
35
+ dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
36
+ image_transforms=self.image_transforms)
37
+ sampler = DistributedSampler(dataset)
38
+ return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
39
+
40
+ def get_pose(transformation):
41
+ # transformation: 4x4
42
+ return transformation
43
+
44
+ class ObjaverseData(Dataset):
45
+ def __init__(self,
46
+ root_dir='.objaverse/hf-objaverse-v1/views',
47
+ image_transforms=None,
48
+ total_view=12,
49
+ validation=False,
50
+ T_in=1,
51
+ T_out=1,
52
+ fix_sample=False,
53
+ ) -> None:
54
+ """Create a dataset from a folder of images.
55
+ If you pass in a root directory it will be searched for images
56
+ ending in ext (ext can be a list)
57
+ """
58
+ self.root_dir = Path(root_dir)
59
+ self.total_view = total_view
60
+ self.T_in = T_in
61
+ self.T_out = T_out
62
+ self.fix_sample = fix_sample
63
+
64
+ self.paths = []
65
+ # # include all folders
66
+ # for folder in os.listdir(self.root_dir):
67
+ # if os.path.isdir(os.path.join(self.root_dir, folder)):
68
+ # self.paths.append(folder)
69
+ # load ids from .npy so we have exactly the same ids/order
70
+ self.paths = np.load("../scripts/obj_ids.npy")
71
+ # # only use 100K objects for ablation study
72
+ # self.paths = self.paths[:100000]
73
+ total_objects = len(self.paths)
74
+ assert total_objects == 790152, 'total objects %d' % total_objects
75
+ if validation:
76
+ self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
77
+ else:
78
+ self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
79
+ print('============= length of dataset %d =============' % len(self.paths))
80
+ self.tform = image_transforms
81
+
82
+ downscale = 512 / 256.
83
+ self.fx = 560. / downscale
84
+ self.fy = 560. / downscale
85
+ self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
86
+
87
+ def __len__(self):
88
+ return len(self.paths)
89
+
90
+ def get_pose(self, transformation):
91
+ # transformation: 4x4
92
+ return transformation
93
+
94
+
95
+ def load_im(self, path, color):
96
+ '''
97
+ replace background pixel with random color in rendering
98
+ '''
99
+ try:
100
+ img = plt.imread(path)
101
+ except:
102
+ print(path)
103
+ sys.exit()
104
+ img[img[:, :, -1] == 0.] = color
105
+ img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
106
+ return img
107
+
108
+ def __getitem__(self, index):
109
+ data = {}
110
+ total_view = 12
111
+
112
+ if self.fix_sample:
113
+ if self.T_out > 1:
114
+ indexes = range(total_view)
115
+ index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
116
+ index_inputs = indexes[1:self.T_in+1] # one overlap identity
117
+ else:
118
+ indexes = range(total_view)
119
+ index_targets = indexes[:self.T_out]
120
+ index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
121
+ else:
122
+ assert self.T_in + self.T_out <= total_view
123
+ # training with replace, including identity
124
+ indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
125
+ index_inputs = indexes[:self.T_in]
126
+ index_targets = indexes[self.T_in:]
127
+ filename = os.path.join(self.root_dir, self.paths[index])
128
+
129
+ color = [1., 1., 1., 1.]
130
+
131
+ try:
132
+ input_ims = []
133
+ target_ims = []
134
+ target_Ts = []
135
+ cond_Ts = []
136
+ for i, index_input in enumerate(index_inputs):
137
+ input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
138
+ input_ims.append(input_im)
139
+ input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
140
+ cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
141
+ for i, index_target in enumerate(index_targets):
142
+ target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
143
+ target_ims.append(target_im)
144
+ target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
145
+ target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
146
+ except:
147
+ print('error loading data ', filename)
148
+ filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8') # this one we know is valid
149
+ input_ims = []
150
+ target_ims = []
151
+ target_Ts = []
152
+ cond_Ts = []
153
+ # very hacky solution, sorry about this
154
+ for i, index_input in enumerate(index_inputs):
155
+ input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
156
+ input_ims.append(input_im)
157
+ input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
158
+ cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
159
+ for i, index_target in enumerate(index_targets):
160
+ target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
161
+ target_ims.append(target_im)
162
+ target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
163
+ target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
164
+
165
+ # stack to batch
166
+ data['image_input'] = torch.stack(input_ims, dim=0)
167
+ data['image_target'] = torch.stack(target_ims, dim=0)
168
+ data['pose_out'] = np.stack(target_Ts)
169
+ data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
170
+ data['pose_in'] = np.stack(cond_Ts)
171
+ data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
172
+ return data
173
+
174
+ def process_im(self, im):
175
+ im = im.convert("RGB")
176
+ return self.tform(im)
6DoF/diffusers/__init__.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.18.2"
2
+
3
+ from .configuration_utils import ConfigMixin
4
+ from .utils import (
5
+ OptionalDependencyNotAvailable,
6
+ is_flax_available,
7
+ is_inflect_available,
8
+ is_invisible_watermark_available,
9
+ is_k_diffusion_available,
10
+ is_k_diffusion_version,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_torch_available,
16
+ is_torchsde_available,
17
+ is_transformers_available,
18
+ is_transformers_version,
19
+ is_unidecode_available,
20
+ logging,
21
+ )
22
+
23
+
24
+ try:
25
+ if not is_onnx_available():
26
+ raise OptionalDependencyNotAvailable()
27
+ except OptionalDependencyNotAvailable:
28
+ from .utils.dummy_onnx_objects import * # noqa F403
29
+ else:
30
+ from .pipelines import OnnxRuntimeModel
31
+
32
+ try:
33
+ if not is_torch_available():
34
+ raise OptionalDependencyNotAvailable()
35
+ except OptionalDependencyNotAvailable:
36
+ from .utils.dummy_pt_objects import * # noqa F403
37
+ else:
38
+ from .models import (
39
+ AutoencoderKL,
40
+ ControlNetModel,
41
+ ModelMixin,
42
+ PriorTransformer,
43
+ T5FilmDecoder,
44
+ Transformer2DModel,
45
+ UNet1DModel,
46
+ UNet2DConditionModel,
47
+ UNet2DModel,
48
+ UNet3DConditionModel,
49
+ VQModel,
50
+ )
51
+ from .optimization import (
52
+ get_constant_schedule,
53
+ get_constant_schedule_with_warmup,
54
+ get_cosine_schedule_with_warmup,
55
+ get_cosine_with_hard_restarts_schedule_with_warmup,
56
+ get_linear_schedule_with_warmup,
57
+ get_polynomial_decay_schedule_with_warmup,
58
+ get_scheduler,
59
+ )
60
+ from .pipelines import (
61
+ AudioPipelineOutput,
62
+ ConsistencyModelPipeline,
63
+ DanceDiffusionPipeline,
64
+ DDIMPipeline,
65
+ DDPMPipeline,
66
+ DiffusionPipeline,
67
+ DiTPipeline,
68
+ ImagePipelineOutput,
69
+ KarrasVePipeline,
70
+ LDMPipeline,
71
+ LDMSuperResolutionPipeline,
72
+ PNDMPipeline,
73
+ RePaintPipeline,
74
+ ScoreSdeVePipeline,
75
+ )
76
+ from .schedulers import (
77
+ CMStochasticIterativeScheduler,
78
+ DDIMInverseScheduler,
79
+ DDIMParallelScheduler,
80
+ DDIMScheduler,
81
+ DDPMParallelScheduler,
82
+ DDPMScheduler,
83
+ DEISMultistepScheduler,
84
+ DPMSolverMultistepInverseScheduler,
85
+ DPMSolverMultistepScheduler,
86
+ DPMSolverSinglestepScheduler,
87
+ EulerAncestralDiscreteScheduler,
88
+ EulerDiscreteScheduler,
89
+ HeunDiscreteScheduler,
90
+ IPNDMScheduler,
91
+ KarrasVeScheduler,
92
+ KDPM2AncestralDiscreteScheduler,
93
+ KDPM2DiscreteScheduler,
94
+ PNDMScheduler,
95
+ RePaintScheduler,
96
+ SchedulerMixin,
97
+ ScoreSdeVeScheduler,
98
+ UnCLIPScheduler,
99
+ UniPCMultistepScheduler,
100
+ VQDiffusionScheduler,
101
+ )
102
+ from .training_utils import EMAModel
103
+
104
+ try:
105
+ if not (is_torch_available() and is_scipy_available()):
106
+ raise OptionalDependencyNotAvailable()
107
+ except OptionalDependencyNotAvailable:
108
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
109
+ else:
110
+ from .schedulers import LMSDiscreteScheduler
111
+
112
+ try:
113
+ if not (is_torch_available() and is_torchsde_available()):
114
+ raise OptionalDependencyNotAvailable()
115
+ except OptionalDependencyNotAvailable:
116
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
117
+ else:
118
+ from .schedulers import DPMSolverSDEScheduler
119
+
120
+ try:
121
+ if not (is_torch_available() and is_transformers_available()):
122
+ raise OptionalDependencyNotAvailable()
123
+ except OptionalDependencyNotAvailable:
124
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
125
+ else:
126
+ from .pipelines import (
127
+ AltDiffusionImg2ImgPipeline,
128
+ AltDiffusionPipeline,
129
+ AudioLDMPipeline,
130
+ CycleDiffusionPipeline,
131
+ IFImg2ImgPipeline,
132
+ IFImg2ImgSuperResolutionPipeline,
133
+ IFInpaintingPipeline,
134
+ IFInpaintingSuperResolutionPipeline,
135
+ IFPipeline,
136
+ IFSuperResolutionPipeline,
137
+ ImageTextPipelineOutput,
138
+ KandinskyImg2ImgPipeline,
139
+ KandinskyInpaintPipeline,
140
+ KandinskyPipeline,
141
+ KandinskyPriorPipeline,
142
+ KandinskyV22ControlnetImg2ImgPipeline,
143
+ KandinskyV22ControlnetPipeline,
144
+ KandinskyV22Img2ImgPipeline,
145
+ KandinskyV22InpaintPipeline,
146
+ KandinskyV22Pipeline,
147
+ KandinskyV22PriorEmb2EmbPipeline,
148
+ KandinskyV22PriorPipeline,
149
+ LDMTextToImagePipeline,
150
+ PaintByExamplePipeline,
151
+ SemanticStableDiffusionPipeline,
152
+ ShapEImg2ImgPipeline,
153
+ ShapEPipeline,
154
+ StableDiffusionAttendAndExcitePipeline,
155
+ StableDiffusionControlNetImg2ImgPipeline,
156
+ StableDiffusionControlNetInpaintPipeline,
157
+ StableDiffusionControlNetPipeline,
158
+ StableDiffusionDepth2ImgPipeline,
159
+ StableDiffusionDiffEditPipeline,
160
+ StableDiffusionImageVariationPipeline,
161
+ StableDiffusionImg2ImgPipeline,
162
+ StableDiffusionInpaintPipeline,
163
+ StableDiffusionInpaintPipelineLegacy,
164
+ StableDiffusionInstructPix2PixPipeline,
165
+ StableDiffusionLatentUpscalePipeline,
166
+ StableDiffusionLDM3DPipeline,
167
+ StableDiffusionModelEditingPipeline,
168
+ StableDiffusionPanoramaPipeline,
169
+ StableDiffusionParadigmsPipeline,
170
+ StableDiffusionPipeline,
171
+ StableDiffusionPipelineSafe,
172
+ StableDiffusionPix2PixZeroPipeline,
173
+ StableDiffusionSAGPipeline,
174
+ StableDiffusionUpscalePipeline,
175
+ StableUnCLIPImg2ImgPipeline,
176
+ StableUnCLIPPipeline,
177
+ TextToVideoSDPipeline,
178
+ TextToVideoZeroPipeline,
179
+ UnCLIPImageVariationPipeline,
180
+ UnCLIPPipeline,
181
+ UniDiffuserModel,
182
+ UniDiffuserPipeline,
183
+ UniDiffuserTextDecoder,
184
+ VersatileDiffusionDualGuidedPipeline,
185
+ VersatileDiffusionImageVariationPipeline,
186
+ VersatileDiffusionPipeline,
187
+ VersatileDiffusionTextToImagePipeline,
188
+ VideoToVideoSDPipeline,
189
+ VQDiffusionPipeline,
190
+ )
191
+
192
+ try:
193
+ if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
194
+ raise OptionalDependencyNotAvailable()
195
+ except OptionalDependencyNotAvailable:
196
+ from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
197
+ else:
198
+ from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
199
+
200
+ try:
201
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
202
+ raise OptionalDependencyNotAvailable()
203
+ except OptionalDependencyNotAvailable:
204
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
205
+ else:
206
+ from .pipelines import StableDiffusionKDiffusionPipeline
207
+
208
+ try:
209
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
210
+ raise OptionalDependencyNotAvailable()
211
+ except OptionalDependencyNotAvailable:
212
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
213
+ else:
214
+ from .pipelines import (
215
+ OnnxStableDiffusionImg2ImgPipeline,
216
+ OnnxStableDiffusionInpaintPipeline,
217
+ OnnxStableDiffusionInpaintPipelineLegacy,
218
+ OnnxStableDiffusionPipeline,
219
+ OnnxStableDiffusionUpscalePipeline,
220
+ StableDiffusionOnnxPipeline,
221
+ )
222
+
223
+ try:
224
+ if not (is_torch_available() and is_librosa_available()):
225
+ raise OptionalDependencyNotAvailable()
226
+ except OptionalDependencyNotAvailable:
227
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
228
+ else:
229
+ from .pipelines import AudioDiffusionPipeline, Mel
230
+
231
+ try:
232
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
233
+ raise OptionalDependencyNotAvailable()
234
+ except OptionalDependencyNotAvailable:
235
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
236
+ else:
237
+ from .pipelines import SpectrogramDiffusionPipeline
238
+
239
+ try:
240
+ if not is_flax_available():
241
+ raise OptionalDependencyNotAvailable()
242
+ except OptionalDependencyNotAvailable:
243
+ from .utils.dummy_flax_objects import * # noqa F403
244
+ else:
245
+ from .models.controlnet_flax import FlaxControlNetModel
246
+ from .models.modeling_flax_utils import FlaxModelMixin
247
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
248
+ from .models.vae_flax import FlaxAutoencoderKL
249
+ from .pipelines import FlaxDiffusionPipeline
250
+ from .schedulers import (
251
+ FlaxDDIMScheduler,
252
+ FlaxDDPMScheduler,
253
+ FlaxDPMSolverMultistepScheduler,
254
+ FlaxKarrasVeScheduler,
255
+ FlaxLMSDiscreteScheduler,
256
+ FlaxPNDMScheduler,
257
+ FlaxSchedulerMixin,
258
+ FlaxScoreSdeVeScheduler,
259
+ )
260
+
261
+
262
+ try:
263
+ if not (is_flax_available() and is_transformers_available()):
264
+ raise OptionalDependencyNotAvailable()
265
+ except OptionalDependencyNotAvailable:
266
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
267
+ else:
268
+ from .pipelines import (
269
+ FlaxStableDiffusionControlNetPipeline,
270
+ FlaxStableDiffusionImg2ImgPipeline,
271
+ FlaxStableDiffusionInpaintPipeline,
272
+ FlaxStableDiffusionPipeline,
273
+ )
274
+
275
+ try:
276
+ if not (is_note_seq_available()):
277
+ raise OptionalDependencyNotAvailable()
278
+ except OptionalDependencyNotAvailable:
279
+ from .utils.dummy_note_seq_objects import * # noqa F403
280
+ else:
281
+ from .pipelines import MidiProcessor
6DoF/diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
6DoF/diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+
20
+
21
+ def main():
22
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
+
25
+ # Register commands
26
+ EnvironmentCommand.register_subcommand(commands_parser)
27
+
28
+ # Let's go
29
+ args = parser.parse_args()
30
+
31
+ if not hasattr(args, "func"):
32
+ parser.print_help()
33
+ exit(1)
34
+
35
+ # Run
36
+ service = args.func(args)
37
+ service.run()
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
6DoF/diffusers/commands/env.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available():
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ accelerate_version = "not installed"
53
+ if is_accelerate_available():
54
+ import accelerate
55
+
56
+ accelerate_version = accelerate.__version__
57
+
58
+ xformers_version = "not installed"
59
+ if is_xformers_available():
60
+ import xformers
61
+
62
+ xformers_version = xformers.__version__
63
+
64
+ info = {
65
+ "`diffusers` version": version,
66
+ "Platform": platform.platform(),
67
+ "Python version": platform.python_version(),
68
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
+ "Huggingface_hub version": hub_version,
70
+ "Transformers version": transformers_version,
71
+ "Accelerate version": accelerate_version,
72
+ "xFormers version": xformers_version,
73
+ "Using GPU in script?": "<fill in>",
74
+ "Using distributed or parallel set-up in script?": "<fill in>",
75
+ }
76
+
77
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
+ print(self.format_dict(info))
79
+
80
+ return info
81
+
82
+ @staticmethod
83
+ def format_dict(d):
84
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
6DoF/diffusers/configuration_utils.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from pathlib import PosixPath
26
+ from typing import Any, Dict, Tuple, Union
27
+
28
+ import numpy as np
29
+ from huggingface_hub import hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import (
35
+ DIFFUSERS_CACHE,
36
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
+ DummyObject,
38
+ deprecate,
39
+ extract_commit_hash,
40
+ http_user_agent,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
+
49
+
50
+ class FrozenDict(OrderedDict):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ for key, value in self.items():
55
+ setattr(self, key, value)
56
+
57
+ self.__frozen = True
58
+
59
+ def __delitem__(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
+
62
+ def setdefault(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
+
65
+ def pop(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
+
68
+ def update(self, *args, **kwargs):
69
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
+
71
+ def __setattr__(self, name, value):
72
+ if hasattr(self, "__frozen") and self.__frozen:
73
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
+ super().__setattr__(name, value)
75
+
76
+ def __setitem__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setitem__(name, value)
80
+
81
+
82
+ class ConfigMixin:
83
+ r"""
84
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
+ saving classes that inherit from [`ConfigMixin`].
87
+
88
+ Class attributes:
89
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
90
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
+ overridden by subclass).
93
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
+ subclass).
97
+ """
98
+ config_name = None
99
+ ignore_for_config = []
100
+ has_compatibles = False
101
+
102
+ _deprecated_kwargs = []
103
+
104
+ def register_to_config(self, **kwargs):
105
+ if self.config_name is None:
106
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
+ # Special case for `kwargs` used in deprecation warning added to schedulers
108
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
+ # or solve in a more general way.
110
+ kwargs.pop("kwargs", None)
111
+
112
+ if not hasattr(self, "_internal_dict"):
113
+ internal_dict = kwargs
114
+ else:
115
+ previous_dict = dict(self._internal_dict)
116
+ internal_dict = {**self._internal_dict, **kwargs}
117
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
+
119
+ self._internal_dict = FrozenDict(internal_dict)
120
+
121
+ def __getattr__(self, name: str) -> Any:
122
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
+
125
+ Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
126
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
+ """
128
+
129
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
+ is_attribute = name in self.__dict__
131
+
132
+ if is_in_config and not is_attribute:
133
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
+ return self._internal_dict[name]
136
+
137
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
+
139
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
+ """
141
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
+ [`~ConfigMixin.from_config`] class method.
143
+
144
+ Args:
145
+ save_directory (`str` or `os.PathLike`):
146
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
147
+ """
148
+ if os.path.isfile(save_directory):
149
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
150
+
151
+ os.makedirs(save_directory, exist_ok=True)
152
+
153
+ # If we save using the predefined names, we can load using `from_config`
154
+ output_config_file = os.path.join(save_directory, self.config_name)
155
+
156
+ self.to_json_file(output_config_file)
157
+ logger.info(f"Configuration saved in {output_config_file}")
158
+
159
+ @classmethod
160
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
161
+ r"""
162
+ Instantiate a Python class from a config dictionary.
163
+
164
+ Parameters:
165
+ config (`Dict[str, Any]`):
166
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
167
+ files of compatible classes.
168
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
169
+ Whether kwargs that are not consumed by the Python class should be returned or not.
170
+ kwargs (remaining dictionary of keyword arguments, *optional*):
171
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
172
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
173
+ overwrite the same named arguments in `config`.
174
+
175
+ Returns:
176
+ [`ModelMixin`] or [`SchedulerMixin`]:
177
+ A model or scheduler object instantiated from a config dictionary.
178
+
179
+ Examples:
180
+
181
+ ```python
182
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
183
+
184
+ >>> # Download scheduler from huggingface.co and cache.
185
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
186
+
187
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
188
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
189
+
190
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
191
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
192
+ ```
193
+ """
194
+ # <===== TO BE REMOVED WITH DEPRECATION
195
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
196
+ if "pretrained_model_name_or_path" in kwargs:
197
+ config = kwargs.pop("pretrained_model_name_or_path")
198
+
199
+ if config is None:
200
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
201
+ # ======>
202
+
203
+ if not isinstance(config, dict):
204
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
205
+ if "Scheduler" in cls.__name__:
206
+ deprecation_message += (
207
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
208
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
209
+ " be removed in v1.0.0."
210
+ )
211
+ elif "Model" in cls.__name__:
212
+ deprecation_message += (
213
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
214
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
215
+ " instead. This functionality will be removed in v1.0.0."
216
+ )
217
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
218
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
219
+
220
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
221
+
222
+ # Allow dtype to be specified on initialization
223
+ if "dtype" in unused_kwargs:
224
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
225
+
226
+ # add possible deprecated kwargs
227
+ for deprecated_kwarg in cls._deprecated_kwargs:
228
+ if deprecated_kwarg in unused_kwargs:
229
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
230
+
231
+ # Return model and optionally state and/or unused_kwargs
232
+ model = cls(**init_dict)
233
+
234
+ # make sure to also save config parameters that might be used for compatible classes
235
+ model.register_to_config(**hidden_dict)
236
+
237
+ # add hidden kwargs of compatible classes to unused_kwargs
238
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
239
+
240
+ if return_unused_kwargs:
241
+ return (model, unused_kwargs)
242
+ else:
243
+ return model
244
+
245
+ @classmethod
246
+ def get_config_dict(cls, *args, **kwargs):
247
+ deprecation_message = (
248
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
249
+ " removed in version v1.0.0"
250
+ )
251
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
252
+ return cls.load_config(*args, **kwargs)
253
+
254
+ @classmethod
255
+ def load_config(
256
+ cls,
257
+ pretrained_model_name_or_path: Union[str, os.PathLike],
258
+ return_unused_kwargs=False,
259
+ return_commit_hash=False,
260
+ **kwargs,
261
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
262
+ r"""
263
+ Load a model or scheduler configuration.
264
+
265
+ Parameters:
266
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
267
+ Can be either:
268
+
269
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
270
+ the Hub.
271
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
272
+ [`~ConfigMixin.save_config`].
273
+
274
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
275
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
276
+ is not used.
277
+ force_download (`bool`, *optional*, defaults to `False`):
278
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
279
+ cached versions if they exist.
280
+ resume_download (`bool`, *optional*, defaults to `False`):
281
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
282
+ incompletely downloaded files are deleted.
283
+ proxies (`Dict[str, str]`, *optional*):
284
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
285
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
286
+ output_loading_info(`bool`, *optional*, defaults to `False`):
287
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
288
+ local_files_only (`bool`, *optional*, defaults to `False`):
289
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
290
+ won't be downloaded from the Hub.
291
+ use_auth_token (`str` or *bool*, *optional*):
292
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
293
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
294
+ revision (`str`, *optional*, defaults to `"main"`):
295
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
296
+ allowed by Git.
297
+ subfolder (`str`, *optional*, defaults to `""`):
298
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
299
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
300
+ Whether unused keyword arguments of the config are returned.
301
+ return_commit_hash (`bool`, *optional*, defaults to `False):
302
+ Whether the `commit_hash` of the loaded configuration are returned.
303
+
304
+ Returns:
305
+ `dict`:
306
+ A dictionary of all the parameters stored in a JSON configuration file.
307
+
308
+ """
309
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
310
+ force_download = kwargs.pop("force_download", False)
311
+ resume_download = kwargs.pop("resume_download", False)
312
+ proxies = kwargs.pop("proxies", None)
313
+ use_auth_token = kwargs.pop("use_auth_token", None)
314
+ local_files_only = kwargs.pop("local_files_only", False)
315
+ revision = kwargs.pop("revision", None)
316
+ _ = kwargs.pop("mirror", None)
317
+ subfolder = kwargs.pop("subfolder", None)
318
+ user_agent = kwargs.pop("user_agent", {})
319
+
320
+ user_agent = {**user_agent, "file_type": "config"}
321
+ user_agent = http_user_agent(user_agent)
322
+
323
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
324
+
325
+ if cls.config_name is None:
326
+ raise ValueError(
327
+ "`self.config_name` is not defined. Note that one should not load a config from "
328
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
329
+ )
330
+
331
+ if os.path.isfile(pretrained_model_name_or_path):
332
+ config_file = pretrained_model_name_or_path
333
+ elif os.path.isdir(pretrained_model_name_or_path):
334
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
335
+ # Load from a PyTorch checkpoint
336
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
337
+ elif subfolder is not None and os.path.isfile(
338
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
339
+ ):
340
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
341
+ else:
342
+ raise EnvironmentError(
343
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
344
+ )
345
+ else:
346
+ try:
347
+ # Load from URL or cache if already cached
348
+ config_file = hf_hub_download(
349
+ pretrained_model_name_or_path,
350
+ filename=cls.config_name,
351
+ cache_dir=cache_dir,
352
+ force_download=force_download,
353
+ proxies=proxies,
354
+ resume_download=resume_download,
355
+ local_files_only=local_files_only,
356
+ use_auth_token=use_auth_token,
357
+ user_agent=user_agent,
358
+ subfolder=subfolder,
359
+ revision=revision,
360
+ )
361
+ except RepositoryNotFoundError:
362
+ raise EnvironmentError(
363
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
364
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
365
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
366
+ " login`."
367
+ )
368
+ except RevisionNotFoundError:
369
+ raise EnvironmentError(
370
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
371
+ " this model name. Check the model page at"
372
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
373
+ )
374
+ except EntryNotFoundError:
375
+ raise EnvironmentError(
376
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
377
+ )
378
+ except HTTPError as err:
379
+ raise EnvironmentError(
380
+ "There was a specific connection error when trying to load"
381
+ f" {pretrained_model_name_or_path}:\n{err}"
382
+ )
383
+ except ValueError:
384
+ raise EnvironmentError(
385
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
386
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
387
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
388
+ " run the library in offline mode at"
389
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
390
+ )
391
+ except EnvironmentError:
392
+ raise EnvironmentError(
393
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
394
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
395
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
396
+ f"containing a {cls.config_name} file"
397
+ )
398
+
399
+ try:
400
+ # Load config dict
401
+ config_dict = cls._dict_from_json_file(config_file)
402
+
403
+ commit_hash = extract_commit_hash(config_file)
404
+ except (json.JSONDecodeError, UnicodeDecodeError):
405
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
406
+
407
+ if not (return_unused_kwargs or return_commit_hash):
408
+ return config_dict
409
+
410
+ outputs = (config_dict,)
411
+
412
+ if return_unused_kwargs:
413
+ outputs += (kwargs,)
414
+
415
+ if return_commit_hash:
416
+ outputs += (commit_hash,)
417
+
418
+ return outputs
419
+
420
+ @staticmethod
421
+ def _get_init_keys(cls):
422
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
423
+
424
+ @classmethod
425
+ def extract_init_dict(cls, config_dict, **kwargs):
426
+ # Skip keys that were not present in the original config, so default __init__ values were used
427
+ used_defaults = config_dict.get("_use_default_values", [])
428
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
429
+
430
+ # 0. Copy origin config dict
431
+ original_dict = dict(config_dict.items())
432
+
433
+ # 1. Retrieve expected config attributes from __init__ signature
434
+ expected_keys = cls._get_init_keys(cls)
435
+ expected_keys.remove("self")
436
+ # remove general kwargs if present in dict
437
+ if "kwargs" in expected_keys:
438
+ expected_keys.remove("kwargs")
439
+ # remove flax internal keys
440
+ if hasattr(cls, "_flax_internal_args"):
441
+ for arg in cls._flax_internal_args:
442
+ expected_keys.remove(arg)
443
+
444
+ # 2. Remove attributes that cannot be expected from expected config attributes
445
+ # remove keys to be ignored
446
+ if len(cls.ignore_for_config) > 0:
447
+ expected_keys = expected_keys - set(cls.ignore_for_config)
448
+
449
+ # load diffusers library to import compatible and original scheduler
450
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
451
+
452
+ if cls.has_compatibles:
453
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
454
+ else:
455
+ compatible_classes = []
456
+
457
+ expected_keys_comp_cls = set()
458
+ for c in compatible_classes:
459
+ expected_keys_c = cls._get_init_keys(c)
460
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
461
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
462
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
463
+
464
+ # remove attributes from orig class that cannot be expected
465
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
466
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
467
+ orig_cls = getattr(diffusers_library, orig_cls_name)
468
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
469
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
470
+
471
+ # remove private attributes
472
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
473
+
474
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
475
+ init_dict = {}
476
+ for key in expected_keys:
477
+ # if config param is passed to kwarg and is present in config dict
478
+ # it should overwrite existing config dict key
479
+ if key in kwargs and key in config_dict:
480
+ config_dict[key] = kwargs.pop(key)
481
+
482
+ if key in kwargs:
483
+ # overwrite key
484
+ init_dict[key] = kwargs.pop(key)
485
+ elif key in config_dict:
486
+ # use value from config dict
487
+ init_dict[key] = config_dict.pop(key)
488
+
489
+ # 4. Give nice warning if unexpected values have been passed
490
+ if len(config_dict) > 0:
491
+ logger.warning(
492
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
493
+ "but are not expected and will be ignored. Please verify your "
494
+ f"{cls.config_name} configuration file."
495
+ )
496
+
497
+ # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
498
+ passed_keys = set(init_dict.keys())
499
+ if len(expected_keys - passed_keys) > 0:
500
+ logger.info(
501
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
502
+ )
503
+
504
+ # 6. Define unused keyword arguments
505
+ unused_kwargs = {**config_dict, **kwargs}
506
+
507
+ # 7. Define "hidden" config parameters that were saved for compatible classes
508
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
509
+
510
+ return init_dict, unused_kwargs, hidden_config_dict
511
+
512
+ @classmethod
513
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
514
+ with open(json_file, "r", encoding="utf-8") as reader:
515
+ text = reader.read()
516
+ return json.loads(text)
517
+
518
+ def __repr__(self):
519
+ return f"{self.__class__.__name__} {self.to_json_string()}"
520
+
521
+ @property
522
+ def config(self) -> Dict[str, Any]:
523
+ """
524
+ Returns the config of the class as a frozen dictionary
525
+
526
+ Returns:
527
+ `Dict[str, Any]`: Config of the class.
528
+ """
529
+ return self._internal_dict
530
+
531
+ def to_json_string(self) -> str:
532
+ """
533
+ Serializes the configuration instance to a JSON string.
534
+
535
+ Returns:
536
+ `str`:
537
+ String containing all the attributes that make up the configuration instance in JSON format.
538
+ """
539
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
540
+ config_dict["_class_name"] = self.__class__.__name__
541
+ config_dict["_diffusers_version"] = __version__
542
+
543
+ def to_json_saveable(value):
544
+ if isinstance(value, np.ndarray):
545
+ value = value.tolist()
546
+ elif isinstance(value, PosixPath):
547
+ value = str(value)
548
+ return value
549
+
550
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
551
+ # Don't save "_ignore_files" or "_use_default_values"
552
+ config_dict.pop("_ignore_files", None)
553
+ config_dict.pop("_use_default_values", None)
554
+
555
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
556
+
557
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
558
+ """
559
+ Save the configuration instance's parameters to a JSON file.
560
+
561
+ Args:
562
+ json_file_path (`str` or `os.PathLike`):
563
+ Path to the JSON file to save a configuration instance's parameters.
564
+ """
565
+ with open(json_file_path, "w", encoding="utf-8") as writer:
566
+ writer.write(self.to_json_string())
567
+
568
+
569
+ def register_to_config(init):
570
+ r"""
571
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
572
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
573
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
574
+
575
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
576
+ """
577
+
578
+ @functools.wraps(init)
579
+ def inner_init(self, *args, **kwargs):
580
+ # Ignore private kwargs in the init.
581
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
582
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
583
+ if not isinstance(self, ConfigMixin):
584
+ raise RuntimeError(
585
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
586
+ "not inherit from `ConfigMixin`."
587
+ )
588
+
589
+ ignore = getattr(self, "ignore_for_config", [])
590
+ # Get positional arguments aligned with kwargs
591
+ new_kwargs = {}
592
+ signature = inspect.signature(init)
593
+ parameters = {
594
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
595
+ }
596
+ for arg, name in zip(args, parameters.keys()):
597
+ new_kwargs[name] = arg
598
+
599
+ # Then add all kwargs
600
+ new_kwargs.update(
601
+ {
602
+ k: init_kwargs.get(k, default)
603
+ for k, default in parameters.items()
604
+ if k not in ignore and k not in new_kwargs
605
+ }
606
+ )
607
+
608
+ # Take note of the parameters that were not present in the loaded config
609
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
610
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
611
+
612
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
613
+ getattr(self, "register_to_config")(**new_kwargs)
614
+ init(self, *args, **init_kwargs)
615
+
616
+ return inner_init
617
+
618
+
619
+ def flax_register_to_config(cls):
620
+ original_init = cls.__init__
621
+
622
+ @functools.wraps(original_init)
623
+ def init(self, *args, **kwargs):
624
+ if not isinstance(self, ConfigMixin):
625
+ raise RuntimeError(
626
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
627
+ "not inherit from `ConfigMixin`."
628
+ )
629
+
630
+ # Ignore private kwargs in the init. Retrieve all passed attributes
631
+ init_kwargs = dict(kwargs.items())
632
+
633
+ # Retrieve default values
634
+ fields = dataclasses.fields(self)
635
+ default_kwargs = {}
636
+ for field in fields:
637
+ # ignore flax specific attributes
638
+ if field.name in self._flax_internal_args:
639
+ continue
640
+ if type(field.default) == dataclasses._MISSING_TYPE:
641
+ default_kwargs[field.name] = None
642
+ else:
643
+ default_kwargs[field.name] = getattr(self, field.name)
644
+
645
+ # Make sure init_kwargs override default kwargs
646
+ new_kwargs = {**default_kwargs, **init_kwargs}
647
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
648
+ if "dtype" in new_kwargs:
649
+ new_kwargs.pop("dtype")
650
+
651
+ # Get positional arguments aligned with kwargs
652
+ for i, arg in enumerate(args):
653
+ name = fields[i].name
654
+ new_kwargs[name] = arg
655
+
656
+ # Take note of the parameters that were not present in the loaded config
657
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
658
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
659
+
660
+ getattr(self, "register_to_config")(**new_kwargs)
661
+ original_init(self, *args, **kwargs)
662
+
663
+ cls.__init__ = init
664
+ return cls
6DoF/diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
+ if sys.version_info < (3, 7):
28
+ pkgs_to_check_at_runtime.append("dataclasses")
29
+ if sys.version_info < (3, 8):
30
+ pkgs_to_check_at_runtime.append("importlib_metadata")
31
+
32
+ for pkg in pkgs_to_check_at_runtime:
33
+ if pkg in deps:
34
+ if pkg == "tokenizers":
35
+ # must be loaded here, or else tqdm check may fail
36
+ from .utils import is_tokenizers_available
37
+
38
+ if not is_tokenizers_available():
39
+ continue # not required, check version only if installed
40
+
41
+ require_version_core(deps[pkg])
42
+ else:
43
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
+
45
+
46
+ def dep_version_check(pkg, hint=None):
47
+ require_version(deps[pkg], hint)
6DoF/diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update``
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "compel": "compel==0.1.8",
8
+ "black": "black~=23.1",
9
+ "datasets": "datasets",
10
+ "filelock": "filelock",
11
+ "flax": "flax>=0.4.1",
12
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
+ "huggingface-hub": "huggingface-hub>=0.13.2",
14
+ "requests-mock": "requests-mock==1.10.0",
15
+ "importlib_metadata": "importlib_metadata",
16
+ "invisible-watermark": "invisible-watermark",
17
+ "isort": "isort>=5.5.4",
18
+ "jax": "jax>=0.2.8,!=0.3.2",
19
+ "jaxlib": "jaxlib>=0.1.65",
20
+ "Jinja2": "Jinja2",
21
+ "k-diffusion": "k-diffusion>=0.0.12",
22
+ "torchsde": "torchsde",
23
+ "note_seq": "note_seq",
24
+ "librosa": "librosa",
25
+ "numpy": "numpy",
26
+ "omegaconf": "omegaconf",
27
+ "parameterized": "parameterized",
28
+ "protobuf": "protobuf>=3.20.3,<4",
29
+ "pytest": "pytest",
30
+ "pytest-timeout": "pytest-timeout",
31
+ "pytest-xdist": "pytest-xdist",
32
+ "ruff": "ruff>=0.0.241",
33
+ "safetensors": "safetensors",
34
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
35
+ "scipy": "scipy",
36
+ "onnx": "onnx",
37
+ "regex": "regex!=2019.12.17",
38
+ "requests": "requests",
39
+ "tensorboard": "tensorboard",
40
+ "torch": "torch>=1.4",
41
+ "torchvision": "torchvision",
42
+ "transformers": "transformers>=4.25.1",
43
+ "urllib3": "urllib3<=2.0.0",
44
+ }
6DoF/diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
6DoF/diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
6DoF/diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils import randn_tensor
22
+ from ...utils.dummy_pt_objects import DDPMScheduler
23
+
24
+
25
+ class ValueGuidedRLPipeline(DiffusionPipeline):
26
+ r"""
27
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
+ Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
30
+
31
+ Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
32
+
33
+ Parameters:
34
+ value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
35
+ unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
36
+ scheduler ([`SchedulerMixin`]):
37
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
38
+ application is [`DDPMScheduler`].
39
+ env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ value_function: UNet1DModel,
45
+ unet: UNet1DModel,
46
+ scheduler: DDPMScheduler,
47
+ env,
48
+ ):
49
+ super().__init__()
50
+ self.value_function = value_function
51
+ self.unet = unet
52
+ self.scheduler = scheduler
53
+ self.env = env
54
+ self.data = env.get_dataset()
55
+ self.means = {}
56
+ for key in self.data.keys():
57
+ try:
58
+ self.means[key] = self.data[key].mean()
59
+ except: # noqa: E722
60
+ pass
61
+ self.stds = {}
62
+ for key in self.data.keys():
63
+ try:
64
+ self.stds[key] = self.data[key].std()
65
+ except: # noqa: E722
66
+ pass
67
+ self.state_dim = env.observation_space.shape[0]
68
+ self.action_dim = env.action_space.shape[0]
69
+
70
+ def normalize(self, x_in, key):
71
+ return (x_in - self.means[key]) / self.stds[key]
72
+
73
+ def de_normalize(self, x_in, key):
74
+ return x_in * self.stds[key] + self.means[key]
75
+
76
+ def to_torch(self, x_in):
77
+ if type(x_in) is dict:
78
+ return {k: self.to_torch(v) for k, v in x_in.items()}
79
+ elif torch.is_tensor(x_in):
80
+ return x_in.to(self.unet.device)
81
+ return torch.tensor(x_in, device=self.unet.device)
82
+
83
+ def reset_x0(self, x_in, cond, act_dim):
84
+ for key, val in cond.items():
85
+ x_in[:, key, act_dim:] = val.clone()
86
+ return x_in
87
+
88
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
89
+ batch_size = x.shape[0]
90
+ y = None
91
+ for i in tqdm.tqdm(self.scheduler.timesteps):
92
+ # create batch of timesteps to pass into model
93
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
94
+ for _ in range(n_guide_steps):
95
+ with torch.enable_grad():
96
+ x.requires_grad_()
97
+
98
+ # permute to match dimension for pre-trained models
99
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
100
+ grad = torch.autograd.grad([y.sum()], [x])[0]
101
+
102
+ posterior_variance = self.scheduler._get_variance(i)
103
+ model_std = torch.exp(0.5 * posterior_variance)
104
+ grad = model_std * grad
105
+
106
+ grad[timesteps < 2] = 0
107
+ x = x.detach()
108
+ x = x + scale * grad
109
+ x = self.reset_x0(x, conditions, self.action_dim)
110
+
111
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
112
+
113
+ # TODO: verify deprecation of this kwarg
114
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
115
+
116
+ # apply conditions to the trajectory (set the initial state)
117
+ x = self.reset_x0(x, conditions, self.action_dim)
118
+ x = self.to_torch(x)
119
+ return x, y
120
+
121
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
122
+ # normalize the observations and create batch dimension
123
+ obs = self.normalize(obs, "observations")
124
+ obs = obs[None].repeat(batch_size, axis=0)
125
+
126
+ conditions = {0: self.to_torch(obs)}
127
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
128
+
129
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
130
+ x1 = randn_tensor(shape, device=self.unet.device)
131
+ x = self.reset_x0(x1, conditions, self.action_dim)
132
+ x = self.to_torch(x)
133
+
134
+ # run the diffusion process
135
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
136
+
137
+ # sort output trajectories by value
138
+ sorted_idx = y.argsort(0, descending=True).squeeze()
139
+ sorted_values = x[sorted_idx]
140
+ actions = sorted_values[:, :, : self.action_dim]
141
+ actions = actions.detach().cpu().numpy()
142
+ denorm_actions = self.de_normalize(actions, key="actions")
143
+
144
+ # select the action with the highest value
145
+ if y is not None:
146
+ selected_index = 0
147
+ else:
148
+ # if we didn't run value guiding, select a random action
149
+ selected_index = np.random.randint(0, batch_size)
150
+
151
+ denorm_actions = denorm_actions[selected_index, 0]
152
+ return denorm_actions
6DoF/diffusers/image_processor.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+
27
+ class VaeImageProcessor(ConfigMixin):
28
+ """
29
+ Image processor for VAE.
30
+
31
+ Args:
32
+ do_resize (`bool`, *optional*, defaults to `True`):
33
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
34
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
35
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
36
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
37
+ resample (`str`, *optional*, defaults to `lanczos`):
38
+ Resampling filter to use when resizing the image.
39
+ do_normalize (`bool`, *optional*, defaults to `True`):
40
+ Whether to normalize the image to [-1,1].
41
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
42
+ Whether to convert the images to RGB format.
43
+ """
44
+
45
+ config_name = CONFIG_NAME
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ do_resize: bool = True,
51
+ vae_scale_factor: int = 8,
52
+ resample: str = "lanczos",
53
+ do_normalize: bool = True,
54
+ do_convert_rgb: bool = False,
55
+ ):
56
+ super().__init__()
57
+
58
+ @staticmethod
59
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
60
+ """
61
+ Convert a numpy image or a batch of images to a PIL image.
62
+ """
63
+ if images.ndim == 3:
64
+ images = images[None, ...]
65
+ images = (images * 255).round().astype("uint8")
66
+ if images.shape[-1] == 1:
67
+ # special case for grayscale (single channel) images
68
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
69
+ else:
70
+ pil_images = [Image.fromarray(image) for image in images]
71
+
72
+ return pil_images
73
+
74
+ @staticmethod
75
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
76
+ """
77
+ Convert a PIL image or a list of PIL images to NumPy arrays.
78
+ """
79
+ if not isinstance(images, list):
80
+ images = [images]
81
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
82
+ images = np.stack(images, axis=0)
83
+
84
+ return images
85
+
86
+ @staticmethod
87
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
88
+ """
89
+ Convert a NumPy image to a PyTorch tensor.
90
+ """
91
+ if images.ndim == 3:
92
+ images = images[..., None]
93
+
94
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
95
+ return images
96
+
97
+ @staticmethod
98
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
99
+ """
100
+ Convert a PyTorch tensor to a NumPy image.
101
+ """
102
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
103
+ return images
104
+
105
+ @staticmethod
106
+ def normalize(images):
107
+ """
108
+ Normalize an image array to [-1,1].
109
+ """
110
+ return 2.0 * images - 1.0
111
+
112
+ @staticmethod
113
+ def denormalize(images):
114
+ """
115
+ Denormalize an image array to [0,1].
116
+ """
117
+ return (images / 2 + 0.5).clamp(0, 1)
118
+
119
+ @staticmethod
120
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
121
+ """
122
+ Converts an image to RGB format.
123
+ """
124
+ image = image.convert("RGB")
125
+ return image
126
+
127
+ def resize(
128
+ self,
129
+ image: PIL.Image.Image,
130
+ height: Optional[int] = None,
131
+ width: Optional[int] = None,
132
+ ) -> PIL.Image.Image:
133
+ """
134
+ Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
135
+ """
136
+ if height is None:
137
+ height = image.height
138
+ if width is None:
139
+ width = image.width
140
+
141
+ width, height = (
142
+ x - x % self.config.vae_scale_factor for x in (width, height)
143
+ ) # resize to integer multiple of vae_scale_factor
144
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
145
+ return image
146
+
147
+ def preprocess(
148
+ self,
149
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
150
+ height: Optional[int] = None,
151
+ width: Optional[int] = None,
152
+ ) -> torch.Tensor:
153
+ """
154
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
155
+ """
156
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
157
+ if isinstance(image, supported_formats):
158
+ image = [image]
159
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
160
+ raise ValueError(
161
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
162
+ )
163
+
164
+ if isinstance(image[0], PIL.Image.Image):
165
+ if self.config.do_convert_rgb:
166
+ image = [self.convert_to_rgb(i) for i in image]
167
+ if self.config.do_resize:
168
+ image = [self.resize(i, height, width) for i in image]
169
+ image = self.pil_to_numpy(image) # to np
170
+ image = self.numpy_to_pt(image) # to pt
171
+
172
+ elif isinstance(image[0], np.ndarray):
173
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
174
+ image = self.numpy_to_pt(image)
175
+ _, _, height, width = image.shape
176
+ if self.config.do_resize and (
177
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
178
+ ):
179
+ raise ValueError(
180
+ f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
181
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
182
+ )
183
+
184
+ elif isinstance(image[0], torch.Tensor):
185
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
186
+ _, channel, height, width = image.shape
187
+
188
+ # don't need any preprocess if the image is latents
189
+ if channel == 4:
190
+ return image
191
+
192
+ if self.config.do_resize and (
193
+ height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
194
+ ):
195
+ raise ValueError(
196
+ f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
197
+ f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
198
+ )
199
+
200
+ # expected range [0,1], normalize to [-1,1]
201
+ do_normalize = self.config.do_normalize
202
+ if image.min() < 0:
203
+ warnings.warn(
204
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
205
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
206
+ FutureWarning,
207
+ )
208
+ do_normalize = False
209
+
210
+ if do_normalize:
211
+ image = self.normalize(image)
212
+
213
+ return image
214
+
215
+ def postprocess(
216
+ self,
217
+ image: torch.FloatTensor,
218
+ output_type: str = "pil",
219
+ do_denormalize: Optional[List[bool]] = None,
220
+ ):
221
+ if not isinstance(image, torch.Tensor):
222
+ raise ValueError(
223
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
224
+ )
225
+ if output_type not in ["latent", "pt", "np", "pil"]:
226
+ deprecation_message = (
227
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
228
+ "`pil`, `np`, `pt`, `latent`"
229
+ )
230
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
231
+ output_type = "np"
232
+
233
+ if output_type == "latent":
234
+ return image
235
+
236
+ if do_denormalize is None:
237
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
238
+
239
+ image = torch.stack(
240
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
241
+ )
242
+
243
+ if output_type == "pt":
244
+ return image
245
+
246
+ image = self.pt_to_numpy(image)
247
+
248
+ if output_type == "np":
249
+ return image
250
+
251
+ if output_type == "pil":
252
+ return self.numpy_to_pil(image)
253
+
254
+
255
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
256
+ """
257
+ Image processor for VAE LDM3D.
258
+
259
+ Args:
260
+ do_resize (`bool`, *optional*, defaults to `True`):
261
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
262
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
263
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
264
+ resample (`str`, *optional*, defaults to `lanczos`):
265
+ Resampling filter to use when resizing the image.
266
+ do_normalize (`bool`, *optional*, defaults to `True`):
267
+ Whether to normalize the image to [-1,1].
268
+ """
269
+
270
+ config_name = CONFIG_NAME
271
+
272
+ @register_to_config
273
+ def __init__(
274
+ self,
275
+ do_resize: bool = True,
276
+ vae_scale_factor: int = 8,
277
+ resample: str = "lanczos",
278
+ do_normalize: bool = True,
279
+ ):
280
+ super().__init__()
281
+
282
+ @staticmethod
283
+ def numpy_to_pil(images):
284
+ """
285
+ Convert a NumPy image or a batch of images to a PIL image.
286
+ """
287
+ if images.ndim == 3:
288
+ images = images[None, ...]
289
+ images = (images * 255).round().astype("uint8")
290
+ if images.shape[-1] == 1:
291
+ # special case for grayscale (single channel) images
292
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
293
+ else:
294
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
295
+
296
+ return pil_images
297
+
298
+ @staticmethod
299
+ def rgblike_to_depthmap(image):
300
+ """
301
+ Args:
302
+ image: RGB-like depth image
303
+
304
+ Returns: depth map
305
+
306
+ """
307
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
308
+
309
+ def numpy_to_depth(self, images):
310
+ """
311
+ Convert a NumPy depth image or a batch of images to a PIL image.
312
+ """
313
+ if images.ndim == 3:
314
+ images = images[None, ...]
315
+ images_depth = images[:, :, :, 3:]
316
+ if images.shape[-1] == 6:
317
+ images_depth = (images_depth * 255).round().astype("uint8")
318
+ pil_images = [
319
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
320
+ ]
321
+ elif images.shape[-1] == 4:
322
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
323
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
324
+ else:
325
+ raise Exception("Not supported")
326
+
327
+ return pil_images
328
+
329
+ def postprocess(
330
+ self,
331
+ image: torch.FloatTensor,
332
+ output_type: str = "pil",
333
+ do_denormalize: Optional[List[bool]] = None,
334
+ ):
335
+ if not isinstance(image, torch.Tensor):
336
+ raise ValueError(
337
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
338
+ )
339
+ if output_type not in ["latent", "pt", "np", "pil"]:
340
+ deprecation_message = (
341
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
342
+ "`pil`, `np`, `pt`, `latent`"
343
+ )
344
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
345
+ output_type = "np"
346
+
347
+ if do_denormalize is None:
348
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
349
+
350
+ image = torch.stack(
351
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
352
+ )
353
+
354
+ image = self.pt_to_numpy(image)
355
+
356
+ if output_type == "np":
357
+ if image.shape[-1] == 6:
358
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
359
+ else:
360
+ image_depth = image[:, :, :, 3:]
361
+ return image[:, :, :, :3], image_depth
362
+
363
+ if output_type == "pil":
364
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
365
+ else:
366
+ raise Exception(f"This type {output_type} is not supported")
6DoF/diffusers/loaders.py ADDED
@@ -0,0 +1,1492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import warnings
16
+ from collections import defaultdict
17
+ from pathlib import Path
18
+ from typing import Callable, Dict, List, Optional, Union
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ from .models.attention_processor import (
25
+ AttnAddedKVProcessor,
26
+ AttnAddedKVProcessor2_0,
27
+ CustomDiffusionAttnProcessor,
28
+ CustomDiffusionXFormersAttnProcessor,
29
+ LoRAAttnAddedKVProcessor,
30
+ LoRAAttnProcessor,
31
+ LoRAAttnProcessor2_0,
32
+ LoRAXFormersAttnProcessor,
33
+ SlicedAttnAddedKVProcessor,
34
+ XFormersAttnProcessor,
35
+ )
36
+ from .utils import (
37
+ DIFFUSERS_CACHE,
38
+ HF_HUB_OFFLINE,
39
+ TEXT_ENCODER_ATTN_MODULE,
40
+ _get_model_file,
41
+ deprecate,
42
+ is_safetensors_available,
43
+ is_transformers_available,
44
+ logging,
45
+ )
46
+
47
+
48
+ if is_safetensors_available():
49
+ import safetensors
50
+
51
+ if is_transformers_available():
52
+ from transformers import PreTrainedModel, PreTrainedTokenizer
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ TEXT_ENCODER_NAME = "text_encoder"
58
+ UNET_NAME = "unet"
59
+
60
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
61
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
62
+
63
+ TEXT_INVERSION_NAME = "learned_embeds.bin"
64
+ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
65
+
66
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
67
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
68
+
69
+
70
+ class AttnProcsLayers(torch.nn.Module):
71
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
72
+ super().__init__()
73
+ self.layers = torch.nn.ModuleList(state_dict.values())
74
+ self.mapping = dict(enumerate(state_dict.keys()))
75
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
76
+
77
+ # .processor for unet, .self_attn for text encoder
78
+ self.split_keys = [".processor", ".self_attn"]
79
+
80
+ # we add a hook to state_dict() and load_state_dict() so that the
81
+ # naming fits with `unet.attn_processors`
82
+ def map_to(module, state_dict, *args, **kwargs):
83
+ new_state_dict = {}
84
+ for key, value in state_dict.items():
85
+ num = int(key.split(".")[1]) # 0 is always "layers"
86
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
87
+ new_state_dict[new_key] = value
88
+
89
+ return new_state_dict
90
+
91
+ def remap_key(key, state_dict):
92
+ for k in self.split_keys:
93
+ if k in key:
94
+ return key.split(k)[0] + k
95
+
96
+ raise ValueError(
97
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
98
+ )
99
+
100
+ def map_from(module, state_dict, *args, **kwargs):
101
+ all_keys = list(state_dict.keys())
102
+ for key in all_keys:
103
+ replace_key = remap_key(key, state_dict)
104
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
105
+ state_dict[new_key] = state_dict[key]
106
+ del state_dict[key]
107
+
108
+ self._register_state_dict_hook(map_to)
109
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
110
+
111
+
112
+ class UNet2DConditionLoadersMixin:
113
+ text_encoder_name = TEXT_ENCODER_NAME
114
+ unet_name = UNET_NAME
115
+
116
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
117
+ r"""
118
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
119
+ defined in
120
+ [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
121
+ and be a `torch.nn.Module` class.
122
+
123
+ Parameters:
124
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
125
+ Can be either:
126
+
127
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
128
+ the Hub.
129
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
130
+ with [`ModelMixin.save_pretrained`].
131
+ - A [torch state
132
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
133
+
134
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
135
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
136
+ is not used.
137
+ force_download (`bool`, *optional*, defaults to `False`):
138
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
+ cached versions if they exist.
140
+ resume_download (`bool`, *optional*, defaults to `False`):
141
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
142
+ incompletely downloaded files are deleted.
143
+ proxies (`Dict[str, str]`, *optional*):
144
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
145
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
146
+ local_files_only (`bool`, *optional*, defaults to `False`):
147
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
148
+ won't be downloaded from the Hub.
149
+ use_auth_token (`str` or *bool*, *optional*):
150
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
151
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
152
+ revision (`str`, *optional*, defaults to `"main"`):
153
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
154
+ allowed by Git.
155
+ subfolder (`str`, *optional*, defaults to `""`):
156
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
157
+ mirror (`str`, *optional*):
158
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
159
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
160
+ information.
161
+
162
+ """
163
+
164
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
165
+ force_download = kwargs.pop("force_download", False)
166
+ resume_download = kwargs.pop("resume_download", False)
167
+ proxies = kwargs.pop("proxies", None)
168
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
169
+ use_auth_token = kwargs.pop("use_auth_token", None)
170
+ revision = kwargs.pop("revision", None)
171
+ subfolder = kwargs.pop("subfolder", None)
172
+ weight_name = kwargs.pop("weight_name", None)
173
+ use_safetensors = kwargs.pop("use_safetensors", None)
174
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
175
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
176
+ network_alpha = kwargs.pop("network_alpha", None)
177
+
178
+ if use_safetensors and not is_safetensors_available():
179
+ raise ValueError(
180
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
181
+ )
182
+
183
+ allow_pickle = False
184
+ if use_safetensors is None:
185
+ use_safetensors = is_safetensors_available()
186
+ allow_pickle = True
187
+
188
+ user_agent = {
189
+ "file_type": "attn_procs_weights",
190
+ "framework": "pytorch",
191
+ }
192
+
193
+ model_file = None
194
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
195
+ # Let's first try to load .safetensors weights
196
+ if (use_safetensors and weight_name is None) or (
197
+ weight_name is not None and weight_name.endswith(".safetensors")
198
+ ):
199
+ try:
200
+ model_file = _get_model_file(
201
+ pretrained_model_name_or_path_or_dict,
202
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
203
+ cache_dir=cache_dir,
204
+ force_download=force_download,
205
+ resume_download=resume_download,
206
+ proxies=proxies,
207
+ local_files_only=local_files_only,
208
+ use_auth_token=use_auth_token,
209
+ revision=revision,
210
+ subfolder=subfolder,
211
+ user_agent=user_agent,
212
+ )
213
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
214
+ except IOError as e:
215
+ if not allow_pickle:
216
+ raise e
217
+ # try loading non-safetensors weights
218
+ pass
219
+ if model_file is None:
220
+ model_file = _get_model_file(
221
+ pretrained_model_name_or_path_or_dict,
222
+ weights_name=weight_name or LORA_WEIGHT_NAME,
223
+ cache_dir=cache_dir,
224
+ force_download=force_download,
225
+ resume_download=resume_download,
226
+ proxies=proxies,
227
+ local_files_only=local_files_only,
228
+ use_auth_token=use_auth_token,
229
+ revision=revision,
230
+ subfolder=subfolder,
231
+ user_agent=user_agent,
232
+ )
233
+ state_dict = torch.load(model_file, map_location="cpu")
234
+ else:
235
+ state_dict = pretrained_model_name_or_path_or_dict
236
+
237
+ # fill attn processors
238
+ attn_processors = {}
239
+
240
+ is_lora = all("lora" in k for k in state_dict.keys())
241
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
242
+
243
+ if is_lora:
244
+ is_new_lora_format = all(
245
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
246
+ )
247
+ if is_new_lora_format:
248
+ # Strip the `"unet"` prefix.
249
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
250
+ if is_text_encoder_present:
251
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
252
+ warnings.warn(warn_message)
253
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
254
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
255
+
256
+ lora_grouped_dict = defaultdict(dict)
257
+ for key, value in state_dict.items():
258
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
259
+ lora_grouped_dict[attn_processor_key][sub_key] = value
260
+
261
+ for key, value_dict in lora_grouped_dict.items():
262
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
263
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
264
+
265
+ attn_processor = self
266
+ for sub_key in key.split("."):
267
+ attn_processor = getattr(attn_processor, sub_key)
268
+
269
+ if isinstance(
270
+ attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
271
+ ):
272
+ cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
273
+ attn_processor_class = LoRAAttnAddedKVProcessor
274
+ else:
275
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
276
+ if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
277
+ attn_processor_class = LoRAXFormersAttnProcessor
278
+ else:
279
+ attn_processor_class = (
280
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
281
+ )
282
+
283
+ attn_processors[key] = attn_processor_class(
284
+ hidden_size=hidden_size,
285
+ cross_attention_dim=cross_attention_dim,
286
+ rank=rank,
287
+ network_alpha=network_alpha,
288
+ )
289
+ attn_processors[key].load_state_dict(value_dict)
290
+ elif is_custom_diffusion:
291
+ custom_diffusion_grouped_dict = defaultdict(dict)
292
+ for key, value in state_dict.items():
293
+ if len(value) == 0:
294
+ custom_diffusion_grouped_dict[key] = {}
295
+ else:
296
+ if "to_out" in key:
297
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
298
+ else:
299
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
300
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
301
+
302
+ for key, value_dict in custom_diffusion_grouped_dict.items():
303
+ if len(value_dict) == 0:
304
+ attn_processors[key] = CustomDiffusionAttnProcessor(
305
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
306
+ )
307
+ else:
308
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
309
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
310
+ train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
311
+ attn_processors[key] = CustomDiffusionAttnProcessor(
312
+ train_kv=True,
313
+ train_q_out=train_q_out,
314
+ hidden_size=hidden_size,
315
+ cross_attention_dim=cross_attention_dim,
316
+ )
317
+ attn_processors[key].load_state_dict(value_dict)
318
+ else:
319
+ raise ValueError(
320
+ f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
321
+ )
322
+
323
+ # set correct dtype & device
324
+ attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
325
+
326
+ # set layers
327
+ self.set_attn_processor(attn_processors)
328
+
329
+ def save_attn_procs(
330
+ self,
331
+ save_directory: Union[str, os.PathLike],
332
+ is_main_process: bool = True,
333
+ weight_name: str = None,
334
+ save_function: Callable = None,
335
+ safe_serialization: bool = False,
336
+ **kwargs,
337
+ ):
338
+ r"""
339
+ Save an attention processor to a directory so that it can be reloaded using the
340
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
341
+
342
+ Arguments:
343
+ save_directory (`str` or `os.PathLike`):
344
+ Directory to save an attention processor to. Will be created if it doesn't exist.
345
+ is_main_process (`bool`, *optional*, defaults to `True`):
346
+ Whether the process calling this is the main process or not. Useful during distributed training and you
347
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
348
+ process to avoid race conditions.
349
+ save_function (`Callable`):
350
+ The function to use to save the state dictionary. Useful during distributed training when you need to
351
+ replace `torch.save` with another method. Can be configured with the environment variable
352
+ `DIFFUSERS_SAVE_MODE`.
353
+
354
+ """
355
+ weight_name = weight_name or deprecate(
356
+ "weights_name",
357
+ "0.20.0",
358
+ "`weights_name` is deprecated, please use `weight_name` instead.",
359
+ take_from=kwargs,
360
+ )
361
+ if os.path.isfile(save_directory):
362
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
363
+ return
364
+
365
+ if save_function is None:
366
+ if safe_serialization:
367
+
368
+ def save_function(weights, filename):
369
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
370
+
371
+ else:
372
+ save_function = torch.save
373
+
374
+ os.makedirs(save_directory, exist_ok=True)
375
+
376
+ is_custom_diffusion = any(
377
+ isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
378
+ for (_, x) in self.attn_processors.items()
379
+ )
380
+ if is_custom_diffusion:
381
+ model_to_save = AttnProcsLayers(
382
+ {
383
+ y: x
384
+ for (y, x) in self.attn_processors.items()
385
+ if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
386
+ }
387
+ )
388
+ state_dict = model_to_save.state_dict()
389
+ for name, attn in self.attn_processors.items():
390
+ if len(attn.state_dict()) == 0:
391
+ state_dict[name] = {}
392
+ else:
393
+ model_to_save = AttnProcsLayers(self.attn_processors)
394
+ state_dict = model_to_save.state_dict()
395
+
396
+ if weight_name is None:
397
+ if safe_serialization:
398
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
399
+ else:
400
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
401
+
402
+ # Save the model
403
+ save_function(state_dict, os.path.join(save_directory, weight_name))
404
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
405
+
406
+
407
+ class TextualInversionLoaderMixin:
408
+ r"""
409
+ Load textual inversion tokens and embeddings to the tokenizer and text encoder.
410
+ """
411
+
412
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
413
+ r"""
414
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
415
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
416
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
417
+
418
+ Parameters:
419
+ prompt (`str` or list of `str`):
420
+ The prompt or prompts to guide the image generation.
421
+ tokenizer (`PreTrainedTokenizer`):
422
+ The tokenizer responsible for encoding the prompt into input tokens.
423
+
424
+ Returns:
425
+ `str` or list of `str`: The converted prompt
426
+ """
427
+ if not isinstance(prompt, List):
428
+ prompts = [prompt]
429
+ else:
430
+ prompts = prompt
431
+
432
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
433
+
434
+ if not isinstance(prompt, List):
435
+ return prompts[0]
436
+
437
+ return prompts
438
+
439
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
440
+ r"""
441
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
442
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
443
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
444
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
445
+
446
+ Parameters:
447
+ prompt (`str`):
448
+ The prompt to guide the image generation.
449
+ tokenizer (`PreTrainedTokenizer`):
450
+ The tokenizer responsible for encoding the prompt into input tokens.
451
+
452
+ Returns:
453
+ `str`: The converted prompt
454
+ """
455
+ tokens = tokenizer.tokenize(prompt)
456
+ unique_tokens = set(tokens)
457
+ for token in unique_tokens:
458
+ if token in tokenizer.added_tokens_encoder:
459
+ replacement = token
460
+ i = 1
461
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
462
+ replacement += f" {token}_{i}"
463
+ i += 1
464
+
465
+ prompt = prompt.replace(token, replacement)
466
+
467
+ return prompt
468
+
469
+ def load_textual_inversion(
470
+ self,
471
+ pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
472
+ token: Optional[Union[str, List[str]]] = None,
473
+ **kwargs,
474
+ ):
475
+ r"""
476
+ Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
477
+ Automatic1111 formats are supported).
478
+
479
+ Parameters:
480
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
481
+ Can be either one of the following or a list of them:
482
+
483
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
484
+ pretrained model hosted on the Hub.
485
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
486
+ inversion weights.
487
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
488
+ - A [torch state
489
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
490
+
491
+ token (`str` or `List[str]`, *optional*):
492
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
493
+ list, then `token` must also be a list of equal length.
494
+ weight_name (`str`, *optional*):
495
+ Name of a custom weight file. This should be used when:
496
+
497
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
498
+ name such as `text_inv.bin`.
499
+ - The saved textual inversion file is in the Automatic1111 format.
500
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
501
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
502
+ is not used.
503
+ force_download (`bool`, *optional*, defaults to `False`):
504
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
505
+ cached versions if they exist.
506
+ resume_download (`bool`, *optional*, defaults to `False`):
507
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
508
+ incompletely downloaded files are deleted.
509
+ proxies (`Dict[str, str]`, *optional*):
510
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
511
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
512
+ local_files_only (`bool`, *optional*, defaults to `False`):
513
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
514
+ won't be downloaded from the Hub.
515
+ use_auth_token (`str` or *bool*, *optional*):
516
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
517
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
518
+ revision (`str`, *optional*, defaults to `"main"`):
519
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
520
+ allowed by Git.
521
+ subfolder (`str`, *optional*, defaults to `""`):
522
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
523
+ mirror (`str`, *optional*):
524
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
525
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
526
+ information.
527
+
528
+ Example:
529
+
530
+ To load a textual inversion embedding vector in 🤗 Diffusers format:
531
+
532
+ ```py
533
+ from diffusers import StableDiffusionPipeline
534
+ import torch
535
+
536
+ model_id = "runwayml/stable-diffusion-v1-5"
537
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
538
+
539
+ pipe.load_textual_inversion("sd-concepts-library/cat-toy")
540
+
541
+ prompt = "A <cat-toy> backpack"
542
+
543
+ image = pipe(prompt, num_inference_steps=50).images[0]
544
+ image.save("cat-backpack.png")
545
+ ```
546
+
547
+ To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
548
+ (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
549
+ locally:
550
+
551
+ ```py
552
+ from diffusers import StableDiffusionPipeline
553
+ import torch
554
+
555
+ model_id = "runwayml/stable-diffusion-v1-5"
556
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
557
+
558
+ pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
559
+
560
+ prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
561
+
562
+ image = pipe(prompt, num_inference_steps=50).images[0]
563
+ image.save("character.png")
564
+ ```
565
+
566
+ """
567
+ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
568
+ raise ValueError(
569
+ f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
570
+ f" `{self.load_textual_inversion.__name__}`"
571
+ )
572
+
573
+ if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
574
+ raise ValueError(
575
+ f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
576
+ f" `{self.load_textual_inversion.__name__}`"
577
+ )
578
+
579
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
580
+ force_download = kwargs.pop("force_download", False)
581
+ resume_download = kwargs.pop("resume_download", False)
582
+ proxies = kwargs.pop("proxies", None)
583
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
584
+ use_auth_token = kwargs.pop("use_auth_token", None)
585
+ revision = kwargs.pop("revision", None)
586
+ subfolder = kwargs.pop("subfolder", None)
587
+ weight_name = kwargs.pop("weight_name", None)
588
+ use_safetensors = kwargs.pop("use_safetensors", None)
589
+
590
+ if use_safetensors and not is_safetensors_available():
591
+ raise ValueError(
592
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
593
+ )
594
+
595
+ allow_pickle = False
596
+ if use_safetensors is None:
597
+ use_safetensors = is_safetensors_available()
598
+ allow_pickle = True
599
+
600
+ user_agent = {
601
+ "file_type": "text_inversion",
602
+ "framework": "pytorch",
603
+ }
604
+
605
+ if not isinstance(pretrained_model_name_or_path, list):
606
+ pretrained_model_name_or_paths = [pretrained_model_name_or_path]
607
+ else:
608
+ pretrained_model_name_or_paths = pretrained_model_name_or_path
609
+
610
+ if isinstance(token, str):
611
+ tokens = [token]
612
+ elif token is None:
613
+ tokens = [None] * len(pretrained_model_name_or_paths)
614
+ else:
615
+ tokens = token
616
+
617
+ if len(pretrained_model_name_or_paths) != len(tokens):
618
+ raise ValueError(
619
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
620
+ f"Make sure both lists have the same length."
621
+ )
622
+
623
+ valid_tokens = [t for t in tokens if t is not None]
624
+ if len(set(valid_tokens)) < len(valid_tokens):
625
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
626
+
627
+ token_ids_and_embeddings = []
628
+
629
+ for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
630
+ if not isinstance(pretrained_model_name_or_path, dict):
631
+ # 1. Load textual inversion file
632
+ model_file = None
633
+ # Let's first try to load .safetensors weights
634
+ if (use_safetensors and weight_name is None) or (
635
+ weight_name is not None and weight_name.endswith(".safetensors")
636
+ ):
637
+ try:
638
+ model_file = _get_model_file(
639
+ pretrained_model_name_or_path,
640
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
641
+ cache_dir=cache_dir,
642
+ force_download=force_download,
643
+ resume_download=resume_download,
644
+ proxies=proxies,
645
+ local_files_only=local_files_only,
646
+ use_auth_token=use_auth_token,
647
+ revision=revision,
648
+ subfolder=subfolder,
649
+ user_agent=user_agent,
650
+ )
651
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
652
+ except Exception as e:
653
+ if not allow_pickle:
654
+ raise e
655
+
656
+ model_file = None
657
+
658
+ if model_file is None:
659
+ model_file = _get_model_file(
660
+ pretrained_model_name_or_path,
661
+ weights_name=weight_name or TEXT_INVERSION_NAME,
662
+ cache_dir=cache_dir,
663
+ force_download=force_download,
664
+ resume_download=resume_download,
665
+ proxies=proxies,
666
+ local_files_only=local_files_only,
667
+ use_auth_token=use_auth_token,
668
+ revision=revision,
669
+ subfolder=subfolder,
670
+ user_agent=user_agent,
671
+ )
672
+ state_dict = torch.load(model_file, map_location="cpu")
673
+ else:
674
+ state_dict = pretrained_model_name_or_path
675
+
676
+ # 2. Load token and embedding correcly from file
677
+ loaded_token = None
678
+ if isinstance(state_dict, torch.Tensor):
679
+ if token is None:
680
+ raise ValueError(
681
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
682
+ )
683
+ embedding = state_dict
684
+ elif len(state_dict) == 1:
685
+ # diffusers
686
+ loaded_token, embedding = next(iter(state_dict.items()))
687
+ elif "string_to_param" in state_dict:
688
+ # A1111
689
+ loaded_token = state_dict["name"]
690
+ embedding = state_dict["string_to_param"]["*"]
691
+
692
+ if token is not None and loaded_token != token:
693
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
694
+ else:
695
+ token = loaded_token
696
+
697
+ embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
698
+
699
+ # 3. Make sure we don't mess up the tokenizer or text encoder
700
+ vocab = self.tokenizer.get_vocab()
701
+ if token in vocab:
702
+ raise ValueError(
703
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
704
+ )
705
+ elif f"{token}_1" in vocab:
706
+ multi_vector_tokens = [token]
707
+ i = 1
708
+ while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
709
+ multi_vector_tokens.append(f"{token}_{i}")
710
+ i += 1
711
+
712
+ raise ValueError(
713
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
714
+ )
715
+
716
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
717
+
718
+ if is_multi_vector:
719
+ tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
720
+ embeddings = [e for e in embedding] # noqa: C416
721
+ else:
722
+ tokens = [token]
723
+ embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
724
+
725
+ # add tokens and get ids
726
+ self.tokenizer.add_tokens(tokens)
727
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
728
+ token_ids_and_embeddings += zip(token_ids, embeddings)
729
+
730
+ logger.info(f"Loaded textual inversion embedding for {token}.")
731
+
732
+ # resize token embeddings and set all new embeddings
733
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
734
+ for token_id, embedding in token_ids_and_embeddings:
735
+ self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
736
+
737
+
738
+ class LoraLoaderMixin:
739
+ r"""
740
+ Load LoRA layers into [`UNet2DConditionModel`] and
741
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
742
+ """
743
+ text_encoder_name = TEXT_ENCODER_NAME
744
+ unet_name = UNET_NAME
745
+
746
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
747
+ r"""
748
+ Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
749
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
750
+
751
+ Parameters:
752
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
753
+ Can be either:
754
+
755
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
756
+ the Hub.
757
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
758
+ with [`ModelMixin.save_pretrained`].
759
+ - A [torch state
760
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
761
+
762
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
763
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
764
+ is not used.
765
+ force_download (`bool`, *optional*, defaults to `False`):
766
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
767
+ cached versions if they exist.
768
+ resume_download (`bool`, *optional*, defaults to `False`):
769
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
770
+ incompletely downloaded files are deleted.
771
+ proxies (`Dict[str, str]`, *optional*):
772
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
773
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
774
+ local_files_only (`bool`, *optional*, defaults to `False`):
775
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
776
+ won't be downloaded from the Hub.
777
+ use_auth_token (`str` or *bool*, *optional*):
778
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
779
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
780
+ revision (`str`, *optional*, defaults to `"main"`):
781
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
782
+ allowed by Git.
783
+ subfolder (`str`, *optional*, defaults to `""`):
784
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
785
+ mirror (`str`, *optional*):
786
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
787
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
788
+ information.
789
+
790
+ """
791
+ # Load the main state dict first which has the LoRA layers for either of
792
+ # UNet and text encoder or both.
793
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
794
+ force_download = kwargs.pop("force_download", False)
795
+ resume_download = kwargs.pop("resume_download", False)
796
+ proxies = kwargs.pop("proxies", None)
797
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
798
+ use_auth_token = kwargs.pop("use_auth_token", None)
799
+ revision = kwargs.pop("revision", None)
800
+ subfolder = kwargs.pop("subfolder", None)
801
+ weight_name = kwargs.pop("weight_name", None)
802
+ use_safetensors = kwargs.pop("use_safetensors", None)
803
+
804
+ # set lora scale to a reasonable default
805
+ self._lora_scale = 1.0
806
+
807
+ if use_safetensors and not is_safetensors_available():
808
+ raise ValueError(
809
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
810
+ )
811
+
812
+ allow_pickle = False
813
+ if use_safetensors is None:
814
+ use_safetensors = is_safetensors_available()
815
+ allow_pickle = True
816
+
817
+ user_agent = {
818
+ "file_type": "attn_procs_weights",
819
+ "framework": "pytorch",
820
+ }
821
+
822
+ model_file = None
823
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
824
+ # Let's first try to load .safetensors weights
825
+ if (use_safetensors and weight_name is None) or (
826
+ weight_name is not None and weight_name.endswith(".safetensors")
827
+ ):
828
+ try:
829
+ model_file = _get_model_file(
830
+ pretrained_model_name_or_path_or_dict,
831
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
832
+ cache_dir=cache_dir,
833
+ force_download=force_download,
834
+ resume_download=resume_download,
835
+ proxies=proxies,
836
+ local_files_only=local_files_only,
837
+ use_auth_token=use_auth_token,
838
+ revision=revision,
839
+ subfolder=subfolder,
840
+ user_agent=user_agent,
841
+ )
842
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
843
+ except IOError as e:
844
+ if not allow_pickle:
845
+ raise e
846
+ # try loading non-safetensors weights
847
+ pass
848
+ if model_file is None:
849
+ model_file = _get_model_file(
850
+ pretrained_model_name_or_path_or_dict,
851
+ weights_name=weight_name or LORA_WEIGHT_NAME,
852
+ cache_dir=cache_dir,
853
+ force_download=force_download,
854
+ resume_download=resume_download,
855
+ proxies=proxies,
856
+ local_files_only=local_files_only,
857
+ use_auth_token=use_auth_token,
858
+ revision=revision,
859
+ subfolder=subfolder,
860
+ user_agent=user_agent,
861
+ )
862
+ state_dict = torch.load(model_file, map_location="cpu")
863
+ else:
864
+ state_dict = pretrained_model_name_or_path_or_dict
865
+
866
+ # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
867
+ network_alpha = None
868
+ if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
869
+ state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
870
+
871
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
872
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
873
+ # their prefixes.
874
+ keys = list(state_dict.keys())
875
+ if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
876
+ # Load the layers corresponding to UNet.
877
+ unet_keys = [k for k in keys if k.startswith(self.unet_name)]
878
+ logger.info(f"Loading {self.unet_name}.")
879
+ unet_lora_state_dict = {
880
+ k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
881
+ }
882
+ self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
883
+
884
+ # Load the layers corresponding to text encoder and make necessary adjustments.
885
+ text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
886
+ text_encoder_lora_state_dict = {
887
+ k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
888
+ }
889
+ if len(text_encoder_lora_state_dict) > 0:
890
+ logger.info(f"Loading {self.text_encoder_name}.")
891
+ attn_procs_text_encoder = self._load_text_encoder_attn_procs(
892
+ text_encoder_lora_state_dict, network_alpha=network_alpha
893
+ )
894
+ self._modify_text_encoder(attn_procs_text_encoder)
895
+
896
+ # save lora attn procs of text encoder so that it can be easily retrieved
897
+ self._text_encoder_lora_attn_procs = attn_procs_text_encoder
898
+
899
+ # Otherwise, we're dealing with the old format. This means the `state_dict` should only
900
+ # contain the module names of the `unet` as its keys WITHOUT any prefix.
901
+ elif not all(
902
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
903
+ ):
904
+ self.unet.load_attn_procs(state_dict)
905
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
906
+ warnings.warn(warn_message)
907
+
908
+ @property
909
+ def lora_scale(self) -> float:
910
+ # property function that returns the lora scale which can be set at run time by the pipeline.
911
+ # if _lora_scale has not been set, return 1
912
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
913
+
914
+ @property
915
+ def text_encoder_lora_attn_procs(self):
916
+ if hasattr(self, "_text_encoder_lora_attn_procs"):
917
+ return self._text_encoder_lora_attn_procs
918
+ return
919
+
920
+ def _remove_text_encoder_monkey_patch(self):
921
+ # Loop over the CLIPAttention module of text_encoder
922
+ for name, attn_module in self.text_encoder.named_modules():
923
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
924
+ # Loop over the LoRA layers
925
+ for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
926
+ # Retrieve the q/k/v/out projection of CLIPAttention
927
+ module = attn_module.get_submodule(text_encoder_attr)
928
+ if hasattr(module, "old_forward"):
929
+ # restore original `forward` to remove monkey-patch
930
+ module.forward = module.old_forward
931
+ delattr(module, "old_forward")
932
+
933
+ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
934
+ r"""
935
+ Monkey-patches the forward passes of attention modules of the text encoder.
936
+
937
+ Parameters:
938
+ attn_processors: Dict[str, `LoRAAttnProcessor`]:
939
+ A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
940
+ """
941
+
942
+ # First, remove any monkey-patch that might have been applied before
943
+ self._remove_text_encoder_monkey_patch()
944
+
945
+ # Loop over the CLIPAttention module of text_encoder
946
+ for name, attn_module in self.text_encoder.named_modules():
947
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
948
+ # Loop over the LoRA layers
949
+ for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
950
+ # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
951
+ module = attn_module.get_submodule(text_encoder_attr)
952
+ lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
953
+
954
+ # save old_forward to module that can be used to remove monkey-patch
955
+ old_forward = module.old_forward = module.forward
956
+
957
+ # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
958
+ # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
959
+ def make_new_forward(old_forward, lora_layer):
960
+ def new_forward(x):
961
+ result = old_forward(x) + self.lora_scale * lora_layer(x)
962
+ return result
963
+
964
+ return new_forward
965
+
966
+ # Monkey-patch.
967
+ module.forward = make_new_forward(old_forward, lora_layer)
968
+
969
+ @property
970
+ def _lora_attn_processor_attr_to_text_encoder_attr(self):
971
+ return {
972
+ "to_q_lora": "q_proj",
973
+ "to_k_lora": "k_proj",
974
+ "to_v_lora": "v_proj",
975
+ "to_out_lora": "out_proj",
976
+ }
977
+
978
+ def _load_text_encoder_attn_procs(
979
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
980
+ ):
981
+ r"""
982
+ Load pretrained attention processor layers for
983
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
984
+
985
+ <Tip warning={true}>
986
+
987
+ This function is experimental and might change in the future.
988
+
989
+ </Tip>
990
+
991
+ Parameters:
992
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
993
+ Can be either:
994
+
995
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
996
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
997
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
998
+ `./my_model_directory/`.
999
+ - A [torch state
1000
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1001
+
1002
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1003
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
1004
+ standard cache should not be used.
1005
+ force_download (`bool`, *optional*, defaults to `False`):
1006
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1007
+ cached versions if they exist.
1008
+ resume_download (`bool`, *optional*, defaults to `False`):
1009
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
1010
+ file exists.
1011
+ proxies (`Dict[str, str]`, *optional*):
1012
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1013
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1014
+ local_files_only (`bool`, *optional*, defaults to `False`):
1015
+ Whether or not to only look at local files (i.e., do not try to download the model).
1016
+ use_auth_token (`str` or *bool*, *optional*):
1017
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
1018
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
1019
+ revision (`str`, *optional*, defaults to `"main"`):
1020
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1021
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1022
+ identifier allowed by git.
1023
+ subfolder (`str`, *optional*, defaults to `""`):
1024
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
1025
+ huggingface.co or downloaded locally), you can specify the folder name here.
1026
+ mirror (`str`, *optional*):
1027
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
1028
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
1029
+ Please refer to the mirror site for more information.
1030
+
1031
+ Returns:
1032
+ `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
1033
+ [`LoRAAttnProcessor`].
1034
+
1035
+ <Tip>
1036
+
1037
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
1038
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
1039
+
1040
+ </Tip>
1041
+ """
1042
+
1043
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1044
+ force_download = kwargs.pop("force_download", False)
1045
+ resume_download = kwargs.pop("resume_download", False)
1046
+ proxies = kwargs.pop("proxies", None)
1047
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1048
+ use_auth_token = kwargs.pop("use_auth_token", None)
1049
+ revision = kwargs.pop("revision", None)
1050
+ subfolder = kwargs.pop("subfolder", None)
1051
+ weight_name = kwargs.pop("weight_name", None)
1052
+ use_safetensors = kwargs.pop("use_safetensors", None)
1053
+ network_alpha = kwargs.pop("network_alpha", None)
1054
+
1055
+ if use_safetensors and not is_safetensors_available():
1056
+ raise ValueError(
1057
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1058
+ )
1059
+
1060
+ allow_pickle = False
1061
+ if use_safetensors is None:
1062
+ use_safetensors = is_safetensors_available()
1063
+ allow_pickle = True
1064
+
1065
+ user_agent = {
1066
+ "file_type": "attn_procs_weights",
1067
+ "framework": "pytorch",
1068
+ }
1069
+
1070
+ model_file = None
1071
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
1072
+ # Let's first try to load .safetensors weights
1073
+ if (use_safetensors and weight_name is None) or (
1074
+ weight_name is not None and weight_name.endswith(".safetensors")
1075
+ ):
1076
+ try:
1077
+ model_file = _get_model_file(
1078
+ pretrained_model_name_or_path_or_dict,
1079
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
1080
+ cache_dir=cache_dir,
1081
+ force_download=force_download,
1082
+ resume_download=resume_download,
1083
+ proxies=proxies,
1084
+ local_files_only=local_files_only,
1085
+ use_auth_token=use_auth_token,
1086
+ revision=revision,
1087
+ subfolder=subfolder,
1088
+ user_agent=user_agent,
1089
+ )
1090
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
1091
+ except IOError as e:
1092
+ if not allow_pickle:
1093
+ raise e
1094
+ # try loading non-safetensors weights
1095
+ pass
1096
+ if model_file is None:
1097
+ model_file = _get_model_file(
1098
+ pretrained_model_name_or_path_or_dict,
1099
+ weights_name=weight_name or LORA_WEIGHT_NAME,
1100
+ cache_dir=cache_dir,
1101
+ force_download=force_download,
1102
+ resume_download=resume_download,
1103
+ proxies=proxies,
1104
+ local_files_only=local_files_only,
1105
+ use_auth_token=use_auth_token,
1106
+ revision=revision,
1107
+ subfolder=subfolder,
1108
+ user_agent=user_agent,
1109
+ )
1110
+ state_dict = torch.load(model_file, map_location="cpu")
1111
+ else:
1112
+ state_dict = pretrained_model_name_or_path_or_dict
1113
+
1114
+ # fill attn processors
1115
+ attn_processors = {}
1116
+
1117
+ is_lora = all("lora" in k for k in state_dict.keys())
1118
+
1119
+ if is_lora:
1120
+ lora_grouped_dict = defaultdict(dict)
1121
+ for key, value in state_dict.items():
1122
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
1123
+ lora_grouped_dict[attn_processor_key][sub_key] = value
1124
+
1125
+ for key, value_dict in lora_grouped_dict.items():
1126
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
1127
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
1128
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
1129
+
1130
+ attn_processor_class = (
1131
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
1132
+ )
1133
+ attn_processors[key] = attn_processor_class(
1134
+ hidden_size=hidden_size,
1135
+ cross_attention_dim=cross_attention_dim,
1136
+ rank=rank,
1137
+ network_alpha=network_alpha,
1138
+ )
1139
+ attn_processors[key].load_state_dict(value_dict)
1140
+
1141
+ else:
1142
+ raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
1143
+
1144
+ # set correct dtype & device
1145
+ attn_processors = {
1146
+ k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
1147
+ }
1148
+ return attn_processors
1149
+
1150
+ @classmethod
1151
+ def save_lora_weights(
1152
+ self,
1153
+ save_directory: Union[str, os.PathLike],
1154
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1155
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1156
+ is_main_process: bool = True,
1157
+ weight_name: str = None,
1158
+ save_function: Callable = None,
1159
+ safe_serialization: bool = False,
1160
+ ):
1161
+ r"""
1162
+ Save the LoRA parameters corresponding to the UNet and text encoder.
1163
+
1164
+ Arguments:
1165
+ save_directory (`str` or `os.PathLike`):
1166
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1167
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1168
+ State dict of the LoRA layers corresponding to the UNet.
1169
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`):
1170
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1171
+ encoder LoRA state dict because it comes 🤗 Transformers.
1172
+ is_main_process (`bool`, *optional*, defaults to `True`):
1173
+ Whether the process calling this is the main process or not. Useful during distributed training and you
1174
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1175
+ process to avoid race conditions.
1176
+ save_function (`Callable`):
1177
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1178
+ replace `torch.save` with another method. Can be configured with the environment variable
1179
+ `DIFFUSERS_SAVE_MODE`.
1180
+ """
1181
+ if os.path.isfile(save_directory):
1182
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1183
+ return
1184
+
1185
+ if save_function is None:
1186
+ if safe_serialization:
1187
+
1188
+ def save_function(weights, filename):
1189
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
1190
+
1191
+ else:
1192
+ save_function = torch.save
1193
+
1194
+ os.makedirs(save_directory, exist_ok=True)
1195
+
1196
+ # Create a flat dictionary.
1197
+ state_dict = {}
1198
+ if unet_lora_layers is not None:
1199
+ weights = (
1200
+ unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
1201
+ )
1202
+
1203
+ unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
1204
+ state_dict.update(unet_lora_state_dict)
1205
+
1206
+ if text_encoder_lora_layers is not None:
1207
+ weights = (
1208
+ text_encoder_lora_layers.state_dict()
1209
+ if isinstance(text_encoder_lora_layers, torch.nn.Module)
1210
+ else text_encoder_lora_layers
1211
+ )
1212
+
1213
+ text_encoder_lora_state_dict = {
1214
+ f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
1215
+ }
1216
+ state_dict.update(text_encoder_lora_state_dict)
1217
+
1218
+ # Save the model
1219
+ if weight_name is None:
1220
+ if safe_serialization:
1221
+ weight_name = LORA_WEIGHT_NAME_SAFE
1222
+ else:
1223
+ weight_name = LORA_WEIGHT_NAME
1224
+
1225
+ save_function(state_dict, os.path.join(save_directory, weight_name))
1226
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1227
+
1228
+ def _convert_kohya_lora_to_diffusers(self, state_dict):
1229
+ unet_state_dict = {}
1230
+ te_state_dict = {}
1231
+ network_alpha = None
1232
+
1233
+ for key, value in state_dict.items():
1234
+ if "lora_down" in key:
1235
+ lora_name = key.split(".")[0]
1236
+ lora_name_up = lora_name + ".lora_up.weight"
1237
+ lora_name_alpha = lora_name + ".alpha"
1238
+ if lora_name_alpha in state_dict:
1239
+ alpha = state_dict[lora_name_alpha].item()
1240
+ if network_alpha is None:
1241
+ network_alpha = alpha
1242
+ elif network_alpha != alpha:
1243
+ raise ValueError("Network alpha is not consistent")
1244
+
1245
+ if lora_name.startswith("lora_unet_"):
1246
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
1247
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
1248
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
1249
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
1250
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
1251
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
1252
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
1253
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
1254
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
1255
+ if "transformer_blocks" in diffusers_name:
1256
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
1257
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
1258
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
1259
+ unet_state_dict[diffusers_name] = value
1260
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1261
+ elif lora_name.startswith("lora_te_"):
1262
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
1263
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
1264
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
1265
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
1266
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
1267
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
1268
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
1269
+ if "self_attn" in diffusers_name:
1270
+ te_state_dict[diffusers_name] = value
1271
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1272
+
1273
+ unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
1274
+ te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
1275
+ new_state_dict = {**unet_state_dict, **te_state_dict}
1276
+ return new_state_dict, network_alpha
1277
+
1278
+
1279
+ class FromSingleFileMixin:
1280
+ """
1281
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
1282
+ """
1283
+
1284
+ @classmethod
1285
+ def from_ckpt(cls, *args, **kwargs):
1286
+ deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
1287
+ deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
1288
+ return cls.from_single_file(*args, **kwargs)
1289
+
1290
+ @classmethod
1291
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
1292
+ r"""
1293
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
1294
+ is set in evaluation mode (`model.eval()`) by default.
1295
+
1296
+ Parameters:
1297
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
1298
+ Can be either:
1299
+ - A link to the `.ckpt` file (for example
1300
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
1301
+ - A path to a *file* containing all pipeline weights.
1302
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1303
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1304
+ dtype is automatically derived from the model's weights.
1305
+ force_download (`bool`, *optional*, defaults to `False`):
1306
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1307
+ cached versions if they exist.
1308
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1309
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1310
+ is not used.
1311
+ resume_download (`bool`, *optional*, defaults to `False`):
1312
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1313
+ incompletely downloaded files are deleted.
1314
+ proxies (`Dict[str, str]`, *optional*):
1315
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1316
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1317
+ local_files_only (`bool`, *optional*, defaults to `False`):
1318
+ Whether to only load local model weights and configuration files or not. If set to True, the model
1319
+ won't be downloaded from the Hub.
1320
+ use_auth_token (`str` or *bool*, *optional*):
1321
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1322
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1323
+ revision (`str`, *optional*, defaults to `"main"`):
1324
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1325
+ allowed by Git.
1326
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1327
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1328
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1329
+ weights. If set to `False`, safetensors weights are not loaded.
1330
+ extract_ema (`bool`, *optional*, defaults to `False`):
1331
+ Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
1332
+ higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
1333
+ upcast_attention (`bool`, *optional*, defaults to `None`):
1334
+ Whether the attention computation should always be upcasted.
1335
+ image_size (`int`, *optional*, defaults to 512):
1336
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
1337
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
1338
+ prediction_type (`str`, *optional*):
1339
+ The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
1340
+ the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
1341
+ num_in_channels (`int`, *optional*, defaults to `None`):
1342
+ The number of input channels. If `None`, it will be automatically inferred.
1343
+ scheduler_type (`str`, *optional*, defaults to `"pndm"`):
1344
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
1345
+ "ddim"]`.
1346
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
1347
+ Whether to load the safety checker or not.
1348
+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
1349
+ An instance of
1350
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
1351
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
1352
+ variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
1353
+ needed.
1354
+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
1355
+ An instance of
1356
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
1357
+ to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
1358
+ itself, if needed.
1359
+ kwargs (remaining dictionary of keyword arguments, *optional*):
1360
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
1361
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
1362
+ method. See example below for more information.
1363
+
1364
+ Examples:
1365
+
1366
+ ```py
1367
+ >>> from diffusers import StableDiffusionPipeline
1368
+
1369
+ >>> # Download pipeline from huggingface.co and cache.
1370
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
1371
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
1372
+ ... )
1373
+
1374
+ >>> # Download pipeline from local file
1375
+ >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
1376
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
1377
+
1378
+ >>> # Enable float16 and move to GPU
1379
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
1380
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
1381
+ ... torch_dtype=torch.float16,
1382
+ ... )
1383
+ >>> pipeline.to("cuda")
1384
+ ```
1385
+ """
1386
+ # import here to avoid circular dependency
1387
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
1388
+
1389
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1390
+ resume_download = kwargs.pop("resume_download", False)
1391
+ force_download = kwargs.pop("force_download", False)
1392
+ proxies = kwargs.pop("proxies", None)
1393
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1394
+ use_auth_token = kwargs.pop("use_auth_token", None)
1395
+ revision = kwargs.pop("revision", None)
1396
+ extract_ema = kwargs.pop("extract_ema", False)
1397
+ image_size = kwargs.pop("image_size", None)
1398
+ scheduler_type = kwargs.pop("scheduler_type", "pndm")
1399
+ num_in_channels = kwargs.pop("num_in_channels", None)
1400
+ upcast_attention = kwargs.pop("upcast_attention", None)
1401
+ load_safety_checker = kwargs.pop("load_safety_checker", True)
1402
+ prediction_type = kwargs.pop("prediction_type", None)
1403
+ text_encoder = kwargs.pop("text_encoder", None)
1404
+ tokenizer = kwargs.pop("tokenizer", None)
1405
+
1406
+ torch_dtype = kwargs.pop("torch_dtype", None)
1407
+
1408
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
1409
+
1410
+ pipeline_name = cls.__name__
1411
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
1412
+ from_safetensors = file_extension == "safetensors"
1413
+
1414
+ if from_safetensors and use_safetensors is False:
1415
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
1416
+
1417
+ # TODO: For now we only support stable diffusion
1418
+ stable_unclip = None
1419
+ model_type = None
1420
+ controlnet = False
1421
+
1422
+ if pipeline_name == "StableDiffusionControlNetPipeline":
1423
+ # Model type will be inferred from the checkpoint.
1424
+ controlnet = True
1425
+ elif "StableDiffusion" in pipeline_name:
1426
+ # Model type will be inferred from the checkpoint.
1427
+ pass
1428
+ elif pipeline_name == "StableUnCLIPPipeline":
1429
+ model_type = "FrozenOpenCLIPEmbedder"
1430
+ stable_unclip = "txt2img"
1431
+ elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
1432
+ model_type = "FrozenOpenCLIPEmbedder"
1433
+ stable_unclip = "img2img"
1434
+ elif pipeline_name == "PaintByExamplePipeline":
1435
+ model_type = "PaintByExample"
1436
+ elif pipeline_name == "LDMTextToImagePipeline":
1437
+ model_type = "LDMTextToImage"
1438
+ else:
1439
+ raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
1440
+
1441
+ # remove huggingface url
1442
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
1443
+ if pretrained_model_link_or_path.startswith(prefix):
1444
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
1445
+
1446
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
1447
+ ckpt_path = Path(pretrained_model_link_or_path)
1448
+ if not ckpt_path.is_file():
1449
+ # get repo_id and (potentially nested) file path of ckpt in repo
1450
+ repo_id = "/".join(ckpt_path.parts[:2])
1451
+ file_path = "/".join(ckpt_path.parts[2:])
1452
+
1453
+ if file_path.startswith("blob/"):
1454
+ file_path = file_path[len("blob/") :]
1455
+
1456
+ if file_path.startswith("main/"):
1457
+ file_path = file_path[len("main/") :]
1458
+
1459
+ pretrained_model_link_or_path = hf_hub_download(
1460
+ repo_id,
1461
+ filename=file_path,
1462
+ cache_dir=cache_dir,
1463
+ resume_download=resume_download,
1464
+ proxies=proxies,
1465
+ local_files_only=local_files_only,
1466
+ use_auth_token=use_auth_token,
1467
+ revision=revision,
1468
+ force_download=force_download,
1469
+ )
1470
+
1471
+ pipe = download_from_original_stable_diffusion_ckpt(
1472
+ pretrained_model_link_or_path,
1473
+ pipeline_class=cls,
1474
+ model_type=model_type,
1475
+ stable_unclip=stable_unclip,
1476
+ controlnet=controlnet,
1477
+ from_safetensors=from_safetensors,
1478
+ extract_ema=extract_ema,
1479
+ image_size=image_size,
1480
+ scheduler_type=scheduler_type,
1481
+ num_in_channels=num_in_channels,
1482
+ upcast_attention=upcast_attention,
1483
+ load_safety_checker=load_safety_checker,
1484
+ prediction_type=prediction_type,
1485
+ text_encoder=text_encoder,
1486
+ tokenizer=tokenizer,
1487
+ )
1488
+
1489
+ if torch_dtype is not None:
1490
+ pipe.to(torch_dtype=torch_dtype)
1491
+
1492
+ return pipe
6DoF/diffusers/models/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils import is_flax_available, is_torch_available
16
+
17
+
18
+ if is_torch_available():
19
+ from .autoencoder_kl import AutoencoderKL
20
+ from .controlnet import ControlNetModel
21
+ from .dual_transformer_2d import DualTransformer2DModel
22
+ from .modeling_utils import ModelMixin
23
+ from .prior_transformer import PriorTransformer
24
+ from .t5_film_transformer import T5FilmDecoder
25
+ from .transformer_2d import Transformer2DModel
26
+ from .unet_1d import UNet1DModel
27
+ from .unet_2d import UNet2DModel
28
+ from .unet_2d_condition import UNet2DConditionModel
29
+ from .unet_3d_condition import UNet3DConditionModel
30
+ from .vq_model import VQModel
31
+
32
+ if is_flax_available():
33
+ from .controlnet_flax import FlaxControlNetModel
34
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
35
+ from .vae_flax import FlaxAutoencoderKL
6DoF/diffusers/models/activations.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ def get_activation(act_fn):
5
+ if act_fn in ["swish", "silu"]:
6
+ return nn.SiLU()
7
+ elif act_fn == "mish":
8
+ return nn.Mish()
9
+ elif act_fn == "gelu":
10
+ return nn.GELU()
11
+ else:
12
+ raise ValueError(f"Unsupported activation function: {act_fn}")
6DoF/diffusers/models/attention.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import maybe_allow_in_graph
21
+ from .activations import get_activation
22
+ from .attention_processor import Attention
23
+ from .embeddings import CombinedTimestepLabelEmbeddings
24
+
25
+
26
+ @maybe_allow_in_graph
27
+ class BasicTransformerBlock(nn.Module):
28
+ r"""
29
+ A basic Transformer block.
30
+
31
+ Parameters:
32
+ dim (`int`): The number of channels in the input and output.
33
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
34
+ attention_head_dim (`int`): The number of channels in each head.
35
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
37
+ only_cross_attention (`bool`, *optional*):
38
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
39
+ double_self_attention (`bool`, *optional*):
40
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
41
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
42
+ num_embeds_ada_norm (:
43
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
44
+ attention_bias (:
45
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ dim: int,
51
+ num_attention_heads: int,
52
+ attention_head_dim: int,
53
+ dropout=0.0,
54
+ cross_attention_dim: Optional[int] = None,
55
+ activation_fn: str = "geglu",
56
+ num_embeds_ada_norm: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ only_cross_attention: bool = False,
59
+ double_self_attention: bool = False,
60
+ upcast_attention: bool = False,
61
+ norm_elementwise_affine: bool = True,
62
+ norm_type: str = "layer_norm",
63
+ final_dropout: bool = False,
64
+ ):
65
+ super().__init__()
66
+ self.only_cross_attention = only_cross_attention
67
+
68
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
69
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
70
+
71
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
72
+ raise ValueError(
73
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
74
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
75
+ )
76
+
77
+ # Define 3 blocks. Each block has its own normalization layer.
78
+ # 1. Self-Attn
79
+ if self.use_ada_layer_norm:
80
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
81
+ elif self.use_ada_layer_norm_zero:
82
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
83
+ else:
84
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
85
+ self.attn1 = Attention(
86
+ query_dim=dim,
87
+ heads=num_attention_heads,
88
+ dim_head=attention_head_dim,
89
+ dropout=dropout,
90
+ bias=attention_bias,
91
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
92
+ upcast_attention=upcast_attention,
93
+ )
94
+
95
+ # 2. Cross-Attn
96
+ if cross_attention_dim is not None or double_self_attention:
97
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
98
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
99
+ # the second cross attention block.
100
+ self.norm2 = (
101
+ AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ if self.use_ada_layer_norm
103
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
104
+ )
105
+ self.attn2 = Attention(
106
+ query_dim=dim,
107
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
108
+ heads=num_attention_heads,
109
+ dim_head=attention_head_dim,
110
+ dropout=dropout,
111
+ bias=attention_bias,
112
+ upcast_attention=upcast_attention,
113
+ ) # is self-attn if encoder_hidden_states is none
114
+ else:
115
+ self.norm2 = None
116
+ self.attn2 = None
117
+
118
+ # 3. Feed-forward
119
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
120
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
121
+
122
+ # let chunk size default to None
123
+ self._chunk_size = None
124
+ self._chunk_dim = 0
125
+
126
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
127
+ # Sets chunk feed-forward
128
+ self._chunk_size = chunk_size
129
+ self._chunk_dim = dim
130
+
131
+ def forward(
132
+ self,
133
+ hidden_states: torch.FloatTensor,
134
+ attention_mask: Optional[torch.FloatTensor] = None,
135
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
136
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
137
+ timestep: Optional[torch.LongTensor] = None,
138
+ posemb: Optional = None,
139
+ cross_attention_kwargs: Dict[str, Any] = None,
140
+ class_labels: Optional[torch.LongTensor] = None,
141
+ ):
142
+ # Notice that normalization is always applied before the real computation in the following blocks.
143
+ # 1. Self-Attention
144
+ if self.use_ada_layer_norm:
145
+ norm_hidden_states = self.norm1(hidden_states, timestep)
146
+ elif self.use_ada_layer_norm_zero:
147
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
148
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
149
+ )
150
+ else:
151
+ norm_hidden_states = self.norm1(hidden_states)
152
+
153
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
154
+
155
+ attn_output = self.attn1(
156
+ norm_hidden_states,
157
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
158
+ attention_mask=attention_mask,
159
+ posemb=posemb, # todo in self attn, posemb shoule be [pose_in, pose_in]?
160
+ **cross_attention_kwargs,
161
+ )
162
+ if self.use_ada_layer_norm_zero:
163
+ attn_output = gate_msa.unsqueeze(1) * attn_output
164
+ hidden_states = attn_output + hidden_states
165
+
166
+ # 2. Cross-Attention
167
+ if self.attn2 is not None:
168
+ norm_hidden_states = (
169
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
170
+ )
171
+
172
+ attn_output = self.attn2(
173
+ norm_hidden_states,
174
+ encoder_hidden_states=encoder_hidden_states,
175
+ attention_mask=encoder_attention_mask,
176
+ posemb=posemb,
177
+ **cross_attention_kwargs,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 3. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self.use_ada_layer_norm_zero:
185
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
186
+
187
+ if self._chunk_size is not None:
188
+ # "feed_forward_chunk_size" can be used to save memory
189
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
190
+ raise ValueError(
191
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
192
+ )
193
+
194
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
195
+ ff_output = torch.cat(
196
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
197
+ dim=self._chunk_dim,
198
+ )
199
+ else:
200
+ ff_output = self.ff(norm_hidden_states)
201
+
202
+ if self.use_ada_layer_norm_zero:
203
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
204
+
205
+ hidden_states = ff_output + hidden_states
206
+
207
+ return hidden_states
208
+
209
+
210
+ class FeedForward(nn.Module):
211
+ r"""
212
+ A feed-forward layer.
213
+
214
+ Parameters:
215
+ dim (`int`): The number of channels in the input.
216
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
217
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
218
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
219
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
220
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ dim: int,
226
+ dim_out: Optional[int] = None,
227
+ mult: int = 4,
228
+ dropout: float = 0.0,
229
+ activation_fn: str = "geglu",
230
+ final_dropout: bool = False,
231
+ ):
232
+ super().__init__()
233
+ inner_dim = int(dim * mult)
234
+ dim_out = dim_out if dim_out is not None else dim
235
+
236
+ if activation_fn == "gelu":
237
+ act_fn = GELU(dim, inner_dim)
238
+ if activation_fn == "gelu-approximate":
239
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
240
+ elif activation_fn == "geglu":
241
+ act_fn = GEGLU(dim, inner_dim)
242
+ elif activation_fn == "geglu-approximate":
243
+ act_fn = ApproximateGELU(dim, inner_dim)
244
+
245
+ self.net = nn.ModuleList([])
246
+ # project in
247
+ self.net.append(act_fn)
248
+ # project dropout
249
+ self.net.append(nn.Dropout(dropout))
250
+ # project out
251
+ self.net.append(nn.Linear(inner_dim, dim_out))
252
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
+ if final_dropout:
254
+ self.net.append(nn.Dropout(dropout))
255
+
256
+ def forward(self, hidden_states):
257
+ for module in self.net:
258
+ hidden_states = module(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class GELU(nn.Module):
263
+ r"""
264
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
265
+ """
266
+
267
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
268
+ super().__init__()
269
+ self.proj = nn.Linear(dim_in, dim_out)
270
+ self.approximate = approximate
271
+
272
+ def gelu(self, gate):
273
+ if gate.device.type != "mps":
274
+ return F.gelu(gate, approximate=self.approximate)
275
+ # mps: gelu is not implemented for float16
276
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
277
+
278
+ def forward(self, hidden_states):
279
+ hidden_states = self.proj(hidden_states)
280
+ hidden_states = self.gelu(hidden_states)
281
+ return hidden_states
282
+
283
+
284
+ class GEGLU(nn.Module):
285
+ r"""
286
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
287
+
288
+ Parameters:
289
+ dim_in (`int`): The number of channels in the input.
290
+ dim_out (`int`): The number of channels in the output.
291
+ """
292
+
293
+ def __init__(self, dim_in: int, dim_out: int):
294
+ super().__init__()
295
+ self.proj = nn.Linear(dim_in, dim_out * 2)
296
+
297
+ def gelu(self, gate):
298
+ if gate.device.type != "mps":
299
+ return F.gelu(gate)
300
+ # mps: gelu is not implemented for float16
301
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
302
+
303
+ def forward(self, hidden_states):
304
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
305
+ return hidden_states * self.gelu(gate)
306
+
307
+
308
+ class ApproximateGELU(nn.Module):
309
+ """
310
+ The approximate form of Gaussian Error Linear Unit (GELU)
311
+
312
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
313
+ """
314
+
315
+ def __init__(self, dim_in: int, dim_out: int):
316
+ super().__init__()
317
+ self.proj = nn.Linear(dim_in, dim_out)
318
+
319
+ def forward(self, x):
320
+ x = self.proj(x)
321
+ return x * torch.sigmoid(1.702 * x)
322
+
323
+
324
+ class AdaLayerNorm(nn.Module):
325
+ """
326
+ Norm layer modified to incorporate timestep embeddings.
327
+ """
328
+
329
+ def __init__(self, embedding_dim, num_embeddings):
330
+ super().__init__()
331
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
332
+ self.silu = nn.SiLU()
333
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
334
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
335
+
336
+ def forward(self, x, timestep):
337
+ emb = self.linear(self.silu(self.emb(timestep)))
338
+ scale, shift = torch.chunk(emb, 2)
339
+ x = self.norm(x) * (1 + scale) + shift
340
+ return x
341
+
342
+
343
+ class AdaLayerNormZero(nn.Module):
344
+ """
345
+ Norm layer adaptive layer norm zero (adaLN-Zero).
346
+ """
347
+
348
+ def __init__(self, embedding_dim, num_embeddings):
349
+ super().__init__()
350
+
351
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
352
+
353
+ self.silu = nn.SiLU()
354
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
355
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
356
+
357
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
358
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
359
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
360
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
361
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
362
+
363
+
364
+ class AdaGroupNorm(nn.Module):
365
+ """
366
+ GroupNorm layer modified to incorporate timestep embeddings.
367
+ """
368
+
369
+ def __init__(
370
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
371
+ ):
372
+ super().__init__()
373
+ self.num_groups = num_groups
374
+ self.eps = eps
375
+
376
+ if act_fn is None:
377
+ self.act = None
378
+ else:
379
+ self.act = get_activation(act_fn)
380
+
381
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
382
+
383
+ def forward(self, x, emb):
384
+ if self.act:
385
+ emb = self.act(emb)
386
+ emb = self.linear(emb)
387
+ emb = emb[:, :, None, None]
388
+ scale, shift = emb.chunk(2, dim=1)
389
+
390
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
391
+ x = x * (1 + scale) + shift
392
+ return x
6DoF/diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+
23
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
+ """Multi-head dot product attention with a limited number of queries."""
25
+ num_kv, num_heads, k_features = key.shape[-3:]
26
+ v_features = value.shape[-1]
27
+ key_chunk_size = min(key_chunk_size, num_kv)
28
+ query = query / jnp.sqrt(k_features)
29
+
30
+ @functools.partial(jax.checkpoint, prevent_cse=False)
31
+ def summarize_chunk(query, key, value):
32
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
+
34
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
+ max_score = jax.lax.stop_gradient(max_score)
36
+ exp_weights = jnp.exp(attn_weights - max_score)
37
+
38
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
+ max_score = jnp.einsum("...qhk->...qh", max_score)
40
+
41
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
42
+
43
+ def chunk_scanner(chunk_idx):
44
+ # julienne key array
45
+ key_chunk = jax.lax.dynamic_slice(
46
+ operand=key,
47
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
+ )
50
+
51
+ # julienne value array
52
+ value_chunk = jax.lax.dynamic_slice(
53
+ operand=value,
54
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
+ )
57
+
58
+ return summarize_chunk(query, key_chunk, value_chunk)
59
+
60
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
+
62
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
+ max_diffs = jnp.exp(chunk_max - global_max)
64
+
65
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
+ chunk_weights *= max_diffs
67
+
68
+ all_values = chunk_values.sum(axis=0)
69
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
+
71
+ return all_values / all_weights
72
+
73
+
74
+ def jax_memory_efficient_attention(
75
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
+ ):
77
+ r"""
78
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
+ https://github.com/AminRezaei0x443/memory-efficient-attention
80
+
81
+ Args:
82
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
+ numerical precision for computation
87
+ query_chunk_size (`int`, *optional*, defaults to 1024):
88
+ chunk size to divide query array value must divide query_length equally without remainder
89
+ key_chunk_size (`int`, *optional*, defaults to 4096):
90
+ chunk size to divide key and value array value must divide key_value_length equally without remainder
91
+
92
+ Returns:
93
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
+ """
95
+ num_q, num_heads, q_features = query.shape[-3:]
96
+
97
+ def chunk_scanner(chunk_idx, _):
98
+ # julienne query array
99
+ query_chunk = jax.lax.dynamic_slice(
100
+ operand=query,
101
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
+ )
104
+
105
+ return (
106
+ chunk_idx + query_chunk_size, # unused ignore it
107
+ _query_chunk_attention(
108
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
+ ),
110
+ )
111
+
112
+ _, res = jax.lax.scan(
113
+ f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
114
+ )
115
+
116
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
117
+
118
+
119
+ class FlaxAttention(nn.Module):
120
+ r"""
121
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
122
+
123
+ Parameters:
124
+ query_dim (:obj:`int`):
125
+ Input hidden states dimension
126
+ heads (:obj:`int`, *optional*, defaults to 8):
127
+ Number of heads
128
+ dim_head (:obj:`int`, *optional*, defaults to 64):
129
+ Hidden states dimension inside each head
130
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
131
+ Dropout rate
132
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
133
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
134
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
135
+ Parameters `dtype`
136
+
137
+ """
138
+ query_dim: int
139
+ heads: int = 8
140
+ dim_head: int = 64
141
+ dropout: float = 0.0
142
+ use_memory_efficient_attention: bool = False
143
+ dtype: jnp.dtype = jnp.float32
144
+
145
+ def setup(self):
146
+ inner_dim = self.dim_head * self.heads
147
+ self.scale = self.dim_head**-0.5
148
+
149
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
150
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
151
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
152
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
153
+
154
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
155
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
156
+
157
+ def reshape_heads_to_batch_dim(self, tensor):
158
+ batch_size, seq_len, dim = tensor.shape
159
+ head_size = self.heads
160
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
161
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
162
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
163
+ return tensor
164
+
165
+ def reshape_batch_dim_to_heads(self, tensor):
166
+ batch_size, seq_len, dim = tensor.shape
167
+ head_size = self.heads
168
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
169
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
170
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
171
+ return tensor
172
+
173
+ def __call__(self, hidden_states, context=None, deterministic=True):
174
+ context = hidden_states if context is None else context
175
+
176
+ query_proj = self.query(hidden_states)
177
+ key_proj = self.key(context)
178
+ value_proj = self.value(context)
179
+
180
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
181
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
182
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
183
+
184
+ if self.use_memory_efficient_attention:
185
+ query_states = query_states.transpose(1, 0, 2)
186
+ key_states = key_states.transpose(1, 0, 2)
187
+ value_states = value_states.transpose(1, 0, 2)
188
+
189
+ # this if statement create a chunk size for each layer of the unet
190
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
191
+
192
+ flatten_latent_dim = query_states.shape[-3]
193
+ if flatten_latent_dim % 64 == 0:
194
+ query_chunk_size = int(flatten_latent_dim / 64)
195
+ elif flatten_latent_dim % 16 == 0:
196
+ query_chunk_size = int(flatten_latent_dim / 16)
197
+ elif flatten_latent_dim % 4 == 0:
198
+ query_chunk_size = int(flatten_latent_dim / 4)
199
+ else:
200
+ query_chunk_size = int(flatten_latent_dim)
201
+
202
+ hidden_states = jax_memory_efficient_attention(
203
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
204
+ )
205
+
206
+ hidden_states = hidden_states.transpose(1, 0, 2)
207
+ else:
208
+ # compute attentions
209
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
210
+ attention_scores = attention_scores * self.scale
211
+ attention_probs = nn.softmax(attention_scores, axis=2)
212
+
213
+ # attend to values
214
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
215
+
216
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
217
+ hidden_states = self.proj_attn(hidden_states)
218
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
219
+
220
+
221
+ class FlaxBasicTransformerBlock(nn.Module):
222
+ r"""
223
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
224
+ https://arxiv.org/abs/1706.03762
225
+
226
+
227
+ Parameters:
228
+ dim (:obj:`int`):
229
+ Inner hidden states dimension
230
+ n_heads (:obj:`int`):
231
+ Number of heads
232
+ d_head (:obj:`int`):
233
+ Hidden states dimension inside each head
234
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
235
+ Dropout rate
236
+ only_cross_attention (`bool`, defaults to `False`):
237
+ Whether to only apply cross attention.
238
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
239
+ Parameters `dtype`
240
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
241
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
242
+ """
243
+ dim: int
244
+ n_heads: int
245
+ d_head: int
246
+ dropout: float = 0.0
247
+ only_cross_attention: bool = False
248
+ dtype: jnp.dtype = jnp.float32
249
+ use_memory_efficient_attention: bool = False
250
+
251
+ def setup(self):
252
+ # self attention (or cross_attention if only_cross_attention is True)
253
+ self.attn1 = FlaxAttention(
254
+ self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
255
+ )
256
+ # cross attention
257
+ self.attn2 = FlaxAttention(
258
+ self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
259
+ )
260
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
261
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
262
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
263
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
264
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
265
+
266
+ def __call__(self, hidden_states, context, deterministic=True):
267
+ # self attention
268
+ residual = hidden_states
269
+ if self.only_cross_attention:
270
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
271
+ else:
272
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
273
+ hidden_states = hidden_states + residual
274
+
275
+ # cross attention
276
+ residual = hidden_states
277
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
278
+ hidden_states = hidden_states + residual
279
+
280
+ # feed forward
281
+ residual = hidden_states
282
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
283
+ hidden_states = hidden_states + residual
284
+
285
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
286
+
287
+
288
+ class FlaxTransformer2DModel(nn.Module):
289
+ r"""
290
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
291
+ https://arxiv.org/pdf/1506.02025.pdf
292
+
293
+
294
+ Parameters:
295
+ in_channels (:obj:`int`):
296
+ Input number of channels
297
+ n_heads (:obj:`int`):
298
+ Number of heads
299
+ d_head (:obj:`int`):
300
+ Hidden states dimension inside each head
301
+ depth (:obj:`int`, *optional*, defaults to 1):
302
+ Number of transformers block
303
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
304
+ Dropout rate
305
+ use_linear_projection (`bool`, defaults to `False`): tbd
306
+ only_cross_attention (`bool`, defaults to `False`): tbd
307
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
308
+ Parameters `dtype`
309
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
310
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
311
+ """
312
+ in_channels: int
313
+ n_heads: int
314
+ d_head: int
315
+ depth: int = 1
316
+ dropout: float = 0.0
317
+ use_linear_projection: bool = False
318
+ only_cross_attention: bool = False
319
+ dtype: jnp.dtype = jnp.float32
320
+ use_memory_efficient_attention: bool = False
321
+
322
+ def setup(self):
323
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
324
+
325
+ inner_dim = self.n_heads * self.d_head
326
+ if self.use_linear_projection:
327
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
328
+ else:
329
+ self.proj_in = nn.Conv(
330
+ inner_dim,
331
+ kernel_size=(1, 1),
332
+ strides=(1, 1),
333
+ padding="VALID",
334
+ dtype=self.dtype,
335
+ )
336
+
337
+ self.transformer_blocks = [
338
+ FlaxBasicTransformerBlock(
339
+ inner_dim,
340
+ self.n_heads,
341
+ self.d_head,
342
+ dropout=self.dropout,
343
+ only_cross_attention=self.only_cross_attention,
344
+ dtype=self.dtype,
345
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
346
+ )
347
+ for _ in range(self.depth)
348
+ ]
349
+
350
+ if self.use_linear_projection:
351
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
352
+ else:
353
+ self.proj_out = nn.Conv(
354
+ inner_dim,
355
+ kernel_size=(1, 1),
356
+ strides=(1, 1),
357
+ padding="VALID",
358
+ dtype=self.dtype,
359
+ )
360
+
361
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
362
+
363
+ def __call__(self, hidden_states, context, deterministic=True):
364
+ batch, height, width, channels = hidden_states.shape
365
+ residual = hidden_states
366
+ hidden_states = self.norm(hidden_states)
367
+ if self.use_linear_projection:
368
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
369
+ hidden_states = self.proj_in(hidden_states)
370
+ else:
371
+ hidden_states = self.proj_in(hidden_states)
372
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
373
+
374
+ for transformer_block in self.transformer_blocks:
375
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
376
+
377
+ if self.use_linear_projection:
378
+ hidden_states = self.proj_out(hidden_states)
379
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
380
+ else:
381
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
382
+ hidden_states = self.proj_out(hidden_states)
383
+
384
+ hidden_states = hidden_states + residual
385
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
386
+
387
+
388
+ class FlaxFeedForward(nn.Module):
389
+ r"""
390
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
391
+ [`FeedForward`] class, with the following simplifications:
392
+ - The activation function is currently hardcoded to a gated linear unit from:
393
+ https://arxiv.org/abs/2002.05202
394
+ - `dim_out` is equal to `dim`.
395
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
396
+
397
+ Parameters:
398
+ dim (:obj:`int`):
399
+ Inner hidden states dimension
400
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
401
+ Dropout rate
402
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
403
+ Parameters `dtype`
404
+ """
405
+ dim: int
406
+ dropout: float = 0.0
407
+ dtype: jnp.dtype = jnp.float32
408
+
409
+ def setup(self):
410
+ # The second linear layer needs to be called
411
+ # net_2 for now to match the index of the Sequential layer
412
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
413
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
414
+
415
+ def __call__(self, hidden_states, deterministic=True):
416
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
417
+ hidden_states = self.net_2(hidden_states)
418
+ return hidden_states
419
+
420
+
421
+ class FlaxGEGLU(nn.Module):
422
+ r"""
423
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
424
+ https://arxiv.org/abs/2002.05202.
425
+
426
+ Parameters:
427
+ dim (:obj:`int`):
428
+ Input hidden states dimension
429
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
430
+ Dropout rate
431
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
432
+ Parameters `dtype`
433
+ """
434
+ dim: int
435
+ dropout: float = 0.0
436
+ dtype: jnp.dtype = jnp.float32
437
+
438
+ def setup(self):
439
+ inner_dim = self.dim * 4
440
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
441
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
442
+
443
+ def __call__(self, hidden_states, deterministic=True):
444
+ hidden_states = self.proj(hidden_states)
445
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
446
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
6DoF/diffusers/models/attention_processor.py ADDED
@@ -0,0 +1,1684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import deprecate, logging, maybe_allow_in_graph
21
+ from ..utils.import_utils import is_xformers_available
22
+
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ # 6DoF CaPE
35
+ import einops
36
+ def cape_embed(f, P):
37
+ # f is feature vector of shape [..., d]
38
+ # P is 4x4 transformation matrix
39
+ f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
40
+ return einops.rearrange(f@P, '... d k -> ... (d k)', k=4)
41
+
42
+ @maybe_allow_in_graph
43
+ class Attention(nn.Module):
44
+ r"""
45
+ A cross attention layer.
46
+
47
+ Parameters:
48
+ query_dim (`int`): The number of channels in the query.
49
+ cross_attention_dim (`int`, *optional*):
50
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
51
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
52
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ bias (`bool`, *optional*, defaults to False):
55
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ query_dim: int,
61
+ cross_attention_dim: Optional[int] = None,
62
+ heads: int = 8,
63
+ dim_head: int = 64,
64
+ dropout: float = 0.0,
65
+ bias=False,
66
+ upcast_attention: bool = False,
67
+ upcast_softmax: bool = False,
68
+ cross_attention_norm: Optional[str] = None,
69
+ cross_attention_norm_num_groups: int = 32,
70
+ added_kv_proj_dim: Optional[int] = None,
71
+ norm_num_groups: Optional[int] = None,
72
+ spatial_norm_dim: Optional[int] = None,
73
+ out_bias: bool = True,
74
+ scale_qk: bool = True,
75
+ only_cross_attention: bool = False,
76
+ eps: float = 1e-5,
77
+ rescale_output_factor: float = 1.0,
78
+ residual_connection: bool = False,
79
+ _from_deprecated_attn_block=False,
80
+ processor: Optional["AttnProcessor"] = None,
81
+ ):
82
+ super().__init__()
83
+ inner_dim = dim_head * heads
84
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
85
+ self.upcast_attention = upcast_attention
86
+ self.upcast_softmax = upcast_softmax
87
+ self.rescale_output_factor = rescale_output_factor
88
+ self.residual_connection = residual_connection
89
+ self.dropout = dropout
90
+
91
+ # we make use of this private variable to know whether this class is loaded
92
+ # with an deprecated state dict so that we can convert it on the fly
93
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
94
+
95
+ self.scale_qk = scale_qk
96
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
97
+
98
+ self.heads = heads
99
+ # for slice_size > 0 the attention score computation
100
+ # is split across the batch axis to save memory
101
+ # You can set slice_size with `set_attention_slice`
102
+ self.sliceable_head_dim = heads
103
+
104
+ self.added_kv_proj_dim = added_kv_proj_dim
105
+ self.only_cross_attention = only_cross_attention
106
+
107
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
108
+ raise ValueError(
109
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
110
+ )
111
+
112
+ if norm_num_groups is not None:
113
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
114
+ else:
115
+ self.group_norm = None
116
+
117
+ if spatial_norm_dim is not None:
118
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
119
+ else:
120
+ self.spatial_norm = None
121
+
122
+ if cross_attention_norm is None:
123
+ self.norm_cross = None
124
+ elif cross_attention_norm == "layer_norm":
125
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
126
+ elif cross_attention_norm == "group_norm":
127
+ if self.added_kv_proj_dim is not None:
128
+ # The given `encoder_hidden_states` are initially of shape
129
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
130
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
131
+ # before the projection, so we need to use `added_kv_proj_dim` as
132
+ # the number of channels for the group norm.
133
+ norm_cross_num_channels = added_kv_proj_dim
134
+ else:
135
+ norm_cross_num_channels = cross_attention_dim
136
+
137
+ self.norm_cross = nn.GroupNorm(
138
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
139
+ )
140
+ else:
141
+ raise ValueError(
142
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
143
+ )
144
+
145
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
146
+
147
+ if not self.only_cross_attention:
148
+ # only relevant for the `AddedKVProcessor` classes
149
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
150
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
151
+ else:
152
+ self.to_k = None
153
+ self.to_v = None
154
+
155
+ if self.added_kv_proj_dim is not None:
156
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
157
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
158
+
159
+ self.to_out = nn.ModuleList([])
160
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
161
+ self.to_out.append(nn.Dropout(dropout))
162
+
163
+ # set attention processor
164
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
165
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
166
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
167
+ if processor is None:
168
+ processor = (
169
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
170
+ )
171
+ self.set_processor(processor)
172
+
173
+ def set_use_memory_efficient_attention_xformers(
174
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
175
+ ):
176
+ is_lora = hasattr(self, "processor") and isinstance(
177
+ self.processor,
178
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
179
+ )
180
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
181
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
182
+ )
183
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
184
+ self.processor,
185
+ (
186
+ AttnAddedKVProcessor,
187
+ AttnAddedKVProcessor2_0,
188
+ SlicedAttnAddedKVProcessor,
189
+ XFormersAttnAddedKVProcessor,
190
+ LoRAAttnAddedKVProcessor,
191
+ ),
192
+ )
193
+
194
+ if use_memory_efficient_attention_xformers:
195
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
196
+ raise NotImplementedError(
197
+ f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
198
+ )
199
+ if not is_xformers_available():
200
+ raise ModuleNotFoundError(
201
+ (
202
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
203
+ " xformers"
204
+ ),
205
+ name="xformers",
206
+ )
207
+ elif not torch.cuda.is_available():
208
+ raise ValueError(
209
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
210
+ " only available for GPU "
211
+ )
212
+ else:
213
+ try:
214
+ # Make sure we can run the memory efficient attention
215
+ _ = xformers.ops.memory_efficient_attention(
216
+ torch.randn((1, 2, 40), device="cuda"),
217
+ torch.randn((1, 2, 40), device="cuda"),
218
+ torch.randn((1, 2, 40), device="cuda"),
219
+ )
220
+ except Exception as e:
221
+ raise e
222
+
223
+ if is_lora:
224
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
225
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
226
+ processor = LoRAXFormersAttnProcessor(
227
+ hidden_size=self.processor.hidden_size,
228
+ cross_attention_dim=self.processor.cross_attention_dim,
229
+ rank=self.processor.rank,
230
+ attention_op=attention_op,
231
+ )
232
+ processor.load_state_dict(self.processor.state_dict())
233
+ processor.to(self.processor.to_q_lora.up.weight.device)
234
+ elif is_custom_diffusion:
235
+ processor = CustomDiffusionXFormersAttnProcessor(
236
+ train_kv=self.processor.train_kv,
237
+ train_q_out=self.processor.train_q_out,
238
+ hidden_size=self.processor.hidden_size,
239
+ cross_attention_dim=self.processor.cross_attention_dim,
240
+ attention_op=attention_op,
241
+ )
242
+ processor.load_state_dict(self.processor.state_dict())
243
+ if hasattr(self.processor, "to_k_custom_diffusion"):
244
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
245
+ elif is_added_kv_processor:
246
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
247
+ # which uses this type of cross attention ONLY because the attention mask of format
248
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
249
+ # throw warning
250
+ logger.info(
251
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
252
+ )
253
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
254
+ else:
255
+ processor = XFormersAttnProcessor(attention_op=attention_op)
256
+ else:
257
+ if is_lora:
258
+ attn_processor_class = (
259
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
260
+ )
261
+ processor = attn_processor_class(
262
+ hidden_size=self.processor.hidden_size,
263
+ cross_attention_dim=self.processor.cross_attention_dim,
264
+ rank=self.processor.rank,
265
+ )
266
+ processor.load_state_dict(self.processor.state_dict())
267
+ processor.to(self.processor.to_q_lora.up.weight.device)
268
+ elif is_custom_diffusion:
269
+ processor = CustomDiffusionAttnProcessor(
270
+ train_kv=self.processor.train_kv,
271
+ train_q_out=self.processor.train_q_out,
272
+ hidden_size=self.processor.hidden_size,
273
+ cross_attention_dim=self.processor.cross_attention_dim,
274
+ )
275
+ processor.load_state_dict(self.processor.state_dict())
276
+ if hasattr(self.processor, "to_k_custom_diffusion"):
277
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
278
+ else:
279
+ # set attention processor
280
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
281
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
282
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
283
+ processor = (
284
+ AttnProcessor2_0()
285
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
286
+ else AttnProcessor()
287
+ )
288
+
289
+ self.set_processor(processor)
290
+
291
+ def set_attention_slice(self, slice_size):
292
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
293
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
294
+
295
+ if slice_size is not None and self.added_kv_proj_dim is not None:
296
+ processor = SlicedAttnAddedKVProcessor(slice_size)
297
+ elif slice_size is not None:
298
+ processor = SlicedAttnProcessor(slice_size)
299
+ elif self.added_kv_proj_dim is not None:
300
+ processor = AttnAddedKVProcessor()
301
+ else:
302
+ # set attention processor
303
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
304
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
305
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
306
+ processor = (
307
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
308
+ )
309
+
310
+ self.set_processor(processor)
311
+
312
+ def set_processor(self, processor: "AttnProcessor"):
313
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
314
+ # pop `processor` from `self._modules`
315
+ if (
316
+ hasattr(self, "processor")
317
+ and isinstance(self.processor, torch.nn.Module)
318
+ and not isinstance(processor, torch.nn.Module)
319
+ ):
320
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
321
+ self._modules.pop("processor")
322
+
323
+ self.processor = processor
324
+
325
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
326
+ # The `Attention` class can call different attention processors / attention functions
327
+ # here we simply pass along all tensors to the selected processor class
328
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
329
+ return self.processor(
330
+ self,
331
+ hidden_states,
332
+ encoder_hidden_states=encoder_hidden_states,
333
+ attention_mask=attention_mask,
334
+ **cross_attention_kwargs,
335
+ )
336
+
337
+ def batch_to_head_dim(self, tensor):
338
+ head_size = self.heads
339
+ batch_size, seq_len, dim = tensor.shape
340
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
341
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
342
+ return tensor
343
+
344
+ def head_to_batch_dim(self, tensor, out_dim=3):
345
+ head_size = self.heads
346
+ batch_size, seq_len, dim = tensor.shape
347
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
348
+ tensor = tensor.permute(0, 2, 1, 3)
349
+
350
+ if out_dim == 3:
351
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
352
+
353
+ return tensor
354
+
355
+ def get_attention_scores(self, query, key, attention_mask=None):
356
+ dtype = query.dtype
357
+ if self.upcast_attention:
358
+ query = query.float()
359
+ key = key.float()
360
+
361
+ if attention_mask is None:
362
+ baddbmm_input = torch.empty(
363
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
364
+ )
365
+ beta = 0
366
+ else:
367
+ baddbmm_input = attention_mask
368
+ beta = 1
369
+
370
+ attention_scores = torch.baddbmm(
371
+ baddbmm_input,
372
+ query,
373
+ key.transpose(-1, -2),
374
+ beta=beta,
375
+ alpha=self.scale,
376
+ )
377
+ del baddbmm_input
378
+
379
+ if self.upcast_softmax:
380
+ attention_scores = attention_scores.float()
381
+
382
+ attention_probs = attention_scores.softmax(dim=-1)
383
+ del attention_scores
384
+
385
+ attention_probs = attention_probs.to(dtype)
386
+
387
+ return attention_probs
388
+
389
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
390
+ if batch_size is None:
391
+ deprecate(
392
+ "batch_size=None",
393
+ "0.0.15",
394
+ (
395
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
396
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
397
+ " `prepare_attention_mask` when preparing the attention_mask."
398
+ ),
399
+ )
400
+ batch_size = 1
401
+
402
+ head_size = self.heads
403
+ if attention_mask is None:
404
+ return attention_mask
405
+
406
+ current_length: int = attention_mask.shape[-1]
407
+ if current_length != target_length:
408
+ if attention_mask.device.type == "mps":
409
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
410
+ # Instead, we can manually construct the padding tensor.
411
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
412
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
413
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
414
+ else:
415
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
416
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
417
+ # remaining_length: int = target_length - current_length
418
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
419
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
420
+
421
+ if out_dim == 3:
422
+ if attention_mask.shape[0] < batch_size * head_size:
423
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
424
+ elif out_dim == 4:
425
+ attention_mask = attention_mask.unsqueeze(1)
426
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
427
+
428
+ return attention_mask
429
+
430
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
431
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432
+
433
+ if isinstance(self.norm_cross, nn.LayerNorm):
434
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435
+ elif isinstance(self.norm_cross, nn.GroupNorm):
436
+ # Group norm norms along the channels dimension and expects
437
+ # input to be in the shape of (N, C, *). In this case, we want
438
+ # to norm along the hidden dimension, so we need to move
439
+ # (batch_size, sequence_length, hidden_size) ->
440
+ # (batch_size, hidden_size, sequence_length)
441
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444
+ else:
445
+ assert False
446
+
447
+ return encoder_hidden_states
448
+
449
+
450
+ class AttnProcessor:
451
+ r"""
452
+ Default processor for performing attention-related computations.
453
+ """
454
+
455
+ def __call__(
456
+ self,
457
+ attn: Attention,
458
+ hidden_states,
459
+ encoder_hidden_states=None,
460
+ attention_mask=None,
461
+ temb=None,
462
+ ):
463
+ residual = hidden_states
464
+
465
+ if attn.spatial_norm is not None:
466
+ hidden_states = attn.spatial_norm(hidden_states, temb)
467
+
468
+ input_ndim = hidden_states.ndim
469
+
470
+ if input_ndim == 4:
471
+ batch_size, channel, height, width = hidden_states.shape
472
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
473
+
474
+ batch_size, sequence_length, _ = (
475
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
476
+ )
477
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
478
+
479
+ if attn.group_norm is not None:
480
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
481
+
482
+ query = attn.to_q(hidden_states)
483
+
484
+ if encoder_hidden_states is None:
485
+ encoder_hidden_states = hidden_states
486
+ elif attn.norm_cross:
487
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
488
+
489
+ key = attn.to_k(encoder_hidden_states)
490
+ value = attn.to_v(encoder_hidden_states)
491
+
492
+ query = attn.head_to_batch_dim(query)
493
+ key = attn.head_to_batch_dim(key)
494
+ value = attn.head_to_batch_dim(value)
495
+
496
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
497
+ hidden_states = torch.bmm(attention_probs, value)
498
+ hidden_states = attn.batch_to_head_dim(hidden_states)
499
+
500
+ # linear proj
501
+ hidden_states = attn.to_out[0](hidden_states)
502
+ # dropout
503
+ hidden_states = attn.to_out[1](hidden_states)
504
+
505
+ if input_ndim == 4:
506
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
507
+
508
+ if attn.residual_connection:
509
+ hidden_states = hidden_states + residual
510
+
511
+ hidden_states = hidden_states / attn.rescale_output_factor
512
+
513
+ return hidden_states
514
+
515
+
516
+ class LoRALinearLayer(nn.Module):
517
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
518
+ super().__init__()
519
+
520
+ if rank > min(in_features, out_features):
521
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
522
+
523
+ self.down = nn.Linear(in_features, rank, bias=False)
524
+ self.up = nn.Linear(rank, out_features, bias=False)
525
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
526
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
527
+ self.network_alpha = network_alpha
528
+ self.rank = rank
529
+
530
+ nn.init.normal_(self.down.weight, std=1 / rank)
531
+ nn.init.zeros_(self.up.weight)
532
+
533
+ def forward(self, hidden_states):
534
+ orig_dtype = hidden_states.dtype
535
+ dtype = self.down.weight.dtype
536
+
537
+ down_hidden_states = self.down(hidden_states.to(dtype))
538
+ up_hidden_states = self.up(down_hidden_states)
539
+
540
+ if self.network_alpha is not None:
541
+ up_hidden_states *= self.network_alpha / self.rank
542
+
543
+ return up_hidden_states.to(orig_dtype)
544
+
545
+
546
+ class LoRAAttnProcessor(nn.Module):
547
+ r"""
548
+ Processor for implementing the LoRA attention mechanism.
549
+
550
+ Args:
551
+ hidden_size (`int`, *optional*):
552
+ The hidden size of the attention layer.
553
+ cross_attention_dim (`int`, *optional*):
554
+ The number of channels in the `encoder_hidden_states`.
555
+ rank (`int`, defaults to 4):
556
+ The dimension of the LoRA update matrices.
557
+ network_alpha (`int`, *optional*):
558
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
559
+ """
560
+
561
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
562
+ super().__init__()
563
+
564
+ self.hidden_size = hidden_size
565
+ self.cross_attention_dim = cross_attention_dim
566
+ self.rank = rank
567
+
568
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
569
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
570
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
571
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
572
+
573
+ def __call__(
574
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
575
+ ):
576
+ residual = hidden_states
577
+
578
+ if attn.spatial_norm is not None:
579
+ hidden_states = attn.spatial_norm(hidden_states, temb)
580
+
581
+ input_ndim = hidden_states.ndim
582
+
583
+ if input_ndim == 4:
584
+ batch_size, channel, height, width = hidden_states.shape
585
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
586
+
587
+ batch_size, sequence_length, _ = (
588
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
589
+ )
590
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
591
+
592
+ if attn.group_norm is not None:
593
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
594
+
595
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
596
+ query = attn.head_to_batch_dim(query)
597
+
598
+ if encoder_hidden_states is None:
599
+ encoder_hidden_states = hidden_states
600
+ elif attn.norm_cross:
601
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
602
+
603
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
604
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
605
+
606
+ key = attn.head_to_batch_dim(key)
607
+ value = attn.head_to_batch_dim(value)
608
+
609
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
610
+ hidden_states = torch.bmm(attention_probs, value)
611
+ hidden_states = attn.batch_to_head_dim(hidden_states)
612
+
613
+ # linear proj
614
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
615
+ # dropout
616
+ hidden_states = attn.to_out[1](hidden_states)
617
+
618
+ if input_ndim == 4:
619
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
620
+
621
+ if attn.residual_connection:
622
+ hidden_states = hidden_states + residual
623
+
624
+ hidden_states = hidden_states / attn.rescale_output_factor
625
+
626
+ return hidden_states
627
+
628
+
629
+ class CustomDiffusionAttnProcessor(nn.Module):
630
+ r"""
631
+ Processor for implementing attention for the Custom Diffusion method.
632
+
633
+ Args:
634
+ train_kv (`bool`, defaults to `True`):
635
+ Whether to newly train the key and value matrices corresponding to the text features.
636
+ train_q_out (`bool`, defaults to `True`):
637
+ Whether to newly train query matrices corresponding to the latent image features.
638
+ hidden_size (`int`, *optional*, defaults to `None`):
639
+ The hidden size of the attention layer.
640
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
641
+ The number of channels in the `encoder_hidden_states`.
642
+ out_bias (`bool`, defaults to `True`):
643
+ Whether to include the bias parameter in `train_q_out`.
644
+ dropout (`float`, *optional*, defaults to 0.0):
645
+ The dropout probability to use.
646
+ """
647
+
648
+ def __init__(
649
+ self,
650
+ train_kv=True,
651
+ train_q_out=True,
652
+ hidden_size=None,
653
+ cross_attention_dim=None,
654
+ out_bias=True,
655
+ dropout=0.0,
656
+ ):
657
+ super().__init__()
658
+ self.train_kv = train_kv
659
+ self.train_q_out = train_q_out
660
+
661
+ self.hidden_size = hidden_size
662
+ self.cross_attention_dim = cross_attention_dim
663
+
664
+ # `_custom_diffusion` id for easy serialization and loading.
665
+ if self.train_kv:
666
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
667
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
668
+ if self.train_q_out:
669
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
670
+ self.to_out_custom_diffusion = nn.ModuleList([])
671
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
672
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
673
+
674
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
675
+ batch_size, sequence_length, _ = hidden_states.shape
676
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
677
+ if self.train_q_out:
678
+ query = self.to_q_custom_diffusion(hidden_states)
679
+ else:
680
+ query = attn.to_q(hidden_states)
681
+
682
+ if encoder_hidden_states is None:
683
+ crossattn = False
684
+ encoder_hidden_states = hidden_states
685
+ else:
686
+ crossattn = True
687
+ if attn.norm_cross:
688
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
689
+
690
+ if self.train_kv:
691
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
692
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
693
+ else:
694
+ key = attn.to_k(encoder_hidden_states)
695
+ value = attn.to_v(encoder_hidden_states)
696
+
697
+ if crossattn:
698
+ detach = torch.ones_like(key)
699
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
700
+ key = detach * key + (1 - detach) * key.detach()
701
+ value = detach * value + (1 - detach) * value.detach()
702
+
703
+ query = attn.head_to_batch_dim(query)
704
+ key = attn.head_to_batch_dim(key)
705
+ value = attn.head_to_batch_dim(value)
706
+
707
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
708
+ hidden_states = torch.bmm(attention_probs, value)
709
+ hidden_states = attn.batch_to_head_dim(hidden_states)
710
+
711
+ if self.train_q_out:
712
+ # linear proj
713
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
714
+ # dropout
715
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
716
+ else:
717
+ # linear proj
718
+ hidden_states = attn.to_out[0](hidden_states)
719
+ # dropout
720
+ hidden_states = attn.to_out[1](hidden_states)
721
+
722
+ return hidden_states
723
+
724
+
725
+ class AttnAddedKVProcessor:
726
+ r"""
727
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
728
+ encoder.
729
+ """
730
+
731
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
732
+ residual = hidden_states
733
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
734
+ batch_size, sequence_length, _ = hidden_states.shape
735
+
736
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
737
+
738
+ if encoder_hidden_states is None:
739
+ encoder_hidden_states = hidden_states
740
+ elif attn.norm_cross:
741
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
742
+
743
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
744
+
745
+ query = attn.to_q(hidden_states)
746
+ query = attn.head_to_batch_dim(query)
747
+
748
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
749
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
750
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
751
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
752
+
753
+ if not attn.only_cross_attention:
754
+ key = attn.to_k(hidden_states)
755
+ value = attn.to_v(hidden_states)
756
+ key = attn.head_to_batch_dim(key)
757
+ value = attn.head_to_batch_dim(value)
758
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
759
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
760
+ else:
761
+ key = encoder_hidden_states_key_proj
762
+ value = encoder_hidden_states_value_proj
763
+
764
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
765
+ hidden_states = torch.bmm(attention_probs, value)
766
+ hidden_states = attn.batch_to_head_dim(hidden_states)
767
+
768
+ # linear proj
769
+ hidden_states = attn.to_out[0](hidden_states)
770
+ # dropout
771
+ hidden_states = attn.to_out[1](hidden_states)
772
+
773
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
774
+ hidden_states = hidden_states + residual
775
+
776
+ return hidden_states
777
+
778
+
779
+ class AttnAddedKVProcessor2_0:
780
+ r"""
781
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
782
+ learnable key and value matrices for the text encoder.
783
+ """
784
+
785
+ def __init__(self):
786
+ if not hasattr(F, "scaled_dot_product_attention"):
787
+ raise ImportError(
788
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
789
+ )
790
+
791
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
792
+ residual = hidden_states
793
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
794
+ batch_size, sequence_length, _ = hidden_states.shape
795
+
796
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
797
+
798
+ if encoder_hidden_states is None:
799
+ encoder_hidden_states = hidden_states
800
+ elif attn.norm_cross:
801
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
802
+
803
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
804
+
805
+ query = attn.to_q(hidden_states)
806
+ query = attn.head_to_batch_dim(query, out_dim=4)
807
+
808
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
809
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
810
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
811
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
812
+
813
+ if not attn.only_cross_attention:
814
+ key = attn.to_k(hidden_states)
815
+ value = attn.to_v(hidden_states)
816
+ key = attn.head_to_batch_dim(key, out_dim=4)
817
+ value = attn.head_to_batch_dim(value, out_dim=4)
818
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
819
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
820
+ else:
821
+ key = encoder_hidden_states_key_proj
822
+ value = encoder_hidden_states_value_proj
823
+
824
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
825
+ # TODO: add support for attn.scale when we move to Torch 2.1
826
+ hidden_states = F.scaled_dot_product_attention(
827
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
828
+ )
829
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
830
+
831
+ # linear proj
832
+ hidden_states = attn.to_out[0](hidden_states)
833
+ # dropout
834
+ hidden_states = attn.to_out[1](hidden_states)
835
+
836
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
837
+ hidden_states = hidden_states + residual
838
+
839
+ return hidden_states
840
+
841
+
842
+ class LoRAAttnAddedKVProcessor(nn.Module):
843
+ r"""
844
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
845
+ encoder.
846
+
847
+ Args:
848
+ hidden_size (`int`, *optional*):
849
+ The hidden size of the attention layer.
850
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
851
+ The number of channels in the `encoder_hidden_states`.
852
+ rank (`int`, defaults to 4):
853
+ The dimension of the LoRA update matrices.
854
+
855
+ """
856
+
857
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
858
+ super().__init__()
859
+
860
+ self.hidden_size = hidden_size
861
+ self.cross_attention_dim = cross_attention_dim
862
+ self.rank = rank
863
+
864
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
865
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
866
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
867
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
868
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
869
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
870
+
871
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
872
+ residual = hidden_states
873
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
874
+ batch_size, sequence_length, _ = hidden_states.shape
875
+
876
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
877
+
878
+ if encoder_hidden_states is None:
879
+ encoder_hidden_states = hidden_states
880
+ elif attn.norm_cross:
881
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
882
+
883
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
884
+
885
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
886
+ query = attn.head_to_batch_dim(query)
887
+
888
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
889
+ encoder_hidden_states
890
+ )
891
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
892
+ encoder_hidden_states
893
+ )
894
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
895
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
896
+
897
+ if not attn.only_cross_attention:
898
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
899
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
900
+ key = attn.head_to_batch_dim(key)
901
+ value = attn.head_to_batch_dim(value)
902
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
903
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
904
+ else:
905
+ key = encoder_hidden_states_key_proj
906
+ value = encoder_hidden_states_value_proj
907
+
908
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
909
+ hidden_states = torch.bmm(attention_probs, value)
910
+ hidden_states = attn.batch_to_head_dim(hidden_states)
911
+
912
+ # linear proj
913
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
914
+ # dropout
915
+ hidden_states = attn.to_out[1](hidden_states)
916
+
917
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
918
+ hidden_states = hidden_states + residual
919
+
920
+ return hidden_states
921
+
922
+
923
+ class XFormersAttnAddedKVProcessor:
924
+ r"""
925
+ Processor for implementing memory efficient attention using xFormers.
926
+
927
+ Args:
928
+ attention_op (`Callable`, *optional*, defaults to `None`):
929
+ The base
930
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
931
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
932
+ operator.
933
+ """
934
+
935
+ def __init__(self, attention_op: Optional[Callable] = None):
936
+ self.attention_op = attention_op
937
+
938
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
939
+ residual = hidden_states
940
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
941
+ batch_size, sequence_length, _ = hidden_states.shape
942
+
943
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
944
+
945
+ if encoder_hidden_states is None:
946
+ encoder_hidden_states = hidden_states
947
+ elif attn.norm_cross:
948
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
949
+
950
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
951
+
952
+ query = attn.to_q(hidden_states)
953
+ query = attn.head_to_batch_dim(query)
954
+
955
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
956
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
957
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
958
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
959
+
960
+ if not attn.only_cross_attention:
961
+ key = attn.to_k(hidden_states)
962
+ value = attn.to_v(hidden_states)
963
+ key = attn.head_to_batch_dim(key)
964
+ value = attn.head_to_batch_dim(value)
965
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
966
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
967
+ else:
968
+ key = encoder_hidden_states_key_proj
969
+ value = encoder_hidden_states_value_proj
970
+
971
+ hidden_states = xformers.ops.memory_efficient_attention(
972
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
973
+ )
974
+ hidden_states = hidden_states.to(query.dtype)
975
+ hidden_states = attn.batch_to_head_dim(hidden_states)
976
+
977
+ # linear proj
978
+ hidden_states = attn.to_out[0](hidden_states)
979
+ # dropout
980
+ hidden_states = attn.to_out[1](hidden_states)
981
+
982
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
983
+ hidden_states = hidden_states + residual
984
+
985
+ return hidden_states
986
+
987
+
988
+ class XFormersAttnProcessor:
989
+ r"""
990
+ Processor for implementing memory efficient attention using xFormers.
991
+
992
+ Args:
993
+ attention_op (`Callable`, *optional*, defaults to `None`):
994
+ The base
995
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
996
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
997
+ operator.
998
+ """
999
+
1000
+ def __init__(self, attention_op: Optional[Callable] = None):
1001
+ self.attention_op = attention_op
1002
+
1003
+ def __call__(
1004
+ self,
1005
+ attn: Attention,
1006
+ hidden_states: torch.FloatTensor,
1007
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1008
+ attention_mask: Optional[torch.FloatTensor] = None,
1009
+ temb: Optional[torch.FloatTensor] = None,
1010
+ posemb: Optional = None,
1011
+ ):
1012
+ residual = hidden_states
1013
+
1014
+ if attn.spatial_norm is not None:
1015
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1016
+
1017
+ input_ndim = hidden_states.ndim
1018
+
1019
+ if input_ndim == 4:
1020
+ batch_size, channel, height, width = hidden_states.shape
1021
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1022
+
1023
+ if posemb is not None:
1024
+ # turn 2d attention into multiview attention
1025
+ self_attn = encoder_hidden_states is None # check if self attn or cross attn
1026
+ [p_out, p_out_inv], [p_in, p_in_inv] = posemb
1027
+ t_out, t_in = p_out.shape[1], p_in.shape[1] # t size
1028
+ hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)
1029
+
1030
+ batch_size, key_tokens, _ = (
1031
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1032
+ )
1033
+
1034
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1035
+ if attention_mask is not None:
1036
+ # expand our mask's singleton query_tokens dimension:
1037
+ # [batch*heads, 1, key_tokens] ->
1038
+ # [batch*heads, query_tokens, key_tokens]
1039
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1040
+ # [batch*heads, query_tokens, key_tokens]
1041
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1042
+ _, query_tokens, _ = hidden_states.shape
1043
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1044
+
1045
+ if attn.group_norm is not None:
1046
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1047
+
1048
+ query = attn.to_q(hidden_states)
1049
+ if encoder_hidden_states is None:
1050
+ encoder_hidden_states = hidden_states
1051
+ elif attn.norm_cross:
1052
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1053
+
1054
+ key = attn.to_k(encoder_hidden_states)
1055
+ value = attn.to_v(encoder_hidden_states)
1056
+
1057
+
1058
+ # apply 6DoF, todo now only for xformer processor
1059
+ if posemb is not None:
1060
+ p_out_inv = einops.repeat(p_out_inv, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out) # query shape
1061
+ if self_attn:
1062
+ p_in = einops.repeat(p_out, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out) # query shape
1063
+ else:
1064
+ p_in = einops.repeat(p_in, 'b t_in f g -> b (t_in l) f g', l=key.shape[1] // t_in) # key shape
1065
+ query = cape_embed(query, p_out_inv) # query f_q @ (p_out)^(-T) .permute(0, 1, 3, 2)
1066
+ key = cape_embed(key, p_in) # key f_k @ p_in
1067
+
1068
+
1069
+ query = attn.head_to_batch_dim(query).contiguous()
1070
+ key = attn.head_to_batch_dim(key).contiguous()
1071
+ value = attn.head_to_batch_dim(value).contiguous()
1072
+
1073
+ # self-ttn (bm) l c x (bm) l c -> (bm) l c
1074
+ # cross-ttn (bm) l c x b (nl) c -> (bm) l c
1075
+ # reuse 2d attention for multiview attention
1076
+ # self-ttn b (ml) c x b (ml) c -> b (ml) c
1077
+ # cross-ttn b (ml) c x b (nl) c -> b (ml) c
1078
+ hidden_states = xformers.ops.memory_efficient_attention( # query: (bm) l c -> b (ml) c; key: b (nl) c
1079
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1080
+ )
1081
+ hidden_states = hidden_states.to(query.dtype)
1082
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1083
+
1084
+ # linear proj
1085
+ hidden_states = attn.to_out[0](hidden_states)
1086
+ # dropout
1087
+ hidden_states = attn.to_out[1](hidden_states)
1088
+
1089
+ if posemb is not None:
1090
+ # reshape back
1091
+ hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)
1092
+
1093
+ if input_ndim == 4:
1094
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1095
+
1096
+ if attn.residual_connection:
1097
+ hidden_states = hidden_states + residual
1098
+
1099
+ hidden_states = hidden_states / attn.rescale_output_factor
1100
+
1101
+
1102
+ return hidden_states
1103
+
1104
+
1105
+ class AttnProcessor2_0:
1106
+ r"""
1107
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1108
+ """
1109
+
1110
+ def __init__(self):
1111
+ if not hasattr(F, "scaled_dot_product_attention"):
1112
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1113
+
1114
+ def __call__(
1115
+ self,
1116
+ attn: Attention,
1117
+ hidden_states,
1118
+ encoder_hidden_states=None,
1119
+ attention_mask=None,
1120
+ temb=None,
1121
+ ):
1122
+ residual = hidden_states
1123
+
1124
+ if attn.spatial_norm is not None:
1125
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1126
+
1127
+ input_ndim = hidden_states.ndim
1128
+
1129
+ if input_ndim == 4:
1130
+ batch_size, channel, height, width = hidden_states.shape
1131
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1132
+
1133
+ batch_size, sequence_length, _ = (
1134
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1135
+ )
1136
+ inner_dim = hidden_states.shape[-1]
1137
+
1138
+ if attention_mask is not None:
1139
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1140
+ # scaled_dot_product_attention expects attention_mask shape to be
1141
+ # (batch, heads, source_length, target_length)
1142
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1143
+
1144
+ if attn.group_norm is not None:
1145
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1146
+
1147
+ query = attn.to_q(hidden_states)
1148
+
1149
+ if encoder_hidden_states is None:
1150
+ encoder_hidden_states = hidden_states
1151
+ elif attn.norm_cross:
1152
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1153
+
1154
+ key = attn.to_k(encoder_hidden_states)
1155
+ value = attn.to_v(encoder_hidden_states)
1156
+
1157
+ head_dim = inner_dim // attn.heads
1158
+
1159
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1160
+
1161
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1162
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1163
+
1164
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1165
+ # TODO: add support for attn.scale when we move to Torch 2.1
1166
+ hidden_states = F.scaled_dot_product_attention(
1167
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1168
+ )
1169
+
1170
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1171
+ hidden_states = hidden_states.to(query.dtype)
1172
+
1173
+ # linear proj
1174
+ hidden_states = attn.to_out[0](hidden_states)
1175
+ # dropout
1176
+ hidden_states = attn.to_out[1](hidden_states)
1177
+
1178
+ if input_ndim == 4:
1179
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1180
+
1181
+ if attn.residual_connection:
1182
+ hidden_states = hidden_states + residual
1183
+
1184
+ hidden_states = hidden_states / attn.rescale_output_factor
1185
+
1186
+ return hidden_states
1187
+
1188
+
1189
+ class LoRAXFormersAttnProcessor(nn.Module):
1190
+ r"""
1191
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1192
+
1193
+ Args:
1194
+ hidden_size (`int`, *optional*):
1195
+ The hidden size of the attention layer.
1196
+ cross_attention_dim (`int`, *optional*):
1197
+ The number of channels in the `encoder_hidden_states`.
1198
+ rank (`int`, defaults to 4):
1199
+ The dimension of the LoRA update matrices.
1200
+ attention_op (`Callable`, *optional*, defaults to `None`):
1201
+ The base
1202
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1203
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1204
+ operator.
1205
+ network_alpha (`int`, *optional*):
1206
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1207
+
1208
+ """
1209
+
1210
+ def __init__(
1211
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1212
+ ):
1213
+ super().__init__()
1214
+
1215
+ self.hidden_size = hidden_size
1216
+ self.cross_attention_dim = cross_attention_dim
1217
+ self.rank = rank
1218
+ self.attention_op = attention_op
1219
+
1220
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1221
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1222
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1223
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1224
+
1225
+ def __call__(
1226
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1227
+ ):
1228
+ residual = hidden_states
1229
+
1230
+ if attn.spatial_norm is not None:
1231
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1232
+
1233
+ input_ndim = hidden_states.ndim
1234
+
1235
+ if input_ndim == 4:
1236
+ batch_size, channel, height, width = hidden_states.shape
1237
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1238
+
1239
+ batch_size, sequence_length, _ = (
1240
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1241
+ )
1242
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1243
+
1244
+ if attn.group_norm is not None:
1245
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1246
+
1247
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1248
+ query = attn.head_to_batch_dim(query).contiguous()
1249
+
1250
+ if encoder_hidden_states is None:
1251
+ encoder_hidden_states = hidden_states
1252
+ elif attn.norm_cross:
1253
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1254
+
1255
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1256
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1257
+
1258
+ key = attn.head_to_batch_dim(key).contiguous()
1259
+ value = attn.head_to_batch_dim(value).contiguous()
1260
+
1261
+ hidden_states = xformers.ops.memory_efficient_attention(
1262
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1263
+ )
1264
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1265
+
1266
+ # linear proj
1267
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1268
+ # dropout
1269
+ hidden_states = attn.to_out[1](hidden_states)
1270
+
1271
+ if input_ndim == 4:
1272
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1273
+
1274
+ if attn.residual_connection:
1275
+ hidden_states = hidden_states + residual
1276
+
1277
+ hidden_states = hidden_states / attn.rescale_output_factor
1278
+
1279
+ return hidden_states
1280
+
1281
+
1282
+ class LoRAAttnProcessor2_0(nn.Module):
1283
+ r"""
1284
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1285
+ attention.
1286
+
1287
+ Args:
1288
+ hidden_size (`int`):
1289
+ The hidden size of the attention layer.
1290
+ cross_attention_dim (`int`, *optional*):
1291
+ The number of channels in the `encoder_hidden_states`.
1292
+ rank (`int`, defaults to 4):
1293
+ The dimension of the LoRA update matrices.
1294
+ network_alpha (`int`, *optional*):
1295
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1296
+ """
1297
+
1298
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1299
+ super().__init__()
1300
+ if not hasattr(F, "scaled_dot_product_attention"):
1301
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1302
+
1303
+ self.hidden_size = hidden_size
1304
+ self.cross_attention_dim = cross_attention_dim
1305
+ self.rank = rank
1306
+
1307
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1308
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1309
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1310
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1311
+
1312
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1313
+ residual = hidden_states
1314
+
1315
+ input_ndim = hidden_states.ndim
1316
+
1317
+ if input_ndim == 4:
1318
+ batch_size, channel, height, width = hidden_states.shape
1319
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1320
+
1321
+ batch_size, sequence_length, _ = (
1322
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1323
+ )
1324
+ inner_dim = hidden_states.shape[-1]
1325
+
1326
+ if attention_mask is not None:
1327
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1328
+ # scaled_dot_product_attention expects attention_mask shape to be
1329
+ # (batch, heads, source_length, target_length)
1330
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1331
+
1332
+ if attn.group_norm is not None:
1333
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1334
+
1335
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1336
+
1337
+ if encoder_hidden_states is None:
1338
+ encoder_hidden_states = hidden_states
1339
+ elif attn.norm_cross:
1340
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1341
+
1342
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1343
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1344
+
1345
+ head_dim = inner_dim // attn.heads
1346
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1347
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1348
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1349
+
1350
+ # TODO: add support for attn.scale when we move to Torch 2.1
1351
+ hidden_states = F.scaled_dot_product_attention(
1352
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1353
+ )
1354
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1355
+ hidden_states = hidden_states.to(query.dtype)
1356
+
1357
+ # linear proj
1358
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1359
+ # dropout
1360
+ hidden_states = attn.to_out[1](hidden_states)
1361
+
1362
+ if input_ndim == 4:
1363
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1364
+
1365
+ if attn.residual_connection:
1366
+ hidden_states = hidden_states + residual
1367
+
1368
+ hidden_states = hidden_states / attn.rescale_output_factor
1369
+
1370
+ return hidden_states
1371
+
1372
+
1373
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1374
+ r"""
1375
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1376
+
1377
+ Args:
1378
+ train_kv (`bool`, defaults to `True`):
1379
+ Whether to newly train the key and value matrices corresponding to the text features.
1380
+ train_q_out (`bool`, defaults to `True`):
1381
+ Whether to newly train query matrices corresponding to the latent image features.
1382
+ hidden_size (`int`, *optional*, defaults to `None`):
1383
+ The hidden size of the attention layer.
1384
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1385
+ The number of channels in the `encoder_hidden_states`.
1386
+ out_bias (`bool`, defaults to `True`):
1387
+ Whether to include the bias parameter in `train_q_out`.
1388
+ dropout (`float`, *optional*, defaults to 0.0):
1389
+ The dropout probability to use.
1390
+ attention_op (`Callable`, *optional*, defaults to `None`):
1391
+ The base
1392
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1393
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1394
+ """
1395
+
1396
+ def __init__(
1397
+ self,
1398
+ train_kv=True,
1399
+ train_q_out=False,
1400
+ hidden_size=None,
1401
+ cross_attention_dim=None,
1402
+ out_bias=True,
1403
+ dropout=0.0,
1404
+ attention_op: Optional[Callable] = None,
1405
+ ):
1406
+ super().__init__()
1407
+ self.train_kv = train_kv
1408
+ self.train_q_out = train_q_out
1409
+
1410
+ self.hidden_size = hidden_size
1411
+ self.cross_attention_dim = cross_attention_dim
1412
+ self.attention_op = attention_op
1413
+
1414
+ # `_custom_diffusion` id for easy serialization and loading.
1415
+ if self.train_kv:
1416
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1417
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1418
+ if self.train_q_out:
1419
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1420
+ self.to_out_custom_diffusion = nn.ModuleList([])
1421
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1422
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1423
+
1424
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1425
+ batch_size, sequence_length, _ = (
1426
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1427
+ )
1428
+
1429
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1430
+
1431
+ if self.train_q_out:
1432
+ query = self.to_q_custom_diffusion(hidden_states)
1433
+ else:
1434
+ query = attn.to_q(hidden_states)
1435
+
1436
+ if encoder_hidden_states is None:
1437
+ crossattn = False
1438
+ encoder_hidden_states = hidden_states
1439
+ else:
1440
+ crossattn = True
1441
+ if attn.norm_cross:
1442
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1443
+
1444
+ if self.train_kv:
1445
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
1446
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
1447
+ else:
1448
+ key = attn.to_k(encoder_hidden_states)
1449
+ value = attn.to_v(encoder_hidden_states)
1450
+
1451
+ if crossattn:
1452
+ detach = torch.ones_like(key)
1453
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1454
+ key = detach * key + (1 - detach) * key.detach()
1455
+ value = detach * value + (1 - detach) * value.detach()
1456
+
1457
+ query = attn.head_to_batch_dim(query).contiguous()
1458
+ key = attn.head_to_batch_dim(key).contiguous()
1459
+ value = attn.head_to_batch_dim(value).contiguous()
1460
+
1461
+ hidden_states = xformers.ops.memory_efficient_attention(
1462
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1463
+ )
1464
+ hidden_states = hidden_states.to(query.dtype)
1465
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1466
+
1467
+ if self.train_q_out:
1468
+ # linear proj
1469
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1470
+ # dropout
1471
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1472
+ else:
1473
+ # linear proj
1474
+ hidden_states = attn.to_out[0](hidden_states)
1475
+ # dropout
1476
+ hidden_states = attn.to_out[1](hidden_states)
1477
+ return hidden_states
1478
+
1479
+
1480
+ class SlicedAttnProcessor:
1481
+ r"""
1482
+ Processor for implementing sliced attention.
1483
+
1484
+ Args:
1485
+ slice_size (`int`, *optional*):
1486
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1487
+ `attention_head_dim` must be a multiple of the `slice_size`.
1488
+ """
1489
+
1490
+ def __init__(self, slice_size):
1491
+ self.slice_size = slice_size
1492
+
1493
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1494
+ residual = hidden_states
1495
+
1496
+ input_ndim = hidden_states.ndim
1497
+
1498
+ if input_ndim == 4:
1499
+ batch_size, channel, height, width = hidden_states.shape
1500
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1501
+
1502
+ batch_size, sequence_length, _ = (
1503
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1504
+ )
1505
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1506
+
1507
+ if attn.group_norm is not None:
1508
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1509
+
1510
+ query = attn.to_q(hidden_states)
1511
+ dim = query.shape[-1]
1512
+ query = attn.head_to_batch_dim(query)
1513
+
1514
+ if encoder_hidden_states is None:
1515
+ encoder_hidden_states = hidden_states
1516
+ elif attn.norm_cross:
1517
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1518
+
1519
+ key = attn.to_k(encoder_hidden_states)
1520
+ value = attn.to_v(encoder_hidden_states)
1521
+ key = attn.head_to_batch_dim(key)
1522
+ value = attn.head_to_batch_dim(value)
1523
+
1524
+ batch_size_attention, query_tokens, _ = query.shape
1525
+ hidden_states = torch.zeros(
1526
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1527
+ )
1528
+
1529
+ for i in range(batch_size_attention // self.slice_size):
1530
+ start_idx = i * self.slice_size
1531
+ end_idx = (i + 1) * self.slice_size
1532
+
1533
+ query_slice = query[start_idx:end_idx]
1534
+ key_slice = key[start_idx:end_idx]
1535
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1536
+
1537
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1538
+
1539
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1540
+
1541
+ hidden_states[start_idx:end_idx] = attn_slice
1542
+
1543
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1544
+
1545
+ # linear proj
1546
+ hidden_states = attn.to_out[0](hidden_states)
1547
+ # dropout
1548
+ hidden_states = attn.to_out[1](hidden_states)
1549
+
1550
+ if input_ndim == 4:
1551
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1552
+
1553
+ if attn.residual_connection:
1554
+ hidden_states = hidden_states + residual
1555
+
1556
+ hidden_states = hidden_states / attn.rescale_output_factor
1557
+
1558
+ return hidden_states
1559
+
1560
+
1561
+ class SlicedAttnAddedKVProcessor:
1562
+ r"""
1563
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1564
+
1565
+ Args:
1566
+ slice_size (`int`, *optional*):
1567
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1568
+ `attention_head_dim` must be a multiple of the `slice_size`.
1569
+ """
1570
+
1571
+ def __init__(self, slice_size):
1572
+ self.slice_size = slice_size
1573
+
1574
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1575
+ residual = hidden_states
1576
+
1577
+ if attn.spatial_norm is not None:
1578
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1579
+
1580
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1581
+
1582
+ batch_size, sequence_length, _ = hidden_states.shape
1583
+
1584
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1585
+
1586
+ if encoder_hidden_states is None:
1587
+ encoder_hidden_states = hidden_states
1588
+ elif attn.norm_cross:
1589
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1590
+
1591
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1592
+
1593
+ query = attn.to_q(hidden_states)
1594
+ dim = query.shape[-1]
1595
+ query = attn.head_to_batch_dim(query)
1596
+
1597
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1598
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1599
+
1600
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1601
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1602
+
1603
+ if not attn.only_cross_attention:
1604
+ key = attn.to_k(hidden_states)
1605
+ value = attn.to_v(hidden_states)
1606
+ key = attn.head_to_batch_dim(key)
1607
+ value = attn.head_to_batch_dim(value)
1608
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1609
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1610
+ else:
1611
+ key = encoder_hidden_states_key_proj
1612
+ value = encoder_hidden_states_value_proj
1613
+
1614
+ batch_size_attention, query_tokens, _ = query.shape
1615
+ hidden_states = torch.zeros(
1616
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1617
+ )
1618
+
1619
+ for i in range(batch_size_attention // self.slice_size):
1620
+ start_idx = i * self.slice_size
1621
+ end_idx = (i + 1) * self.slice_size
1622
+
1623
+ query_slice = query[start_idx:end_idx]
1624
+ key_slice = key[start_idx:end_idx]
1625
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1626
+
1627
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1628
+
1629
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1630
+
1631
+ hidden_states[start_idx:end_idx] = attn_slice
1632
+
1633
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1634
+
1635
+ # linear proj
1636
+ hidden_states = attn.to_out[0](hidden_states)
1637
+ # dropout
1638
+ hidden_states = attn.to_out[1](hidden_states)
1639
+
1640
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1641
+ hidden_states = hidden_states + residual
1642
+
1643
+ return hidden_states
1644
+
1645
+
1646
+ AttentionProcessor = Union[
1647
+ AttnProcessor,
1648
+ AttnProcessor2_0,
1649
+ XFormersAttnProcessor,
1650
+ SlicedAttnProcessor,
1651
+ AttnAddedKVProcessor,
1652
+ SlicedAttnAddedKVProcessor,
1653
+ AttnAddedKVProcessor2_0,
1654
+ XFormersAttnAddedKVProcessor,
1655
+ LoRAAttnProcessor,
1656
+ LoRAXFormersAttnProcessor,
1657
+ LoRAAttnProcessor2_0,
1658
+ LoRAAttnAddedKVProcessor,
1659
+ CustomDiffusionAttnProcessor,
1660
+ CustomDiffusionXFormersAttnProcessor,
1661
+ ]
1662
+
1663
+
1664
+ class SpatialNorm(nn.Module):
1665
+ """
1666
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1667
+ """
1668
+
1669
+ def __init__(
1670
+ self,
1671
+ f_channels,
1672
+ zq_channels,
1673
+ ):
1674
+ super().__init__()
1675
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1676
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1677
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1678
+
1679
+ def forward(self, f, zq):
1680
+ f_size = f.shape[-2:]
1681
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1682
+ norm_f = self.norm_layer(f)
1683
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1684
+ return new_f
6DoF/diffusers/models/autoencoder_kl.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput, apply_forward_hook
22
+ from .attention_processor import AttentionProcessor, AttnProcessor
23
+ from .modeling_utils import ModelMixin
24
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
25
+
26
+
27
+ @dataclass
28
+ class AutoencoderKLOutput(BaseOutput):
29
+ """
30
+ Output of AutoencoderKL encoding method.
31
+
32
+ Args:
33
+ latent_dist (`DiagonalGaussianDistribution`):
34
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
35
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
36
+ """
37
+
38
+ latent_dist: "DiagonalGaussianDistribution"
39
+
40
+
41
+ class AutoencoderKL(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
44
+
45
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
46
+ for all models (such as downloading or saving).
47
+
48
+ Parameters:
49
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
50
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
51
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
52
+ Tuple of downsample block types.
53
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
54
+ Tuple of upsample block types.
55
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
56
+ Tuple of block output channels.
57
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
58
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
59
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
60
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
61
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
62
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
63
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67
+ """
68
+
69
+ _supports_gradient_checkpointing = True
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ in_channels: int = 3,
75
+ out_channels: int = 3,
76
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
77
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
78
+ block_out_channels: Tuple[int] = (64,),
79
+ layers_per_block: int = 1,
80
+ act_fn: str = "silu",
81
+ latent_channels: int = 4,
82
+ norm_num_groups: int = 32,
83
+ sample_size: int = 32,
84
+ scaling_factor: float = 0.18215,
85
+ ):
86
+ super().__init__()
87
+
88
+ # pass init params to Encoder
89
+ self.encoder = Encoder(
90
+ in_channels=in_channels,
91
+ out_channels=latent_channels,
92
+ down_block_types=down_block_types,
93
+ block_out_channels=block_out_channels,
94
+ layers_per_block=layers_per_block,
95
+ act_fn=act_fn,
96
+ norm_num_groups=norm_num_groups,
97
+ double_z=True,
98
+ )
99
+
100
+ # pass init params to Decoder
101
+ self.decoder = Decoder(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ )
110
+
111
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
112
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
113
+
114
+ self.use_slicing = False
115
+ self.use_tiling = False
116
+
117
+ # only relevant if vae tiling is enabled
118
+ self.tile_sample_min_size = self.config.sample_size
119
+ sample_size = (
120
+ self.config.sample_size[0]
121
+ if isinstance(self.config.sample_size, (list, tuple))
122
+ else self.config.sample_size
123
+ )
124
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
125
+ self.tile_overlap_factor = 0.25
126
+
127
+ def _set_gradient_checkpointing(self, module, value=False):
128
+ if isinstance(module, (Encoder, Decoder)):
129
+ module.gradient_checkpointing = value
130
+
131
+ def enable_tiling(self, use_tiling: bool = True):
132
+ r"""
133
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
134
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
135
+ processing larger images.
136
+ """
137
+ self.use_tiling = use_tiling
138
+
139
+ def disable_tiling(self):
140
+ r"""
141
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
142
+ decoding in one step.
143
+ """
144
+ self.enable_tiling(False)
145
+
146
+ def enable_slicing(self):
147
+ r"""
148
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
149
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
150
+ """
151
+ self.use_slicing = True
152
+
153
+ def disable_slicing(self):
154
+ r"""
155
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
156
+ decoding in one step.
157
+ """
158
+ self.use_slicing = False
159
+
160
+ @property
161
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
162
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
163
+ r"""
164
+ Returns:
165
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
166
+ indexed by its weight name.
167
+ """
168
+ # set recursively
169
+ processors = {}
170
+
171
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
172
+ if hasattr(module, "set_processor"):
173
+ processors[f"{name}.processor"] = module.processor
174
+
175
+ for sub_name, child in module.named_children():
176
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
177
+
178
+ return processors
179
+
180
+ for name, module in self.named_children():
181
+ fn_recursive_add_processors(name, module, processors)
182
+
183
+ return processors
184
+
185
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
186
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
187
+ r"""
188
+ Sets the attention processor to use to compute attention.
189
+
190
+ Parameters:
191
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
192
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
193
+ for **all** `Attention` layers.
194
+
195
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
196
+ processor. This is strongly recommended when setting trainable attention processors.
197
+
198
+ """
199
+ count = len(self.attn_processors.keys())
200
+
201
+ if isinstance(processor, dict) and len(processor) != count:
202
+ raise ValueError(
203
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
204
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
205
+ )
206
+
207
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
208
+ if hasattr(module, "set_processor"):
209
+ if not isinstance(processor, dict):
210
+ module.set_processor(processor)
211
+ else:
212
+ module.set_processor(processor.pop(f"{name}.processor"))
213
+
214
+ for sub_name, child in module.named_children():
215
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
216
+
217
+ for name, module in self.named_children():
218
+ fn_recursive_attn_processor(name, module, processor)
219
+
220
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
221
+ def set_default_attn_processor(self):
222
+ """
223
+ Disables custom attention processors and sets the default attention implementation.
224
+ """
225
+ self.set_attn_processor(AttnProcessor())
226
+
227
+ @apply_forward_hook
228
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
229
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
230
+ return self.tiled_encode(x, return_dict=return_dict)
231
+
232
+ if self.use_slicing and x.shape[0] > 1:
233
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
234
+ h = torch.cat(encoded_slices)
235
+ else:
236
+ h = self.encoder(x)
237
+
238
+ moments = self.quant_conv(h)
239
+ posterior = DiagonalGaussianDistribution(moments)
240
+
241
+ if not return_dict:
242
+ return (posterior,)
243
+
244
+ return AutoencoderKLOutput(latent_dist=posterior)
245
+
246
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
247
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
248
+ return self.tiled_decode(z, return_dict=return_dict)
249
+
250
+ z = self.post_quant_conv(z)
251
+ dec = self.decoder(z)
252
+
253
+ if not return_dict:
254
+ return (dec,)
255
+
256
+ return DecoderOutput(sample=dec)
257
+
258
+ @apply_forward_hook
259
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
260
+ if self.use_slicing and z.shape[0] > 1:
261
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
262
+ decoded = torch.cat(decoded_slices)
263
+ else:
264
+ decoded = self._decode(z).sample
265
+
266
+ if not return_dict:
267
+ return (decoded,)
268
+
269
+ return DecoderOutput(sample=decoded)
270
+
271
+ def blend_v(self, a, b, blend_extent):
272
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
273
+ for y in range(blend_extent):
274
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
275
+ return b
276
+
277
+ def blend_h(self, a, b, blend_extent):
278
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
279
+ for x in range(blend_extent):
280
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
281
+ return b
282
+
283
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
284
+ r"""Encode a batch of images using a tiled encoder.
285
+
286
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
287
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
288
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
289
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
290
+ output, but they should be much less noticeable.
291
+
292
+ Args:
293
+ x (`torch.FloatTensor`): Input batch of images.
294
+ return_dict (`bool`, *optional*, defaults to `True`):
295
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
296
+
297
+ Returns:
298
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
299
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
300
+ `tuple` is returned.
301
+ """
302
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
303
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
304
+ row_limit = self.tile_latent_min_size - blend_extent
305
+
306
+ # Split the image into 512x512 tiles and encode them separately.
307
+ rows = []
308
+ for i in range(0, x.shape[2], overlap_size):
309
+ row = []
310
+ for j in range(0, x.shape[3], overlap_size):
311
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
312
+ tile = self.encoder(tile)
313
+ tile = self.quant_conv(tile)
314
+ row.append(tile)
315
+ rows.append(row)
316
+ result_rows = []
317
+ for i, row in enumerate(rows):
318
+ result_row = []
319
+ for j, tile in enumerate(row):
320
+ # blend the above tile and the left tile
321
+ # to the current tile and add the current tile to the result row
322
+ if i > 0:
323
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
324
+ if j > 0:
325
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
326
+ result_row.append(tile[:, :, :row_limit, :row_limit])
327
+ result_rows.append(torch.cat(result_row, dim=3))
328
+
329
+ moments = torch.cat(result_rows, dim=2)
330
+ posterior = DiagonalGaussianDistribution(moments)
331
+
332
+ if not return_dict:
333
+ return (posterior,)
334
+
335
+ return AutoencoderKLOutput(latent_dist=posterior)
336
+
337
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
338
+ r"""
339
+ Decode a batch of images using a tiled decoder.
340
+
341
+ Args:
342
+ z (`torch.FloatTensor`): Input batch of latent vectors.
343
+ return_dict (`bool`, *optional*, defaults to `True`):
344
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
345
+
346
+ Returns:
347
+ [`~models.vae.DecoderOutput`] or `tuple`:
348
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
349
+ returned.
350
+ """
351
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
352
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
353
+ row_limit = self.tile_sample_min_size - blend_extent
354
+
355
+ # Split z into overlapping 64x64 tiles and decode them separately.
356
+ # The tiles have an overlap to avoid seams between tiles.
357
+ rows = []
358
+ for i in range(0, z.shape[2], overlap_size):
359
+ row = []
360
+ for j in range(0, z.shape[3], overlap_size):
361
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
362
+ tile = self.post_quant_conv(tile)
363
+ decoded = self.decoder(tile)
364
+ row.append(decoded)
365
+ rows.append(row)
366
+ result_rows = []
367
+ for i, row in enumerate(rows):
368
+ result_row = []
369
+ for j, tile in enumerate(row):
370
+ # blend the above tile and the left tile
371
+ # to the current tile and add the current tile to the result row
372
+ if i > 0:
373
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
374
+ if j > 0:
375
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
376
+ result_row.append(tile[:, :, :row_limit, :row_limit])
377
+ result_rows.append(torch.cat(result_row, dim=3))
378
+
379
+ dec = torch.cat(result_rows, dim=2)
380
+ if not return_dict:
381
+ return (dec,)
382
+
383
+ return DecoderOutput(sample=dec)
384
+
385
+ def forward(
386
+ self,
387
+ sample: torch.FloatTensor,
388
+ sample_posterior: bool = False,
389
+ return_dict: bool = True,
390
+ generator: Optional[torch.Generator] = None,
391
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
392
+ r"""
393
+ Args:
394
+ sample (`torch.FloatTensor`): Input sample.
395
+ sample_posterior (`bool`, *optional*, defaults to `False`):
396
+ Whether to sample from the posterior.
397
+ return_dict (`bool`, *optional*, defaults to `True`):
398
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
399
+ """
400
+ x = sample
401
+ posterior = self.encode(x).latent_dist
402
+ if sample_posterior:
403
+ z = posterior.sample(generator=generator)
404
+ else:
405
+ z = posterior.mode()
406
+ dec = self.decode(z).sample
407
+
408
+ if not return_dict:
409
+ return (dec,)
410
+
411
+ return DecoderOutput(sample=dec)
6DoF/diffusers/models/controlnet.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging
23
+ from .attention_processor import AttentionProcessor, AttnProcessor
24
+ from .embeddings import TimestepEmbedding, Timesteps
25
+ from .modeling_utils import ModelMixin
26
+ from .unet_2d_blocks import (
27
+ CrossAttnDownBlock2D,
28
+ DownBlock2D,
29
+ UNetMidBlock2DCrossAttn,
30
+ get_down_block,
31
+ )
32
+ from .unet_2d_condition import UNet2DConditionModel
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class ControlNetOutput(BaseOutput):
40
+ """
41
+ The output of [`ControlNetModel`].
42
+
43
+ Args:
44
+ down_block_res_samples (`tuple[torch.Tensor]`):
45
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
46
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
47
+ used to condition the original UNet's downsampling activations.
48
+ mid_down_block_re_sample (`torch.Tensor`):
49
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
50
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
51
+ Output can be used to condition the original UNet's middle block activation.
52
+ """
53
+
54
+ down_block_res_samples: Tuple[torch.Tensor]
55
+ mid_block_res_sample: torch.Tensor
56
+
57
+
58
+ class ControlNetConditioningEmbedding(nn.Module):
59
+ """
60
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
61
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
62
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
63
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
64
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
65
+ model) to encode image-space conditions ... into feature maps ..."
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ conditioning_embedding_channels: int,
71
+ conditioning_channels: int = 3,
72
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
73
+ ):
74
+ super().__init__()
75
+
76
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
77
+
78
+ self.blocks = nn.ModuleList([])
79
+
80
+ for i in range(len(block_out_channels) - 1):
81
+ channel_in = block_out_channels[i]
82
+ channel_out = block_out_channels[i + 1]
83
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
84
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
85
+
86
+ self.conv_out = zero_module(
87
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
88
+ )
89
+
90
+ def forward(self, conditioning):
91
+ embedding = self.conv_in(conditioning)
92
+ embedding = F.silu(embedding)
93
+
94
+ for block in self.blocks:
95
+ embedding = block(embedding)
96
+ embedding = F.silu(embedding)
97
+
98
+ embedding = self.conv_out(embedding)
99
+
100
+ return embedding
101
+
102
+
103
+ class ControlNetModel(ModelMixin, ConfigMixin):
104
+ """
105
+ A ControlNet model.
106
+
107
+ Args:
108
+ in_channels (`int`, defaults to 4):
109
+ The number of channels in the input sample.
110
+ flip_sin_to_cos (`bool`, defaults to `True`):
111
+ Whether to flip the sin to cos in the time embedding.
112
+ freq_shift (`int`, defaults to 0):
113
+ The frequency shift to apply to the time embedding.
114
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
115
+ The tuple of downsample blocks to use.
116
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
117
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
118
+ The tuple of output channels for each block.
119
+ layers_per_block (`int`, defaults to 2):
120
+ The number of layers per block.
121
+ downsample_padding (`int`, defaults to 1):
122
+ The padding to use for the downsampling convolution.
123
+ mid_block_scale_factor (`float`, defaults to 1):
124
+ The scale factor to use for the mid block.
125
+ act_fn (`str`, defaults to "silu"):
126
+ The activation function to use.
127
+ norm_num_groups (`int`, *optional*, defaults to 32):
128
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
129
+ in post-processing.
130
+ norm_eps (`float`, defaults to 1e-5):
131
+ The epsilon to use for the normalization.
132
+ cross_attention_dim (`int`, defaults to 1280):
133
+ The dimension of the cross attention features.
134
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
135
+ The dimension of the attention heads.
136
+ use_linear_projection (`bool`, defaults to `False`):
137
+ class_embed_type (`str`, *optional*, defaults to `None`):
138
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
139
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
140
+ num_class_embeds (`int`, *optional*, defaults to 0):
141
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
142
+ class conditioning with `class_embed_type` equal to `None`.
143
+ upcast_attention (`bool`, defaults to `False`):
144
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
145
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
146
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
147
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
148
+ `class_embed_type="projection"`.
149
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
150
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
151
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
152
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
153
+ global_pool_conditions (`bool`, defaults to `False`):
154
+ """
155
+
156
+ _supports_gradient_checkpointing = True
157
+
158
+ @register_to_config
159
+ def __init__(
160
+ self,
161
+ in_channels: int = 4,
162
+ conditioning_channels: int = 3,
163
+ flip_sin_to_cos: bool = True,
164
+ freq_shift: int = 0,
165
+ down_block_types: Tuple[str] = (
166
+ "CrossAttnDownBlock2D",
167
+ "CrossAttnDownBlock2D",
168
+ "CrossAttnDownBlock2D",
169
+ "DownBlock2D",
170
+ ),
171
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
172
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
173
+ layers_per_block: int = 2,
174
+ downsample_padding: int = 1,
175
+ mid_block_scale_factor: float = 1,
176
+ act_fn: str = "silu",
177
+ norm_num_groups: Optional[int] = 32,
178
+ norm_eps: float = 1e-5,
179
+ cross_attention_dim: int = 1280,
180
+ attention_head_dim: Union[int, Tuple[int]] = 8,
181
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
182
+ use_linear_projection: bool = False,
183
+ class_embed_type: Optional[str] = None,
184
+ num_class_embeds: Optional[int] = None,
185
+ upcast_attention: bool = False,
186
+ resnet_time_scale_shift: str = "default",
187
+ projection_class_embeddings_input_dim: Optional[int] = None,
188
+ controlnet_conditioning_channel_order: str = "rgb",
189
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
190
+ global_pool_conditions: bool = False,
191
+ ):
192
+ super().__init__()
193
+
194
+ # If `num_attention_heads` is not defined (which is the case for most models)
195
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
196
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
197
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
198
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
199
+ # which is why we correct for the naming here.
200
+ num_attention_heads = num_attention_heads or attention_head_dim
201
+
202
+ # Check inputs
203
+ if len(block_out_channels) != len(down_block_types):
204
+ raise ValueError(
205
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
206
+ )
207
+
208
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
209
+ raise ValueError(
210
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
211
+ )
212
+
213
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
214
+ raise ValueError(
215
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
216
+ )
217
+
218
+ # input
219
+ conv_in_kernel = 3
220
+ conv_in_padding = (conv_in_kernel - 1) // 2
221
+ self.conv_in = nn.Conv2d(
222
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
223
+ )
224
+
225
+ # time
226
+ time_embed_dim = block_out_channels[0] * 4
227
+
228
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
+ timestep_input_dim = block_out_channels[0]
230
+
231
+ self.time_embedding = TimestepEmbedding(
232
+ timestep_input_dim,
233
+ time_embed_dim,
234
+ act_fn=act_fn,
235
+ )
236
+
237
+ # class embedding
238
+ if class_embed_type is None and num_class_embeds is not None:
239
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
240
+ elif class_embed_type == "timestep":
241
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
242
+ elif class_embed_type == "identity":
243
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
244
+ elif class_embed_type == "projection":
245
+ if projection_class_embeddings_input_dim is None:
246
+ raise ValueError(
247
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
248
+ )
249
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
250
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
251
+ # 2. it projects from an arbitrary input dimension.
252
+ #
253
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
254
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
255
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
256
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
257
+ else:
258
+ self.class_embedding = None
259
+
260
+ # control net conditioning embedding
261
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
262
+ conditioning_embedding_channels=block_out_channels[0],
263
+ block_out_channels=conditioning_embedding_out_channels,
264
+ conditioning_channels=conditioning_channels,
265
+ )
266
+
267
+ self.down_blocks = nn.ModuleList([])
268
+ self.controlnet_down_blocks = nn.ModuleList([])
269
+
270
+ if isinstance(only_cross_attention, bool):
271
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
272
+
273
+ if isinstance(attention_head_dim, int):
274
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
275
+
276
+ if isinstance(num_attention_heads, int):
277
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
278
+
279
+ # down
280
+ output_channel = block_out_channels[0]
281
+
282
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
283
+ controlnet_block = zero_module(controlnet_block)
284
+ self.controlnet_down_blocks.append(controlnet_block)
285
+
286
+ for i, down_block_type in enumerate(down_block_types):
287
+ input_channel = output_channel
288
+ output_channel = block_out_channels[i]
289
+ is_final_block = i == len(block_out_channels) - 1
290
+
291
+ down_block = get_down_block(
292
+ down_block_type,
293
+ num_layers=layers_per_block,
294
+ in_channels=input_channel,
295
+ out_channels=output_channel,
296
+ temb_channels=time_embed_dim,
297
+ add_downsample=not is_final_block,
298
+ resnet_eps=norm_eps,
299
+ resnet_act_fn=act_fn,
300
+ resnet_groups=norm_num_groups,
301
+ cross_attention_dim=cross_attention_dim,
302
+ num_attention_heads=num_attention_heads[i],
303
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
304
+ downsample_padding=downsample_padding,
305
+ use_linear_projection=use_linear_projection,
306
+ only_cross_attention=only_cross_attention[i],
307
+ upcast_attention=upcast_attention,
308
+ resnet_time_scale_shift=resnet_time_scale_shift,
309
+ )
310
+ self.down_blocks.append(down_block)
311
+
312
+ for _ in range(layers_per_block):
313
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
314
+ controlnet_block = zero_module(controlnet_block)
315
+ self.controlnet_down_blocks.append(controlnet_block)
316
+
317
+ if not is_final_block:
318
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
319
+ controlnet_block = zero_module(controlnet_block)
320
+ self.controlnet_down_blocks.append(controlnet_block)
321
+
322
+ # mid
323
+ mid_block_channel = block_out_channels[-1]
324
+
325
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
326
+ controlnet_block = zero_module(controlnet_block)
327
+ self.controlnet_mid_block = controlnet_block
328
+
329
+ self.mid_block = UNetMidBlock2DCrossAttn(
330
+ in_channels=mid_block_channel,
331
+ temb_channels=time_embed_dim,
332
+ resnet_eps=norm_eps,
333
+ resnet_act_fn=act_fn,
334
+ output_scale_factor=mid_block_scale_factor,
335
+ resnet_time_scale_shift=resnet_time_scale_shift,
336
+ cross_attention_dim=cross_attention_dim,
337
+ num_attention_heads=num_attention_heads[-1],
338
+ resnet_groups=norm_num_groups,
339
+ use_linear_projection=use_linear_projection,
340
+ upcast_attention=upcast_attention,
341
+ )
342
+
343
+ @classmethod
344
+ def from_unet(
345
+ cls,
346
+ unet: UNet2DConditionModel,
347
+ controlnet_conditioning_channel_order: str = "rgb",
348
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
349
+ load_weights_from_unet: bool = True,
350
+ ):
351
+ r"""
352
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
353
+
354
+ Parameters:
355
+ unet (`UNet2DConditionModel`):
356
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
357
+ where applicable.
358
+ """
359
+ controlnet = cls(
360
+ in_channels=unet.config.in_channels,
361
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
362
+ freq_shift=unet.config.freq_shift,
363
+ down_block_types=unet.config.down_block_types,
364
+ only_cross_attention=unet.config.only_cross_attention,
365
+ block_out_channels=unet.config.block_out_channels,
366
+ layers_per_block=unet.config.layers_per_block,
367
+ downsample_padding=unet.config.downsample_padding,
368
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
369
+ act_fn=unet.config.act_fn,
370
+ norm_num_groups=unet.config.norm_num_groups,
371
+ norm_eps=unet.config.norm_eps,
372
+ cross_attention_dim=unet.config.cross_attention_dim,
373
+ attention_head_dim=unet.config.attention_head_dim,
374
+ num_attention_heads=unet.config.num_attention_heads,
375
+ use_linear_projection=unet.config.use_linear_projection,
376
+ class_embed_type=unet.config.class_embed_type,
377
+ num_class_embeds=unet.config.num_class_embeds,
378
+ upcast_attention=unet.config.upcast_attention,
379
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
380
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
381
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
382
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
383
+ )
384
+
385
+ if load_weights_from_unet:
386
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
387
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
388
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
389
+
390
+ if controlnet.class_embedding:
391
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
392
+
393
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
394
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
395
+
396
+ return controlnet
397
+
398
+ @property
399
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
400
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
401
+ r"""
402
+ Returns:
403
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
404
+ indexed by its weight name.
405
+ """
406
+ # set recursively
407
+ processors = {}
408
+
409
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
410
+ if hasattr(module, "set_processor"):
411
+ processors[f"{name}.processor"] = module.processor
412
+
413
+ for sub_name, child in module.named_children():
414
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
415
+
416
+ return processors
417
+
418
+ for name, module in self.named_children():
419
+ fn_recursive_add_processors(name, module, processors)
420
+
421
+ return processors
422
+
423
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
424
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
425
+ r"""
426
+ Sets the attention processor to use to compute attention.
427
+
428
+ Parameters:
429
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
430
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
431
+ for **all** `Attention` layers.
432
+
433
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
434
+ processor. This is strongly recommended when setting trainable attention processors.
435
+
436
+ """
437
+ count = len(self.attn_processors.keys())
438
+
439
+ if isinstance(processor, dict) and len(processor) != count:
440
+ raise ValueError(
441
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
442
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
443
+ )
444
+
445
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
446
+ if hasattr(module, "set_processor"):
447
+ if not isinstance(processor, dict):
448
+ module.set_processor(processor)
449
+ else:
450
+ module.set_processor(processor.pop(f"{name}.processor"))
451
+
452
+ for sub_name, child in module.named_children():
453
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
454
+
455
+ for name, module in self.named_children():
456
+ fn_recursive_attn_processor(name, module, processor)
457
+
458
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
459
+ def set_default_attn_processor(self):
460
+ """
461
+ Disables custom attention processors and sets the default attention implementation.
462
+ """
463
+ self.set_attn_processor(AttnProcessor())
464
+
465
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
466
+ def set_attention_slice(self, slice_size):
467
+ r"""
468
+ Enable sliced attention computation.
469
+
470
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
471
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
472
+
473
+ Args:
474
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
475
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
476
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
477
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
478
+ must be a multiple of `slice_size`.
479
+ """
480
+ sliceable_head_dims = []
481
+
482
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
483
+ if hasattr(module, "set_attention_slice"):
484
+ sliceable_head_dims.append(module.sliceable_head_dim)
485
+
486
+ for child in module.children():
487
+ fn_recursive_retrieve_sliceable_dims(child)
488
+
489
+ # retrieve number of attention layers
490
+ for module in self.children():
491
+ fn_recursive_retrieve_sliceable_dims(module)
492
+
493
+ num_sliceable_layers = len(sliceable_head_dims)
494
+
495
+ if slice_size == "auto":
496
+ # half the attention head size is usually a good trade-off between
497
+ # speed and memory
498
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
499
+ elif slice_size == "max":
500
+ # make smallest slice possible
501
+ slice_size = num_sliceable_layers * [1]
502
+
503
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
504
+
505
+ if len(slice_size) != len(sliceable_head_dims):
506
+ raise ValueError(
507
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
508
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
509
+ )
510
+
511
+ for i in range(len(slice_size)):
512
+ size = slice_size[i]
513
+ dim = sliceable_head_dims[i]
514
+ if size is not None and size > dim:
515
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
516
+
517
+ # Recursively walk through all the children.
518
+ # Any children which exposes the set_attention_slice method
519
+ # gets the message
520
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
521
+ if hasattr(module, "set_attention_slice"):
522
+ module.set_attention_slice(slice_size.pop())
523
+
524
+ for child in module.children():
525
+ fn_recursive_set_attention_slice(child, slice_size)
526
+
527
+ reversed_slice_size = list(reversed(slice_size))
528
+ for module in self.children():
529
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
530
+
531
+ def _set_gradient_checkpointing(self, module, value=False):
532
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
533
+ module.gradient_checkpointing = value
534
+
535
+ def forward(
536
+ self,
537
+ sample: torch.FloatTensor,
538
+ timestep: Union[torch.Tensor, float, int],
539
+ encoder_hidden_states: torch.Tensor,
540
+ controlnet_cond: torch.FloatTensor,
541
+ conditioning_scale: float = 1.0,
542
+ class_labels: Optional[torch.Tensor] = None,
543
+ timestep_cond: Optional[torch.Tensor] = None,
544
+ attention_mask: Optional[torch.Tensor] = None,
545
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
546
+ guess_mode: bool = False,
547
+ return_dict: bool = True,
548
+ ) -> Union[ControlNetOutput, Tuple]:
549
+ """
550
+ The [`ControlNetModel`] forward method.
551
+
552
+ Args:
553
+ sample (`torch.FloatTensor`):
554
+ The noisy input tensor.
555
+ timestep (`Union[torch.Tensor, float, int]`):
556
+ The number of timesteps to denoise an input.
557
+ encoder_hidden_states (`torch.Tensor`):
558
+ The encoder hidden states.
559
+ controlnet_cond (`torch.FloatTensor`):
560
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
561
+ conditioning_scale (`float`, defaults to `1.0`):
562
+ The scale factor for ControlNet outputs.
563
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
564
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
565
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
566
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
567
+ cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
568
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
569
+ guess_mode (`bool`, defaults to `False`):
570
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
571
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
572
+ return_dict (`bool`, defaults to `True`):
573
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
574
+
575
+ Returns:
576
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
577
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
578
+ returned where the first element is the sample tensor.
579
+ """
580
+ # check channel order
581
+ channel_order = self.config.controlnet_conditioning_channel_order
582
+
583
+ if channel_order == "rgb":
584
+ # in rgb order by default
585
+ ...
586
+ elif channel_order == "bgr":
587
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
588
+ else:
589
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
590
+
591
+ # prepare attention_mask
592
+ if attention_mask is not None:
593
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
594
+ attention_mask = attention_mask.unsqueeze(1)
595
+
596
+ # 1. time
597
+ timesteps = timestep
598
+ if not torch.is_tensor(timesteps):
599
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
600
+ # This would be a good case for the `match` statement (Python 3.10+)
601
+ is_mps = sample.device.type == "mps"
602
+ if isinstance(timestep, float):
603
+ dtype = torch.float32 if is_mps else torch.float64
604
+ else:
605
+ dtype = torch.int32 if is_mps else torch.int64
606
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
607
+ elif len(timesteps.shape) == 0:
608
+ timesteps = timesteps[None].to(sample.device)
609
+
610
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
611
+ timesteps = timesteps.expand(sample.shape[0])
612
+
613
+ t_emb = self.time_proj(timesteps)
614
+
615
+ # timesteps does not contain any weights and will always return f32 tensors
616
+ # but time_embedding might actually be running in fp16. so we need to cast here.
617
+ # there might be better ways to encapsulate this.
618
+ t_emb = t_emb.to(dtype=sample.dtype)
619
+
620
+ emb = self.time_embedding(t_emb, timestep_cond)
621
+
622
+ if self.class_embedding is not None:
623
+ if class_labels is None:
624
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
625
+
626
+ if self.config.class_embed_type == "timestep":
627
+ class_labels = self.time_proj(class_labels)
628
+
629
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
630
+ emb = emb + class_emb
631
+
632
+ # 2. pre-process
633
+ sample = self.conv_in(sample)
634
+
635
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
636
+
637
+ sample = sample + controlnet_cond
638
+
639
+ # 3. down
640
+ down_block_res_samples = (sample,)
641
+ for downsample_block in self.down_blocks:
642
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
643
+ sample, res_samples = downsample_block(
644
+ hidden_states=sample,
645
+ temb=emb,
646
+ encoder_hidden_states=encoder_hidden_states,
647
+ attention_mask=attention_mask,
648
+ cross_attention_kwargs=cross_attention_kwargs,
649
+ )
650
+ else:
651
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
652
+
653
+ down_block_res_samples += res_samples
654
+
655
+ # 4. mid
656
+ if self.mid_block is not None:
657
+ sample = self.mid_block(
658
+ sample,
659
+ emb,
660
+ encoder_hidden_states=encoder_hidden_states,
661
+ attention_mask=attention_mask,
662
+ cross_attention_kwargs=cross_attention_kwargs,
663
+ )
664
+
665
+ # 5. Control net blocks
666
+
667
+ controlnet_down_block_res_samples = ()
668
+
669
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
670
+ down_block_res_sample = controlnet_block(down_block_res_sample)
671
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
672
+
673
+ down_block_res_samples = controlnet_down_block_res_samples
674
+
675
+ mid_block_res_sample = self.controlnet_mid_block(sample)
676
+
677
+ # 6. scaling
678
+ if guess_mode and not self.config.global_pool_conditions:
679
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
680
+
681
+ scales = scales * conditioning_scale
682
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
683
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
684
+ else:
685
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
686
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
687
+
688
+ if self.config.global_pool_conditions:
689
+ down_block_res_samples = [
690
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
691
+ ]
692
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
693
+
694
+ if not return_dict:
695
+ return (down_block_res_samples, mid_block_res_sample)
696
+
697
+ return ControlNetOutput(
698
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
699
+ )
700
+
701
+
702
+ def zero_module(module):
703
+ for p in module.parameters():
704
+ nn.init.zeros_(p)
705
+ return module
6DoF/diffusers/models/controlnet_flax.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
23
+ from ..utils import BaseOutput
24
+ from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .modeling_flax_utils import FlaxModelMixin
26
+ from .unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxDownBlock2D,
29
+ FlaxUNetMidBlock2DCrossAttn,
30
+ )
31
+
32
+
33
+ @flax.struct.dataclass
34
+ class FlaxControlNetOutput(BaseOutput):
35
+ """
36
+ The output of [`FlaxControlNetModel`].
37
+
38
+ Args:
39
+ down_block_res_samples (`jnp.ndarray`):
40
+ mid_block_res_sample (`jnp.ndarray`):
41
+ """
42
+
43
+ down_block_res_samples: jnp.ndarray
44
+ mid_block_res_sample: jnp.ndarray
45
+
46
+
47
+ class FlaxControlNetConditioningEmbedding(nn.Module):
48
+ conditioning_embedding_channels: int
49
+ block_out_channels: Tuple[int] = (16, 32, 96, 256)
50
+ dtype: jnp.dtype = jnp.float32
51
+
52
+ def setup(self):
53
+ self.conv_in = nn.Conv(
54
+ self.block_out_channels[0],
55
+ kernel_size=(3, 3),
56
+ padding=((1, 1), (1, 1)),
57
+ dtype=self.dtype,
58
+ )
59
+
60
+ blocks = []
61
+ for i in range(len(self.block_out_channels) - 1):
62
+ channel_in = self.block_out_channels[i]
63
+ channel_out = self.block_out_channels[i + 1]
64
+ conv1 = nn.Conv(
65
+ channel_in,
66
+ kernel_size=(3, 3),
67
+ padding=((1, 1), (1, 1)),
68
+ dtype=self.dtype,
69
+ )
70
+ blocks.append(conv1)
71
+ conv2 = nn.Conv(
72
+ channel_out,
73
+ kernel_size=(3, 3),
74
+ strides=(2, 2),
75
+ padding=((1, 1), (1, 1)),
76
+ dtype=self.dtype,
77
+ )
78
+ blocks.append(conv2)
79
+ self.blocks = blocks
80
+
81
+ self.conv_out = nn.Conv(
82
+ self.conditioning_embedding_channels,
83
+ kernel_size=(3, 3),
84
+ padding=((1, 1), (1, 1)),
85
+ kernel_init=nn.initializers.zeros_init(),
86
+ bias_init=nn.initializers.zeros_init(),
87
+ dtype=self.dtype,
88
+ )
89
+
90
+ def __call__(self, conditioning):
91
+ embedding = self.conv_in(conditioning)
92
+ embedding = nn.silu(embedding)
93
+
94
+ for block in self.blocks:
95
+ embedding = block(embedding)
96
+ embedding = nn.silu(embedding)
97
+
98
+ embedding = self.conv_out(embedding)
99
+
100
+ return embedding
101
+
102
+
103
+ @flax_register_to_config
104
+ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
+ r"""
106
+ A ControlNet model.
107
+
108
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
109
+ implemented for all models (such as downloading or saving).
110
+
111
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
+ general usage and behavior.
114
+
115
+ Inherent JAX features such as the following are supported:
116
+
117
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
+
122
+ Parameters:
123
+ sample_size (`int`, *optional*):
124
+ The size of the input sample.
125
+ in_channels (`int`, *optional*, defaults to 4):
126
+ The number of channels in the input sample.
127
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
+ The tuple of downsample blocks to use.
129
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
+ The tuple of output channels for each block.
131
+ layers_per_block (`int`, *optional*, defaults to 2):
132
+ The number of layers per block.
133
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
+ The dimension of the attention heads.
135
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
+ The number of attention heads.
137
+ cross_attention_dim (`int`, *optional*, defaults to 768):
138
+ The dimension of the cross attention features.
139
+ dropout (`float`, *optional*, defaults to 0):
140
+ Dropout probability for down, up and bottleneck blocks.
141
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
+ Whether to flip the sin to cos in the time embedding.
143
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
+ controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
146
+ conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
148
+ """
149
+ sample_size: int = 32
150
+ in_channels: int = 4
151
+ down_block_types: Tuple[str] = (
152
+ "CrossAttnDownBlock2D",
153
+ "CrossAttnDownBlock2D",
154
+ "CrossAttnDownBlock2D",
155
+ "DownBlock2D",
156
+ )
157
+ only_cross_attention: Union[bool, Tuple[bool]] = False
158
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
159
+ layers_per_block: int = 2
160
+ attention_head_dim: Union[int, Tuple[int]] = 8
161
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
162
+ cross_attention_dim: int = 1280
163
+ dropout: float = 0.0
164
+ use_linear_projection: bool = False
165
+ dtype: jnp.dtype = jnp.float32
166
+ flip_sin_to_cos: bool = True
167
+ freq_shift: int = 0
168
+ controlnet_conditioning_channel_order: str = "rgb"
169
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
+
171
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
172
+ # init input tensors
173
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
175
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
176
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
177
+ controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
178
+ controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
179
+
180
+ params_rng, dropout_rng = jax.random.split(rng)
181
+ rngs = {"params": params_rng, "dropout": dropout_rng}
182
+
183
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
+
185
+ def setup(self):
186
+ block_out_channels = self.block_out_channels
187
+ time_embed_dim = block_out_channels[0] * 4
188
+
189
+ # If `num_attention_heads` is not defined (which is the case for most models)
190
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
+ # which is why we correct for the naming here.
195
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
+
197
+ # input
198
+ self.conv_in = nn.Conv(
199
+ block_out_channels[0],
200
+ kernel_size=(3, 3),
201
+ strides=(1, 1),
202
+ padding=((1, 1), (1, 1)),
203
+ dtype=self.dtype,
204
+ )
205
+
206
+ # time
207
+ self.time_proj = FlaxTimesteps(
208
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
209
+ )
210
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
211
+
212
+ self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
213
+ conditioning_embedding_channels=block_out_channels[0],
214
+ block_out_channels=self.conditioning_embedding_out_channels,
215
+ )
216
+
217
+ only_cross_attention = self.only_cross_attention
218
+ if isinstance(only_cross_attention, bool):
219
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
220
+
221
+ if isinstance(num_attention_heads, int):
222
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
223
+
224
+ # down
225
+ down_blocks = []
226
+ controlnet_down_blocks = []
227
+
228
+ output_channel = block_out_channels[0]
229
+
230
+ controlnet_block = nn.Conv(
231
+ output_channel,
232
+ kernel_size=(1, 1),
233
+ padding="VALID",
234
+ kernel_init=nn.initializers.zeros_init(),
235
+ bias_init=nn.initializers.zeros_init(),
236
+ dtype=self.dtype,
237
+ )
238
+ controlnet_down_blocks.append(controlnet_block)
239
+
240
+ for i, down_block_type in enumerate(self.down_block_types):
241
+ input_channel = output_channel
242
+ output_channel = block_out_channels[i]
243
+ is_final_block = i == len(block_out_channels) - 1
244
+
245
+ if down_block_type == "CrossAttnDownBlock2D":
246
+ down_block = FlaxCrossAttnDownBlock2D(
247
+ in_channels=input_channel,
248
+ out_channels=output_channel,
249
+ dropout=self.dropout,
250
+ num_layers=self.layers_per_block,
251
+ num_attention_heads=num_attention_heads[i],
252
+ add_downsample=not is_final_block,
253
+ use_linear_projection=self.use_linear_projection,
254
+ only_cross_attention=only_cross_attention[i],
255
+ dtype=self.dtype,
256
+ )
257
+ else:
258
+ down_block = FlaxDownBlock2D(
259
+ in_channels=input_channel,
260
+ out_channels=output_channel,
261
+ dropout=self.dropout,
262
+ num_layers=self.layers_per_block,
263
+ add_downsample=not is_final_block,
264
+ dtype=self.dtype,
265
+ )
266
+
267
+ down_blocks.append(down_block)
268
+
269
+ for _ in range(self.layers_per_block):
270
+ controlnet_block = nn.Conv(
271
+ output_channel,
272
+ kernel_size=(1, 1),
273
+ padding="VALID",
274
+ kernel_init=nn.initializers.zeros_init(),
275
+ bias_init=nn.initializers.zeros_init(),
276
+ dtype=self.dtype,
277
+ )
278
+ controlnet_down_blocks.append(controlnet_block)
279
+
280
+ if not is_final_block:
281
+ controlnet_block = nn.Conv(
282
+ output_channel,
283
+ kernel_size=(1, 1),
284
+ padding="VALID",
285
+ kernel_init=nn.initializers.zeros_init(),
286
+ bias_init=nn.initializers.zeros_init(),
287
+ dtype=self.dtype,
288
+ )
289
+ controlnet_down_blocks.append(controlnet_block)
290
+
291
+ self.down_blocks = down_blocks
292
+ self.controlnet_down_blocks = controlnet_down_blocks
293
+
294
+ # mid
295
+ mid_block_channel = block_out_channels[-1]
296
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
297
+ in_channels=mid_block_channel,
298
+ dropout=self.dropout,
299
+ num_attention_heads=num_attention_heads[-1],
300
+ use_linear_projection=self.use_linear_projection,
301
+ dtype=self.dtype,
302
+ )
303
+
304
+ self.controlnet_mid_block = nn.Conv(
305
+ mid_block_channel,
306
+ kernel_size=(1, 1),
307
+ padding="VALID",
308
+ kernel_init=nn.initializers.zeros_init(),
309
+ bias_init=nn.initializers.zeros_init(),
310
+ dtype=self.dtype,
311
+ )
312
+
313
+ def __call__(
314
+ self,
315
+ sample,
316
+ timesteps,
317
+ encoder_hidden_states,
318
+ controlnet_cond,
319
+ conditioning_scale: float = 1.0,
320
+ return_dict: bool = True,
321
+ train: bool = False,
322
+ ) -> Union[FlaxControlNetOutput, Tuple]:
323
+ r"""
324
+ Args:
325
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
+ controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
+ conditioning_scale: (`float`) the scale factor for controlnet outputs
330
+ return_dict (`bool`, *optional*, defaults to `True`):
331
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
332
+ plain tuple.
333
+ train (`bool`, *optional*, defaults to `False`):
334
+ Use deterministic functions and disable dropout when not training.
335
+
336
+ Returns:
337
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
338
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
339
+ When returning a tuple, the first element is the sample tensor.
340
+ """
341
+ channel_order = self.controlnet_conditioning_channel_order
342
+ if channel_order == "bgr":
343
+ controlnet_cond = jnp.flip(controlnet_cond, axis=1)
344
+
345
+ # 1. time
346
+ if not isinstance(timesteps, jnp.ndarray):
347
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
348
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
349
+ timesteps = timesteps.astype(dtype=jnp.float32)
350
+ timesteps = jnp.expand_dims(timesteps, 0)
351
+
352
+ t_emb = self.time_proj(timesteps)
353
+ t_emb = self.time_embedding(t_emb)
354
+
355
+ # 2. pre-process
356
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
357
+ sample = self.conv_in(sample)
358
+
359
+ controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
360
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
361
+ sample += controlnet_cond
362
+
363
+ # 3. down
364
+ down_block_res_samples = (sample,)
365
+ for down_block in self.down_blocks:
366
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
367
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
368
+ else:
369
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
370
+ down_block_res_samples += res_samples
371
+
372
+ # 4. mid
373
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
374
+
375
+ # 5. contronet blocks
376
+ controlnet_down_block_res_samples = ()
377
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
378
+ down_block_res_sample = controlnet_block(down_block_res_sample)
379
+ controlnet_down_block_res_samples += (down_block_res_sample,)
380
+
381
+ down_block_res_samples = controlnet_down_block_res_samples
382
+
383
+ mid_block_res_sample = self.controlnet_mid_block(sample)
384
+
385
+ # 6. scaling
386
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
387
+ mid_block_res_sample *= conditioning_scale
388
+
389
+ if not return_dict:
390
+ return (down_block_res_samples, mid_block_res_sample)
391
+
392
+ return FlaxControlNetOutput(
393
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
394
+ )
6DoF/diffusers/models/cross_attention.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from ..utils import deprecate
15
+ from .attention_processor import ( # noqa: F401
16
+ Attention,
17
+ AttentionProcessor,
18
+ AttnAddedKVProcessor,
19
+ AttnProcessor2_0,
20
+ LoRAAttnProcessor,
21
+ LoRALinearLayer,
22
+ LoRAXFormersAttnProcessor,
23
+ SlicedAttnAddedKVProcessor,
24
+ SlicedAttnProcessor,
25
+ XFormersAttnProcessor,
26
+ )
27
+ from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
28
+
29
+
30
+ deprecate(
31
+ "cross_attention",
32
+ "0.20.0",
33
+ "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34
+ standard_warn=False,
35
+ )
36
+
37
+
38
+ AttnProcessor = AttentionProcessor
39
+
40
+
41
+ class CrossAttention(Attention):
42
+ def __init__(self, *args, **kwargs):
43
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
45
+ super().__init__(*args, **kwargs)
46
+
47
+
48
+ class CrossAttnProcessor(AttnProcessorRename):
49
+ def __init__(self, *args, **kwargs):
50
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
52
+ super().__init__(*args, **kwargs)
53
+
54
+
55
+ class LoRACrossAttnProcessor(LoRAAttnProcessor):
56
+ def __init__(self, *args, **kwargs):
57
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
58
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
59
+ super().__init__(*args, **kwargs)
60
+
61
+
62
+ class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63
+ def __init__(self, *args, **kwargs):
64
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
65
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
66
+ super().__init__(*args, **kwargs)
67
+
68
+
69
+ class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70
+ def __init__(self, *args, **kwargs):
71
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
72
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
73
+ super().__init__(*args, **kwargs)
74
+
75
+
76
+ class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77
+ def __init__(self, *args, **kwargs):
78
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
79
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
80
+ super().__init__(*args, **kwargs)
81
+
82
+
83
+ class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84
+ def __init__(self, *args, **kwargs):
85
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
86
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
87
+ super().__init__(*args, **kwargs)
88
+
89
+
90
+ class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91
+ def __init__(self, *args, **kwargs):
92
+ deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
93
+ deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
94
+ super().__init__(*args, **kwargs)
6DoF/diffusers/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more than steps than `num_embeds_ada_norm`.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention
118
+ return_dict (`bool`, *optional*, defaults to `True`):
119
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
+
121
+ Returns:
122
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
+ returning a tuple, the first element is the sample tensor.
125
+ """
126
+ input_states = hidden_states
127
+
128
+ encoded_states = []
129
+ tokens_start = 0
130
+ # attention_mask is not used yet
131
+ for i in range(2):
132
+ # for each of the two transformers, pass the corresponding condition tokens
133
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
+ transformer_index = self.transformer_index_for_condition[i]
135
+ encoded_state = self.transformers[transformer_index](
136
+ input_states,
137
+ encoder_hidden_states=condition_state,
138
+ timestep=timestep,
139
+ cross_attention_kwargs=cross_attention_kwargs,
140
+ return_dict=False,
141
+ )[0]
142
+ encoded_states.append(encoded_state - input_states)
143
+ tokens_start += self.condition_lengths[i]
144
+
145
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
+ output_states = output_states + input_states
147
+
148
+ if not return_dict:
149
+ return (output_states,)
150
+
151
+ return Transformer2DModelOutput(sample=output_states)
6DoF/diffusers/models/embeddings.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+
21
+ from .activations import get_activation
22
+
23
+
24
+ def get_timestep_embedding(
25
+ timesteps: torch.Tensor,
26
+ embedding_dim: int,
27
+ flip_sin_to_cos: bool = False,
28
+ downscale_freq_shift: float = 1,
29
+ scale: float = 1,
30
+ max_period: int = 10000,
31
+ ):
32
+ """
33
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
34
+
35
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
36
+ These may be fractional.
37
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
38
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
39
+ """
40
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
41
+
42
+ half_dim = embedding_dim // 2
43
+ exponent = -math.log(max_period) * torch.arange(
44
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
45
+ )
46
+ exponent = exponent / (half_dim - downscale_freq_shift)
47
+
48
+ emb = torch.exp(exponent)
49
+ emb = timesteps[:, None].float() * emb[None, :]
50
+
51
+ # scale embeddings
52
+ emb = scale * emb
53
+
54
+ # concat sine and cosine embeddings
55
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
56
+
57
+ # flip sine and cosine embeddings
58
+ if flip_sin_to_cos:
59
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
60
+
61
+ # zero pad
62
+ if embedding_dim % 2 == 1:
63
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
64
+ return emb
65
+
66
+
67
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
68
+ """
69
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
70
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71
+ """
72
+ grid_h = np.arange(grid_size, dtype=np.float32)
73
+ grid_w = np.arange(grid_size, dtype=np.float32)
74
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
75
+ grid = np.stack(grid, axis=0)
76
+
77
+ grid = grid.reshape([2, 1, grid_size, grid_size])
78
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79
+ if cls_token and extra_tokens > 0:
80
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
81
+ return pos_embed
82
+
83
+
84
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
85
+ if embed_dim % 2 != 0:
86
+ raise ValueError("embed_dim must be divisible by 2")
87
+
88
+ # use half of dimensions to encode grid_h
89
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
90
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
91
+
92
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
93
+ return emb
94
+
95
+
96
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
97
+ """
98
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
99
+ """
100
+ if embed_dim % 2 != 0:
101
+ raise ValueError("embed_dim must be divisible by 2")
102
+
103
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
104
+ omega /= embed_dim / 2.0
105
+ omega = 1.0 / 10000**omega # (D/2,)
106
+
107
+ pos = pos.reshape(-1) # (M,)
108
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
109
+
110
+ emb_sin = np.sin(out) # (M, D/2)
111
+ emb_cos = np.cos(out) # (M, D/2)
112
+
113
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
114
+ return emb
115
+
116
+
117
+ class PatchEmbed(nn.Module):
118
+ """2D Image to Patch Embedding"""
119
+
120
+ def __init__(
121
+ self,
122
+ height=224,
123
+ width=224,
124
+ patch_size=16,
125
+ in_channels=3,
126
+ embed_dim=768,
127
+ layer_norm=False,
128
+ flatten=True,
129
+ bias=True,
130
+ ):
131
+ super().__init__()
132
+
133
+ num_patches = (height // patch_size) * (width // patch_size)
134
+ self.flatten = flatten
135
+ self.layer_norm = layer_norm
136
+
137
+ self.proj = nn.Conv2d(
138
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
139
+ )
140
+ if layer_norm:
141
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
142
+ else:
143
+ self.norm = None
144
+
145
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
146
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
147
+
148
+ def forward(self, latent):
149
+ latent = self.proj(latent)
150
+ if self.flatten:
151
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
152
+ if self.layer_norm:
153
+ latent = self.norm(latent)
154
+ return latent + self.pos_embed
155
+
156
+
157
+ class TimestepEmbedding(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels: int,
161
+ time_embed_dim: int,
162
+ act_fn: str = "silu",
163
+ out_dim: int = None,
164
+ post_act_fn: Optional[str] = None,
165
+ cond_proj_dim=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
170
+
171
+ if cond_proj_dim is not None:
172
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
173
+ else:
174
+ self.cond_proj = None
175
+
176
+ self.act = get_activation(act_fn)
177
+
178
+ if out_dim is not None:
179
+ time_embed_dim_out = out_dim
180
+ else:
181
+ time_embed_dim_out = time_embed_dim
182
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
183
+
184
+ if post_act_fn is None:
185
+ self.post_act = None
186
+ else:
187
+ self.post_act = get_activation(post_act_fn)
188
+
189
+ def forward(self, sample, condition=None):
190
+ if condition is not None:
191
+ sample = sample + self.cond_proj(condition)
192
+ sample = self.linear_1(sample)
193
+
194
+ if self.act is not None:
195
+ sample = self.act(sample)
196
+
197
+ sample = self.linear_2(sample)
198
+
199
+ if self.post_act is not None:
200
+ sample = self.post_act(sample)
201
+ return sample
202
+
203
+
204
+ class Timesteps(nn.Module):
205
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
206
+ super().__init__()
207
+ self.num_channels = num_channels
208
+ self.flip_sin_to_cos = flip_sin_to_cos
209
+ self.downscale_freq_shift = downscale_freq_shift
210
+
211
+ def forward(self, timesteps):
212
+ t_emb = get_timestep_embedding(
213
+ timesteps,
214
+ self.num_channels,
215
+ flip_sin_to_cos=self.flip_sin_to_cos,
216
+ downscale_freq_shift=self.downscale_freq_shift,
217
+ )
218
+ return t_emb
219
+
220
+
221
+ class GaussianFourierProjection(nn.Module):
222
+ """Gaussian Fourier embeddings for noise levels."""
223
+
224
+ def __init__(
225
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
226
+ ):
227
+ super().__init__()
228
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
229
+ self.log = log
230
+ self.flip_sin_to_cos = flip_sin_to_cos
231
+
232
+ if set_W_to_weight:
233
+ # to delete later
234
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
235
+
236
+ self.weight = self.W
237
+
238
+ def forward(self, x):
239
+ if self.log:
240
+ x = torch.log(x)
241
+
242
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
243
+
244
+ if self.flip_sin_to_cos:
245
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
246
+ else:
247
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
248
+ return out
249
+
250
+
251
+ class ImagePositionalEmbeddings(nn.Module):
252
+ """
253
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
254
+ height and width of the latent space.
255
+
256
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
257
+
258
+ For VQ-diffusion:
259
+
260
+ Output vector embeddings are used as input for the transformer.
261
+
262
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
263
+
264
+ Args:
265
+ num_embed (`int`):
266
+ Number of embeddings for the latent pixels embeddings.
267
+ height (`int`):
268
+ Height of the latent image i.e. the number of height embeddings.
269
+ width (`int`):
270
+ Width of the latent image i.e. the number of width embeddings.
271
+ embed_dim (`int`):
272
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ num_embed: int,
278
+ height: int,
279
+ width: int,
280
+ embed_dim: int,
281
+ ):
282
+ super().__init__()
283
+
284
+ self.height = height
285
+ self.width = width
286
+ self.num_embed = num_embed
287
+ self.embed_dim = embed_dim
288
+
289
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
290
+ self.height_emb = nn.Embedding(self.height, embed_dim)
291
+ self.width_emb = nn.Embedding(self.width, embed_dim)
292
+
293
+ def forward(self, index):
294
+ emb = self.emb(index)
295
+
296
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
297
+
298
+ # 1 x H x D -> 1 x H x 1 x D
299
+ height_emb = height_emb.unsqueeze(2)
300
+
301
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
302
+
303
+ # 1 x W x D -> 1 x 1 x W x D
304
+ width_emb = width_emb.unsqueeze(1)
305
+
306
+ pos_emb = height_emb + width_emb
307
+
308
+ # 1 x H x W x D -> 1 x L xD
309
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
310
+
311
+ emb = emb + pos_emb[:, : emb.shape[1], :]
312
+
313
+ return emb
314
+
315
+
316
+ class LabelEmbedding(nn.Module):
317
+ """
318
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319
+
320
+ Args:
321
+ num_classes (`int`): The number of classes.
322
+ hidden_size (`int`): The size of the vector embeddings.
323
+ dropout_prob (`float`): The probability of dropping a label.
324
+ """
325
+
326
+ def __init__(self, num_classes, hidden_size, dropout_prob):
327
+ super().__init__()
328
+ use_cfg_embedding = dropout_prob > 0
329
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
330
+ self.num_classes = num_classes
331
+ self.dropout_prob = dropout_prob
332
+
333
+ def token_drop(self, labels, force_drop_ids=None):
334
+ """
335
+ Drops labels to enable classifier-free guidance.
336
+ """
337
+ if force_drop_ids is None:
338
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
339
+ else:
340
+ drop_ids = torch.tensor(force_drop_ids == 1)
341
+ labels = torch.where(drop_ids, self.num_classes, labels)
342
+ return labels
343
+
344
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
345
+ use_dropout = self.dropout_prob > 0
346
+ if (self.training and use_dropout) or (force_drop_ids is not None):
347
+ labels = self.token_drop(labels, force_drop_ids)
348
+ embeddings = self.embedding_table(labels)
349
+ return embeddings
350
+
351
+
352
+ class TextImageProjection(nn.Module):
353
+ def __init__(
354
+ self,
355
+ text_embed_dim: int = 1024,
356
+ image_embed_dim: int = 768,
357
+ cross_attention_dim: int = 768,
358
+ num_image_text_embeds: int = 10,
359
+ ):
360
+ super().__init__()
361
+
362
+ self.num_image_text_embeds = num_image_text_embeds
363
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
364
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
365
+
366
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
367
+ batch_size = text_embeds.shape[0]
368
+
369
+ # image
370
+ image_text_embeds = self.image_embeds(image_embeds)
371
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
372
+
373
+ # text
374
+ text_embeds = self.text_proj(text_embeds)
375
+
376
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
377
+
378
+
379
+ class ImageProjection(nn.Module):
380
+ def __init__(
381
+ self,
382
+ image_embed_dim: int = 768,
383
+ cross_attention_dim: int = 768,
384
+ num_image_text_embeds: int = 32,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.num_image_text_embeds = num_image_text_embeds
389
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
390
+ self.norm = nn.LayerNorm(cross_attention_dim)
391
+
392
+ def forward(self, image_embeds: torch.FloatTensor):
393
+ batch_size = image_embeds.shape[0]
394
+
395
+ # image
396
+ image_embeds = self.image_embeds(image_embeds)
397
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
398
+ image_embeds = self.norm(image_embeds)
399
+ return image_embeds
400
+
401
+
402
+ class CombinedTimestepLabelEmbeddings(nn.Module):
403
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
404
+ super().__init__()
405
+
406
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
407
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
408
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
409
+
410
+ def forward(self, timestep, class_labels, hidden_dtype=None):
411
+ timesteps_proj = self.time_proj(timestep)
412
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
413
+
414
+ class_labels = self.class_embedder(class_labels) # (N, D)
415
+
416
+ conditioning = timesteps_emb + class_labels # (N, D)
417
+
418
+ return conditioning
419
+
420
+
421
+ class TextTimeEmbedding(nn.Module):
422
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
423
+ super().__init__()
424
+ self.norm1 = nn.LayerNorm(encoder_dim)
425
+ self.pool = AttentionPooling(num_heads, encoder_dim)
426
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
427
+ self.norm2 = nn.LayerNorm(time_embed_dim)
428
+
429
+ def forward(self, hidden_states):
430
+ hidden_states = self.norm1(hidden_states)
431
+ hidden_states = self.pool(hidden_states)
432
+ hidden_states = self.proj(hidden_states)
433
+ hidden_states = self.norm2(hidden_states)
434
+ return hidden_states
435
+
436
+
437
+ class TextImageTimeEmbedding(nn.Module):
438
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
439
+ super().__init__()
440
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
441
+ self.text_norm = nn.LayerNorm(time_embed_dim)
442
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
443
+
444
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
445
+ # text
446
+ time_text_embeds = self.text_proj(text_embeds)
447
+ time_text_embeds = self.text_norm(time_text_embeds)
448
+
449
+ # image
450
+ time_image_embeds = self.image_proj(image_embeds)
451
+
452
+ return time_image_embeds + time_text_embeds
453
+
454
+
455
+ class ImageTimeEmbedding(nn.Module):
456
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
457
+ super().__init__()
458
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
459
+ self.image_norm = nn.LayerNorm(time_embed_dim)
460
+
461
+ def forward(self, image_embeds: torch.FloatTensor):
462
+ # image
463
+ time_image_embeds = self.image_proj(image_embeds)
464
+ time_image_embeds = self.image_norm(time_image_embeds)
465
+ return time_image_embeds
466
+
467
+
468
+ class ImageHintTimeEmbedding(nn.Module):
469
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
470
+ super().__init__()
471
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
472
+ self.image_norm = nn.LayerNorm(time_embed_dim)
473
+ self.input_hint_block = nn.Sequential(
474
+ nn.Conv2d(3, 16, 3, padding=1),
475
+ nn.SiLU(),
476
+ nn.Conv2d(16, 16, 3, padding=1),
477
+ nn.SiLU(),
478
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
479
+ nn.SiLU(),
480
+ nn.Conv2d(32, 32, 3, padding=1),
481
+ nn.SiLU(),
482
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
483
+ nn.SiLU(),
484
+ nn.Conv2d(96, 96, 3, padding=1),
485
+ nn.SiLU(),
486
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
487
+ nn.SiLU(),
488
+ nn.Conv2d(256, 4, 3, padding=1),
489
+ )
490
+
491
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
492
+ # image
493
+ time_image_embeds = self.image_proj(image_embeds)
494
+ time_image_embeds = self.image_norm(time_image_embeds)
495
+ hint = self.input_hint_block(hint)
496
+ return time_image_embeds, hint
497
+
498
+
499
+ class AttentionPooling(nn.Module):
500
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
501
+
502
+ def __init__(self, num_heads, embed_dim, dtype=None):
503
+ super().__init__()
504
+ self.dtype = dtype
505
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
506
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
507
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
508
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
509
+ self.num_heads = num_heads
510
+ self.dim_per_head = embed_dim // self.num_heads
511
+
512
+ def forward(self, x):
513
+ bs, length, width = x.size()
514
+
515
+ def shape(x):
516
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
517
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
518
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
519
+ x = x.transpose(1, 2)
520
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
521
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
522
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
523
+ x = x.transpose(1, 2)
524
+ return x
525
+
526
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
527
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
528
+
529
+ # (bs*n_heads, class_token_length, dim_per_head)
530
+ q = shape(self.q_proj(class_token))
531
+ # (bs*n_heads, length+class_token_length, dim_per_head)
532
+ k = shape(self.k_proj(x))
533
+ v = shape(self.v_proj(x))
534
+
535
+ # (bs*n_heads, class_token_length, length+class_token_length):
536
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
537
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
538
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
539
+
540
+ # (bs*n_heads, dim_per_head, class_token_length)
541
+ a = torch.einsum("bts,bcs->bct", weight, v)
542
+
543
+ # (bs, length+1, width)
544
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
545
+
546
+ return a[:, 0, :] # cls_token
6DoF/diffusers/models/embeddings_flax.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import flax.linen as nn
17
+ import jax.numpy as jnp
18
+
19
+
20
+ def get_sinusoidal_embeddings(
21
+ timesteps: jnp.ndarray,
22
+ embedding_dim: int,
23
+ freq_shift: float = 1,
24
+ min_timescale: float = 1,
25
+ max_timescale: float = 1.0e4,
26
+ flip_sin_to_cos: bool = False,
27
+ scale: float = 1.0,
28
+ ) -> jnp.ndarray:
29
+ """Returns the positional encoding (same as Tensor2Tensor).
30
+
31
+ Args:
32
+ timesteps: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ embedding_dim: The number of output channels.
35
+ min_timescale: The smallest time unit (should probably be 0.0).
36
+ max_timescale: The largest time unit.
37
+ Returns:
38
+ a Tensor of timing signals [N, num_channels]
39
+ """
40
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42
+ num_timescales = float(embedding_dim // 2)
43
+ log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46
+
47
+ # scale embeddings
48
+ scaled_time = scale * emb
49
+
50
+ if flip_sin_to_cos:
51
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52
+ else:
53
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55
+ return signal
56
+
57
+
58
+ class FlaxTimestepEmbedding(nn.Module):
59
+ r"""
60
+ Time step Embedding Module. Learns embeddings for input time steps.
61
+
62
+ Args:
63
+ time_embed_dim (`int`, *optional*, defaults to `32`):
64
+ Time step embedding dimension
65
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66
+ Parameters `dtype`
67
+ """
68
+ time_embed_dim: int = 32
69
+ dtype: jnp.dtype = jnp.float32
70
+
71
+ @nn.compact
72
+ def __call__(self, temb):
73
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74
+ temb = nn.silu(temb)
75
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76
+ return temb
77
+
78
+
79
+ class FlaxTimesteps(nn.Module):
80
+ r"""
81
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82
+
83
+ Args:
84
+ dim (`int`, *optional*, defaults to `32`):
85
+ Time step embedding dimension
86
+ """
87
+ dim: int = 32
88
+ flip_sin_to_cos: bool = False
89
+ freq_shift: float = 1
90
+
91
+ @nn.compact
92
+ def __call__(self, timesteps):
93
+ return get_sinusoidal_embeddings(
94
+ timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95
+ )
6DoF/diffusers/models/modeling_flax_pytorch_utils.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+ import re
17
+
18
+ import jax.numpy as jnp
19
+ from flax.traverse_util import flatten_dict, unflatten_dict
20
+ from jax.random import PRNGKey
21
+
22
+ from ..utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def rename_key(key):
29
+ regex = r"\w+[.]\d+"
30
+ pats = re.findall(regex, key)
31
+ for pat in pats:
32
+ key = key.replace(pat, "_".join(pat.split(".")))
33
+ return key
34
+
35
+
36
+ #####################
37
+ # PyTorch => Flax #
38
+ #####################
39
+
40
+
41
+ # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42
+ # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43
+ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45
+
46
+ # conv norm or layer norm
47
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
48
+ if (
49
+ any("norm" in str_ for str_ in pt_tuple_key)
50
+ and (pt_tuple_key[-1] == "bias")
51
+ and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
52
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
53
+ ):
54
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
55
+ return renamed_pt_tuple_key, pt_tensor
56
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
57
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
58
+ return renamed_pt_tuple_key, pt_tensor
59
+
60
+ # embedding
61
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
62
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
63
+ return renamed_pt_tuple_key, pt_tensor
64
+
65
+ # conv layer
66
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
67
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
68
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
69
+ return renamed_pt_tuple_key, pt_tensor
70
+
71
+ # linear layer
72
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
73
+ if pt_tuple_key[-1] == "weight":
74
+ pt_tensor = pt_tensor.T
75
+ return renamed_pt_tuple_key, pt_tensor
76
+
77
+ # old PyTorch layer norm weight
78
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
79
+ if pt_tuple_key[-1] == "gamma":
80
+ return renamed_pt_tuple_key, pt_tensor
81
+
82
+ # old PyTorch layer norm bias
83
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
84
+ if pt_tuple_key[-1] == "beta":
85
+ return renamed_pt_tuple_key, pt_tensor
86
+
87
+ return pt_tuple_key, pt_tensor
88
+
89
+
90
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
91
+ # Step 1: Convert pytorch tensor to numpy
92
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
93
+
94
+ # Step 2: Since the model is stateless, get random Flax params
95
+ random_flax_params = flax_model.init_weights(PRNGKey(init_key))
96
+
97
+ random_flax_state_dict = flatten_dict(random_flax_params)
98
+ flax_state_dict = {}
99
+
100
+ # Need to change some parameters name to match Flax names
101
+ for pt_key, pt_tensor in pt_state_dict.items():
102
+ renamed_pt_key = rename_key(pt_key)
103
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
104
+
105
+ # Correctly rename weight parameters
106
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
107
+
108
+ if flax_key in random_flax_state_dict:
109
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
110
+ raise ValueError(
111
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
112
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
113
+ )
114
+
115
+ # also add unexpected weight so that warning is thrown
116
+ flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
117
+
118
+ return unflatten_dict(flax_state_dict)
6DoF/diffusers/models/modeling_flax_utils.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from pickle import UnpicklingError
18
+ from typing import Any, Dict, Union
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import msgpack.exceptions
23
+ from flax.core.frozen_dict import FrozenDict, unfreeze
24
+ from flax.serialization import from_bytes, to_bytes
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from huggingface_hub import hf_hub_download
27
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
+ from requests import HTTPError
29
+
30
+ from .. import __version__, is_torch_available
31
+ from ..utils import (
32
+ CONFIG_NAME,
33
+ DIFFUSERS_CACHE,
34
+ FLAX_WEIGHTS_NAME,
35
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
+ WEIGHTS_NAME,
37
+ logging,
38
+ )
39
+ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+
45
+ class FlaxModelMixin:
46
+ r"""
47
+ Base class for all Flax models.
48
+
49
+ [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
50
+ saving models.
51
+
52
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
53
+ """
54
+ config_name = CONFIG_NAME
55
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
56
+ _flax_internal_args = ["name", "parent", "dtype"]
57
+
58
+ @classmethod
59
+ def _from_config(cls, config, **kwargs):
60
+ """
61
+ All context managers that the model should be initialized under go here.
62
+ """
63
+ return cls(config, **kwargs)
64
+
65
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
66
+ """
67
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
68
+ """
69
+
70
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
71
+ def conditional_cast(param):
72
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
73
+ param = param.astype(dtype)
74
+ return param
75
+
76
+ if mask is None:
77
+ return jax.tree_map(conditional_cast, params)
78
+
79
+ flat_params = flatten_dict(params)
80
+ flat_mask, _ = jax.tree_flatten(mask)
81
+
82
+ for masked, key in zip(flat_mask, flat_params.keys()):
83
+ if masked:
84
+ param = flat_params[key]
85
+ flat_params[key] = conditional_cast(param)
86
+
87
+ return unflatten_dict(flat_params)
88
+
89
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
90
+ r"""
91
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
92
+ the `params` in place.
93
+
94
+ This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
95
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
96
+
97
+ Arguments:
98
+ params (`Union[Dict, FrozenDict]`):
99
+ A `PyTree` of model parameters.
100
+ mask (`Union[Dict, FrozenDict]`):
101
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
102
+ for params you want to cast, and `False` for those you want to skip.
103
+
104
+ Examples:
105
+
106
+ ```python
107
+ >>> from diffusers import FlaxUNet2DConditionModel
108
+
109
+ >>> # load model
110
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
111
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
112
+ >>> params = model.to_bf16(params)
113
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
114
+ >>> # then pass the mask as follows
115
+ >>> from flax import traverse_util
116
+
117
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
118
+ >>> flat_params = traverse_util.flatten_dict(params)
119
+ >>> mask = {
120
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
121
+ ... for path in flat_params
122
+ ... }
123
+ >>> mask = traverse_util.unflatten_dict(mask)
124
+ >>> params = model.to_bf16(params, mask)
125
+ ```"""
126
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
127
+
128
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
129
+ r"""
130
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
131
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
132
+
133
+ Arguments:
134
+ params (`Union[Dict, FrozenDict]`):
135
+ A `PyTree` of model parameters.
136
+ mask (`Union[Dict, FrozenDict]`):
137
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
138
+ for params you want to cast, and `False` for those you want to skip.
139
+
140
+ Examples:
141
+
142
+ ```python
143
+ >>> from diffusers import FlaxUNet2DConditionModel
144
+
145
+ >>> # Download model and configuration from huggingface.co
146
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
147
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
148
+ >>> # we'll first cast to fp16 and back to fp32
149
+ >>> params = model.to_f16(params)
150
+ >>> # now cast back to fp32
151
+ >>> params = model.to_fp32(params)
152
+ ```"""
153
+ return self._cast_floating_to(params, jnp.float32, mask)
154
+
155
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
156
+ r"""
157
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
158
+ `params` in place.
159
+
160
+ This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
161
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
162
+
163
+ Arguments:
164
+ params (`Union[Dict, FrozenDict]`):
165
+ A `PyTree` of model parameters.
166
+ mask (`Union[Dict, FrozenDict]`):
167
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
168
+ for params you want to cast, and `False` for those you want to skip.
169
+
170
+ Examples:
171
+
172
+ ```python
173
+ >>> from diffusers import FlaxUNet2DConditionModel
174
+
175
+ >>> # load model
176
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
177
+ >>> # By default, the model params will be in fp32, to cast these to float16
178
+ >>> params = model.to_fp16(params)
179
+ >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
180
+ >>> # then pass the mask as follows
181
+ >>> from flax import traverse_util
182
+
183
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
184
+ >>> flat_params = traverse_util.flatten_dict(params)
185
+ >>> mask = {
186
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
187
+ ... for path in flat_params
188
+ ... }
189
+ >>> mask = traverse_util.unflatten_dict(mask)
190
+ >>> params = model.to_fp16(params, mask)
191
+ ```"""
192
+ return self._cast_floating_to(params, jnp.float16, mask)
193
+
194
+ def init_weights(self, rng: jax.random.KeyArray) -> Dict:
195
+ raise NotImplementedError(f"init_weights method has to be implemented for {self}")
196
+
197
+ @classmethod
198
+ def from_pretrained(
199
+ cls,
200
+ pretrained_model_name_or_path: Union[str, os.PathLike],
201
+ dtype: jnp.dtype = jnp.float32,
202
+ *model_args,
203
+ **kwargs,
204
+ ):
205
+ r"""
206
+ Instantiate a pretrained Flax model from a pretrained model configuration.
207
+
208
+ Parameters:
209
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
210
+ Can be either:
211
+
212
+ - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
213
+ hosted on the Hub.
214
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
215
+ using [`~FlaxModelMixin.save_pretrained`].
216
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
217
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
218
+ `jax.numpy.bfloat16` (on TPUs).
219
+
220
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
221
+ specified, all the computation will be performed with the given `dtype`.
222
+
223
+ <Tip>
224
+
225
+ This only specifies the dtype of the *computation* and does not influence the dtype of model
226
+ parameters.
227
+
228
+ If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
229
+ [`~FlaxModelMixin.to_bf16`].
230
+
231
+ </Tip>
232
+
233
+ model_args (sequence of positional arguments, *optional*):
234
+ All remaining positional arguments are passed to the underlying model's `__init__` method.
235
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
236
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
237
+ is not used.
238
+ force_download (`bool`, *optional*, defaults to `False`):
239
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
240
+ cached versions if they exist.
241
+ resume_download (`bool`, *optional*, defaults to `False`):
242
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
243
+ incompletely downloaded files are deleted.
244
+ proxies (`Dict[str, str]`, *optional*):
245
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
246
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
247
+ local_files_only(`bool`, *optional*, defaults to `False`):
248
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
249
+ won't be downloaded from the Hub.
250
+ revision (`str`, *optional*, defaults to `"main"`):
251
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
252
+ allowed by Git.
253
+ from_pt (`bool`, *optional*, defaults to `False`):
254
+ Load the model weights from a PyTorch checkpoint save file.
255
+ kwargs (remaining dictionary of keyword arguments, *optional*):
256
+ Can be used to update the configuration object (after it is loaded) and initiate the model (for
257
+ example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
258
+ automatically loaded:
259
+
260
+ - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
261
+ model's `__init__` method (we assume all relevant updates to the configuration have already been
262
+ done).
263
+ - If a configuration is not provided, `kwargs` are first passed to the configuration class
264
+ initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
265
+ to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
266
+ Remaining keys that do not correspond to any configuration attribute are passed to the underlying
267
+ model's `__init__` function.
268
+
269
+ Examples:
270
+
271
+ ```python
272
+ >>> from diffusers import FlaxUNet2DConditionModel
273
+
274
+ >>> # Download model and configuration from huggingface.co and cache.
275
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
276
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
277
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
278
+ ```
279
+
280
+ If you get the error message below, you need to finetune the weights for your downstream task:
281
+
282
+ ```bash
283
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
284
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
285
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
286
+ ```
287
+ """
288
+ config = kwargs.pop("config", None)
289
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
290
+ force_download = kwargs.pop("force_download", False)
291
+ from_pt = kwargs.pop("from_pt", False)
292
+ resume_download = kwargs.pop("resume_download", False)
293
+ proxies = kwargs.pop("proxies", None)
294
+ local_files_only = kwargs.pop("local_files_only", False)
295
+ use_auth_token = kwargs.pop("use_auth_token", None)
296
+ revision = kwargs.pop("revision", None)
297
+ subfolder = kwargs.pop("subfolder", None)
298
+
299
+ user_agent = {
300
+ "diffusers": __version__,
301
+ "file_type": "model",
302
+ "framework": "flax",
303
+ }
304
+
305
+ # Load config if we don't provide a configuration
306
+ config_path = config if config is not None else pretrained_model_name_or_path
307
+ model, model_kwargs = cls.from_config(
308
+ config_path,
309
+ cache_dir=cache_dir,
310
+ return_unused_kwargs=True,
311
+ force_download=force_download,
312
+ resume_download=resume_download,
313
+ proxies=proxies,
314
+ local_files_only=local_files_only,
315
+ use_auth_token=use_auth_token,
316
+ revision=revision,
317
+ subfolder=subfolder,
318
+ # model args
319
+ dtype=dtype,
320
+ **kwargs,
321
+ )
322
+
323
+ # Load model
324
+ pretrained_path_with_subfolder = (
325
+ pretrained_model_name_or_path
326
+ if subfolder is None
327
+ else os.path.join(pretrained_model_name_or_path, subfolder)
328
+ )
329
+ if os.path.isdir(pretrained_path_with_subfolder):
330
+ if from_pt:
331
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
332
+ raise EnvironmentError(
333
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
334
+ )
335
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
336
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
337
+ # Load from a Flax checkpoint
338
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
339
+ # Check if pytorch weights exist instead
340
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
341
+ raise EnvironmentError(
342
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
343
+ " using `from_pt=True`."
344
+ )
345
+ else:
346
+ raise EnvironmentError(
347
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
348
+ f"{pretrained_path_with_subfolder}."
349
+ )
350
+ else:
351
+ try:
352
+ model_file = hf_hub_download(
353
+ pretrained_model_name_or_path,
354
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
355
+ cache_dir=cache_dir,
356
+ force_download=force_download,
357
+ proxies=proxies,
358
+ resume_download=resume_download,
359
+ local_files_only=local_files_only,
360
+ use_auth_token=use_auth_token,
361
+ user_agent=user_agent,
362
+ subfolder=subfolder,
363
+ revision=revision,
364
+ )
365
+
366
+ except RepositoryNotFoundError:
367
+ raise EnvironmentError(
368
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
369
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
370
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
371
+ "login`."
372
+ )
373
+ except RevisionNotFoundError:
374
+ raise EnvironmentError(
375
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
376
+ "this model name. Check the model page at "
377
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
378
+ )
379
+ except EntryNotFoundError:
380
+ raise EnvironmentError(
381
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
382
+ )
383
+ except HTTPError as err:
384
+ raise EnvironmentError(
385
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
386
+ f"{err}"
387
+ )
388
+ except ValueError:
389
+ raise EnvironmentError(
390
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
391
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
392
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
393
+ " internet connection or see how to run the library in offline mode at"
394
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
395
+ )
396
+ except EnvironmentError:
397
+ raise EnvironmentError(
398
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
399
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
400
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
401
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
402
+ )
403
+
404
+ if from_pt:
405
+ if is_torch_available():
406
+ from .modeling_utils import load_state_dict
407
+ else:
408
+ raise EnvironmentError(
409
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
410
+ "Please, install PyTorch or use native Flax weights."
411
+ )
412
+
413
+ # Step 1: Get the pytorch file
414
+ pytorch_model_file = load_state_dict(model_file)
415
+
416
+ # Step 2: Convert the weights
417
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
418
+ else:
419
+ try:
420
+ with open(model_file, "rb") as state_f:
421
+ state = from_bytes(cls, state_f.read())
422
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
423
+ try:
424
+ with open(model_file) as f:
425
+ if f.read().startswith("version"):
426
+ raise OSError(
427
+ "You seem to have cloned a repository without having git-lfs installed. Please"
428
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
429
+ " folder you cloned."
430
+ )
431
+ else:
432
+ raise ValueError from e
433
+ except (UnicodeDecodeError, ValueError):
434
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
435
+ # make sure all arrays are stored as jnp.ndarray
436
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
437
+ # https://github.com/google/flax/issues/1261
438
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
439
+
440
+ # flatten dicts
441
+ state = flatten_dict(state)
442
+
443
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
444
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
445
+
446
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
447
+
448
+ missing_keys = required_params - set(state.keys())
449
+ unexpected_keys = set(state.keys()) - required_params
450
+
451
+ if missing_keys:
452
+ logger.warning(
453
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
454
+ "Make sure to call model.init_weights to initialize the missing weights."
455
+ )
456
+ cls._missing_keys = missing_keys
457
+
458
+ for key in state.keys():
459
+ if key in shape_state and state[key].shape != shape_state[key].shape:
460
+ raise ValueError(
461
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
462
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
463
+ )
464
+
465
+ # remove unexpected keys to not be saved again
466
+ for unexpected_key in unexpected_keys:
467
+ del state[unexpected_key]
468
+
469
+ if len(unexpected_keys) > 0:
470
+ logger.warning(
471
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
472
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
473
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
474
+ " with another architecture."
475
+ )
476
+ else:
477
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
478
+
479
+ if len(missing_keys) > 0:
480
+ logger.warning(
481
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
482
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
483
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
484
+ )
485
+ else:
486
+ logger.info(
487
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
488
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
489
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
490
+ " training."
491
+ )
492
+
493
+ return model, unflatten_dict(state)
494
+
495
+ def save_pretrained(
496
+ self,
497
+ save_directory: Union[str, os.PathLike],
498
+ params: Union[Dict, FrozenDict],
499
+ is_main_process: bool = True,
500
+ ):
501
+ """
502
+ Save a model and its configuration file to a directory so that it can be reloaded using the
503
+ [`~FlaxModelMixin.from_pretrained`] class method.
504
+
505
+ Arguments:
506
+ save_directory (`str` or `os.PathLike`):
507
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
508
+ params (`Union[Dict, FrozenDict]`):
509
+ A `PyTree` of model parameters.
510
+ is_main_process (`bool`, *optional*, defaults to `True`):
511
+ Whether the process calling this is the main process or not. Useful during distributed training and you
512
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
513
+ process to avoid race conditions.
514
+ """
515
+ if os.path.isfile(save_directory):
516
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
517
+ return
518
+
519
+ os.makedirs(save_directory, exist_ok=True)
520
+
521
+ model_to_save = self
522
+
523
+ # Attach architecture to the config
524
+ # Save the config
525
+ if is_main_process:
526
+ model_to_save.save_config(save_directory)
527
+
528
+ # save model
529
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
530
+ with open(output_model_file, "wb") as f:
531
+ model_bytes = to_bytes(params)
532
+ f.write(model_bytes)
533
+
534
+ logger.info(f"Model weights saved in {output_model_file}")
6DoF/diffusers/models/modeling_pytorch_flax_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+
17
+ from pickle import UnpicklingError
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from flax.serialization import from_bytes
23
+ from flax.traverse_util import flatten_dict
24
+
25
+ from ..utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ #####################
32
+ # Flax => PyTorch #
33
+ #####################
34
+
35
+
36
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37
+ def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38
+ try:
39
+ with open(model_file, "rb") as flax_state_f:
40
+ flax_state = from_bytes(None, flax_state_f.read())
41
+ except UnpicklingError as e:
42
+ try:
43
+ with open(model_file) as f:
44
+ if f.read().startswith("version"):
45
+ raise OSError(
46
+ "You seem to have cloned a repository without having git-lfs installed. Please"
47
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48
+ " folder you cloned."
49
+ )
50
+ else:
51
+ raise ValueError from e
52
+ except (UnicodeDecodeError, ValueError):
53
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54
+
55
+ return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56
+
57
+
58
+ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59
+ """Load flax checkpoints in a PyTorch model"""
60
+
61
+ try:
62
+ import torch # noqa: F401
63
+ except ImportError:
64
+ logger.error(
65
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67
+ " instructions."
68
+ )
69
+ raise
70
+
71
+ # check if we have bf16 weights
72
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73
+ if any(is_type_bf16):
74
+ # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
75
+
76
+ # and bf16 is not fully supported in PT yet.
77
+ logger.warning(
78
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79
+ "before loading those in PyTorch model."
80
+ )
81
+ flax_state = jax.tree_util.tree_map(
82
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83
+ )
84
+
85
+ pt_model.base_model_prefix = ""
86
+
87
+ flax_state_dict = flatten_dict(flax_state, sep=".")
88
+ pt_model_dict = pt_model.state_dict()
89
+
90
+ # keep track of unexpected & missing keys
91
+ unexpected_keys = []
92
+ missing_keys = set(pt_model_dict.keys())
93
+
94
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
95
+ flax_key_tuple_array = flax_key_tuple.split(".")
96
+
97
+ if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100
+ elif flax_key_tuple_array[-1] == "kernel":
101
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102
+ flax_tensor = flax_tensor.T
103
+ elif flax_key_tuple_array[-1] == "scale":
104
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105
+
106
+ if "time_embedding" not in flax_key_tuple_array:
107
+ for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108
+ flax_key_tuple_array[i] = (
109
+ flax_key_tuple_string.replace("_0", ".0")
110
+ .replace("_1", ".1")
111
+ .replace("_2", ".2")
112
+ .replace("_3", ".3")
113
+ .replace("_4", ".4")
114
+ .replace("_5", ".5")
115
+ .replace("_6", ".6")
116
+ .replace("_7", ".7")
117
+ .replace("_8", ".8")
118
+ .replace("_9", ".9")
119
+ )
120
+
121
+ flax_key = ".".join(flax_key_tuple_array)
122
+
123
+ if flax_key in pt_model_dict:
124
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
125
+ raise ValueError(
126
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128
+ )
129
+ else:
130
+ # add weight to pytorch dict
131
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133
+ # remove from missing keys
134
+ missing_keys.remove(flax_key)
135
+ else:
136
+ # weight is not expected by PyTorch model
137
+ unexpected_keys.append(flax_key)
138
+
139
+ pt_model.load_state_dict(pt_model_dict)
140
+
141
+ # re-transform missing_keys to list
142
+ missing_keys = list(missing_keys)
143
+
144
+ if len(unexpected_keys) > 0:
145
+ logger.warning(
146
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
147
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152
+ " FlaxBertForSequenceClassification model)."
153
+ )
154
+ if len(missing_keys) > 0:
155
+ logger.warning(
156
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158
+ " use it for predictions and inference."
159
+ )
160
+
161
+ return pt_model
6DoF/diffusers/models/modeling_utils.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ import re
21
+ from functools import partial
22
+ from typing import Any, Callable, List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import Tensor, device, nn
26
+
27
+ from .. import __version__
28
+ from ..utils import (
29
+ CONFIG_NAME,
30
+ DIFFUSERS_CACHE,
31
+ FLAX_WEIGHTS_NAME,
32
+ HF_HUB_OFFLINE,
33
+ SAFETENSORS_WEIGHTS_NAME,
34
+ WEIGHTS_NAME,
35
+ _add_variant,
36
+ _get_model_file,
37
+ deprecate,
38
+ is_accelerate_available,
39
+ is_safetensors_available,
40
+ is_torch_version,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ if is_torch_version(">=", "1.9.0"):
49
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
50
+ else:
51
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
52
+
53
+
54
+ if is_accelerate_available():
55
+ import accelerate
56
+ from accelerate.utils import set_module_tensor_to_device
57
+ from accelerate.utils.versions import is_torch_version
58
+
59
+ if is_safetensors_available():
60
+ import safetensors
61
+
62
+
63
+ def get_parameter_device(parameter: torch.nn.Module):
64
+ try:
65
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
66
+ return next(parameters_and_buffers).device
67
+ except StopIteration:
68
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
69
+
70
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
71
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
72
+ return tuples
73
+
74
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
75
+ first_tuple = next(gen)
76
+ return first_tuple[1].device
77
+
78
+
79
+ def get_parameter_dtype(parameter: torch.nn.Module):
80
+ try:
81
+ params = tuple(parameter.parameters())
82
+ if len(params) > 0:
83
+ return params[0].dtype
84
+
85
+ buffers = tuple(parameter.buffers())
86
+ if len(buffers) > 0:
87
+ return buffers[0].dtype
88
+
89
+ except StopIteration:
90
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
91
+
92
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
93
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
94
+ return tuples
95
+
96
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
97
+ first_tuple = next(gen)
98
+ return first_tuple[1].dtype
99
+
100
+
101
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
102
+ """
103
+ Reads a checkpoint file, returning properly formatted errors if they arise.
104
+ """
105
+ try:
106
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
107
+ return torch.load(checkpoint_file, map_location="cpu")
108
+ else:
109
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
110
+ except Exception as e:
111
+ try:
112
+ with open(checkpoint_file) as f:
113
+ if f.read().startswith("version"):
114
+ raise OSError(
115
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
116
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
117
+ "you cloned."
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
122
+ "model. Make sure you have saved the model properly."
123
+ ) from e
124
+ except (UnicodeDecodeError, ValueError):
125
+ raise OSError(
126
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
127
+ f"at '{checkpoint_file}'. "
128
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
129
+ )
130
+
131
+
132
+ def _load_state_dict_into_model(model_to_load, state_dict):
133
+ # Convert old format to new format if needed from a PyTorch state_dict
134
+ # copy state_dict so _load_from_state_dict can modify it
135
+ state_dict = state_dict.copy()
136
+ error_msgs = []
137
+
138
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
139
+ # so we need to apply the function recursively.
140
+ def load(module: torch.nn.Module, prefix=""):
141
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
142
+ module._load_from_state_dict(*args)
143
+
144
+ for name, child in module._modules.items():
145
+ if child is not None:
146
+ load(child, prefix + name + ".")
147
+
148
+ load(model_to_load)
149
+
150
+ return error_msgs
151
+
152
+
153
+ class ModelMixin(torch.nn.Module):
154
+ r"""
155
+ Base class for all models.
156
+
157
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
158
+ saving models.
159
+
160
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
161
+ """
162
+ config_name = CONFIG_NAME
163
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164
+ _supports_gradient_checkpointing = False
165
+ _keys_to_ignore_on_load_unexpected = None
166
+
167
+ def __init__(self):
168
+ super().__init__()
169
+
170
+ def __getattr__(self, name: str) -> Any:
171
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
172
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
173
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
174
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
175
+ """
176
+
177
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
178
+ is_attribute = name in self.__dict__
179
+
180
+ if is_in_config and not is_attribute:
181
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
182
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
183
+ return self._internal_dict[name]
184
+
185
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
186
+ return super().__getattr__(name)
187
+
188
+ @property
189
+ def is_gradient_checkpointing(self) -> bool:
190
+ """
191
+ Whether gradient checkpointing is activated for this model or not.
192
+ """
193
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
194
+
195
+ def enable_gradient_checkpointing(self):
196
+ """
197
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
198
+ *checkpoint activations* in other frameworks).
199
+ """
200
+ if not self._supports_gradient_checkpointing:
201
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
202
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
203
+
204
+ def disable_gradient_checkpointing(self):
205
+ """
206
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
207
+ *checkpoint activations* in other frameworks).
208
+ """
209
+ if self._supports_gradient_checkpointing:
210
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
211
+
212
+ def set_use_memory_efficient_attention_xformers(
213
+ self, valid: bool, attention_op: Optional[Callable] = None
214
+ ) -> None:
215
+ # Recursively walk through all the children.
216
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
217
+ # gets the message
218
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
219
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
220
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
221
+
222
+ for child in module.children():
223
+ fn_recursive_set_mem_eff(child)
224
+
225
+ for module in self.children():
226
+ if isinstance(module, torch.nn.Module):
227
+ fn_recursive_set_mem_eff(module)
228
+
229
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
230
+ r"""
231
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
232
+
233
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
234
+ inference. Speed up during training is not guaranteed.
235
+
236
+ <Tip warning={true}>
237
+
238
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
239
+ precedent.
240
+
241
+ </Tip>
242
+
243
+ Parameters:
244
+ attention_op (`Callable`, *optional*):
245
+ Override the default `None` operator for use as `op` argument to the
246
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
247
+ function of xFormers.
248
+
249
+ Examples:
250
+
251
+ ```py
252
+ >>> import torch
253
+ >>> from diffusers import UNet2DConditionModel
254
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
255
+
256
+ >>> model = UNet2DConditionModel.from_pretrained(
257
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
258
+ ... )
259
+ >>> model = model.to("cuda")
260
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
261
+ ```
262
+ """
263
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
264
+
265
+ def disable_xformers_memory_efficient_attention(self):
266
+ r"""
267
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
268
+ """
269
+ self.set_use_memory_efficient_attention_xformers(False)
270
+
271
+ def save_pretrained(
272
+ self,
273
+ save_directory: Union[str, os.PathLike],
274
+ is_main_process: bool = True,
275
+ save_function: Callable = None,
276
+ safe_serialization: bool = False,
277
+ variant: Optional[str] = None,
278
+ ):
279
+ """
280
+ Save a model and its configuration file to a directory so that it can be reloaded using the
281
+ [`~models.ModelMixin.from_pretrained`] class method.
282
+
283
+ Arguments:
284
+ save_directory (`str` or `os.PathLike`):
285
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
286
+ is_main_process (`bool`, *optional*, defaults to `True`):
287
+ Whether the process calling this is the main process or not. Useful during distributed training and you
288
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
289
+ process to avoid race conditions.
290
+ save_function (`Callable`):
291
+ The function to use to save the state dictionary. Useful during distributed training when you need to
292
+ replace `torch.save` with another method. Can be configured with the environment variable
293
+ `DIFFUSERS_SAVE_MODE`.
294
+ safe_serialization (`bool`, *optional*, defaults to `False`):
295
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
296
+ variant (`str`, *optional*):
297
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
298
+ """
299
+ if safe_serialization and not is_safetensors_available():
300
+ raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
301
+
302
+ if os.path.isfile(save_directory):
303
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
304
+ return
305
+
306
+ os.makedirs(save_directory, exist_ok=True)
307
+
308
+ model_to_save = self
309
+
310
+ # Attach architecture to the config
311
+ # Save the config
312
+ if is_main_process:
313
+ model_to_save.save_config(save_directory)
314
+
315
+ # Save the model
316
+ state_dict = model_to_save.state_dict()
317
+
318
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
+ weights_name = _add_variant(weights_name, variant)
320
+
321
+ # Save the model
322
+ if safe_serialization:
323
+ safetensors.torch.save_file(
324
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
325
+ )
326
+ else:
327
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
328
+
329
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
330
+
331
+ @classmethod
332
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
333
+ r"""
334
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
335
+
336
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
337
+ train the model, set it back in training mode with `model.train()`.
338
+
339
+ Parameters:
340
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
341
+ Can be either:
342
+
343
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
344
+ the Hub.
345
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
346
+ with [`~ModelMixin.save_pretrained`].
347
+
348
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
349
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
350
+ is not used.
351
+ torch_dtype (`str` or `torch.dtype`, *optional*):
352
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
353
+ dtype is automatically derived from the model's weights.
354
+ force_download (`bool`, *optional*, defaults to `False`):
355
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
356
+ cached versions if they exist.
357
+ resume_download (`bool`, *optional*, defaults to `False`):
358
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
359
+ incompletely downloaded files are deleted.
360
+ proxies (`Dict[str, str]`, *optional*):
361
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
362
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
363
+ output_loading_info (`bool`, *optional*, defaults to `False`):
364
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
365
+ local_files_only(`bool`, *optional*, defaults to `False`):
366
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
367
+ won't be downloaded from the Hub.
368
+ use_auth_token (`str` or *bool*, *optional*):
369
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
370
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
371
+ revision (`str`, *optional*, defaults to `"main"`):
372
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
373
+ allowed by Git.
374
+ from_flax (`bool`, *optional*, defaults to `False`):
375
+ Load the model weights from a Flax checkpoint save file.
376
+ subfolder (`str`, *optional*, defaults to `""`):
377
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
378
+ mirror (`str`, *optional*):
379
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
380
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
381
+ information.
382
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
383
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
384
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
385
+ same device.
386
+
387
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
388
+ more information about each option see [designing a device
389
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
390
+ max_memory (`Dict`, *optional*):
391
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
392
+ each GPU and the available CPU RAM if unset.
393
+ offload_folder (`str` or `os.PathLike`, *optional*):
394
+ The path to offload weights if `device_map` contains the value `"disk"`.
395
+ offload_state_dict (`bool`, *optional*):
396
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
397
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
398
+ when there is some disk offload.
399
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
400
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
401
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
402
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
403
+ argument to `True` will raise an error.
404
+ variant (`str`, *optional*):
405
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
406
+ loading `from_flax`.
407
+ use_safetensors (`bool`, *optional*, defaults to `None`):
408
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
409
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
410
+ weights. If set to `False`, `safetensors` weights are not loaded.
411
+
412
+ <Tip>
413
+
414
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
415
+ `huggingface-cli login`. You can also activate the special
416
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
417
+ firewalled environment.
418
+
419
+ </Tip>
420
+
421
+ Example:
422
+
423
+ ```py
424
+ from diffusers import UNet2DConditionModel
425
+
426
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
427
+ ```
428
+
429
+ If you get the error message below, you need to finetune the weights for your downstream task:
430
+
431
+ ```bash
432
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
433
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
434
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
435
+ ```
436
+ """
437
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
438
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
439
+ force_download = kwargs.pop("force_download", False)
440
+ from_flax = kwargs.pop("from_flax", False)
441
+ resume_download = kwargs.pop("resume_download", False)
442
+ proxies = kwargs.pop("proxies", None)
443
+ output_loading_info = kwargs.pop("output_loading_info", False)
444
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
445
+ use_auth_token = kwargs.pop("use_auth_token", None)
446
+ revision = kwargs.pop("revision", None)
447
+ torch_dtype = kwargs.pop("torch_dtype", None)
448
+ subfolder = kwargs.pop("subfolder", None)
449
+ device_map = kwargs.pop("device_map", None)
450
+ max_memory = kwargs.pop("max_memory", None)
451
+ offload_folder = kwargs.pop("offload_folder", None)
452
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
453
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
454
+ variant = kwargs.pop("variant", None)
455
+ use_safetensors = kwargs.pop("use_safetensors", None)
456
+
457
+ if use_safetensors and not is_safetensors_available():
458
+ raise ValueError(
459
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
460
+ )
461
+
462
+ allow_pickle = False
463
+ if use_safetensors is None:
464
+ use_safetensors = is_safetensors_available()
465
+ allow_pickle = True
466
+
467
+ if low_cpu_mem_usage and not is_accelerate_available():
468
+ low_cpu_mem_usage = False
469
+ logger.warning(
470
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
471
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
472
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
473
+ " install accelerate\n```\n."
474
+ )
475
+
476
+ if device_map is not None and not is_accelerate_available():
477
+ raise NotImplementedError(
478
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
479
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
480
+ )
481
+
482
+ # Check if we can handle device_map and dispatching the weights
483
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
484
+ raise NotImplementedError(
485
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
486
+ " `device_map=None`."
487
+ )
488
+
489
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
490
+ raise NotImplementedError(
491
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
492
+ " `low_cpu_mem_usage=False`."
493
+ )
494
+
495
+ if low_cpu_mem_usage is False and device_map is not None:
496
+ raise ValueError(
497
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
498
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
499
+ )
500
+
501
+ # Load config if we don't provide a configuration
502
+ config_path = pretrained_model_name_or_path
503
+
504
+ user_agent = {
505
+ "diffusers": __version__,
506
+ "file_type": "model",
507
+ "framework": "pytorch",
508
+ }
509
+
510
+ # load config
511
+ config, unused_kwargs, commit_hash = cls.load_config(
512
+ config_path,
513
+ cache_dir=cache_dir,
514
+ return_unused_kwargs=True,
515
+ return_commit_hash=True,
516
+ force_download=force_download,
517
+ resume_download=resume_download,
518
+ proxies=proxies,
519
+ local_files_only=local_files_only,
520
+ use_auth_token=use_auth_token,
521
+ revision=revision,
522
+ subfolder=subfolder,
523
+ device_map=device_map,
524
+ max_memory=max_memory,
525
+ offload_folder=offload_folder,
526
+ offload_state_dict=offload_state_dict,
527
+ user_agent=user_agent,
528
+ **kwargs,
529
+ )
530
+
531
+ # load model
532
+ model_file = None
533
+ if from_flax:
534
+ model_file = _get_model_file(
535
+ pretrained_model_name_or_path,
536
+ weights_name=FLAX_WEIGHTS_NAME,
537
+ cache_dir=cache_dir,
538
+ force_download=force_download,
539
+ resume_download=resume_download,
540
+ proxies=proxies,
541
+ local_files_only=local_files_only,
542
+ use_auth_token=use_auth_token,
543
+ revision=revision,
544
+ subfolder=subfolder,
545
+ user_agent=user_agent,
546
+ commit_hash=commit_hash,
547
+ )
548
+ model = cls.from_config(config, **unused_kwargs)
549
+
550
+ # Convert the weights
551
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
552
+
553
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
554
+ else:
555
+ if use_safetensors:
556
+ try:
557
+ model_file = _get_model_file(
558
+ pretrained_model_name_or_path,
559
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
560
+ cache_dir=cache_dir,
561
+ force_download=force_download,
562
+ resume_download=resume_download,
563
+ proxies=proxies,
564
+ local_files_only=local_files_only,
565
+ use_auth_token=use_auth_token,
566
+ revision=revision,
567
+ subfolder=subfolder,
568
+ user_agent=user_agent,
569
+ commit_hash=commit_hash,
570
+ )
571
+ except IOError as e:
572
+ if not allow_pickle:
573
+ raise e
574
+ pass
575
+ if model_file is None:
576
+ model_file = _get_model_file(
577
+ pretrained_model_name_or_path,
578
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
579
+ cache_dir=cache_dir,
580
+ force_download=force_download,
581
+ resume_download=resume_download,
582
+ proxies=proxies,
583
+ local_files_only=local_files_only,
584
+ use_auth_token=use_auth_token,
585
+ revision=revision,
586
+ subfolder=subfolder,
587
+ user_agent=user_agent,
588
+ commit_hash=commit_hash,
589
+ )
590
+
591
+ if low_cpu_mem_usage:
592
+ # Instantiate model with empty weights
593
+ with accelerate.init_empty_weights():
594
+ model = cls.from_config(config, **unused_kwargs)
595
+
596
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
597
+ if device_map is None:
598
+ param_device = "cpu"
599
+ state_dict = load_state_dict(model_file, variant=variant)
600
+ model._convert_deprecated_attention_blocks(state_dict)
601
+ # move the params from meta device to cpu
602
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
603
+ if len(missing_keys) > 0:
604
+ raise ValueError(
605
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
606
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
607
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
608
+ " those weights or else make sure your checkpoint file is correct."
609
+ )
610
+ unexpected_keys = []
611
+
612
+ empty_state_dict = model.state_dict()
613
+ for param_name, param in state_dict.items():
614
+ accepts_dtype = "dtype" in set(
615
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
616
+ )
617
+
618
+ if param_name not in empty_state_dict:
619
+ unexpected_keys.append(param_name)
620
+ continue
621
+
622
+ if empty_state_dict[param_name].shape != param.shape:
623
+ raise ValueError(
624
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
625
+ )
626
+
627
+ if accepts_dtype:
628
+ set_module_tensor_to_device(
629
+ model, param_name, param_device, value=param, dtype=torch_dtype
630
+ )
631
+ else:
632
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
633
+
634
+ if cls._keys_to_ignore_on_load_unexpected is not None:
635
+ for pat in cls._keys_to_ignore_on_load_unexpected:
636
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
637
+
638
+ if len(unexpected_keys) > 0:
639
+ logger.warn(
640
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
641
+ )
642
+
643
+ else: # else let accelerate handle loading and dispatching.
644
+ # Load weights and dispatch according to the device_map
645
+ # by default the device_map is None and the weights are loaded on the CPU
646
+ try:
647
+ accelerate.load_checkpoint_and_dispatch(
648
+ model,
649
+ model_file,
650
+ device_map,
651
+ max_memory=max_memory,
652
+ offload_folder=offload_folder,
653
+ offload_state_dict=offload_state_dict,
654
+ dtype=torch_dtype,
655
+ )
656
+ except AttributeError as e:
657
+ # When using accelerate loading, we do not have the ability to load the state
658
+ # dict and rename the weight names manually. Additionally, accelerate skips
659
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
660
+ # (which look like they should be private variables?), so we can't use the standard hooks
661
+ # to rename parameters on load. We need to mimic the original weight names so the correct
662
+ # attributes are available. After we have loaded the weights, we convert the deprecated
663
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
664
+ # the weights so we don't have to do this again.
665
+
666
+ if "'Attention' object has no attribute" in str(e):
667
+ logger.warn(
668
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
669
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
670
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
671
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
672
+ " please also re-upload it or open a PR on the original repository."
673
+ )
674
+ model._temp_convert_self_to_deprecated_attention_blocks()
675
+ accelerate.load_checkpoint_and_dispatch(
676
+ model,
677
+ model_file,
678
+ device_map,
679
+ max_memory=max_memory,
680
+ offload_folder=offload_folder,
681
+ offload_state_dict=offload_state_dict,
682
+ dtype=torch_dtype,
683
+ )
684
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
685
+ else:
686
+ raise e
687
+
688
+ loading_info = {
689
+ "missing_keys": [],
690
+ "unexpected_keys": [],
691
+ "mismatched_keys": [],
692
+ "error_msgs": [],
693
+ }
694
+ else:
695
+ model = cls.from_config(config, **unused_kwargs)
696
+
697
+ state_dict = load_state_dict(model_file, variant=variant)
698
+ model._convert_deprecated_attention_blocks(state_dict)
699
+
700
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
701
+ model,
702
+ state_dict,
703
+ model_file,
704
+ pretrained_model_name_or_path,
705
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
706
+ )
707
+
708
+ loading_info = {
709
+ "missing_keys": missing_keys,
710
+ "unexpected_keys": unexpected_keys,
711
+ "mismatched_keys": mismatched_keys,
712
+ "error_msgs": error_msgs,
713
+ }
714
+
715
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
716
+ raise ValueError(
717
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
718
+ )
719
+ elif torch_dtype is not None:
720
+ model = model.to(torch_dtype)
721
+
722
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
723
+
724
+ # Set model in evaluation mode to deactivate DropOut modules by default
725
+ model.eval()
726
+ if output_loading_info:
727
+ return model, loading_info
728
+
729
+ return model
730
+
731
+ @classmethod
732
+ def _load_pretrained_model(
733
+ cls,
734
+ model,
735
+ state_dict,
736
+ resolved_archive_file,
737
+ pretrained_model_name_or_path,
738
+ ignore_mismatched_sizes=False,
739
+ ):
740
+ # Retrieve missing & unexpected_keys
741
+ model_state_dict = model.state_dict()
742
+ loaded_keys = list(state_dict.keys())
743
+
744
+ expected_keys = list(model_state_dict.keys())
745
+
746
+ original_loaded_keys = loaded_keys
747
+
748
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
749
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
750
+
751
+ # Make sure we are able to load base models as well as derived models (with heads)
752
+ model_to_load = model
753
+
754
+ def _find_mismatched_keys(
755
+ state_dict,
756
+ model_state_dict,
757
+ loaded_keys,
758
+ ignore_mismatched_sizes,
759
+ ):
760
+ mismatched_keys = []
761
+ if ignore_mismatched_sizes:
762
+ for checkpoint_key in loaded_keys:
763
+ model_key = checkpoint_key
764
+
765
+ if (
766
+ model_key in model_state_dict
767
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
768
+ ):
769
+ mismatched_keys.append(
770
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
771
+ )
772
+ del state_dict[checkpoint_key]
773
+ return mismatched_keys
774
+
775
+ if state_dict is not None:
776
+ # Whole checkpoint
777
+ mismatched_keys = _find_mismatched_keys(
778
+ state_dict,
779
+ model_state_dict,
780
+ original_loaded_keys,
781
+ ignore_mismatched_sizes,
782
+ )
783
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
784
+
785
+ if len(error_msgs) > 0:
786
+ error_msg = "\n\t".join(error_msgs)
787
+ if "size mismatch" in error_msg:
788
+ error_msg += (
789
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
790
+ )
791
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
792
+
793
+ if len(unexpected_keys) > 0:
794
+ logger.warning(
795
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
796
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
797
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
798
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
799
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
800
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
801
+ " identical (initializing a BertForSequenceClassification model from a"
802
+ " BertForSequenceClassification model)."
803
+ )
804
+ else:
805
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
806
+ if len(missing_keys) > 0:
807
+ logger.warning(
808
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
809
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
810
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
811
+ )
812
+ elif len(mismatched_keys) == 0:
813
+ logger.info(
814
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
815
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
816
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
817
+ " without further training."
818
+ )
819
+ if len(mismatched_keys) > 0:
820
+ mismatched_warning = "\n".join(
821
+ [
822
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
823
+ for key, shape1, shape2 in mismatched_keys
824
+ ]
825
+ )
826
+ logger.warning(
827
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
828
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
829
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
830
+ " able to use it for predictions and inference."
831
+ )
832
+
833
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
834
+
835
+ @property
836
+ def device(self) -> device:
837
+ """
838
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
839
+ device).
840
+ """
841
+ return get_parameter_device(self)
842
+
843
+ @property
844
+ def dtype(self) -> torch.dtype:
845
+ """
846
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
847
+ """
848
+ return get_parameter_dtype(self)
849
+
850
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
851
+ """
852
+ Get number of (trainable or non-embedding) parameters in the module.
853
+
854
+ Args:
855
+ only_trainable (`bool`, *optional*, defaults to `False`):
856
+ Whether or not to return only the number of trainable parameters.
857
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
858
+ Whether or not to return only the number of non-embedding parameters.
859
+
860
+ Returns:
861
+ `int`: The number of parameters.
862
+
863
+ Example:
864
+
865
+ ```py
866
+ from diffusers import UNet2DConditionModel
867
+
868
+ model_id = "runwayml/stable-diffusion-v1-5"
869
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
870
+ unet.num_parameters(only_trainable=True)
871
+ 859520964
872
+ ```
873
+ """
874
+
875
+ if exclude_embeddings:
876
+ embedding_param_names = [
877
+ f"{name}.weight"
878
+ for name, module_type in self.named_modules()
879
+ if isinstance(module_type, torch.nn.Embedding)
880
+ ]
881
+ non_embedding_parameters = [
882
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
883
+ ]
884
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
885
+ else:
886
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
887
+
888
+ def _convert_deprecated_attention_blocks(self, state_dict):
889
+ deprecated_attention_block_paths = []
890
+
891
+ def recursive_find_attn_block(name, module):
892
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
893
+ deprecated_attention_block_paths.append(name)
894
+
895
+ for sub_name, sub_module in module.named_children():
896
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
897
+ recursive_find_attn_block(sub_name, sub_module)
898
+
899
+ recursive_find_attn_block("", self)
900
+
901
+ # NOTE: we have to check if the deprecated parameters are in the state dict
902
+ # because it is possible we are loading from a state dict that was already
903
+ # converted
904
+
905
+ for path in deprecated_attention_block_paths:
906
+ # group_norm path stays the same
907
+
908
+ # query -> to_q
909
+ if f"{path}.query.weight" in state_dict:
910
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
911
+ if f"{path}.query.bias" in state_dict:
912
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
913
+
914
+ # key -> to_k
915
+ if f"{path}.key.weight" in state_dict:
916
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
917
+ if f"{path}.key.bias" in state_dict:
918
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
919
+
920
+ # value -> to_v
921
+ if f"{path}.value.weight" in state_dict:
922
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
923
+ if f"{path}.value.bias" in state_dict:
924
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
925
+
926
+ # proj_attn -> to_out.0
927
+ if f"{path}.proj_attn.weight" in state_dict:
928
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
929
+ if f"{path}.proj_attn.bias" in state_dict:
930
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
931
+
932
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
933
+ deprecated_attention_block_modules = []
934
+
935
+ def recursive_find_attn_block(module):
936
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
937
+ deprecated_attention_block_modules.append(module)
938
+
939
+ for sub_module in module.children():
940
+ recursive_find_attn_block(sub_module)
941
+
942
+ recursive_find_attn_block(self)
943
+
944
+ for module in deprecated_attention_block_modules:
945
+ module.query = module.to_q
946
+ module.key = module.to_k
947
+ module.value = module.to_v
948
+ module.proj_attn = module.to_out[0]
949
+
950
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
951
+ # that _all_ the weights are loaded into the new attributes and we're not
952
+ # making an incorrect assumption that this model should be converted when
953
+ # it really shouldn't be.
954
+ del module.to_q
955
+ del module.to_k
956
+ del module.to_v
957
+ del module.to_out
958
+
959
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
960
+ deprecated_attention_block_modules = []
961
+
962
+ def recursive_find_attn_block(module):
963
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
964
+ deprecated_attention_block_modules.append(module)
965
+
966
+ for sub_module in module.children():
967
+ recursive_find_attn_block(sub_module)
968
+
969
+ recursive_find_attn_block(self)
970
+
971
+ for module in deprecated_attention_block_modules:
972
+ module.to_q = module.query
973
+ module.to_k = module.key
974
+ module.to_v = module.value
975
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
976
+
977
+ del module.query
978
+ del module.key
979
+ del module.value
980
+ del module.proj_attn
6DoF/diffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..utils import BaseOutput
10
+ from .attention import BasicTransformerBlock
11
+ from .attention_processor import AttentionProcessor, AttnProcessor
12
+ from .embeddings import TimestepEmbedding, Timesteps
13
+ from .modeling_utils import ModelMixin
14
+
15
+
16
+ @dataclass
17
+ class PriorTransformerOutput(BaseOutput):
18
+ """
19
+ The output of [`PriorTransformer`].
20
+
21
+ Args:
22
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
23
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
24
+ """
25
+
26
+ predicted_image_embedding: torch.FloatTensor
27
+
28
+
29
+ class PriorTransformer(ModelMixin, ConfigMixin):
30
+ """
31
+ A Prior Transformer model.
32
+
33
+ Parameters:
34
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
35
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
36
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
37
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
38
+ num_embeddings (`int`, *optional*, defaults to 77):
39
+ The number of embeddings of the model input `hidden_states`
40
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
41
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
42
+ additional_embeddings`.
43
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
44
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
45
+ The activation function to use to create timestep embeddings.
46
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
47
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
48
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
49
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
50
+ needed.
51
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
52
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
53
+ `encoder_hidden_states` is `None`.
54
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
55
+ Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
56
+ product between the text embedding and image embedding as proposed in the unclip paper
57
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
58
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
59
+ If None, will be set to `num_attention_heads * attention_head_dim`
60
+ embedding_proj_dim (`int`, *optional*, default to None):
61
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
62
+ clip_embed_dim (`int`, *optional*, default to None):
63
+ The dimension of the output. If None, will be set to `embedding_dim`.
64
+ """
65
+
66
+ @register_to_config
67
+ def __init__(
68
+ self,
69
+ num_attention_heads: int = 32,
70
+ attention_head_dim: int = 64,
71
+ num_layers: int = 20,
72
+ embedding_dim: int = 768,
73
+ num_embeddings=77,
74
+ additional_embeddings=4,
75
+ dropout: float = 0.0,
76
+ time_embed_act_fn: str = "silu",
77
+ norm_in_type: Optional[str] = None, # layer
78
+ embedding_proj_norm_type: Optional[str] = None, # layer
79
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
80
+ added_emb_type: Optional[str] = "prd", # prd
81
+ time_embed_dim: Optional[int] = None,
82
+ embedding_proj_dim: Optional[int] = None,
83
+ clip_embed_dim: Optional[int] = None,
84
+ ):
85
+ super().__init__()
86
+ self.num_attention_heads = num_attention_heads
87
+ self.attention_head_dim = attention_head_dim
88
+ inner_dim = num_attention_heads * attention_head_dim
89
+ self.additional_embeddings = additional_embeddings
90
+
91
+ time_embed_dim = time_embed_dim or inner_dim
92
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
93
+ clip_embed_dim = clip_embed_dim or embedding_dim
94
+
95
+ self.time_proj = Timesteps(inner_dim, True, 0)
96
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
97
+
98
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
99
+
100
+ if embedding_proj_norm_type is None:
101
+ self.embedding_proj_norm = None
102
+ elif embedding_proj_norm_type == "layer":
103
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
104
+ else:
105
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
106
+
107
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
108
+
109
+ if encoder_hid_proj_type is None:
110
+ self.encoder_hidden_states_proj = None
111
+ elif encoder_hid_proj_type == "linear":
112
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
113
+ else:
114
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
115
+
116
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
117
+
118
+ if added_emb_type == "prd":
119
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
120
+ elif added_emb_type is None:
121
+ self.prd_embedding = None
122
+ else:
123
+ raise ValueError(
124
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
125
+ )
126
+
127
+ self.transformer_blocks = nn.ModuleList(
128
+ [
129
+ BasicTransformerBlock(
130
+ inner_dim,
131
+ num_attention_heads,
132
+ attention_head_dim,
133
+ dropout=dropout,
134
+ activation_fn="gelu",
135
+ attention_bias=True,
136
+ )
137
+ for d in range(num_layers)
138
+ ]
139
+ )
140
+
141
+ if norm_in_type == "layer":
142
+ self.norm_in = nn.LayerNorm(inner_dim)
143
+ elif norm_in_type is None:
144
+ self.norm_in = None
145
+ else:
146
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
147
+
148
+ self.norm_out = nn.LayerNorm(inner_dim)
149
+
150
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
151
+
152
+ causal_attention_mask = torch.full(
153
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
154
+ )
155
+ causal_attention_mask.triu_(1)
156
+ causal_attention_mask = causal_attention_mask[None, ...]
157
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
158
+
159
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
160
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
161
+
162
+ @property
163
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
164
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
165
+ r"""
166
+ Returns:
167
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
168
+ indexed by its weight name.
169
+ """
170
+ # set recursively
171
+ processors = {}
172
+
173
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
174
+ if hasattr(module, "set_processor"):
175
+ processors[f"{name}.processor"] = module.processor
176
+
177
+ for sub_name, child in module.named_children():
178
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
179
+
180
+ return processors
181
+
182
+ for name, module in self.named_children():
183
+ fn_recursive_add_processors(name, module, processors)
184
+
185
+ return processors
186
+
187
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
188
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
189
+ r"""
190
+ Sets the attention processor to use to compute attention.
191
+
192
+ Parameters:
193
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
194
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
195
+ for **all** `Attention` layers.
196
+
197
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
198
+ processor. This is strongly recommended when setting trainable attention processors.
199
+
200
+ """
201
+ count = len(self.attn_processors.keys())
202
+
203
+ if isinstance(processor, dict) and len(processor) != count:
204
+ raise ValueError(
205
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
206
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
207
+ )
208
+
209
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
210
+ if hasattr(module, "set_processor"):
211
+ if not isinstance(processor, dict):
212
+ module.set_processor(processor)
213
+ else:
214
+ module.set_processor(processor.pop(f"{name}.processor"))
215
+
216
+ for sub_name, child in module.named_children():
217
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
218
+
219
+ for name, module in self.named_children():
220
+ fn_recursive_attn_processor(name, module, processor)
221
+
222
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
223
+ def set_default_attn_processor(self):
224
+ """
225
+ Disables custom attention processors and sets the default attention implementation.
226
+ """
227
+ self.set_attn_processor(AttnProcessor())
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states,
232
+ timestep: Union[torch.Tensor, float, int],
233
+ proj_embedding: torch.FloatTensor,
234
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
235
+ attention_mask: Optional[torch.BoolTensor] = None,
236
+ return_dict: bool = True,
237
+ ):
238
+ """
239
+ The [`PriorTransformer`] forward method.
240
+
241
+ Args:
242
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
243
+ The currently predicted image embeddings.
244
+ timestep (`torch.LongTensor`):
245
+ Current denoising step.
246
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247
+ Projected embedding vector the denoising process is conditioned on.
248
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
249
+ Hidden states of the text embeddings the denoising process is conditioned on.
250
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
251
+ Text mask for the text embeddings.
252
+ return_dict (`bool`, *optional*, defaults to `True`):
253
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
254
+ tuple.
255
+
256
+ Returns:
257
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
258
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
259
+ tuple is returned where the first element is the sample tensor.
260
+ """
261
+ batch_size = hidden_states.shape[0]
262
+
263
+ timesteps = timestep
264
+ if not torch.is_tensor(timesteps):
265
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
266
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
267
+ timesteps = timesteps[None].to(hidden_states.device)
268
+
269
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
270
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
271
+
272
+ timesteps_projected = self.time_proj(timesteps)
273
+
274
+ # timesteps does not contain any weights and will always return f32 tensors
275
+ # but time_embedding might be fp16, so we need to cast here.
276
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
277
+ time_embeddings = self.time_embedding(timesteps_projected)
278
+
279
+ if self.embedding_proj_norm is not None:
280
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
281
+
282
+ proj_embeddings = self.embedding_proj(proj_embedding)
283
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
284
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
285
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
286
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
287
+
288
+ hidden_states = self.proj_in(hidden_states)
289
+
290
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
291
+
292
+ additional_embeds = []
293
+ additional_embeddings_len = 0
294
+
295
+ if encoder_hidden_states is not None:
296
+ additional_embeds.append(encoder_hidden_states)
297
+ additional_embeddings_len += encoder_hidden_states.shape[1]
298
+
299
+ if len(proj_embeddings.shape) == 2:
300
+ proj_embeddings = proj_embeddings[:, None, :]
301
+
302
+ if len(hidden_states.shape) == 2:
303
+ hidden_states = hidden_states[:, None, :]
304
+
305
+ additional_embeds = additional_embeds + [
306
+ proj_embeddings,
307
+ time_embeddings[:, None, :],
308
+ hidden_states,
309
+ ]
310
+
311
+ if self.prd_embedding is not None:
312
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
313
+ additional_embeds.append(prd_embedding)
314
+
315
+ hidden_states = torch.cat(
316
+ additional_embeds,
317
+ dim=1,
318
+ )
319
+
320
+ # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
321
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
322
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
323
+ positional_embeddings = F.pad(
324
+ positional_embeddings,
325
+ (
326
+ 0,
327
+ 0,
328
+ additional_embeddings_len,
329
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
330
+ ),
331
+ value=0.0,
332
+ )
333
+
334
+ hidden_states = hidden_states + positional_embeddings
335
+
336
+ if attention_mask is not None:
337
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
338
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
339
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
340
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
341
+
342
+ if self.norm_in is not None:
343
+ hidden_states = self.norm_in(hidden_states)
344
+
345
+ for block in self.transformer_blocks:
346
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
347
+
348
+ hidden_states = self.norm_out(hidden_states)
349
+
350
+ if self.prd_embedding is not None:
351
+ hidden_states = hidden_states[:, -1]
352
+ else:
353
+ hidden_states = hidden_states[:, additional_embeddings_len:]
354
+
355
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
356
+
357
+ if not return_dict:
358
+ return (predicted_image_embedding,)
359
+
360
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
361
+
362
+ def post_process_latents(self, prior_latents):
363
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
364
+ return prior_latents
6DoF/diffusers/models/resnet.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from .activations import get_activation
24
+ from .attention import AdaGroupNorm
25
+ from .attention_processor import SpatialNorm
26
+
27
+
28
+ class Upsample1D(nn.Module):
29
+ """A 1D upsampling layer with an optional convolution.
30
+
31
+ Parameters:
32
+ channels (`int`):
33
+ number of channels in the inputs and outputs.
34
+ use_conv (`bool`, default `False`):
35
+ option to use a convolution.
36
+ use_conv_transpose (`bool`, default `False`):
37
+ option to use a convolution transpose.
38
+ out_channels (`int`, optional):
39
+ number of output channels. Defaults to `channels`.
40
+ """
41
+
42
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
43
+ super().__init__()
44
+ self.channels = channels
45
+ self.out_channels = out_channels or channels
46
+ self.use_conv = use_conv
47
+ self.use_conv_transpose = use_conv_transpose
48
+ self.name = name
49
+
50
+ self.conv = None
51
+ if use_conv_transpose:
52
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
53
+ elif use_conv:
54
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
55
+
56
+ def forward(self, inputs):
57
+ assert inputs.shape[1] == self.channels
58
+ if self.use_conv_transpose:
59
+ return self.conv(inputs)
60
+
61
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
62
+
63
+ if self.use_conv:
64
+ outputs = self.conv(outputs)
65
+
66
+ return outputs
67
+
68
+
69
+ class Downsample1D(nn.Module):
70
+ """A 1D downsampling layer with an optional convolution.
71
+
72
+ Parameters:
73
+ channels (`int`):
74
+ number of channels in the inputs and outputs.
75
+ use_conv (`bool`, default `False`):
76
+ option to use a convolution.
77
+ out_channels (`int`, optional):
78
+ number of output channels. Defaults to `channels`.
79
+ padding (`int`, default `1`):
80
+ padding for the convolution.
81
+ """
82
+
83
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
84
+ super().__init__()
85
+ self.channels = channels
86
+ self.out_channels = out_channels or channels
87
+ self.use_conv = use_conv
88
+ self.padding = padding
89
+ stride = 2
90
+ self.name = name
91
+
92
+ if use_conv:
93
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
94
+ else:
95
+ assert self.channels == self.out_channels
96
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
97
+
98
+ def forward(self, inputs):
99
+ assert inputs.shape[1] == self.channels
100
+ return self.conv(inputs)
101
+
102
+
103
+ class Upsample2D(nn.Module):
104
+ """A 2D upsampling layer with an optional convolution.
105
+
106
+ Parameters:
107
+ channels (`int`):
108
+ number of channels in the inputs and outputs.
109
+ use_conv (`bool`, default `False`):
110
+ option to use a convolution.
111
+ use_conv_transpose (`bool`, default `False`):
112
+ option to use a convolution transpose.
113
+ out_channels (`int`, optional):
114
+ number of output channels. Defaults to `channels`.
115
+ """
116
+
117
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
118
+ super().__init__()
119
+ self.channels = channels
120
+ self.out_channels = out_channels or channels
121
+ self.use_conv = use_conv
122
+ self.use_conv_transpose = use_conv_transpose
123
+ self.name = name
124
+
125
+ conv = None
126
+ if use_conv_transpose:
127
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
128
+ elif use_conv:
129
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
130
+
131
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
132
+ if name == "conv":
133
+ self.conv = conv
134
+ else:
135
+ self.Conv2d_0 = conv
136
+
137
+ def forward(self, hidden_states, output_size=None):
138
+ assert hidden_states.shape[1] == self.channels
139
+
140
+ if self.use_conv_transpose:
141
+ return self.conv(hidden_states)
142
+
143
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
144
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
145
+ # https://github.com/pytorch/pytorch/issues/86679
146
+ dtype = hidden_states.dtype
147
+ if dtype == torch.bfloat16:
148
+ hidden_states = hidden_states.to(torch.float32)
149
+
150
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
151
+ if hidden_states.shape[0] >= 64:
152
+ hidden_states = hidden_states.contiguous()
153
+
154
+ # if `output_size` is passed we force the interpolation output
155
+ # size and do not make use of `scale_factor=2`
156
+ if output_size is None:
157
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
158
+ else:
159
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
160
+
161
+ # If the input is bfloat16, we cast back to bfloat16
162
+ if dtype == torch.bfloat16:
163
+ hidden_states = hidden_states.to(dtype)
164
+
165
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
166
+ if self.use_conv:
167
+ if self.name == "conv":
168
+ hidden_states = self.conv(hidden_states)
169
+ else:
170
+ hidden_states = self.Conv2d_0(hidden_states)
171
+
172
+ return hidden_states
173
+
174
+
175
+ class Downsample2D(nn.Module):
176
+ """A 2D downsampling layer with an optional convolution.
177
+
178
+ Parameters:
179
+ channels (`int`):
180
+ number of channels in the inputs and outputs.
181
+ use_conv (`bool`, default `False`):
182
+ option to use a convolution.
183
+ out_channels (`int`, optional):
184
+ number of output channels. Defaults to `channels`.
185
+ padding (`int`, default `1`):
186
+ padding for the convolution.
187
+ """
188
+
189
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
190
+ super().__init__()
191
+ self.channels = channels
192
+ self.out_channels = out_channels or channels
193
+ self.use_conv = use_conv
194
+ self.padding = padding
195
+ stride = 2
196
+ self.name = name
197
+
198
+ if use_conv:
199
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
200
+ else:
201
+ assert self.channels == self.out_channels
202
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
203
+
204
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
205
+ if name == "conv":
206
+ self.Conv2d_0 = conv
207
+ self.conv = conv
208
+ elif name == "Conv2d_0":
209
+ self.conv = conv
210
+ else:
211
+ self.conv = conv
212
+
213
+ def forward(self, hidden_states):
214
+ assert hidden_states.shape[1] == self.channels
215
+ if self.use_conv and self.padding == 0:
216
+ pad = (0, 1, 0, 1)
217
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
218
+
219
+ assert hidden_states.shape[1] == self.channels
220
+ hidden_states = self.conv(hidden_states)
221
+
222
+ return hidden_states
223
+
224
+
225
+ class FirUpsample2D(nn.Module):
226
+ """A 2D FIR upsampling layer with an optional convolution.
227
+
228
+ Parameters:
229
+ channels (`int`):
230
+ number of channels in the inputs and outputs.
231
+ use_conv (`bool`, default `False`):
232
+ option to use a convolution.
233
+ out_channels (`int`, optional):
234
+ number of output channels. Defaults to `channels`.
235
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
236
+ kernel for the FIR filter.
237
+ """
238
+
239
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
240
+ super().__init__()
241
+ out_channels = out_channels if out_channels else channels
242
+ if use_conv:
243
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
244
+ self.use_conv = use_conv
245
+ self.fir_kernel = fir_kernel
246
+ self.out_channels = out_channels
247
+
248
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
249
+ """Fused `upsample_2d()` followed by `Conv2d()`.
250
+
251
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
252
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
253
+ arbitrary order.
254
+
255
+ Args:
256
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
257
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
258
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
259
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
260
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
261
+ factor: Integer upsampling factor (default: 2).
262
+ gain: Scaling factor for signal magnitude (default: 1.0).
263
+
264
+ Returns:
265
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
266
+ datatype as `hidden_states`.
267
+ """
268
+
269
+ assert isinstance(factor, int) and factor >= 1
270
+
271
+ # Setup filter kernel.
272
+ if kernel is None:
273
+ kernel = [1] * factor
274
+
275
+ # setup kernel
276
+ kernel = torch.tensor(kernel, dtype=torch.float32)
277
+ if kernel.ndim == 1:
278
+ kernel = torch.outer(kernel, kernel)
279
+ kernel /= torch.sum(kernel)
280
+
281
+ kernel = kernel * (gain * (factor**2))
282
+
283
+ if self.use_conv:
284
+ convH = weight.shape[2]
285
+ convW = weight.shape[3]
286
+ inC = weight.shape[1]
287
+
288
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
289
+
290
+ stride = (factor, factor)
291
+ # Determine data dimensions.
292
+ output_shape = (
293
+ (hidden_states.shape[2] - 1) * factor + convH,
294
+ (hidden_states.shape[3] - 1) * factor + convW,
295
+ )
296
+ output_padding = (
297
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
298
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
299
+ )
300
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
301
+ num_groups = hidden_states.shape[1] // inC
302
+
303
+ # Transpose weights.
304
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
305
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
306
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
307
+
308
+ inverse_conv = F.conv_transpose2d(
309
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
310
+ )
311
+
312
+ output = upfirdn2d_native(
313
+ inverse_conv,
314
+ torch.tensor(kernel, device=inverse_conv.device),
315
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
316
+ )
317
+ else:
318
+ pad_value = kernel.shape[0] - factor
319
+ output = upfirdn2d_native(
320
+ hidden_states,
321
+ torch.tensor(kernel, device=hidden_states.device),
322
+ up=factor,
323
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
324
+ )
325
+
326
+ return output
327
+
328
+ def forward(self, hidden_states):
329
+ if self.use_conv:
330
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
331
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
332
+ else:
333
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
334
+
335
+ return height
336
+
337
+
338
+ class FirDownsample2D(nn.Module):
339
+ """A 2D FIR downsampling layer with an optional convolution.
340
+
341
+ Parameters:
342
+ channels (`int`):
343
+ number of channels in the inputs and outputs.
344
+ use_conv (`bool`, default `False`):
345
+ option to use a convolution.
346
+ out_channels (`int`, optional):
347
+ number of output channels. Defaults to `channels`.
348
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
349
+ kernel for the FIR filter.
350
+ """
351
+
352
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
353
+ super().__init__()
354
+ out_channels = out_channels if out_channels else channels
355
+ if use_conv:
356
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
357
+ self.fir_kernel = fir_kernel
358
+ self.use_conv = use_conv
359
+ self.out_channels = out_channels
360
+
361
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
362
+ """Fused `Conv2d()` followed by `downsample_2d()`.
363
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
364
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
365
+ arbitrary order.
366
+
367
+ Args:
368
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
369
+ weight:
370
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
371
+ performed by `inChannels = x.shape[0] // numGroups`.
372
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
373
+ factor`, which corresponds to average pooling.
374
+ factor: Integer downsampling factor (default: 2).
375
+ gain: Scaling factor for signal magnitude (default: 1.0).
376
+
377
+ Returns:
378
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
379
+ same datatype as `x`.
380
+ """
381
+
382
+ assert isinstance(factor, int) and factor >= 1
383
+ if kernel is None:
384
+ kernel = [1] * factor
385
+
386
+ # setup kernel
387
+ kernel = torch.tensor(kernel, dtype=torch.float32)
388
+ if kernel.ndim == 1:
389
+ kernel = torch.outer(kernel, kernel)
390
+ kernel /= torch.sum(kernel)
391
+
392
+ kernel = kernel * gain
393
+
394
+ if self.use_conv:
395
+ _, _, convH, convW = weight.shape
396
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
397
+ stride_value = [factor, factor]
398
+ upfirdn_input = upfirdn2d_native(
399
+ hidden_states,
400
+ torch.tensor(kernel, device=hidden_states.device),
401
+ pad=((pad_value + 1) // 2, pad_value // 2),
402
+ )
403
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
404
+ else:
405
+ pad_value = kernel.shape[0] - factor
406
+ output = upfirdn2d_native(
407
+ hidden_states,
408
+ torch.tensor(kernel, device=hidden_states.device),
409
+ down=factor,
410
+ pad=((pad_value + 1) // 2, pad_value // 2),
411
+ )
412
+
413
+ return output
414
+
415
+ def forward(self, hidden_states):
416
+ if self.use_conv:
417
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
418
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
419
+ else:
420
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
421
+
422
+ return hidden_states
423
+
424
+
425
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
426
+ class KDownsample2D(nn.Module):
427
+ def __init__(self, pad_mode="reflect"):
428
+ super().__init__()
429
+ self.pad_mode = pad_mode
430
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
431
+ self.pad = kernel_1d.shape[1] // 2 - 1
432
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
433
+
434
+ def forward(self, inputs):
435
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
436
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
437
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
438
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
439
+ weight[indices, indices] = kernel
440
+ return F.conv2d(inputs, weight, stride=2)
441
+
442
+
443
+ class KUpsample2D(nn.Module):
444
+ def __init__(self, pad_mode="reflect"):
445
+ super().__init__()
446
+ self.pad_mode = pad_mode
447
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
448
+ self.pad = kernel_1d.shape[1] // 2 - 1
449
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
450
+
451
+ def forward(self, inputs):
452
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
453
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
454
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
455
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
456
+ weight[indices, indices] = kernel
457
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
458
+
459
+
460
+ class ResnetBlock2D(nn.Module):
461
+ r"""
462
+ A Resnet block.
463
+
464
+ Parameters:
465
+ in_channels (`int`): The number of channels in the input.
466
+ out_channels (`int`, *optional*, default to be `None`):
467
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
468
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
469
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
470
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
471
+ groups_out (`int`, *optional*, default to None):
472
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
473
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
474
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
475
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
476
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
477
+ "ada_group" for a stronger conditioning with scale and shift.
478
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
479
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
480
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
481
+ use_in_shortcut (`bool`, *optional*, default to `True`):
482
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
483
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
484
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
485
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
486
+ `conv_shortcut` output.
487
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
488
+ If None, same as `out_channels`.
489
+ """
490
+
491
+ def __init__(
492
+ self,
493
+ *,
494
+ in_channels,
495
+ out_channels=None,
496
+ conv_shortcut=False,
497
+ dropout=0.0,
498
+ temb_channels=512,
499
+ groups=32,
500
+ groups_out=None,
501
+ pre_norm=True,
502
+ eps=1e-6,
503
+ non_linearity="swish",
504
+ skip_time_act=False,
505
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
506
+ kernel=None,
507
+ output_scale_factor=1.0,
508
+ use_in_shortcut=None,
509
+ up=False,
510
+ down=False,
511
+ conv_shortcut_bias: bool = True,
512
+ conv_2d_out_channels: Optional[int] = None,
513
+ ):
514
+ super().__init__()
515
+ self.pre_norm = pre_norm
516
+ self.pre_norm = True
517
+ self.in_channels = in_channels
518
+ out_channels = in_channels if out_channels is None else out_channels
519
+ self.out_channels = out_channels
520
+ self.use_conv_shortcut = conv_shortcut
521
+ self.up = up
522
+ self.down = down
523
+ self.output_scale_factor = output_scale_factor
524
+ self.time_embedding_norm = time_embedding_norm
525
+ self.skip_time_act = skip_time_act
526
+
527
+ if groups_out is None:
528
+ groups_out = groups
529
+
530
+ if self.time_embedding_norm == "ada_group":
531
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
532
+ elif self.time_embedding_norm == "spatial":
533
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
534
+ else:
535
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
536
+
537
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
538
+
539
+ if temb_channels is not None:
540
+ if self.time_embedding_norm == "default":
541
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
542
+ elif self.time_embedding_norm == "scale_shift":
543
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
544
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
545
+ self.time_emb_proj = None
546
+ else:
547
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
548
+ else:
549
+ self.time_emb_proj = None
550
+
551
+ if self.time_embedding_norm == "ada_group":
552
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
553
+ elif self.time_embedding_norm == "spatial":
554
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
555
+ else:
556
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
557
+
558
+ self.dropout = torch.nn.Dropout(dropout)
559
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
560
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
561
+
562
+ self.nonlinearity = get_activation(non_linearity)
563
+
564
+ self.upsample = self.downsample = None
565
+ if self.up:
566
+ if kernel == "fir":
567
+ fir_kernel = (1, 3, 3, 1)
568
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
569
+ elif kernel == "sde_vp":
570
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
571
+ else:
572
+ self.upsample = Upsample2D(in_channels, use_conv=False)
573
+ elif self.down:
574
+ if kernel == "fir":
575
+ fir_kernel = (1, 3, 3, 1)
576
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
577
+ elif kernel == "sde_vp":
578
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
579
+ else:
580
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
581
+
582
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
583
+
584
+ self.conv_shortcut = None
585
+ if self.use_in_shortcut:
586
+ self.conv_shortcut = torch.nn.Conv2d(
587
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
588
+ )
589
+
590
+ def forward(self, input_tensor, temb):
591
+ hidden_states = input_tensor
592
+
593
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
594
+ hidden_states = self.norm1(hidden_states, temb)
595
+ else:
596
+ hidden_states = self.norm1(hidden_states)
597
+
598
+ hidden_states = self.nonlinearity(hidden_states)
599
+
600
+ if self.upsample is not None:
601
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
602
+ if hidden_states.shape[0] >= 64:
603
+ input_tensor = input_tensor.contiguous()
604
+ hidden_states = hidden_states.contiguous()
605
+ input_tensor = self.upsample(input_tensor)
606
+ hidden_states = self.upsample(hidden_states)
607
+ elif self.downsample is not None:
608
+ input_tensor = self.downsample(input_tensor)
609
+ hidden_states = self.downsample(hidden_states)
610
+
611
+ hidden_states = self.conv1(hidden_states)
612
+
613
+ if self.time_emb_proj is not None:
614
+ if not self.skip_time_act:
615
+ temb = self.nonlinearity(temb)
616
+ temb = self.time_emb_proj(temb)[:, :, None, None]
617
+
618
+ if temb is not None and self.time_embedding_norm == "default":
619
+ hidden_states = hidden_states + temb
620
+
621
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
622
+ hidden_states = self.norm2(hidden_states, temb)
623
+ else:
624
+ hidden_states = self.norm2(hidden_states)
625
+
626
+ if temb is not None and self.time_embedding_norm == "scale_shift":
627
+ scale, shift = torch.chunk(temb, 2, dim=1)
628
+ hidden_states = hidden_states * (1 + scale) + shift
629
+
630
+ hidden_states = self.nonlinearity(hidden_states)
631
+
632
+ hidden_states = self.dropout(hidden_states)
633
+ hidden_states = self.conv2(hidden_states)
634
+
635
+ if self.conv_shortcut is not None:
636
+ input_tensor = self.conv_shortcut(input_tensor)
637
+
638
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
639
+
640
+ return output_tensor
641
+
642
+
643
+ # unet_rl.py
644
+ def rearrange_dims(tensor):
645
+ if len(tensor.shape) == 2:
646
+ return tensor[:, :, None]
647
+ if len(tensor.shape) == 3:
648
+ return tensor[:, :, None, :]
649
+ elif len(tensor.shape) == 4:
650
+ return tensor[:, :, 0, :]
651
+ else:
652
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
653
+
654
+
655
+ class Conv1dBlock(nn.Module):
656
+ """
657
+ Conv1d --> GroupNorm --> Mish
658
+ """
659
+
660
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
661
+ super().__init__()
662
+
663
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
664
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
665
+ self.mish = nn.Mish()
666
+
667
+ def forward(self, inputs):
668
+ intermediate_repr = self.conv1d(inputs)
669
+ intermediate_repr = rearrange_dims(intermediate_repr)
670
+ intermediate_repr = self.group_norm(intermediate_repr)
671
+ intermediate_repr = rearrange_dims(intermediate_repr)
672
+ output = self.mish(intermediate_repr)
673
+ return output
674
+
675
+
676
+ # unet_rl.py
677
+ class ResidualTemporalBlock1D(nn.Module):
678
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
679
+ super().__init__()
680
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
681
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
682
+
683
+ self.time_emb_act = nn.Mish()
684
+ self.time_emb = nn.Linear(embed_dim, out_channels)
685
+
686
+ self.residual_conv = (
687
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
688
+ )
689
+
690
+ def forward(self, inputs, t):
691
+ """
692
+ Args:
693
+ inputs : [ batch_size x inp_channels x horizon ]
694
+ t : [ batch_size x embed_dim ]
695
+
696
+ returns:
697
+ out : [ batch_size x out_channels x horizon ]
698
+ """
699
+ t = self.time_emb_act(t)
700
+ t = self.time_emb(t)
701
+ out = self.conv_in(inputs) + rearrange_dims(t)
702
+ out = self.conv_out(out)
703
+ return out + self.residual_conv(inputs)
704
+
705
+
706
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
707
+ r"""Upsample2D a batch of 2D images with the given filter.
708
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
709
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
710
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
711
+ a: multiple of the upsampling factor.
712
+
713
+ Args:
714
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
715
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
716
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
717
+ factor: Integer upsampling factor (default: 2).
718
+ gain: Scaling factor for signal magnitude (default: 1.0).
719
+
720
+ Returns:
721
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
722
+ """
723
+ assert isinstance(factor, int) and factor >= 1
724
+ if kernel is None:
725
+ kernel = [1] * factor
726
+
727
+ kernel = torch.tensor(kernel, dtype=torch.float32)
728
+ if kernel.ndim == 1:
729
+ kernel = torch.outer(kernel, kernel)
730
+ kernel /= torch.sum(kernel)
731
+
732
+ kernel = kernel * (gain * (factor**2))
733
+ pad_value = kernel.shape[0] - factor
734
+ output = upfirdn2d_native(
735
+ hidden_states,
736
+ kernel.to(device=hidden_states.device),
737
+ up=factor,
738
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
739
+ )
740
+ return output
741
+
742
+
743
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
744
+ r"""Downsample2D a batch of 2D images with the given filter.
745
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
746
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
747
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
748
+ shape is a multiple of the downsampling factor.
749
+
750
+ Args:
751
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
752
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
753
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
754
+ factor: Integer downsampling factor (default: 2).
755
+ gain: Scaling factor for signal magnitude (default: 1.0).
756
+
757
+ Returns:
758
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
759
+ """
760
+
761
+ assert isinstance(factor, int) and factor >= 1
762
+ if kernel is None:
763
+ kernel = [1] * factor
764
+
765
+ kernel = torch.tensor(kernel, dtype=torch.float32)
766
+ if kernel.ndim == 1:
767
+ kernel = torch.outer(kernel, kernel)
768
+ kernel /= torch.sum(kernel)
769
+
770
+ kernel = kernel * gain
771
+ pad_value = kernel.shape[0] - factor
772
+ output = upfirdn2d_native(
773
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
774
+ )
775
+ return output
776
+
777
+
778
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
779
+ up_x = up_y = up
780
+ down_x = down_y = down
781
+ pad_x0 = pad_y0 = pad[0]
782
+ pad_x1 = pad_y1 = pad[1]
783
+
784
+ _, channel, in_h, in_w = tensor.shape
785
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
786
+
787
+ _, in_h, in_w, minor = tensor.shape
788
+ kernel_h, kernel_w = kernel.shape
789
+
790
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
791
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
792
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
793
+
794
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
795
+ out = out.to(tensor.device) # Move back to mps if necessary
796
+ out = out[
797
+ :,
798
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
799
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
800
+ :,
801
+ ]
802
+
803
+ out = out.permute(0, 3, 1, 2)
804
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
805
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
806
+ out = F.conv2d(out, w)
807
+ out = out.reshape(
808
+ -1,
809
+ minor,
810
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
811
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
812
+ )
813
+ out = out.permute(0, 2, 3, 1)
814
+ out = out[:, ::down_y, ::down_x, :]
815
+
816
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
817
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
818
+
819
+ return out.view(-1, channel, out_h, out_w)
820
+
821
+
822
+ class TemporalConvLayer(nn.Module):
823
+ """
824
+ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
825
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
826
+ """
827
+
828
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
829
+ super().__init__()
830
+ out_dim = out_dim or in_dim
831
+ self.in_dim = in_dim
832
+ self.out_dim = out_dim
833
+
834
+ # conv layers
835
+ self.conv1 = nn.Sequential(
836
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
837
+ )
838
+ self.conv2 = nn.Sequential(
839
+ nn.GroupNorm(32, out_dim),
840
+ nn.SiLU(),
841
+ nn.Dropout(dropout),
842
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
843
+ )
844
+ self.conv3 = nn.Sequential(
845
+ nn.GroupNorm(32, out_dim),
846
+ nn.SiLU(),
847
+ nn.Dropout(dropout),
848
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
849
+ )
850
+ self.conv4 = nn.Sequential(
851
+ nn.GroupNorm(32, out_dim),
852
+ nn.SiLU(),
853
+ nn.Dropout(dropout),
854
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
855
+ )
856
+
857
+ # zero out the last layer params,so the conv block is identity
858
+ nn.init.zeros_(self.conv4[-1].weight)
859
+ nn.init.zeros_(self.conv4[-1].bias)
860
+
861
+ def forward(self, hidden_states, num_frames=1):
862
+ hidden_states = (
863
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
864
+ )
865
+
866
+ identity = hidden_states
867
+ hidden_states = self.conv1(hidden_states)
868
+ hidden_states = self.conv2(hidden_states)
869
+ hidden_states = self.conv3(hidden_states)
870
+ hidden_states = self.conv4(hidden_states)
871
+
872
+ hidden_states = identity + hidden_states
873
+
874
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
875
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
876
+ )
877
+ return hidden_states
6DoF/diffusers/models/resnet_flax.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import flax.linen as nn
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+
19
+ class FlaxUpsample2D(nn.Module):
20
+ out_channels: int
21
+ dtype: jnp.dtype = jnp.float32
22
+
23
+ def setup(self):
24
+ self.conv = nn.Conv(
25
+ self.out_channels,
26
+ kernel_size=(3, 3),
27
+ strides=(1, 1),
28
+ padding=((1, 1), (1, 1)),
29
+ dtype=self.dtype,
30
+ )
31
+
32
+ def __call__(self, hidden_states):
33
+ batch, height, width, channels = hidden_states.shape
34
+ hidden_states = jax.image.resize(
35
+ hidden_states,
36
+ shape=(batch, height * 2, width * 2, channels),
37
+ method="nearest",
38
+ )
39
+ hidden_states = self.conv(hidden_states)
40
+ return hidden_states
41
+
42
+
43
+ class FlaxDownsample2D(nn.Module):
44
+ out_channels: int
45
+ dtype: jnp.dtype = jnp.float32
46
+
47
+ def setup(self):
48
+ self.conv = nn.Conv(
49
+ self.out_channels,
50
+ kernel_size=(3, 3),
51
+ strides=(2, 2),
52
+ padding=((1, 1), (1, 1)), # padding="VALID",
53
+ dtype=self.dtype,
54
+ )
55
+
56
+ def __call__(self, hidden_states):
57
+ # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
58
+ # hidden_states = jnp.pad(hidden_states, pad_width=pad)
59
+ hidden_states = self.conv(hidden_states)
60
+ return hidden_states
61
+
62
+
63
+ class FlaxResnetBlock2D(nn.Module):
64
+ in_channels: int
65
+ out_channels: int = None
66
+ dropout_prob: float = 0.0
67
+ use_nin_shortcut: bool = None
68
+ dtype: jnp.dtype = jnp.float32
69
+
70
+ def setup(self):
71
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
72
+
73
+ self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
74
+ self.conv1 = nn.Conv(
75
+ out_channels,
76
+ kernel_size=(3, 3),
77
+ strides=(1, 1),
78
+ padding=((1, 1), (1, 1)),
79
+ dtype=self.dtype,
80
+ )
81
+
82
+ self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
83
+
84
+ self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
85
+ self.dropout = nn.Dropout(self.dropout_prob)
86
+ self.conv2 = nn.Conv(
87
+ out_channels,
88
+ kernel_size=(3, 3),
89
+ strides=(1, 1),
90
+ padding=((1, 1), (1, 1)),
91
+ dtype=self.dtype,
92
+ )
93
+
94
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
95
+
96
+ self.conv_shortcut = None
97
+ if use_nin_shortcut:
98
+ self.conv_shortcut = nn.Conv(
99
+ out_channels,
100
+ kernel_size=(1, 1),
101
+ strides=(1, 1),
102
+ padding="VALID",
103
+ dtype=self.dtype,
104
+ )
105
+
106
+ def __call__(self, hidden_states, temb, deterministic=True):
107
+ residual = hidden_states
108
+ hidden_states = self.norm1(hidden_states)
109
+ hidden_states = nn.swish(hidden_states)
110
+ hidden_states = self.conv1(hidden_states)
111
+
112
+ temb = self.time_emb_proj(nn.swish(temb))
113
+ temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
114
+ hidden_states = hidden_states + temb
115
+
116
+ hidden_states = self.norm2(hidden_states)
117
+ hidden_states = nn.swish(hidden_states)
118
+ hidden_states = self.dropout(hidden_states, deterministic)
119
+ hidden_states = self.conv2(hidden_states)
120
+
121
+ if self.conv_shortcut is not None:
122
+ residual = self.conv_shortcut(residual)
123
+
124
+ return hidden_states + residual
6DoF/diffusers/models/t5_film_transformer.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from .attention_processor import Attention
21
+ from .embeddings import get_timestep_embedding
22
+ from .modeling_utils import ModelMixin
23
+
24
+
25
+ class T5FilmDecoder(ModelMixin, ConfigMixin):
26
+ @register_to_config
27
+ def __init__(
28
+ self,
29
+ input_dims: int = 128,
30
+ targets_length: int = 256,
31
+ max_decoder_noise_time: float = 2000.0,
32
+ d_model: int = 768,
33
+ num_layers: int = 12,
34
+ num_heads: int = 12,
35
+ d_kv: int = 64,
36
+ d_ff: int = 2048,
37
+ dropout_rate: float = 0.1,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.conditioning_emb = nn.Sequential(
42
+ nn.Linear(d_model, d_model * 4, bias=False),
43
+ nn.SiLU(),
44
+ nn.Linear(d_model * 4, d_model * 4, bias=False),
45
+ nn.SiLU(),
46
+ )
47
+
48
+ self.position_encoding = nn.Embedding(targets_length, d_model)
49
+ self.position_encoding.weight.requires_grad = False
50
+
51
+ self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
52
+
53
+ self.dropout = nn.Dropout(p=dropout_rate)
54
+
55
+ self.decoders = nn.ModuleList()
56
+ for lyr_num in range(num_layers):
57
+ # FiLM conditional T5 decoder
58
+ lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
59
+ self.decoders.append(lyr)
60
+
61
+ self.decoder_norm = T5LayerNorm(d_model)
62
+
63
+ self.post_dropout = nn.Dropout(p=dropout_rate)
64
+ self.spec_out = nn.Linear(d_model, input_dims, bias=False)
65
+
66
+ def encoder_decoder_mask(self, query_input, key_input):
67
+ mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
68
+ return mask.unsqueeze(-3)
69
+
70
+ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
71
+ batch, _, _ = decoder_input_tokens.shape
72
+ assert decoder_noise_time.shape == (batch,)
73
+
74
+ # decoder_noise_time is in [0, 1), so rescale to expected timing range.
75
+ time_steps = get_timestep_embedding(
76
+ decoder_noise_time * self.config.max_decoder_noise_time,
77
+ embedding_dim=self.config.d_model,
78
+ max_period=self.config.max_decoder_noise_time,
79
+ ).to(dtype=self.dtype)
80
+
81
+ conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
82
+
83
+ assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
84
+
85
+ seq_length = decoder_input_tokens.shape[1]
86
+
87
+ # If we want to use relative positions for audio context, we can just offset
88
+ # this sequence by the length of encodings_and_masks.
89
+ decoder_positions = torch.broadcast_to(
90
+ torch.arange(seq_length, device=decoder_input_tokens.device),
91
+ (batch, seq_length),
92
+ )
93
+
94
+ position_encodings = self.position_encoding(decoder_positions)
95
+
96
+ inputs = self.continuous_inputs_projection(decoder_input_tokens)
97
+ inputs += position_encodings
98
+ y = self.dropout(inputs)
99
+
100
+ # decoder: No padding present.
101
+ decoder_mask = torch.ones(
102
+ decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
103
+ )
104
+
105
+ # Translate encoding masks to encoder-decoder masks.
106
+ encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
107
+
108
+ # cross attend style: concat encodings
109
+ encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
110
+ encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
111
+
112
+ for lyr in self.decoders:
113
+ y = lyr(
114
+ y,
115
+ conditioning_emb=conditioning_emb,
116
+ encoder_hidden_states=encoded,
117
+ encoder_attention_mask=encoder_decoder_mask,
118
+ )[0]
119
+
120
+ y = self.decoder_norm(y)
121
+ y = self.post_dropout(y)
122
+
123
+ spec_out = self.spec_out(y)
124
+ return spec_out
125
+
126
+
127
+ class DecoderLayer(nn.Module):
128
+ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
129
+ super().__init__()
130
+ self.layer = nn.ModuleList()
131
+
132
+ # cond self attention: layer 0
133
+ self.layer.append(
134
+ T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
135
+ )
136
+
137
+ # cross attention: layer 1
138
+ self.layer.append(
139
+ T5LayerCrossAttention(
140
+ d_model=d_model,
141
+ d_kv=d_kv,
142
+ num_heads=num_heads,
143
+ dropout_rate=dropout_rate,
144
+ layer_norm_epsilon=layer_norm_epsilon,
145
+ )
146
+ )
147
+
148
+ # Film Cond MLP + dropout: last layer
149
+ self.layer.append(
150
+ T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
151
+ )
152
+
153
+ def forward(
154
+ self,
155
+ hidden_states,
156
+ conditioning_emb=None,
157
+ attention_mask=None,
158
+ encoder_hidden_states=None,
159
+ encoder_attention_mask=None,
160
+ encoder_decoder_position_bias=None,
161
+ ):
162
+ hidden_states = self.layer[0](
163
+ hidden_states,
164
+ conditioning_emb=conditioning_emb,
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ if encoder_hidden_states is not None:
169
+ encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
170
+ encoder_hidden_states.dtype
171
+ )
172
+
173
+ hidden_states = self.layer[1](
174
+ hidden_states,
175
+ key_value_states=encoder_hidden_states,
176
+ attention_mask=encoder_extended_attention_mask,
177
+ )
178
+
179
+ # Apply Film Conditional Feed Forward layer
180
+ hidden_states = self.layer[-1](hidden_states, conditioning_emb)
181
+
182
+ return (hidden_states,)
183
+
184
+
185
+ class T5LayerSelfAttentionCond(nn.Module):
186
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate):
187
+ super().__init__()
188
+ self.layer_norm = T5LayerNorm(d_model)
189
+ self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
190
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
191
+ self.dropout = nn.Dropout(dropout_rate)
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states,
196
+ conditioning_emb=None,
197
+ attention_mask=None,
198
+ ):
199
+ # pre_self_attention_layer_norm
200
+ normed_hidden_states = self.layer_norm(hidden_states)
201
+
202
+ if conditioning_emb is not None:
203
+ normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
204
+
205
+ # Self-attention block
206
+ attention_output = self.attention(normed_hidden_states)
207
+
208
+ hidden_states = hidden_states + self.dropout(attention_output)
209
+
210
+ return hidden_states
211
+
212
+
213
+ class T5LayerCrossAttention(nn.Module):
214
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
215
+ super().__init__()
216
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
217
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
218
+ self.dropout = nn.Dropout(dropout_rate)
219
+
220
+ def forward(
221
+ self,
222
+ hidden_states,
223
+ key_value_states=None,
224
+ attention_mask=None,
225
+ ):
226
+ normed_hidden_states = self.layer_norm(hidden_states)
227
+ attention_output = self.attention(
228
+ normed_hidden_states,
229
+ encoder_hidden_states=key_value_states,
230
+ attention_mask=attention_mask.squeeze(1),
231
+ )
232
+ layer_output = hidden_states + self.dropout(attention_output)
233
+ return layer_output
234
+
235
+
236
+ class T5LayerFFCond(nn.Module):
237
+ def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
238
+ super().__init__()
239
+ self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
240
+ self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
241
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
242
+ self.dropout = nn.Dropout(dropout_rate)
243
+
244
+ def forward(self, hidden_states, conditioning_emb=None):
245
+ forwarded_states = self.layer_norm(hidden_states)
246
+ if conditioning_emb is not None:
247
+ forwarded_states = self.film(forwarded_states, conditioning_emb)
248
+
249
+ forwarded_states = self.DenseReluDense(forwarded_states)
250
+ hidden_states = hidden_states + self.dropout(forwarded_states)
251
+ return hidden_states
252
+
253
+
254
+ class T5DenseGatedActDense(nn.Module):
255
+ def __init__(self, d_model, d_ff, dropout_rate):
256
+ super().__init__()
257
+ self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
258
+ self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
259
+ self.wo = nn.Linear(d_ff, d_model, bias=False)
260
+ self.dropout = nn.Dropout(dropout_rate)
261
+ self.act = NewGELUActivation()
262
+
263
+ def forward(self, hidden_states):
264
+ hidden_gelu = self.act(self.wi_0(hidden_states))
265
+ hidden_linear = self.wi_1(hidden_states)
266
+ hidden_states = hidden_gelu * hidden_linear
267
+ hidden_states = self.dropout(hidden_states)
268
+
269
+ hidden_states = self.wo(hidden_states)
270
+ return hidden_states
271
+
272
+
273
+ class T5LayerNorm(nn.Module):
274
+ def __init__(self, hidden_size, eps=1e-6):
275
+ """
276
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
277
+ """
278
+ super().__init__()
279
+ self.weight = nn.Parameter(torch.ones(hidden_size))
280
+ self.variance_epsilon = eps
281
+
282
+ def forward(self, hidden_states):
283
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
284
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
285
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
286
+ # half-precision inputs is done in fp32
287
+
288
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
289
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
290
+
291
+ # convert into half-precision if necessary
292
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
293
+ hidden_states = hidden_states.to(self.weight.dtype)
294
+
295
+ return self.weight * hidden_states
296
+
297
+
298
+ class NewGELUActivation(nn.Module):
299
+ """
300
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
301
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
302
+ """
303
+
304
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
305
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
306
+
307
+
308
+ class T5FiLMLayer(nn.Module):
309
+ """
310
+ FiLM Layer
311
+ """
312
+
313
+ def __init__(self, in_features, out_features):
314
+ super().__init__()
315
+ self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
316
+
317
+ def forward(self, x, conditioning_emb):
318
+ emb = self.scale_bias(conditioning_emb)
319
+ scale, shift = torch.chunk(emb, 2, -1)
320
+ x = x * (1 + scale) + shift
321
+ return x
6DoF/diffusers/models/transformer_2d.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..models.embeddings import ImagePositionalEmbeddings
23
+ from ..utils import BaseOutput, deprecate
24
+ from .attention import BasicTransformerBlock
25
+ from .embeddings import PatchEmbed
26
+ from .modeling_utils import ModelMixin
27
+
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ The output of [`Transformer2DModel`].
33
+
34
+ Args:
35
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
36
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
37
+ distributions for the unnoised latent pixels.
38
+ """
39
+
40
+ sample: torch.FloatTensor
41
+
42
+
43
+ class Transformer2DModel(ModelMixin, ConfigMixin):
44
+ """
45
+ A 2D Transformer model for image-like data.
46
+
47
+ Parameters:
48
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
49
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
50
+ in_channels (`int`, *optional*):
51
+ The number of channels in the input and output (specify if the input is **continuous**).
52
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
55
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
56
+ This is fixed during training since it is used to learn a number of position embeddings.
57
+ num_vector_embeds (`int`, *optional*):
58
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
59
+ Includes the class for the masked latent pixel.
60
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
61
+ num_embeds_ada_norm ( `int`, *optional*):
62
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
63
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
64
+ added to the hidden states.
65
+
66
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
67
+ attention_bias (`bool`, *optional*):
68
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
69
+ """
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ num_attention_heads: int = 16,
75
+ attention_head_dim: int = 88,
76
+ in_channels: Optional[int] = None,
77
+ out_channels: Optional[int] = None,
78
+ num_layers: int = 1,
79
+ dropout: float = 0.0,
80
+ norm_num_groups: int = 32,
81
+ cross_attention_dim: Optional[int] = None,
82
+ attention_bias: bool = False,
83
+ sample_size: Optional[int] = None,
84
+ num_vector_embeds: Optional[int] = None,
85
+ patch_size: Optional[int] = None,
86
+ activation_fn: str = "geglu",
87
+ num_embeds_ada_norm: Optional[int] = None,
88
+ use_linear_projection: bool = False,
89
+ only_cross_attention: bool = False,
90
+ upcast_attention: bool = False,
91
+ norm_type: str = "layer_norm",
92
+ norm_elementwise_affine: bool = True,
93
+ ):
94
+ super().__init__()
95
+ self.use_linear_projection = use_linear_projection
96
+ self.num_attention_heads = num_attention_heads
97
+ self.attention_head_dim = attention_head_dim
98
+ inner_dim = num_attention_heads * attention_head_dim
99
+
100
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
101
+ # Define whether input is continuous or discrete depending on configuration
102
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
103
+ self.is_input_vectorized = num_vector_embeds is not None
104
+ self.is_input_patches = in_channels is not None and patch_size is not None
105
+
106
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
107
+ deprecation_message = (
108
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
109
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
110
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
111
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
112
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
113
+ )
114
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
115
+ norm_type = "ada_norm"
116
+
117
+ if self.is_input_continuous and self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is None."
121
+ )
122
+ elif self.is_input_vectorized and self.is_input_patches:
123
+ raise ValueError(
124
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
125
+ " sure that either `num_vector_embeds` or `num_patches` is None."
126
+ )
127
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
128
+ raise ValueError(
129
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
130
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
131
+ )
132
+
133
+ # 2. Define input layers
134
+ if self.is_input_continuous:
135
+ self.in_channels = in_channels
136
+
137
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+ if use_linear_projection:
139
+ self.proj_in = nn.Linear(in_channels, inner_dim)
140
+ else:
141
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
+ elif self.is_input_vectorized:
143
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
+
146
+ self.height = sample_size
147
+ self.width = sample_size
148
+ self.num_vector_embeds = num_vector_embeds
149
+ self.num_latent_pixels = self.height * self.width
150
+
151
+ self.latent_image_embedding = ImagePositionalEmbeddings(
152
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
+ )
154
+ elif self.is_input_patches:
155
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
156
+
157
+ self.height = sample_size
158
+ self.width = sample_size
159
+
160
+ self.patch_size = patch_size
161
+ self.pos_embed = PatchEmbed(
162
+ height=sample_size,
163
+ width=sample_size,
164
+ patch_size=patch_size,
165
+ in_channels=in_channels,
166
+ embed_dim=inner_dim,
167
+ )
168
+
169
+ # 3. Define transformers blocks
170
+ self.transformer_blocks = nn.ModuleList(
171
+ [
172
+ BasicTransformerBlock(
173
+ inner_dim,
174
+ num_attention_heads,
175
+ attention_head_dim,
176
+ dropout=dropout,
177
+ cross_attention_dim=cross_attention_dim,
178
+ activation_fn=activation_fn,
179
+ num_embeds_ada_norm=num_embeds_ada_norm,
180
+ attention_bias=attention_bias,
181
+ only_cross_attention=only_cross_attention,
182
+ upcast_attention=upcast_attention,
183
+ norm_type=norm_type,
184
+ norm_elementwise_affine=norm_elementwise_affine,
185
+ )
186
+ for d in range(num_layers)
187
+ ]
188
+ )
189
+
190
+ # 4. Define output layers
191
+ self.out_channels = in_channels if out_channels is None else out_channels
192
+ if self.is_input_continuous:
193
+ # TODO: should use out_channels for continuous projections
194
+ if use_linear_projection:
195
+ self.proj_out = nn.Linear(inner_dim, in_channels)
196
+ else:
197
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
198
+ elif self.is_input_vectorized:
199
+ self.norm_out = nn.LayerNorm(inner_dim)
200
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
201
+ elif self.is_input_patches:
202
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
203
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
204
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
205
+
206
+ def forward(
207
+ self,
208
+ hidden_states: torch.Tensor,
209
+ encoder_hidden_states: Optional[torch.Tensor] = None,
210
+ timestep: Optional[torch.LongTensor] = None,
211
+ class_labels: Optional[torch.LongTensor] = None,
212
+ posemb: Optional = None,
213
+ cross_attention_kwargs: Dict[str, Any] = None,
214
+ attention_mask: Optional[torch.Tensor] = None,
215
+ encoder_attention_mask: Optional[torch.Tensor] = None,
216
+ return_dict: bool = True,
217
+ ):
218
+ """
219
+ The [`Transformer2DModel`] forward method.
220
+
221
+ Args:
222
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
223
+ Input `hidden_states`.
224
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
225
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
226
+ self-attention.
227
+ timestep ( `torch.LongTensor`, *optional*):
228
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
229
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
230
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
231
+ `AdaLayerZeroNorm`.
232
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
233
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
234
+
235
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
236
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
237
+
238
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
239
+ above. This bias will be added to the cross-attention scores.
240
+ return_dict (`bool`, *optional*, defaults to `True`):
241
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
242
+ tuple.
243
+
244
+ Returns:
245
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
246
+ `tuple` where the first element is the sample tensor.
247
+ """
248
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
249
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
250
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
251
+ # expects mask of shape:
252
+ # [batch, key_tokens]
253
+ # adds singleton query_tokens dimension:
254
+ # [batch, 1, key_tokens]
255
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
256
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
257
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
258
+ if attention_mask is not None and attention_mask.ndim == 2:
259
+ # assume that mask is expressed as:
260
+ # (1 = keep, 0 = discard)
261
+ # convert mask into a bias that can be added to attention scores:
262
+ # (keep = +0, discard = -10000.0)
263
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
264
+ attention_mask = attention_mask.unsqueeze(1)
265
+
266
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
267
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
268
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
269
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
270
+
271
+ # 1. Input
272
+ if self.is_input_continuous:
273
+ batch, _, height, width = hidden_states.shape
274
+ residual = hidden_states
275
+
276
+ hidden_states = self.norm(hidden_states)
277
+ if not self.use_linear_projection:
278
+ hidden_states = self.proj_in(hidden_states)
279
+ inner_dim = hidden_states.shape[1]
280
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
281
+ else:
282
+ inner_dim = hidden_states.shape[1]
283
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
284
+ hidden_states = self.proj_in(hidden_states)
285
+ elif self.is_input_vectorized:
286
+ hidden_states = self.latent_image_embedding(hidden_states)
287
+ elif self.is_input_patches:
288
+ hidden_states = self.pos_embed(hidden_states)
289
+
290
+ # 2. Blocks
291
+ for block in self.transformer_blocks:
292
+ hidden_states = block(
293
+ hidden_states,
294
+ attention_mask=attention_mask,
295
+ encoder_hidden_states=encoder_hidden_states,
296
+ encoder_attention_mask=encoder_attention_mask,
297
+ timestep=timestep,
298
+ posemb=posemb,
299
+ cross_attention_kwargs=cross_attention_kwargs,
300
+ class_labels=class_labels,
301
+ )
302
+
303
+ # 3. Output
304
+ if self.is_input_continuous:
305
+ if not self.use_linear_projection:
306
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
307
+ hidden_states = self.proj_out(hidden_states)
308
+ else:
309
+ hidden_states = self.proj_out(hidden_states)
310
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
311
+
312
+ output = hidden_states + residual
313
+ elif self.is_input_vectorized:
314
+ hidden_states = self.norm_out(hidden_states)
315
+ logits = self.out(hidden_states)
316
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
317
+ logits = logits.permute(0, 2, 1)
318
+
319
+ # log(p(x_0))
320
+ output = F.log_softmax(logits.double(), dim=1).float()
321
+ elif self.is_input_patches:
322
+ # TODO: cleanup!
323
+ conditioning = self.transformer_blocks[0].norm1.emb(
324
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
325
+ )
326
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
327
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
328
+ hidden_states = self.proj_out_2(hidden_states)
329
+
330
+ # unpatchify
331
+ height = width = int(hidden_states.shape[1] ** 0.5)
332
+ hidden_states = hidden_states.reshape(
333
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
334
+ )
335
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
336
+ output = hidden_states.reshape(
337
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
338
+ )
339
+
340
+ if not return_dict:
341
+ return (output,)
342
+
343
+ return Transformer2DModelOutput(sample=output)
6DoF/diffusers/models/transformer_temporal.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .attention import BasicTransformerBlock
23
+ from .modeling_utils import ModelMixin
24
+
25
+
26
+ @dataclass
27
+ class TransformerTemporalModelOutput(BaseOutput):
28
+ """
29
+ The output of [`TransformerTemporalModel`].
30
+
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33
+ The hidden states output conditioned on `encoder_hidden_states` input.
34
+ """
35
+
36
+ sample: torch.FloatTensor
37
+
38
+
39
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
40
+ """
41
+ A Transformer model for video-like data.
42
+
43
+ Parameters:
44
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46
+ in_channels (`int`, *optional*):
47
+ The number of channels in the input and output (specify if the input is **continuous**).
48
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52
+ This is fixed during training since it is used to learn a number of position embeddings.
53
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
54
+ attention_bias (`bool`, *optional*):
55
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
56
+ double_self_attention (`bool`, *optional*):
57
+ Configure if each `TransformerBlock` should contain two self-attention layers.
58
+ """
59
+
60
+ @register_to_config
61
+ def __init__(
62
+ self,
63
+ num_attention_heads: int = 16,
64
+ attention_head_dim: int = 88,
65
+ in_channels: Optional[int] = None,
66
+ out_channels: Optional[int] = None,
67
+ num_layers: int = 1,
68
+ dropout: float = 0.0,
69
+ norm_num_groups: int = 32,
70
+ cross_attention_dim: Optional[int] = None,
71
+ attention_bias: bool = False,
72
+ sample_size: Optional[int] = None,
73
+ activation_fn: str = "geglu",
74
+ norm_elementwise_affine: bool = True,
75
+ double_self_attention: bool = True,
76
+ ):
77
+ super().__init__()
78
+ self.num_attention_heads = num_attention_heads
79
+ self.attention_head_dim = attention_head_dim
80
+ inner_dim = num_attention_heads * attention_head_dim
81
+
82
+ self.in_channels = in_channels
83
+
84
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85
+ self.proj_in = nn.Linear(in_channels, inner_dim)
86
+
87
+ # 3. Define transformers blocks
88
+ self.transformer_blocks = nn.ModuleList(
89
+ [
90
+ BasicTransformerBlock(
91
+ inner_dim,
92
+ num_attention_heads,
93
+ attention_head_dim,
94
+ dropout=dropout,
95
+ cross_attention_dim=cross_attention_dim,
96
+ activation_fn=activation_fn,
97
+ attention_bias=attention_bias,
98
+ double_self_attention=double_self_attention,
99
+ norm_elementwise_affine=norm_elementwise_affine,
100
+ )
101
+ for d in range(num_layers)
102
+ ]
103
+ )
104
+
105
+ self.proj_out = nn.Linear(inner_dim, in_channels)
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ timestep=None,
112
+ class_labels=None,
113
+ num_frames=1,
114
+ cross_attention_kwargs=None,
115
+ return_dict: bool = True,
116
+ ):
117
+ """
118
+ The [`TransformerTemporal`] forward method.
119
+
120
+ Args:
121
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
122
+ Input hidden_states.
123
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125
+ self-attention.
126
+ timestep ( `torch.long`, *optional*):
127
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130
+ `AdaLayerZeroNorm`.
131
+ return_dict (`bool`, *optional*, defaults to `True`):
132
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
133
+ tuple.
134
+
135
+ Returns:
136
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
137
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
138
+ returned, otherwise a `tuple` where the first element is the sample tensor.
139
+ """
140
+ # 1. Input
141
+ batch_frames, channel, height, width = hidden_states.shape
142
+ batch_size = batch_frames // num_frames
143
+
144
+ residual = hidden_states
145
+
146
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
147
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
148
+
149
+ hidden_states = self.norm(hidden_states)
150
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
151
+
152
+ hidden_states = self.proj_in(hidden_states)
153
+
154
+ # 2. Blocks
155
+ for block in self.transformer_blocks:
156
+ hidden_states = block(
157
+ hidden_states,
158
+ encoder_hidden_states=encoder_hidden_states,
159
+ timestep=timestep,
160
+ cross_attention_kwargs=cross_attention_kwargs,
161
+ class_labels=class_labels,
162
+ )
163
+
164
+ # 3. Output
165
+ hidden_states = self.proj_out(hidden_states)
166
+ hidden_states = (
167
+ hidden_states[None, None, :]
168
+ .reshape(batch_size, height, width, channel, num_frames)
169
+ .permute(0, 3, 4, 1, 2)
170
+ .contiguous()
171
+ )
172
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
173
+
174
+ output = hidden_states + residual
175
+
176
+ if not return_dict:
177
+ return (output,)
178
+
179
+ return TransformerTemporalModelOutput(sample=output)
6DoF/diffusers/models/unet_1d.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
24
+ from .modeling_utils import ModelMixin
25
+ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
26
+
27
+
28
+ @dataclass
29
+ class UNet1DOutput(BaseOutput):
30
+ """
31
+ The output of [`UNet1DModel`].
32
+
33
+ Args:
34
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
35
+ The hidden states output from the last layer of the model.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class UNet1DModel(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
44
+
45
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
46
+ for all models (such as downloading or saving).
47
+
48
+ Parameters:
49
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
50
+ in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
51
+ out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
52
+ extra_in_channels (`int`, *optional*, defaults to 0):
53
+ Number of additional channels to be added to the input of the first down block. Useful for cases where the
54
+ input data has more channels than what the model was initially designed for.
55
+ time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
56
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
57
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
58
+ Whether to flip sin to cos for Fourier time embedding.
59
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
60
+ Tuple of downsample block types.
61
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
62
+ Tuple of upsample block types.
63
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
64
+ Tuple of block output channels.
65
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
66
+ out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
67
+ act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
68
+ norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
69
+ layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
70
+ downsample_each_block (`int`, *optional*, defaults to `False`):
71
+ Experimental feature for using a UNet without upsampling.
72
+ """
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ sample_size: int = 65536,
78
+ sample_rate: Optional[int] = None,
79
+ in_channels: int = 2,
80
+ out_channels: int = 2,
81
+ extra_in_channels: int = 0,
82
+ time_embedding_type: str = "fourier",
83
+ flip_sin_to_cos: bool = True,
84
+ use_timestep_embedding: bool = False,
85
+ freq_shift: float = 0.0,
86
+ down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
87
+ up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
88
+ mid_block_type: Tuple[str] = "UNetMidBlock1D",
89
+ out_block_type: str = None,
90
+ block_out_channels: Tuple[int] = (32, 32, 64),
91
+ act_fn: str = None,
92
+ norm_num_groups: int = 8,
93
+ layers_per_block: int = 1,
94
+ downsample_each_block: bool = False,
95
+ ):
96
+ super().__init__()
97
+ self.sample_size = sample_size
98
+
99
+ # time
100
+ if time_embedding_type == "fourier":
101
+ self.time_proj = GaussianFourierProjection(
102
+ embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
103
+ )
104
+ timestep_input_dim = 2 * block_out_channels[0]
105
+ elif time_embedding_type == "positional":
106
+ self.time_proj = Timesteps(
107
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
108
+ )
109
+ timestep_input_dim = block_out_channels[0]
110
+
111
+ if use_timestep_embedding:
112
+ time_embed_dim = block_out_channels[0] * 4
113
+ self.time_mlp = TimestepEmbedding(
114
+ in_channels=timestep_input_dim,
115
+ time_embed_dim=time_embed_dim,
116
+ act_fn=act_fn,
117
+ out_dim=block_out_channels[0],
118
+ )
119
+
120
+ self.down_blocks = nn.ModuleList([])
121
+ self.mid_block = None
122
+ self.up_blocks = nn.ModuleList([])
123
+ self.out_block = None
124
+
125
+ # down
126
+ output_channel = in_channels
127
+ for i, down_block_type in enumerate(down_block_types):
128
+ input_channel = output_channel
129
+ output_channel = block_out_channels[i]
130
+
131
+ if i == 0:
132
+ input_channel += extra_in_channels
133
+
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=block_out_channels[0],
142
+ add_downsample=not is_final_block or downsample_each_block,
143
+ )
144
+ self.down_blocks.append(down_block)
145
+
146
+ # mid
147
+ self.mid_block = get_mid_block(
148
+ mid_block_type,
149
+ in_channels=block_out_channels[-1],
150
+ mid_channels=block_out_channels[-1],
151
+ out_channels=block_out_channels[-1],
152
+ embed_dim=block_out_channels[0],
153
+ num_layers=layers_per_block,
154
+ add_downsample=downsample_each_block,
155
+ )
156
+
157
+ # up
158
+ reversed_block_out_channels = list(reversed(block_out_channels))
159
+ output_channel = reversed_block_out_channels[0]
160
+ if out_block_type is None:
161
+ final_upsample_channels = out_channels
162
+ else:
163
+ final_upsample_channels = block_out_channels[0]
164
+
165
+ for i, up_block_type in enumerate(up_block_types):
166
+ prev_output_channel = output_channel
167
+ output_channel = (
168
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
169
+ )
170
+
171
+ is_final_block = i == len(block_out_channels) - 1
172
+
173
+ up_block = get_up_block(
174
+ up_block_type,
175
+ num_layers=layers_per_block,
176
+ in_channels=prev_output_channel,
177
+ out_channels=output_channel,
178
+ temb_channels=block_out_channels[0],
179
+ add_upsample=not is_final_block,
180
+ )
181
+ self.up_blocks.append(up_block)
182
+ prev_output_channel = output_channel
183
+
184
+ # out
185
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
186
+ self.out_block = get_out_block(
187
+ out_block_type=out_block_type,
188
+ num_groups_out=num_groups_out,
189
+ embed_dim=block_out_channels[0],
190
+ out_channels=out_channels,
191
+ act_fn=act_fn,
192
+ fc_dim=block_out_channels[-1] // 4,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ sample: torch.FloatTensor,
198
+ timestep: Union[torch.Tensor, float, int],
199
+ return_dict: bool = True,
200
+ ) -> Union[UNet1DOutput, Tuple]:
201
+ r"""
202
+ The [`UNet1DModel`] forward method.
203
+
204
+ Args:
205
+ sample (`torch.FloatTensor`):
206
+ The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
207
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
208
+ return_dict (`bool`, *optional*, defaults to `True`):
209
+ Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
210
+
211
+ Returns:
212
+ [`~models.unet_1d.UNet1DOutput`] or `tuple`:
213
+ If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
214
+ returned where the first element is the sample tensor.
215
+ """
216
+
217
+ # 1. time
218
+ timesteps = timestep
219
+ if not torch.is_tensor(timesteps):
220
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
221
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
222
+ timesteps = timesteps[None].to(sample.device)
223
+
224
+ timestep_embed = self.time_proj(timesteps)
225
+ if self.config.use_timestep_embedding:
226
+ timestep_embed = self.time_mlp(timestep_embed)
227
+ else:
228
+ timestep_embed = timestep_embed[..., None]
229
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
230
+ timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
231
+
232
+ # 2. down
233
+ down_block_res_samples = ()
234
+ for downsample_block in self.down_blocks:
235
+ sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
236
+ down_block_res_samples += res_samples
237
+
238
+ # 3. mid
239
+ if self.mid_block:
240
+ sample = self.mid_block(sample, timestep_embed)
241
+
242
+ # 4. up
243
+ for i, upsample_block in enumerate(self.up_blocks):
244
+ res_samples = down_block_res_samples[-1:]
245
+ down_block_res_samples = down_block_res_samples[:-1]
246
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
247
+
248
+ # 5. post-process
249
+ if self.out_block:
250
+ sample = self.out_block(sample, timestep_embed)
251
+
252
+ if not return_dict:
253
+ return (sample,)
254
+
255
+ return UNet1DOutput(sample=sample)
6DoF/diffusers/models/unet_1d_blocks.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from .activations import get_activation
21
+ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
22
+
23
+
24
+ class DownResnetBlock1D(nn.Module):
25
+ def __init__(
26
+ self,
27
+ in_channels,
28
+ out_channels=None,
29
+ num_layers=1,
30
+ conv_shortcut=False,
31
+ temb_channels=32,
32
+ groups=32,
33
+ groups_out=None,
34
+ non_linearity=None,
35
+ time_embedding_norm="default",
36
+ output_scale_factor=1.0,
37
+ add_downsample=True,
38
+ ):
39
+ super().__init__()
40
+ self.in_channels = in_channels
41
+ out_channels = in_channels if out_channels is None else out_channels
42
+ self.out_channels = out_channels
43
+ self.use_conv_shortcut = conv_shortcut
44
+ self.time_embedding_norm = time_embedding_norm
45
+ self.add_downsample = add_downsample
46
+ self.output_scale_factor = output_scale_factor
47
+
48
+ if groups_out is None:
49
+ groups_out = groups
50
+
51
+ # there will always be at least one resnet
52
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
53
+
54
+ for _ in range(num_layers):
55
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
56
+
57
+ self.resnets = nn.ModuleList(resnets)
58
+
59
+ if non_linearity is None:
60
+ self.nonlinearity = None
61
+ else:
62
+ self.nonlinearity = get_activation(non_linearity)
63
+
64
+ self.downsample = None
65
+ if add_downsample:
66
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
67
+
68
+ def forward(self, hidden_states, temb=None):
69
+ output_states = ()
70
+
71
+ hidden_states = self.resnets[0](hidden_states, temb)
72
+ for resnet in self.resnets[1:]:
73
+ hidden_states = resnet(hidden_states, temb)
74
+
75
+ output_states += (hidden_states,)
76
+
77
+ if self.nonlinearity is not None:
78
+ hidden_states = self.nonlinearity(hidden_states)
79
+
80
+ if self.downsample is not None:
81
+ hidden_states = self.downsample(hidden_states)
82
+
83
+ return hidden_states, output_states
84
+
85
+
86
+ class UpResnetBlock1D(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ out_channels=None,
91
+ num_layers=1,
92
+ temb_channels=32,
93
+ groups=32,
94
+ groups_out=None,
95
+ non_linearity=None,
96
+ time_embedding_norm="default",
97
+ output_scale_factor=1.0,
98
+ add_upsample=True,
99
+ ):
100
+ super().__init__()
101
+ self.in_channels = in_channels
102
+ out_channels = in_channels if out_channels is None else out_channels
103
+ self.out_channels = out_channels
104
+ self.time_embedding_norm = time_embedding_norm
105
+ self.add_upsample = add_upsample
106
+ self.output_scale_factor = output_scale_factor
107
+
108
+ if groups_out is None:
109
+ groups_out = groups
110
+
111
+ # there will always be at least one resnet
112
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
113
+
114
+ for _ in range(num_layers):
115
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
116
+
117
+ self.resnets = nn.ModuleList(resnets)
118
+
119
+ if non_linearity is None:
120
+ self.nonlinearity = None
121
+ else:
122
+ self.nonlinearity = get_activation(non_linearity)
123
+
124
+ self.upsample = None
125
+ if add_upsample:
126
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
127
+
128
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
129
+ if res_hidden_states_tuple is not None:
130
+ res_hidden_states = res_hidden_states_tuple[-1]
131
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
132
+
133
+ hidden_states = self.resnets[0](hidden_states, temb)
134
+ for resnet in self.resnets[1:]:
135
+ hidden_states = resnet(hidden_states, temb)
136
+
137
+ if self.nonlinearity is not None:
138
+ hidden_states = self.nonlinearity(hidden_states)
139
+
140
+ if self.upsample is not None:
141
+ hidden_states = self.upsample(hidden_states)
142
+
143
+ return hidden_states
144
+
145
+
146
+ class ValueFunctionMidBlock1D(nn.Module):
147
+ def __init__(self, in_channels, out_channels, embed_dim):
148
+ super().__init__()
149
+ self.in_channels = in_channels
150
+ self.out_channels = out_channels
151
+ self.embed_dim = embed_dim
152
+
153
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
154
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
155
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
156
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
157
+
158
+ def forward(self, x, temb=None):
159
+ x = self.res1(x, temb)
160
+ x = self.down1(x)
161
+ x = self.res2(x, temb)
162
+ x = self.down2(x)
163
+ return x
164
+
165
+
166
+ class MidResTemporalBlock1D(nn.Module):
167
+ def __init__(
168
+ self,
169
+ in_channels,
170
+ out_channels,
171
+ embed_dim,
172
+ num_layers: int = 1,
173
+ add_downsample: bool = False,
174
+ add_upsample: bool = False,
175
+ non_linearity=None,
176
+ ):
177
+ super().__init__()
178
+ self.in_channels = in_channels
179
+ self.out_channels = out_channels
180
+ self.add_downsample = add_downsample
181
+
182
+ # there will always be at least one resnet
183
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
184
+
185
+ for _ in range(num_layers):
186
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
187
+
188
+ self.resnets = nn.ModuleList(resnets)
189
+
190
+ if non_linearity is None:
191
+ self.nonlinearity = None
192
+ else:
193
+ self.nonlinearity = get_activation(non_linearity)
194
+
195
+ self.upsample = None
196
+ if add_upsample:
197
+ self.upsample = Downsample1D(out_channels, use_conv=True)
198
+
199
+ self.downsample = None
200
+ if add_downsample:
201
+ self.downsample = Downsample1D(out_channels, use_conv=True)
202
+
203
+ if self.upsample and self.downsample:
204
+ raise ValueError("Block cannot downsample and upsample")
205
+
206
+ def forward(self, hidden_states, temb):
207
+ hidden_states = self.resnets[0](hidden_states, temb)
208
+ for resnet in self.resnets[1:]:
209
+ hidden_states = resnet(hidden_states, temb)
210
+
211
+ if self.upsample:
212
+ hidden_states = self.upsample(hidden_states)
213
+ if self.downsample:
214
+ self.downsample = self.downsample(hidden_states)
215
+
216
+ return hidden_states
217
+
218
+
219
+ class OutConv1DBlock(nn.Module):
220
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
221
+ super().__init__()
222
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
223
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
224
+ self.final_conv1d_act = get_activation(act_fn)
225
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
226
+
227
+ def forward(self, hidden_states, temb=None):
228
+ hidden_states = self.final_conv1d_1(hidden_states)
229
+ hidden_states = rearrange_dims(hidden_states)
230
+ hidden_states = self.final_conv1d_gn(hidden_states)
231
+ hidden_states = rearrange_dims(hidden_states)
232
+ hidden_states = self.final_conv1d_act(hidden_states)
233
+ hidden_states = self.final_conv1d_2(hidden_states)
234
+ return hidden_states
235
+
236
+
237
+ class OutValueFunctionBlock(nn.Module):
238
+ def __init__(self, fc_dim, embed_dim):
239
+ super().__init__()
240
+ self.final_block = nn.ModuleList(
241
+ [
242
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
243
+ nn.Mish(),
244
+ nn.Linear(fc_dim // 2, 1),
245
+ ]
246
+ )
247
+
248
+ def forward(self, hidden_states, temb):
249
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
250
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
251
+ for layer in self.final_block:
252
+ hidden_states = layer(hidden_states)
253
+
254
+ return hidden_states
255
+
256
+
257
+ _kernels = {
258
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
259
+ "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
260
+ "lanczos3": [
261
+ 0.003689131001010537,
262
+ 0.015056144446134567,
263
+ -0.03399861603975296,
264
+ -0.066637322306633,
265
+ 0.13550527393817902,
266
+ 0.44638532400131226,
267
+ 0.44638532400131226,
268
+ 0.13550527393817902,
269
+ -0.066637322306633,
270
+ -0.03399861603975296,
271
+ 0.015056144446134567,
272
+ 0.003689131001010537,
273
+ ],
274
+ }
275
+
276
+
277
+ class Downsample1d(nn.Module):
278
+ def __init__(self, kernel="linear", pad_mode="reflect"):
279
+ super().__init__()
280
+ self.pad_mode = pad_mode
281
+ kernel_1d = torch.tensor(_kernels[kernel])
282
+ self.pad = kernel_1d.shape[0] // 2 - 1
283
+ self.register_buffer("kernel", kernel_1d)
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
287
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
288
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
289
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
290
+ weight[indices, indices] = kernel
291
+ return F.conv1d(hidden_states, weight, stride=2)
292
+
293
+
294
+ class Upsample1d(nn.Module):
295
+ def __init__(self, kernel="linear", pad_mode="reflect"):
296
+ super().__init__()
297
+ self.pad_mode = pad_mode
298
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
299
+ self.pad = kernel_1d.shape[0] // 2 - 1
300
+ self.register_buffer("kernel", kernel_1d)
301
+
302
+ def forward(self, hidden_states, temb=None):
303
+ hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
304
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
305
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
306
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
307
+ weight[indices, indices] = kernel
308
+ return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
309
+
310
+
311
+ class SelfAttention1d(nn.Module):
312
+ def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
313
+ super().__init__()
314
+ self.channels = in_channels
315
+ self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
316
+ self.num_heads = n_head
317
+
318
+ self.query = nn.Linear(self.channels, self.channels)
319
+ self.key = nn.Linear(self.channels, self.channels)
320
+ self.value = nn.Linear(self.channels, self.channels)
321
+
322
+ self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
323
+
324
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
325
+
326
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
327
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
328
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
329
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
330
+ return new_projection
331
+
332
+ def forward(self, hidden_states):
333
+ residual = hidden_states
334
+ batch, channel_dim, seq = hidden_states.shape
335
+
336
+ hidden_states = self.group_norm(hidden_states)
337
+ hidden_states = hidden_states.transpose(1, 2)
338
+
339
+ query_proj = self.query(hidden_states)
340
+ key_proj = self.key(hidden_states)
341
+ value_proj = self.value(hidden_states)
342
+
343
+ query_states = self.transpose_for_scores(query_proj)
344
+ key_states = self.transpose_for_scores(key_proj)
345
+ value_states = self.transpose_for_scores(value_proj)
346
+
347
+ scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
348
+
349
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
350
+ attention_probs = torch.softmax(attention_scores, dim=-1)
351
+
352
+ # compute attention output
353
+ hidden_states = torch.matmul(attention_probs, value_states)
354
+
355
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
356
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
357
+ hidden_states = hidden_states.view(new_hidden_states_shape)
358
+
359
+ # compute next hidden_states
360
+ hidden_states = self.proj_attn(hidden_states)
361
+ hidden_states = hidden_states.transpose(1, 2)
362
+ hidden_states = self.dropout(hidden_states)
363
+
364
+ output = hidden_states + residual
365
+
366
+ return output
367
+
368
+
369
+ class ResConvBlock(nn.Module):
370
+ def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
371
+ super().__init__()
372
+ self.is_last = is_last
373
+ self.has_conv_skip = in_channels != out_channels
374
+
375
+ if self.has_conv_skip:
376
+ self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
377
+
378
+ self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
379
+ self.group_norm_1 = nn.GroupNorm(1, mid_channels)
380
+ self.gelu_1 = nn.GELU()
381
+ self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
382
+
383
+ if not self.is_last:
384
+ self.group_norm_2 = nn.GroupNorm(1, out_channels)
385
+ self.gelu_2 = nn.GELU()
386
+
387
+ def forward(self, hidden_states):
388
+ residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
389
+
390
+ hidden_states = self.conv_1(hidden_states)
391
+ hidden_states = self.group_norm_1(hidden_states)
392
+ hidden_states = self.gelu_1(hidden_states)
393
+ hidden_states = self.conv_2(hidden_states)
394
+
395
+ if not self.is_last:
396
+ hidden_states = self.group_norm_2(hidden_states)
397
+ hidden_states = self.gelu_2(hidden_states)
398
+
399
+ output = hidden_states + residual
400
+ return output
401
+
402
+
403
+ class UNetMidBlock1D(nn.Module):
404
+ def __init__(self, mid_channels, in_channels, out_channels=None):
405
+ super().__init__()
406
+
407
+ out_channels = in_channels if out_channels is None else out_channels
408
+
409
+ # there is always at least one resnet
410
+ self.down = Downsample1d("cubic")
411
+ resnets = [
412
+ ResConvBlock(in_channels, mid_channels, mid_channels),
413
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
414
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
415
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
416
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
417
+ ResConvBlock(mid_channels, mid_channels, out_channels),
418
+ ]
419
+ attentions = [
420
+ SelfAttention1d(mid_channels, mid_channels // 32),
421
+ SelfAttention1d(mid_channels, mid_channels // 32),
422
+ SelfAttention1d(mid_channels, mid_channels // 32),
423
+ SelfAttention1d(mid_channels, mid_channels // 32),
424
+ SelfAttention1d(mid_channels, mid_channels // 32),
425
+ SelfAttention1d(out_channels, out_channels // 32),
426
+ ]
427
+ self.up = Upsample1d(kernel="cubic")
428
+
429
+ self.attentions = nn.ModuleList(attentions)
430
+ self.resnets = nn.ModuleList(resnets)
431
+
432
+ def forward(self, hidden_states, temb=None):
433
+ hidden_states = self.down(hidden_states)
434
+ for attn, resnet in zip(self.attentions, self.resnets):
435
+ hidden_states = resnet(hidden_states)
436
+ hidden_states = attn(hidden_states)
437
+
438
+ hidden_states = self.up(hidden_states)
439
+
440
+ return hidden_states
441
+
442
+
443
+ class AttnDownBlock1D(nn.Module):
444
+ def __init__(self, out_channels, in_channels, mid_channels=None):
445
+ super().__init__()
446
+ mid_channels = out_channels if mid_channels is None else mid_channels
447
+
448
+ self.down = Downsample1d("cubic")
449
+ resnets = [
450
+ ResConvBlock(in_channels, mid_channels, mid_channels),
451
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
452
+ ResConvBlock(mid_channels, mid_channels, out_channels),
453
+ ]
454
+ attentions = [
455
+ SelfAttention1d(mid_channels, mid_channels // 32),
456
+ SelfAttention1d(mid_channels, mid_channels // 32),
457
+ SelfAttention1d(out_channels, out_channels // 32),
458
+ ]
459
+
460
+ self.attentions = nn.ModuleList(attentions)
461
+ self.resnets = nn.ModuleList(resnets)
462
+
463
+ def forward(self, hidden_states, temb=None):
464
+ hidden_states = self.down(hidden_states)
465
+
466
+ for resnet, attn in zip(self.resnets, self.attentions):
467
+ hidden_states = resnet(hidden_states)
468
+ hidden_states = attn(hidden_states)
469
+
470
+ return hidden_states, (hidden_states,)
471
+
472
+
473
+ class DownBlock1D(nn.Module):
474
+ def __init__(self, out_channels, in_channels, mid_channels=None):
475
+ super().__init__()
476
+ mid_channels = out_channels if mid_channels is None else mid_channels
477
+
478
+ self.down = Downsample1d("cubic")
479
+ resnets = [
480
+ ResConvBlock(in_channels, mid_channels, mid_channels),
481
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
482
+ ResConvBlock(mid_channels, mid_channels, out_channels),
483
+ ]
484
+
485
+ self.resnets = nn.ModuleList(resnets)
486
+
487
+ def forward(self, hidden_states, temb=None):
488
+ hidden_states = self.down(hidden_states)
489
+
490
+ for resnet in self.resnets:
491
+ hidden_states = resnet(hidden_states)
492
+
493
+ return hidden_states, (hidden_states,)
494
+
495
+
496
+ class DownBlock1DNoSkip(nn.Module):
497
+ def __init__(self, out_channels, in_channels, mid_channels=None):
498
+ super().__init__()
499
+ mid_channels = out_channels if mid_channels is None else mid_channels
500
+
501
+ resnets = [
502
+ ResConvBlock(in_channels, mid_channels, mid_channels),
503
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
504
+ ResConvBlock(mid_channels, mid_channels, out_channels),
505
+ ]
506
+
507
+ self.resnets = nn.ModuleList(resnets)
508
+
509
+ def forward(self, hidden_states, temb=None):
510
+ hidden_states = torch.cat([hidden_states, temb], dim=1)
511
+ for resnet in self.resnets:
512
+ hidden_states = resnet(hidden_states)
513
+
514
+ return hidden_states, (hidden_states,)
515
+
516
+
517
+ class AttnUpBlock1D(nn.Module):
518
+ def __init__(self, in_channels, out_channels, mid_channels=None):
519
+ super().__init__()
520
+ mid_channels = out_channels if mid_channels is None else mid_channels
521
+
522
+ resnets = [
523
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
524
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
525
+ ResConvBlock(mid_channels, mid_channels, out_channels),
526
+ ]
527
+ attentions = [
528
+ SelfAttention1d(mid_channels, mid_channels // 32),
529
+ SelfAttention1d(mid_channels, mid_channels // 32),
530
+ SelfAttention1d(out_channels, out_channels // 32),
531
+ ]
532
+
533
+ self.attentions = nn.ModuleList(attentions)
534
+ self.resnets = nn.ModuleList(resnets)
535
+ self.up = Upsample1d(kernel="cubic")
536
+
537
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
538
+ res_hidden_states = res_hidden_states_tuple[-1]
539
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
540
+
541
+ for resnet, attn in zip(self.resnets, self.attentions):
542
+ hidden_states = resnet(hidden_states)
543
+ hidden_states = attn(hidden_states)
544
+
545
+ hidden_states = self.up(hidden_states)
546
+
547
+ return hidden_states
548
+
549
+
550
+ class UpBlock1D(nn.Module):
551
+ def __init__(self, in_channels, out_channels, mid_channels=None):
552
+ super().__init__()
553
+ mid_channels = in_channels if mid_channels is None else mid_channels
554
+
555
+ resnets = [
556
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
557
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
558
+ ResConvBlock(mid_channels, mid_channels, out_channels),
559
+ ]
560
+
561
+ self.resnets = nn.ModuleList(resnets)
562
+ self.up = Upsample1d(kernel="cubic")
563
+
564
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
565
+ res_hidden_states = res_hidden_states_tuple[-1]
566
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
567
+
568
+ for resnet in self.resnets:
569
+ hidden_states = resnet(hidden_states)
570
+
571
+ hidden_states = self.up(hidden_states)
572
+
573
+ return hidden_states
574
+
575
+
576
+ class UpBlock1DNoSkip(nn.Module):
577
+ def __init__(self, in_channels, out_channels, mid_channels=None):
578
+ super().__init__()
579
+ mid_channels = in_channels if mid_channels is None else mid_channels
580
+
581
+ resnets = [
582
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
583
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
584
+ ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
585
+ ]
586
+
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
590
+ res_hidden_states = res_hidden_states_tuple[-1]
591
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
592
+
593
+ for resnet in self.resnets:
594
+ hidden_states = resnet(hidden_states)
595
+
596
+ return hidden_states
597
+
598
+
599
+ def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
600
+ if down_block_type == "DownResnetBlock1D":
601
+ return DownResnetBlock1D(
602
+ in_channels=in_channels,
603
+ num_layers=num_layers,
604
+ out_channels=out_channels,
605
+ temb_channels=temb_channels,
606
+ add_downsample=add_downsample,
607
+ )
608
+ elif down_block_type == "DownBlock1D":
609
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
610
+ elif down_block_type == "AttnDownBlock1D":
611
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
612
+ elif down_block_type == "DownBlock1DNoSkip":
613
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
614
+ raise ValueError(f"{down_block_type} does not exist.")
615
+
616
+
617
+ def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
618
+ if up_block_type == "UpResnetBlock1D":
619
+ return UpResnetBlock1D(
620
+ in_channels=in_channels,
621
+ num_layers=num_layers,
622
+ out_channels=out_channels,
623
+ temb_channels=temb_channels,
624
+ add_upsample=add_upsample,
625
+ )
626
+ elif up_block_type == "UpBlock1D":
627
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
628
+ elif up_block_type == "AttnUpBlock1D":
629
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
630
+ elif up_block_type == "UpBlock1DNoSkip":
631
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
632
+ raise ValueError(f"{up_block_type} does not exist.")
633
+
634
+
635
+ def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
636
+ if mid_block_type == "MidResTemporalBlock1D":
637
+ return MidResTemporalBlock1D(
638
+ num_layers=num_layers,
639
+ in_channels=in_channels,
640
+ out_channels=out_channels,
641
+ embed_dim=embed_dim,
642
+ add_downsample=add_downsample,
643
+ )
644
+ elif mid_block_type == "ValueFunctionMidBlock1D":
645
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
646
+ elif mid_block_type == "UNetMidBlock1D":
647
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
648
+ raise ValueError(f"{mid_block_type} does not exist.")
649
+
650
+
651
+ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
652
+ if out_block_type == "OutConv1DBlock":
653
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
654
+ elif out_block_type == "ValueFunction":
655
+ return OutValueFunctionBlock(fc_dim, embed_dim)
656
+ return None
6DoF/diffusers/models/unet_2d.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from .modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
+
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ The output of [`UNet2DModel`].
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ The hidden states output from the last layer of the model.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class UNet2DModel(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
49
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
50
+ 1)`.
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
54
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
55
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
56
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
57
+ Whether to flip sin to cos for Fourier time embedding.
58
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
59
+ Tuple of downsample block types.
60
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
62
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
63
+ Tuple of upsample block types.
64
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
65
+ Tuple of block output channels.
66
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
67
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
68
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
69
+ downsample_type (`str`, *optional*, defaults to `conv`):
70
+ The downsample type for downsampling layers. Choose between "conv" and "resnet"
71
+ upsample_type (`str`, *optional*, defaults to `conv`):
72
+ The upsample type for upsampling layers. Choose between "conv" and "resnet"
73
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
74
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
75
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
76
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
77
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
78
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
79
+ class_embed_type (`str`, *optional*, defaults to `None`):
80
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
81
+ `"timestep"`, or `"identity"`.
82
+ num_class_embeds (`int`, *optional*, defaults to `None`):
83
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
84
+ conditioning with `class_embed_type` equal to `None`.
85
+ """
86
+
87
+ @register_to_config
88
+ def __init__(
89
+ self,
90
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
91
+ in_channels: int = 3,
92
+ out_channels: int = 3,
93
+ center_input_sample: bool = False,
94
+ time_embedding_type: str = "positional",
95
+ freq_shift: int = 0,
96
+ flip_sin_to_cos: bool = True,
97
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
98
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
99
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
100
+ layers_per_block: int = 2,
101
+ mid_block_scale_factor: float = 1,
102
+ downsample_padding: int = 1,
103
+ downsample_type: str = "conv",
104
+ upsample_type: str = "conv",
105
+ act_fn: str = "silu",
106
+ attention_head_dim: Optional[int] = 8,
107
+ norm_num_groups: int = 32,
108
+ norm_eps: float = 1e-5,
109
+ resnet_time_scale_shift: str = "default",
110
+ add_attention: bool = True,
111
+ class_embed_type: Optional[str] = None,
112
+ num_class_embeds: Optional[int] = None,
113
+ ):
114
+ super().__init__()
115
+
116
+ self.sample_size = sample_size
117
+ time_embed_dim = block_out_channels[0] * 4
118
+
119
+ # Check inputs
120
+ if len(down_block_types) != len(up_block_types):
121
+ raise ValueError(
122
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
123
+ )
124
+
125
+ if len(block_out_channels) != len(down_block_types):
126
+ raise ValueError(
127
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
128
+ )
129
+
130
+ # input
131
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
132
+
133
+ # time
134
+ if time_embedding_type == "fourier":
135
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
136
+ timestep_input_dim = 2 * block_out_channels[0]
137
+ elif time_embedding_type == "positional":
138
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
139
+ timestep_input_dim = block_out_channels[0]
140
+
141
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
142
+
143
+ # class embedding
144
+ if class_embed_type is None and num_class_embeds is not None:
145
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
146
+ elif class_embed_type == "timestep":
147
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
148
+ elif class_embed_type == "identity":
149
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
150
+ else:
151
+ self.class_embedding = None
152
+
153
+ self.down_blocks = nn.ModuleList([])
154
+ self.mid_block = None
155
+ self.up_blocks = nn.ModuleList([])
156
+
157
+ # down
158
+ output_channel = block_out_channels[0]
159
+ for i, down_block_type in enumerate(down_block_types):
160
+ input_channel = output_channel
161
+ output_channel = block_out_channels[i]
162
+ is_final_block = i == len(block_out_channels) - 1
163
+
164
+ down_block = get_down_block(
165
+ down_block_type,
166
+ num_layers=layers_per_block,
167
+ in_channels=input_channel,
168
+ out_channels=output_channel,
169
+ temb_channels=time_embed_dim,
170
+ add_downsample=not is_final_block,
171
+ resnet_eps=norm_eps,
172
+ resnet_act_fn=act_fn,
173
+ resnet_groups=norm_num_groups,
174
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
175
+ downsample_padding=downsample_padding,
176
+ resnet_time_scale_shift=resnet_time_scale_shift,
177
+ downsample_type=downsample_type,
178
+ )
179
+ self.down_blocks.append(down_block)
180
+
181
+ # mid
182
+ self.mid_block = UNetMidBlock2D(
183
+ in_channels=block_out_channels[-1],
184
+ temb_channels=time_embed_dim,
185
+ resnet_eps=norm_eps,
186
+ resnet_act_fn=act_fn,
187
+ output_scale_factor=mid_block_scale_factor,
188
+ resnet_time_scale_shift=resnet_time_scale_shift,
189
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
190
+ resnet_groups=norm_num_groups,
191
+ add_attention=add_attention,
192
+ )
193
+
194
+ # up
195
+ reversed_block_out_channels = list(reversed(block_out_channels))
196
+ output_channel = reversed_block_out_channels[0]
197
+ for i, up_block_type in enumerate(up_block_types):
198
+ prev_output_channel = output_channel
199
+ output_channel = reversed_block_out_channels[i]
200
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
201
+
202
+ is_final_block = i == len(block_out_channels) - 1
203
+
204
+ up_block = get_up_block(
205
+ up_block_type,
206
+ num_layers=layers_per_block + 1,
207
+ in_channels=input_channel,
208
+ out_channels=output_channel,
209
+ prev_output_channel=prev_output_channel,
210
+ temb_channels=time_embed_dim,
211
+ add_upsample=not is_final_block,
212
+ resnet_eps=norm_eps,
213
+ resnet_act_fn=act_fn,
214
+ resnet_groups=norm_num_groups,
215
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
216
+ resnet_time_scale_shift=resnet_time_scale_shift,
217
+ upsample_type=upsample_type,
218
+ )
219
+ self.up_blocks.append(up_block)
220
+ prev_output_channel = output_channel
221
+
222
+ # out
223
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
224
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
225
+ self.conv_act = nn.SiLU()
226
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
227
+
228
+ def forward(
229
+ self,
230
+ sample: torch.FloatTensor,
231
+ timestep: Union[torch.Tensor, float, int],
232
+ class_labels: Optional[torch.Tensor] = None,
233
+ return_dict: bool = True,
234
+ ) -> Union[UNet2DOutput, Tuple]:
235
+ r"""
236
+ The [`UNet2DModel`] forward method.
237
+
238
+ Args:
239
+ sample (`torch.FloatTensor`):
240
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
241
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
242
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
243
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
244
+ return_dict (`bool`, *optional*, defaults to `True`):
245
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
246
+
247
+ Returns:
248
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`:
249
+ If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
250
+ returned where the first element is the sample tensor.
251
+ """
252
+ # 0. center input if necessary
253
+ if self.config.center_input_sample:
254
+ sample = 2 * sample - 1.0
255
+
256
+ # 1. time
257
+ timesteps = timestep
258
+ if not torch.is_tensor(timesteps):
259
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
260
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
261
+ timesteps = timesteps[None].to(sample.device)
262
+
263
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
264
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
265
+
266
+ t_emb = self.time_proj(timesteps)
267
+
268
+ # timesteps does not contain any weights and will always return f32 tensors
269
+ # but time_embedding might actually be running in fp16. so we need to cast here.
270
+ # there might be better ways to encapsulate this.
271
+ t_emb = t_emb.to(dtype=self.dtype)
272
+ emb = self.time_embedding(t_emb)
273
+
274
+ if self.class_embedding is not None:
275
+ if class_labels is None:
276
+ raise ValueError("class_labels should be provided when doing class conditioning")
277
+
278
+ if self.config.class_embed_type == "timestep":
279
+ class_labels = self.time_proj(class_labels)
280
+
281
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
282
+ emb = emb + class_emb
283
+
284
+ # 2. pre-process
285
+ skip_sample = sample
286
+ sample = self.conv_in(sample)
287
+
288
+ # 3. down
289
+ down_block_res_samples = (sample,)
290
+ for downsample_block in self.down_blocks:
291
+ if hasattr(downsample_block, "skip_conv"):
292
+ sample, res_samples, skip_sample = downsample_block(
293
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
294
+ )
295
+ else:
296
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
297
+
298
+ down_block_res_samples += res_samples
299
+
300
+ # 4. mid
301
+ sample = self.mid_block(sample, emb)
302
+
303
+ # 5. up
304
+ skip_sample = None
305
+ for upsample_block in self.up_blocks:
306
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
307
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
308
+
309
+ if hasattr(upsample_block, "skip_conv"):
310
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
311
+ else:
312
+ sample = upsample_block(sample, res_samples, emb)
313
+
314
+ # 6. post-process
315
+ sample = self.conv_norm_out(sample)
316
+ sample = self.conv_act(sample)
317
+ sample = self.conv_out(sample)
318
+
319
+ if skip_sample is not None:
320
+ sample += skip_sample
321
+
322
+ if self.config.time_embedding_type == "fourier":
323
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
324
+ sample = sample / timesteps
325
+
326
+ if not return_dict:
327
+ return (sample,)
328
+
329
+ return UNet2DOutput(sample=sample)
6DoF/diffusers/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
6DoF/diffusers/models/unet_2d_blocks_flax.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import flax.linen as nn
16
+ import jax.numpy as jnp
17
+
18
+ from .attention_flax import FlaxTransformer2DModel
19
+ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
20
+
21
+
22
+ class FlaxCrossAttnDownBlock2D(nn.Module):
23
+ r"""
24
+ Cross Attention 2D Downsizing block - original architecture from Unet transformers:
25
+ https://arxiv.org/abs/2103.06104
26
+
27
+ Parameters:
28
+ in_channels (:obj:`int`):
29
+ Input channels
30
+ out_channels (:obj:`int`):
31
+ Output channels
32
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
33
+ Dropout rate
34
+ num_layers (:obj:`int`, *optional*, defaults to 1):
35
+ Number of attention blocks layers
36
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
37
+ Number of attention heads of each spatial transformer block
38
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
39
+ Whether to add downsampling layer before each final output
40
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
41
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
42
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
43
+ Parameters `dtype`
44
+ """
45
+ in_channels: int
46
+ out_channels: int
47
+ dropout: float = 0.0
48
+ num_layers: int = 1
49
+ num_attention_heads: int = 1
50
+ add_downsample: bool = True
51
+ use_linear_projection: bool = False
52
+ only_cross_attention: bool = False
53
+ use_memory_efficient_attention: bool = False
54
+ dtype: jnp.dtype = jnp.float32
55
+
56
+ def setup(self):
57
+ resnets = []
58
+ attentions = []
59
+
60
+ for i in range(self.num_layers):
61
+ in_channels = self.in_channels if i == 0 else self.out_channels
62
+
63
+ res_block = FlaxResnetBlock2D(
64
+ in_channels=in_channels,
65
+ out_channels=self.out_channels,
66
+ dropout_prob=self.dropout,
67
+ dtype=self.dtype,
68
+ )
69
+ resnets.append(res_block)
70
+
71
+ attn_block = FlaxTransformer2DModel(
72
+ in_channels=self.out_channels,
73
+ n_heads=self.num_attention_heads,
74
+ d_head=self.out_channels // self.num_attention_heads,
75
+ depth=1,
76
+ use_linear_projection=self.use_linear_projection,
77
+ only_cross_attention=self.only_cross_attention,
78
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
79
+ dtype=self.dtype,
80
+ )
81
+ attentions.append(attn_block)
82
+
83
+ self.resnets = resnets
84
+ self.attentions = attentions
85
+
86
+ if self.add_downsample:
87
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
88
+
89
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
90
+ output_states = ()
91
+
92
+ for resnet, attn in zip(self.resnets, self.attentions):
93
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
94
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
95
+ output_states += (hidden_states,)
96
+
97
+ if self.add_downsample:
98
+ hidden_states = self.downsamplers_0(hidden_states)
99
+ output_states += (hidden_states,)
100
+
101
+ return hidden_states, output_states
102
+
103
+
104
+ class FlaxDownBlock2D(nn.Module):
105
+ r"""
106
+ Flax 2D downsizing block
107
+
108
+ Parameters:
109
+ in_channels (:obj:`int`):
110
+ Input channels
111
+ out_channels (:obj:`int`):
112
+ Output channels
113
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
114
+ Dropout rate
115
+ num_layers (:obj:`int`, *optional*, defaults to 1):
116
+ Number of attention blocks layers
117
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
118
+ Whether to add downsampling layer before each final output
119
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
120
+ Parameters `dtype`
121
+ """
122
+ in_channels: int
123
+ out_channels: int
124
+ dropout: float = 0.0
125
+ num_layers: int = 1
126
+ add_downsample: bool = True
127
+ dtype: jnp.dtype = jnp.float32
128
+
129
+ def setup(self):
130
+ resnets = []
131
+
132
+ for i in range(self.num_layers):
133
+ in_channels = self.in_channels if i == 0 else self.out_channels
134
+
135
+ res_block = FlaxResnetBlock2D(
136
+ in_channels=in_channels,
137
+ out_channels=self.out_channels,
138
+ dropout_prob=self.dropout,
139
+ dtype=self.dtype,
140
+ )
141
+ resnets.append(res_block)
142
+ self.resnets = resnets
143
+
144
+ if self.add_downsample:
145
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
146
+
147
+ def __call__(self, hidden_states, temb, deterministic=True):
148
+ output_states = ()
149
+
150
+ for resnet in self.resnets:
151
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
152
+ output_states += (hidden_states,)
153
+
154
+ if self.add_downsample:
155
+ hidden_states = self.downsamplers_0(hidden_states)
156
+ output_states += (hidden_states,)
157
+
158
+ return hidden_states, output_states
159
+
160
+
161
+ class FlaxCrossAttnUpBlock2D(nn.Module):
162
+ r"""
163
+ Cross Attention 2D Upsampling block - original architecture from Unet transformers:
164
+ https://arxiv.org/abs/2103.06104
165
+
166
+ Parameters:
167
+ in_channels (:obj:`int`):
168
+ Input channels
169
+ out_channels (:obj:`int`):
170
+ Output channels
171
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
172
+ Dropout rate
173
+ num_layers (:obj:`int`, *optional*, defaults to 1):
174
+ Number of attention blocks layers
175
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
176
+ Number of attention heads of each spatial transformer block
177
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
178
+ Whether to add upsampling layer before each final output
179
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
180
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
181
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
182
+ Parameters `dtype`
183
+ """
184
+ in_channels: int
185
+ out_channels: int
186
+ prev_output_channel: int
187
+ dropout: float = 0.0
188
+ num_layers: int = 1
189
+ num_attention_heads: int = 1
190
+ add_upsample: bool = True
191
+ use_linear_projection: bool = False
192
+ only_cross_attention: bool = False
193
+ use_memory_efficient_attention: bool = False
194
+ dtype: jnp.dtype = jnp.float32
195
+
196
+ def setup(self):
197
+ resnets = []
198
+ attentions = []
199
+
200
+ for i in range(self.num_layers):
201
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
202
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
203
+
204
+ res_block = FlaxResnetBlock2D(
205
+ in_channels=resnet_in_channels + res_skip_channels,
206
+ out_channels=self.out_channels,
207
+ dropout_prob=self.dropout,
208
+ dtype=self.dtype,
209
+ )
210
+ resnets.append(res_block)
211
+
212
+ attn_block = FlaxTransformer2DModel(
213
+ in_channels=self.out_channels,
214
+ n_heads=self.num_attention_heads,
215
+ d_head=self.out_channels // self.num_attention_heads,
216
+ depth=1,
217
+ use_linear_projection=self.use_linear_projection,
218
+ only_cross_attention=self.only_cross_attention,
219
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
220
+ dtype=self.dtype,
221
+ )
222
+ attentions.append(attn_block)
223
+
224
+ self.resnets = resnets
225
+ self.attentions = attentions
226
+
227
+ if self.add_upsample:
228
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
229
+
230
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
231
+ for resnet, attn in zip(self.resnets, self.attentions):
232
+ # pop res hidden states
233
+ res_hidden_states = res_hidden_states_tuple[-1]
234
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
235
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
236
+
237
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
238
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
239
+
240
+ if self.add_upsample:
241
+ hidden_states = self.upsamplers_0(hidden_states)
242
+
243
+ return hidden_states
244
+
245
+
246
+ class FlaxUpBlock2D(nn.Module):
247
+ r"""
248
+ Flax 2D upsampling block
249
+
250
+ Parameters:
251
+ in_channels (:obj:`int`):
252
+ Input channels
253
+ out_channels (:obj:`int`):
254
+ Output channels
255
+ prev_output_channel (:obj:`int`):
256
+ Output channels from the previous block
257
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
258
+ Dropout rate
259
+ num_layers (:obj:`int`, *optional*, defaults to 1):
260
+ Number of attention blocks layers
261
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
262
+ Whether to add downsampling layer before each final output
263
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
264
+ Parameters `dtype`
265
+ """
266
+ in_channels: int
267
+ out_channels: int
268
+ prev_output_channel: int
269
+ dropout: float = 0.0
270
+ num_layers: int = 1
271
+ add_upsample: bool = True
272
+ dtype: jnp.dtype = jnp.float32
273
+
274
+ def setup(self):
275
+ resnets = []
276
+
277
+ for i in range(self.num_layers):
278
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
279
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
280
+
281
+ res_block = FlaxResnetBlock2D(
282
+ in_channels=resnet_in_channels + res_skip_channels,
283
+ out_channels=self.out_channels,
284
+ dropout_prob=self.dropout,
285
+ dtype=self.dtype,
286
+ )
287
+ resnets.append(res_block)
288
+
289
+ self.resnets = resnets
290
+
291
+ if self.add_upsample:
292
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
293
+
294
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
295
+ for resnet in self.resnets:
296
+ # pop res hidden states
297
+ res_hidden_states = res_hidden_states_tuple[-1]
298
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
299
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
300
+
301
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
302
+
303
+ if self.add_upsample:
304
+ hidden_states = self.upsamplers_0(hidden_states)
305
+
306
+ return hidden_states
307
+
308
+
309
+ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
310
+ r"""
311
+ Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
312
+
313
+ Parameters:
314
+ in_channels (:obj:`int`):
315
+ Input channels
316
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
317
+ Dropout rate
318
+ num_layers (:obj:`int`, *optional*, defaults to 1):
319
+ Number of attention blocks layers
320
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
321
+ Number of attention heads of each spatial transformer block
322
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
323
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
324
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
325
+ Parameters `dtype`
326
+ """
327
+ in_channels: int
328
+ dropout: float = 0.0
329
+ num_layers: int = 1
330
+ num_attention_heads: int = 1
331
+ use_linear_projection: bool = False
332
+ use_memory_efficient_attention: bool = False
333
+ dtype: jnp.dtype = jnp.float32
334
+
335
+ def setup(self):
336
+ # there is always at least one resnet
337
+ resnets = [
338
+ FlaxResnetBlock2D(
339
+ in_channels=self.in_channels,
340
+ out_channels=self.in_channels,
341
+ dropout_prob=self.dropout,
342
+ dtype=self.dtype,
343
+ )
344
+ ]
345
+
346
+ attentions = []
347
+
348
+ for _ in range(self.num_layers):
349
+ attn_block = FlaxTransformer2DModel(
350
+ in_channels=self.in_channels,
351
+ n_heads=self.num_attention_heads,
352
+ d_head=self.in_channels // self.num_attention_heads,
353
+ depth=1,
354
+ use_linear_projection=self.use_linear_projection,
355
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
356
+ dtype=self.dtype,
357
+ )
358
+ attentions.append(attn_block)
359
+
360
+ res_block = FlaxResnetBlock2D(
361
+ in_channels=self.in_channels,
362
+ out_channels=self.in_channels,
363
+ dropout_prob=self.dropout,
364
+ dtype=self.dtype,
365
+ )
366
+ resnets.append(res_block)
367
+
368
+ self.resnets = resnets
369
+ self.attentions = attentions
370
+
371
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
372
+ hidden_states = self.resnets[0](hidden_states, temb)
373
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
374
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
375
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
376
+
377
+ return hidden_states
6DoF/diffusers/models/unet_2d_condition.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import UNet2DConditionLoadersMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .activations import get_activation
25
+ from .attention_processor import AttentionProcessor, AttnProcessor
26
+ from .embeddings import (
27
+ GaussianFourierProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ TextImageProjection,
32
+ TextImageTimeEmbedding,
33
+ TextTimeEmbedding,
34
+ TimestepEmbedding,
35
+ Timesteps,
36
+ )
37
+ from .modeling_utils import ModelMixin
38
+ from .unet_2d_blocks import (
39
+ CrossAttnDownBlock2D,
40
+ CrossAttnUpBlock2D,
41
+ DownBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ UNetMidBlock2DSimpleCrossAttn,
44
+ UpBlock2D,
45
+ get_down_block,
46
+ get_up_block,
47
+ )
48
+
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+
53
+ @dataclass
54
+ class UNet2DConditionOutput(BaseOutput):
55
+ """
56
+ The output of [`UNet2DConditionModel`].
57
+
58
+ Args:
59
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
60
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
61
+ """
62
+
63
+ sample: torch.FloatTensor = None
64
+
65
+
66
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
67
+ r"""
68
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
69
+ shaped output.
70
+
71
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
72
+ for all models (such as downloading or saving).
73
+
74
+ Parameters:
75
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
76
+ Height and width of input/output sample.
77
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
78
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
79
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
80
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
81
+ Whether to flip the sin to cos in the time embedding.
82
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
83
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
84
+ The tuple of downsample blocks to use.
85
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
86
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
87
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
88
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
89
+ The tuple of upsample blocks to use.
90
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
91
+ Whether to include self-attention in the basic transformer blocks, see
92
+ [`~models.attention.BasicTransformerBlock`].
93
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
94
+ The tuple of output channels for each block.
95
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
96
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
97
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
98
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
99
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
100
+ If `None`, normalization and activation layers is skipped in post-processing.
101
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
102
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
103
+ The dimension of the cross attention features.
104
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
105
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
106
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
107
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
108
+ encoder_hid_dim (`int`, *optional*, defaults to None):
109
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
110
+ dimension to `cross_attention_dim`.
111
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
112
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
113
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
114
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
115
+ num_attention_heads (`int`, *optional*):
116
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
117
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
118
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
119
+ class_embed_type (`str`, *optional*, defaults to `None`):
120
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
121
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
122
+ addition_embed_type (`str`, *optional*, defaults to `None`):
123
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
124
+ "text". "text" will use the `TextTimeEmbedding` layer.
125
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
126
+ Dimension for the timestep embeddings.
127
+ num_class_embeds (`int`, *optional*, defaults to `None`):
128
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
129
+ class conditioning with `class_embed_type` equal to `None`.
130
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
131
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
132
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
133
+ An optional override for the dimension of the projected time embedding.
134
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
135
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
136
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
137
+ timestep_post_act (`str`, *optional*, defaults to `None`):
138
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
139
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
140
+ The dimension of `cond_proj` layer in the timestep embedding.
141
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
142
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
143
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
144
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
145
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
146
+ embeddings with the class embeddings.
147
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
148
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
149
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
150
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
151
+ otherwise.
152
+ """
153
+
154
+ _supports_gradient_checkpointing = True
155
+
156
+ @register_to_config
157
+ def __init__(
158
+ self,
159
+ sample_size: Optional[int] = None,
160
+ in_channels: int = 4,
161
+ out_channels: int = 4,
162
+ center_input_sample: bool = False,
163
+ flip_sin_to_cos: bool = True,
164
+ freq_shift: int = 0,
165
+ down_block_types: Tuple[str] = (
166
+ "CrossAttnDownBlock2D",
167
+ "CrossAttnDownBlock2D",
168
+ "CrossAttnDownBlock2D",
169
+ "DownBlock2D",
170
+ ),
171
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
172
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
173
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
174
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
175
+ layers_per_block: Union[int, Tuple[int]] = 2,
176
+ downsample_padding: int = 1,
177
+ mid_block_scale_factor: float = 1,
178
+ act_fn: str = "silu",
179
+ norm_num_groups: Optional[int] = 32,
180
+ norm_eps: float = 1e-5,
181
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
182
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
183
+ encoder_hid_dim: Optional[int] = None,
184
+ encoder_hid_dim_type: Optional[str] = None,
185
+ attention_head_dim: Union[int, Tuple[int]] = 8,
186
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
187
+ dual_cross_attention: bool = False,
188
+ use_linear_projection: bool = False,
189
+ class_embed_type: Optional[str] = None,
190
+ addition_embed_type: Optional[str] = None,
191
+ addition_time_embed_dim: Optional[int] = None,
192
+ num_class_embeds: Optional[int] = None,
193
+ upcast_attention: bool = False,
194
+ resnet_time_scale_shift: str = "default",
195
+ resnet_skip_time_act: bool = False,
196
+ resnet_out_scale_factor: int = 1.0,
197
+ time_embedding_type: str = "positional",
198
+ time_embedding_dim: Optional[int] = None,
199
+ time_embedding_act_fn: Optional[str] = None,
200
+ timestep_post_act: Optional[str] = None,
201
+ time_cond_proj_dim: Optional[int] = None,
202
+ conv_in_kernel: int = 3,
203
+ conv_out_kernel: int = 3,
204
+ projection_class_embeddings_input_dim: Optional[int] = None,
205
+ class_embeddings_concat: bool = False,
206
+ mid_block_only_cross_attention: Optional[bool] = None,
207
+ cross_attention_norm: Optional[str] = None,
208
+ addition_embed_type_num_heads=64,
209
+ ):
210
+ super().__init__()
211
+
212
+ self.sample_size = sample_size
213
+
214
+ if num_attention_heads is not None:
215
+ raise ValueError(
216
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
217
+ )
218
+
219
+ # If `num_attention_heads` is not defined (which is the case for most models)
220
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
221
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
222
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
223
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
224
+ # which is why we correct for the naming here.
225
+ num_attention_heads = num_attention_heads or attention_head_dim
226
+
227
+ # Check inputs
228
+ if len(down_block_types) != len(up_block_types):
229
+ raise ValueError(
230
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
231
+ )
232
+
233
+ if len(block_out_channels) != len(down_block_types):
234
+ raise ValueError(
235
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
236
+ )
237
+
238
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
239
+ raise ValueError(
240
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
241
+ )
242
+
243
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
244
+ raise ValueError(
245
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
246
+ )
247
+
248
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
249
+ raise ValueError(
250
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
251
+ )
252
+
253
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
254
+ raise ValueError(
255
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
256
+ )
257
+
258
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
259
+ raise ValueError(
260
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ # input
264
+ conv_in_padding = (conv_in_kernel - 1) // 2
265
+ self.conv_in = nn.Conv2d(
266
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
267
+ )
268
+
269
+ # time
270
+ if time_embedding_type == "fourier":
271
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
272
+ if time_embed_dim % 2 != 0:
273
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
274
+ self.time_proj = GaussianFourierProjection(
275
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
276
+ )
277
+ timestep_input_dim = time_embed_dim
278
+ elif time_embedding_type == "positional":
279
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
280
+
281
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
282
+ timestep_input_dim = block_out_channels[0]
283
+ else:
284
+ raise ValueError(
285
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
286
+ )
287
+
288
+ self.time_embedding = TimestepEmbedding(
289
+ timestep_input_dim,
290
+ time_embed_dim,
291
+ act_fn=act_fn,
292
+ post_act_fn=timestep_post_act,
293
+ cond_proj_dim=time_cond_proj_dim,
294
+ )
295
+
296
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
297
+ encoder_hid_dim_type = "text_proj"
298
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
299
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
300
+
301
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
302
+ raise ValueError(
303
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
304
+ )
305
+
306
+ if encoder_hid_dim_type == "text_proj":
307
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
308
+ elif encoder_hid_dim_type == "text_image_proj":
309
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
310
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
311
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
312
+ self.encoder_hid_proj = TextImageProjection(
313
+ text_embed_dim=encoder_hid_dim,
314
+ image_embed_dim=cross_attention_dim,
315
+ cross_attention_dim=cross_attention_dim,
316
+ )
317
+ elif encoder_hid_dim_type == "image_proj":
318
+ # Kandinsky 2.2
319
+ self.encoder_hid_proj = ImageProjection(
320
+ image_embed_dim=encoder_hid_dim,
321
+ cross_attention_dim=cross_attention_dim,
322
+ )
323
+ elif encoder_hid_dim_type is not None:
324
+ raise ValueError(
325
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
326
+ )
327
+ else:
328
+ self.encoder_hid_proj = None
329
+
330
+ # class embedding
331
+ if class_embed_type is None and num_class_embeds is not None:
332
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
333
+ elif class_embed_type == "timestep":
334
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
335
+ elif class_embed_type == "identity":
336
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
337
+ elif class_embed_type == "projection":
338
+ if projection_class_embeddings_input_dim is None:
339
+ raise ValueError(
340
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
341
+ )
342
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
343
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
344
+ # 2. it projects from an arbitrary input dimension.
345
+ #
346
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
347
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
348
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
349
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
350
+ elif class_embed_type == "simple_projection":
351
+ if projection_class_embeddings_input_dim is None:
352
+ raise ValueError(
353
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
354
+ )
355
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
356
+ else:
357
+ self.class_embedding = None
358
+
359
+ if addition_embed_type == "text":
360
+ if encoder_hid_dim is not None:
361
+ text_time_embedding_from_dim = encoder_hid_dim
362
+ else:
363
+ text_time_embedding_from_dim = cross_attention_dim
364
+
365
+ self.add_embedding = TextTimeEmbedding(
366
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367
+ )
368
+ elif addition_embed_type == "text_image":
369
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
372
+ self.add_embedding = TextImageTimeEmbedding(
373
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374
+ )
375
+ elif addition_embed_type == "text_time":
376
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
+ elif addition_embed_type == "image":
379
+ # Kandinsky 2.2
380
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
381
+ elif addition_embed_type == "image_hint":
382
+ # Kandinsky 2.2 ControlNet
383
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384
+ elif addition_embed_type is not None:
385
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
386
+
387
+ if time_embedding_act_fn is None:
388
+ self.time_embed_act = None
389
+ else:
390
+ self.time_embed_act = get_activation(time_embedding_act_fn)
391
+
392
+ self.down_blocks = nn.ModuleList([])
393
+ self.up_blocks = nn.ModuleList([])
394
+
395
+ if isinstance(only_cross_attention, bool):
396
+ if mid_block_only_cross_attention is None:
397
+ mid_block_only_cross_attention = only_cross_attention
398
+
399
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
400
+
401
+ if mid_block_only_cross_attention is None:
402
+ mid_block_only_cross_attention = False
403
+
404
+ if isinstance(num_attention_heads, int):
405
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
406
+
407
+ if isinstance(attention_head_dim, int):
408
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
409
+
410
+ if isinstance(cross_attention_dim, int):
411
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
412
+
413
+ if isinstance(layers_per_block, int):
414
+ layers_per_block = [layers_per_block] * len(down_block_types)
415
+
416
+ if isinstance(transformer_layers_per_block, int):
417
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
418
+
419
+ if class_embeddings_concat:
420
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
421
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
422
+ # regular time embeddings
423
+ blocks_time_embed_dim = time_embed_dim * 2
424
+ else:
425
+ blocks_time_embed_dim = time_embed_dim
426
+
427
+ # down
428
+ output_channel = block_out_channels[0]
429
+ for i, down_block_type in enumerate(down_block_types):
430
+ input_channel = output_channel
431
+ output_channel = block_out_channels[i]
432
+ is_final_block = i == len(block_out_channels) - 1
433
+
434
+ down_block = get_down_block(
435
+ down_block_type,
436
+ num_layers=layers_per_block[i],
437
+ transformer_layers_per_block=transformer_layers_per_block[i],
438
+ in_channels=input_channel,
439
+ out_channels=output_channel,
440
+ temb_channels=blocks_time_embed_dim,
441
+ add_downsample=not is_final_block,
442
+ resnet_eps=norm_eps,
443
+ resnet_act_fn=act_fn,
444
+ resnet_groups=norm_num_groups,
445
+ cross_attention_dim=cross_attention_dim[i],
446
+ num_attention_heads=num_attention_heads[i],
447
+ downsample_padding=downsample_padding,
448
+ dual_cross_attention=dual_cross_attention,
449
+ use_linear_projection=use_linear_projection,
450
+ only_cross_attention=only_cross_attention[i],
451
+ upcast_attention=upcast_attention,
452
+ resnet_time_scale_shift=resnet_time_scale_shift,
453
+ resnet_skip_time_act=resnet_skip_time_act,
454
+ resnet_out_scale_factor=resnet_out_scale_factor,
455
+ cross_attention_norm=cross_attention_norm,
456
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
457
+ )
458
+ self.down_blocks.append(down_block)
459
+
460
+ # mid
461
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
462
+ self.mid_block = UNetMidBlock2DCrossAttn(
463
+ transformer_layers_per_block=transformer_layers_per_block[-1],
464
+ in_channels=block_out_channels[-1],
465
+ temb_channels=blocks_time_embed_dim,
466
+ resnet_eps=norm_eps,
467
+ resnet_act_fn=act_fn,
468
+ output_scale_factor=mid_block_scale_factor,
469
+ resnet_time_scale_shift=resnet_time_scale_shift,
470
+ cross_attention_dim=cross_attention_dim[-1],
471
+ num_attention_heads=num_attention_heads[-1],
472
+ resnet_groups=norm_num_groups,
473
+ dual_cross_attention=dual_cross_attention,
474
+ use_linear_projection=use_linear_projection,
475
+ upcast_attention=upcast_attention,
476
+ )
477
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
478
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
479
+ in_channels=block_out_channels[-1],
480
+ temb_channels=blocks_time_embed_dim,
481
+ resnet_eps=norm_eps,
482
+ resnet_act_fn=act_fn,
483
+ output_scale_factor=mid_block_scale_factor,
484
+ cross_attention_dim=cross_attention_dim[-1],
485
+ attention_head_dim=attention_head_dim[-1],
486
+ resnet_groups=norm_num_groups,
487
+ resnet_time_scale_shift=resnet_time_scale_shift,
488
+ skip_time_act=resnet_skip_time_act,
489
+ only_cross_attention=mid_block_only_cross_attention,
490
+ cross_attention_norm=cross_attention_norm,
491
+ )
492
+ elif mid_block_type is None:
493
+ self.mid_block = None
494
+ else:
495
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
496
+
497
+ # count how many layers upsample the images
498
+ self.num_upsamplers = 0
499
+
500
+ # up
501
+ reversed_block_out_channels = list(reversed(block_out_channels))
502
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
503
+ reversed_layers_per_block = list(reversed(layers_per_block))
504
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
505
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
506
+ only_cross_attention = list(reversed(only_cross_attention))
507
+
508
+ output_channel = reversed_block_out_channels[0]
509
+ for i, up_block_type in enumerate(up_block_types):
510
+ is_final_block = i == len(block_out_channels) - 1
511
+
512
+ prev_output_channel = output_channel
513
+ output_channel = reversed_block_out_channels[i]
514
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
515
+
516
+ # add upsample block for all BUT final layer
517
+ if not is_final_block:
518
+ add_upsample = True
519
+ self.num_upsamplers += 1
520
+ else:
521
+ add_upsample = False
522
+
523
+ up_block = get_up_block(
524
+ up_block_type,
525
+ num_layers=reversed_layers_per_block[i] + 1,
526
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
527
+ in_channels=input_channel,
528
+ out_channels=output_channel,
529
+ prev_output_channel=prev_output_channel,
530
+ temb_channels=blocks_time_embed_dim,
531
+ add_upsample=add_upsample,
532
+ resnet_eps=norm_eps,
533
+ resnet_act_fn=act_fn,
534
+ resnet_groups=norm_num_groups,
535
+ cross_attention_dim=reversed_cross_attention_dim[i],
536
+ num_attention_heads=reversed_num_attention_heads[i],
537
+ dual_cross_attention=dual_cross_attention,
538
+ use_linear_projection=use_linear_projection,
539
+ only_cross_attention=only_cross_attention[i],
540
+ upcast_attention=upcast_attention,
541
+ resnet_time_scale_shift=resnet_time_scale_shift,
542
+ resnet_skip_time_act=resnet_skip_time_act,
543
+ resnet_out_scale_factor=resnet_out_scale_factor,
544
+ cross_attention_norm=cross_attention_norm,
545
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
546
+ )
547
+ self.up_blocks.append(up_block)
548
+ prev_output_channel = output_channel
549
+
550
+ # out
551
+ if norm_num_groups is not None:
552
+ self.conv_norm_out = nn.GroupNorm(
553
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
554
+ )
555
+
556
+ self.conv_act = get_activation(act_fn)
557
+
558
+ else:
559
+ self.conv_norm_out = None
560
+ self.conv_act = None
561
+
562
+ conv_out_padding = (conv_out_kernel - 1) // 2
563
+ self.conv_out = nn.Conv2d(
564
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
565
+ )
566
+
567
+ @property
568
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
569
+ r"""
570
+ Returns:
571
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
572
+ indexed by its weight name.
573
+ """
574
+ # set recursively
575
+ processors = {}
576
+
577
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
578
+ if hasattr(module, "set_processor"):
579
+ processors[f"{name}.processor"] = module.processor
580
+
581
+ for sub_name, child in module.named_children():
582
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
583
+
584
+ return processors
585
+
586
+ for name, module in self.named_children():
587
+ fn_recursive_add_processors(name, module, processors)
588
+
589
+ return processors
590
+
591
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
592
+ r"""
593
+ Sets the attention processor to use to compute attention.
594
+
595
+ Parameters:
596
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
597
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
598
+ for **all** `Attention` layers.
599
+
600
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
601
+ processor. This is strongly recommended when setting trainable attention processors.
602
+
603
+ """
604
+ count = len(self.attn_processors.keys())
605
+
606
+ if isinstance(processor, dict) and len(processor) != count:
607
+ raise ValueError(
608
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
609
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
610
+ )
611
+
612
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
613
+ if hasattr(module, "set_processor"):
614
+ if not isinstance(processor, dict):
615
+ module.set_processor(processor)
616
+ else:
617
+ module.set_processor(processor.pop(f"{name}.processor"))
618
+
619
+ for sub_name, child in module.named_children():
620
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
621
+
622
+ for name, module in self.named_children():
623
+ fn_recursive_attn_processor(name, module, processor)
624
+
625
+ def set_default_attn_processor(self):
626
+ """
627
+ Disables custom attention processors and sets the default attention implementation.
628
+ """
629
+ self.set_attn_processor(AttnProcessor())
630
+
631
+ def set_attention_slice(self, slice_size):
632
+ r"""
633
+ Enable sliced attention computation.
634
+
635
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
636
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
637
+
638
+ Args:
639
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
640
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
641
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
642
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
643
+ must be a multiple of `slice_size`.
644
+ """
645
+ sliceable_head_dims = []
646
+
647
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
648
+ if hasattr(module, "set_attention_slice"):
649
+ sliceable_head_dims.append(module.sliceable_head_dim)
650
+
651
+ for child in module.children():
652
+ fn_recursive_retrieve_sliceable_dims(child)
653
+
654
+ # retrieve number of attention layers
655
+ for module in self.children():
656
+ fn_recursive_retrieve_sliceable_dims(module)
657
+
658
+ num_sliceable_layers = len(sliceable_head_dims)
659
+
660
+ if slice_size == "auto":
661
+ # half the attention head size is usually a good trade-off between
662
+ # speed and memory
663
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
664
+ elif slice_size == "max":
665
+ # make smallest slice possible
666
+ slice_size = num_sliceable_layers * [1]
667
+
668
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
669
+
670
+ if len(slice_size) != len(sliceable_head_dims):
671
+ raise ValueError(
672
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
673
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
674
+ )
675
+
676
+ for i in range(len(slice_size)):
677
+ size = slice_size[i]
678
+ dim = sliceable_head_dims[i]
679
+ if size is not None and size > dim:
680
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
681
+
682
+ # Recursively walk through all the children.
683
+ # Any children which exposes the set_attention_slice method
684
+ # gets the message
685
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
686
+ if hasattr(module, "set_attention_slice"):
687
+ module.set_attention_slice(slice_size.pop())
688
+
689
+ for child in module.children():
690
+ fn_recursive_set_attention_slice(child, slice_size)
691
+
692
+ reversed_slice_size = list(reversed(slice_size))
693
+ for module in self.children():
694
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
695
+
696
+ def _set_gradient_checkpointing(self, module, value=False):
697
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
698
+ module.gradient_checkpointing = value
699
+
700
+ def forward(
701
+ self,
702
+ sample: torch.FloatTensor,
703
+ timestep: Union[torch.Tensor, float, int],
704
+ encoder_hidden_states: torch.Tensor,
705
+ class_labels: Optional[torch.Tensor] = None,
706
+ timestep_cond: Optional[torch.Tensor] = None,
707
+ attention_mask: Optional[torch.Tensor] = None,
708
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
709
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
710
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
711
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
712
+ encoder_attention_mask: Optional[torch.Tensor] = None,
713
+ return_dict: bool = True,
714
+ ) -> Union[UNet2DConditionOutput, Tuple]:
715
+ r"""
716
+ The [`UNet2DConditionModel`] forward method.
717
+
718
+ Args:
719
+ sample (`torch.FloatTensor`):
720
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
721
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
722
+ encoder_hidden_states (`torch.FloatTensor`):
723
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
724
+ encoder_attention_mask (`torch.Tensor`):
725
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
726
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
727
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
728
+ return_dict (`bool`, *optional*, defaults to `True`):
729
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
730
+ tuple.
731
+ cross_attention_kwargs (`dict`, *optional*):
732
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
733
+ added_cond_kwargs: (`dict`, *optional*):
734
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
735
+ are passed along to the UNet blocks.
736
+
737
+ Returns:
738
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
739
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
740
+ a `tuple` is returned where the first element is the sample tensor.
741
+ """
742
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
743
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
744
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
745
+ # on the fly if necessary.
746
+ default_overall_up_factor = 2**self.num_upsamplers
747
+
748
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
749
+ forward_upsample_size = False
750
+ upsample_size = None
751
+
752
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
753
+ logger.info("Forward upsample size to force interpolation output size.")
754
+ forward_upsample_size = True
755
+
756
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
757
+ # expects mask of shape:
758
+ # [batch, key_tokens]
759
+ # adds singleton query_tokens dimension:
760
+ # [batch, 1, key_tokens]
761
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
762
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
763
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
764
+ if attention_mask is not None:
765
+ # assume that mask is expressed as:
766
+ # (1 = keep, 0 = discard)
767
+ # convert mask into a bias that can be added to attention scores:
768
+ # (keep = +0, discard = -10000.0)
769
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
770
+ attention_mask = attention_mask.unsqueeze(1)
771
+
772
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
773
+ if encoder_attention_mask is not None:
774
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
775
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
776
+
777
+ # 0. center input if necessary
778
+ if self.config.center_input_sample:
779
+ sample = 2 * sample - 1.0
780
+
781
+ # 1. time
782
+ timesteps = timestep
783
+ if not torch.is_tensor(timesteps):
784
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
785
+ # This would be a good case for the `match` statement (Python 3.10+)
786
+ is_mps = sample.device.type == "mps"
787
+ if isinstance(timestep, float):
788
+ dtype = torch.float32 if is_mps else torch.float64
789
+ else:
790
+ dtype = torch.int32 if is_mps else torch.int64
791
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
792
+ elif len(timesteps.shape) == 0:
793
+ timesteps = timesteps[None].to(sample.device)
794
+
795
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
796
+ timesteps = timesteps.expand(sample.shape[0])
797
+
798
+ t_emb = self.time_proj(timesteps)
799
+
800
+ # `Timesteps` does not contain any weights and will always return f32 tensors
801
+ # but time_embedding might actually be running in fp16. so we need to cast here.
802
+ # there might be better ways to encapsulate this.
803
+ t_emb = t_emb.to(dtype=sample.dtype)
804
+
805
+ emb = self.time_embedding(t_emb, timestep_cond)
806
+ aug_emb = None
807
+
808
+ if self.class_embedding is not None:
809
+ if class_labels is None:
810
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
811
+
812
+ if self.config.class_embed_type == "timestep":
813
+ class_labels = self.time_proj(class_labels)
814
+
815
+ # `Timesteps` does not contain any weights and will always return f32 tensors
816
+ # there might be better ways to encapsulate this.
817
+ class_labels = class_labels.to(dtype=sample.dtype)
818
+
819
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
820
+
821
+ if self.config.class_embeddings_concat:
822
+ emb = torch.cat([emb, class_emb], dim=-1)
823
+ else:
824
+ emb = emb + class_emb
825
+
826
+ if self.config.addition_embed_type == "text":
827
+ aug_emb = self.add_embedding(encoder_hidden_states)
828
+ elif self.config.addition_embed_type == "text_image":
829
+ # Kandinsky 2.1 - style
830
+ if "image_embeds" not in added_cond_kwargs:
831
+ raise ValueError(
832
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
833
+ )
834
+
835
+ image_embs = added_cond_kwargs.get("image_embeds")
836
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
837
+ aug_emb = self.add_embedding(text_embs, image_embs)
838
+ elif self.config.addition_embed_type == "text_time":
839
+ if "text_embeds" not in added_cond_kwargs:
840
+ raise ValueError(
841
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
842
+ )
843
+ text_embeds = added_cond_kwargs.get("text_embeds")
844
+ if "time_ids" not in added_cond_kwargs:
845
+ raise ValueError(
846
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
847
+ )
848
+ time_ids = added_cond_kwargs.get("time_ids")
849
+ time_embeds = self.add_time_proj(time_ids.flatten())
850
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
851
+
852
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
853
+ add_embeds = add_embeds.to(emb.dtype)
854
+ aug_emb = self.add_embedding(add_embeds)
855
+ elif self.config.addition_embed_type == "image":
856
+ # Kandinsky 2.2 - style
857
+ if "image_embeds" not in added_cond_kwargs:
858
+ raise ValueError(
859
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
860
+ )
861
+ image_embs = added_cond_kwargs.get("image_embeds")
862
+ aug_emb = self.add_embedding(image_embs)
863
+ elif self.config.addition_embed_type == "image_hint":
864
+ # Kandinsky 2.2 - style
865
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
866
+ raise ValueError(
867
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
868
+ )
869
+ image_embs = added_cond_kwargs.get("image_embeds")
870
+ hint = added_cond_kwargs.get("hint")
871
+ aug_emb, hint = self.add_embedding(image_embs, hint)
872
+ sample = torch.cat([sample, hint], dim=1)
873
+
874
+ emb = emb + aug_emb if aug_emb is not None else emb
875
+
876
+ if self.time_embed_act is not None:
877
+ emb = self.time_embed_act(emb)
878
+
879
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
880
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
881
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
882
+ # Kadinsky 2.1 - style
883
+ if "image_embeds" not in added_cond_kwargs:
884
+ raise ValueError(
885
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
886
+ )
887
+
888
+ image_embeds = added_cond_kwargs.get("image_embeds")
889
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
890
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
891
+ # Kandinsky 2.2 - style
892
+ if "image_embeds" not in added_cond_kwargs:
893
+ raise ValueError(
894
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
895
+ )
896
+ image_embeds = added_cond_kwargs.get("image_embeds")
897
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
898
+ # 2. pre-process
899
+ sample = self.conv_in(sample)
900
+
901
+ # 3. down
902
+ down_block_res_samples = (sample,)
903
+ for downsample_block in self.down_blocks:
904
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
905
+ sample, res_samples = downsample_block(
906
+ hidden_states=sample,
907
+ temb=emb,
908
+ encoder_hidden_states=encoder_hidden_states,
909
+ attention_mask=attention_mask,
910
+ cross_attention_kwargs=cross_attention_kwargs,
911
+ encoder_attention_mask=encoder_attention_mask,
912
+ )
913
+ else:
914
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
915
+
916
+ down_block_res_samples += res_samples
917
+
918
+ if down_block_additional_residuals is not None:
919
+ new_down_block_res_samples = ()
920
+
921
+ for down_block_res_sample, down_block_additional_residual in zip(
922
+ down_block_res_samples, down_block_additional_residuals
923
+ ):
924
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
925
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
926
+
927
+ down_block_res_samples = new_down_block_res_samples
928
+
929
+ # 4. mid
930
+ if self.mid_block is not None:
931
+ sample = self.mid_block(
932
+ sample,
933
+ emb,
934
+ encoder_hidden_states=encoder_hidden_states,
935
+ attention_mask=attention_mask,
936
+ cross_attention_kwargs=cross_attention_kwargs,
937
+ encoder_attention_mask=encoder_attention_mask,
938
+ )
939
+
940
+ if mid_block_additional_residual is not None:
941
+ sample = sample + mid_block_additional_residual
942
+
943
+ # 5. up
944
+ for i, upsample_block in enumerate(self.up_blocks):
945
+ is_final_block = i == len(self.up_blocks) - 1
946
+
947
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
948
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
949
+
950
+ # if we have not reached the final block and need to forward the
951
+ # upsample size, we do it here
952
+ if not is_final_block and forward_upsample_size:
953
+ upsample_size = down_block_res_samples[-1].shape[2:]
954
+
955
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
956
+ sample = upsample_block(
957
+ hidden_states=sample,
958
+ temb=emb,
959
+ res_hidden_states_tuple=res_samples,
960
+ encoder_hidden_states=encoder_hidden_states,
961
+ cross_attention_kwargs=cross_attention_kwargs,
962
+ upsample_size=upsample_size,
963
+ attention_mask=attention_mask,
964
+ encoder_attention_mask=encoder_attention_mask,
965
+ )
966
+ else:
967
+ sample = upsample_block(
968
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
969
+ )
970
+
971
+ # 6. post-process
972
+ if self.conv_norm_out:
973
+ sample = self.conv_norm_out(sample)
974
+ sample = self.conv_act(sample)
975
+ sample = self.conv_out(sample)
976
+
977
+ if not return_dict:
978
+ return (sample,)
979
+
980
+ return UNet2DConditionOutput(sample=sample)
6DoF/diffusers/models/unet_2d_condition_flax.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
23
+ from ..utils import BaseOutput
24
+ from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .modeling_flax_utils import FlaxModelMixin
26
+ from .unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxCrossAttnUpBlock2D,
29
+ FlaxDownBlock2D,
30
+ FlaxUNetMidBlock2DCrossAttn,
31
+ FlaxUpBlock2D,
32
+ )
33
+
34
+
35
+ @flax.struct.dataclass
36
+ class FlaxUNet2DConditionOutput(BaseOutput):
37
+ """
38
+ The output of [`FlaxUNet2DConditionModel`].
39
+
40
+ Args:
41
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
42
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
43
+ """
44
+
45
+ sample: jnp.ndarray
46
+
47
+
48
+ @flax_register_to_config
49
+ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
50
+ r"""
51
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
52
+ shaped output.
53
+
54
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
55
+ implemented for all models (such as downloading or saving).
56
+
57
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
58
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
59
+ general usage and behavior.
60
+
61
+ Inherent JAX features such as the following are supported:
62
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
63
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
64
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
65
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
66
+
67
+ Parameters:
68
+ sample_size (`int`, *optional*):
69
+ The size of the input sample.
70
+ in_channels (`int`, *optional*, defaults to 4):
71
+ The number of channels in the input sample.
72
+ out_channels (`int`, *optional*, defaults to 4):
73
+ The number of channels in the output.
74
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
75
+ The tuple of downsample blocks to use.
76
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
77
+ The tuple of upsample blocks to use.
78
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
79
+ The tuple of output channels for each block.
80
+ layers_per_block (`int`, *optional*, defaults to 2):
81
+ The number of layers per block.
82
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
83
+ The dimension of the attention heads.
84
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
85
+ The number of attention heads.
86
+ cross_attention_dim (`int`, *optional*, defaults to 768):
87
+ The dimension of the cross attention features.
88
+ dropout (`float`, *optional*, defaults to 0):
89
+ Dropout probability for down, up and bottleneck blocks.
90
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
91
+ Whether to flip the sin to cos in the time embedding.
92
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
93
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
94
+ Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
95
+ """
96
+
97
+ sample_size: int = 32
98
+ in_channels: int = 4
99
+ out_channels: int = 4
100
+ down_block_types: Tuple[str] = (
101
+ "CrossAttnDownBlock2D",
102
+ "CrossAttnDownBlock2D",
103
+ "CrossAttnDownBlock2D",
104
+ "DownBlock2D",
105
+ )
106
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
107
+ only_cross_attention: Union[bool, Tuple[bool]] = False
108
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
109
+ layers_per_block: int = 2
110
+ attention_head_dim: Union[int, Tuple[int]] = 8
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
112
+ cross_attention_dim: int = 1280
113
+ dropout: float = 0.0
114
+ use_linear_projection: bool = False
115
+ dtype: jnp.dtype = jnp.float32
116
+ flip_sin_to_cos: bool = True
117
+ freq_shift: int = 0
118
+ use_memory_efficient_attention: bool = False
119
+
120
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
121
+ # init input tensors
122
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
123
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
124
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
125
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
126
+
127
+ params_rng, dropout_rng = jax.random.split(rng)
128
+ rngs = {"params": params_rng, "dropout": dropout_rng}
129
+
130
+ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
131
+
132
+ def setup(self):
133
+ block_out_channels = self.block_out_channels
134
+ time_embed_dim = block_out_channels[0] * 4
135
+
136
+ if self.num_attention_heads is not None:
137
+ raise ValueError(
138
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
139
+ )
140
+
141
+ # If `num_attention_heads` is not defined (which is the case for most models)
142
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
143
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
144
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
145
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
146
+ # which is why we correct for the naming here.
147
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
148
+
149
+ # input
150
+ self.conv_in = nn.Conv(
151
+ block_out_channels[0],
152
+ kernel_size=(3, 3),
153
+ strides=(1, 1),
154
+ padding=((1, 1), (1, 1)),
155
+ dtype=self.dtype,
156
+ )
157
+
158
+ # time
159
+ self.time_proj = FlaxTimesteps(
160
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
161
+ )
162
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
163
+
164
+ only_cross_attention = self.only_cross_attention
165
+ if isinstance(only_cross_attention, bool):
166
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
167
+
168
+ if isinstance(num_attention_heads, int):
169
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
170
+
171
+ # down
172
+ down_blocks = []
173
+ output_channel = block_out_channels[0]
174
+ for i, down_block_type in enumerate(self.down_block_types):
175
+ input_channel = output_channel
176
+ output_channel = block_out_channels[i]
177
+ is_final_block = i == len(block_out_channels) - 1
178
+
179
+ if down_block_type == "CrossAttnDownBlock2D":
180
+ down_block = FlaxCrossAttnDownBlock2D(
181
+ in_channels=input_channel,
182
+ out_channels=output_channel,
183
+ dropout=self.dropout,
184
+ num_layers=self.layers_per_block,
185
+ num_attention_heads=num_attention_heads[i],
186
+ add_downsample=not is_final_block,
187
+ use_linear_projection=self.use_linear_projection,
188
+ only_cross_attention=only_cross_attention[i],
189
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
190
+ dtype=self.dtype,
191
+ )
192
+ else:
193
+ down_block = FlaxDownBlock2D(
194
+ in_channels=input_channel,
195
+ out_channels=output_channel,
196
+ dropout=self.dropout,
197
+ num_layers=self.layers_per_block,
198
+ add_downsample=not is_final_block,
199
+ dtype=self.dtype,
200
+ )
201
+
202
+ down_blocks.append(down_block)
203
+ self.down_blocks = down_blocks
204
+
205
+ # mid
206
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
207
+ in_channels=block_out_channels[-1],
208
+ dropout=self.dropout,
209
+ num_attention_heads=num_attention_heads[-1],
210
+ use_linear_projection=self.use_linear_projection,
211
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
212
+ dtype=self.dtype,
213
+ )
214
+
215
+ # up
216
+ up_blocks = []
217
+ reversed_block_out_channels = list(reversed(block_out_channels))
218
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
219
+ only_cross_attention = list(reversed(only_cross_attention))
220
+ output_channel = reversed_block_out_channels[0]
221
+ for i, up_block_type in enumerate(self.up_block_types):
222
+ prev_output_channel = output_channel
223
+ output_channel = reversed_block_out_channels[i]
224
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
225
+
226
+ is_final_block = i == len(block_out_channels) - 1
227
+
228
+ if up_block_type == "CrossAttnUpBlock2D":
229
+ up_block = FlaxCrossAttnUpBlock2D(
230
+ in_channels=input_channel,
231
+ out_channels=output_channel,
232
+ prev_output_channel=prev_output_channel,
233
+ num_layers=self.layers_per_block + 1,
234
+ num_attention_heads=reversed_num_attention_heads[i],
235
+ add_upsample=not is_final_block,
236
+ dropout=self.dropout,
237
+ use_linear_projection=self.use_linear_projection,
238
+ only_cross_attention=only_cross_attention[i],
239
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
240
+ dtype=self.dtype,
241
+ )
242
+ else:
243
+ up_block = FlaxUpBlock2D(
244
+ in_channels=input_channel,
245
+ out_channels=output_channel,
246
+ prev_output_channel=prev_output_channel,
247
+ num_layers=self.layers_per_block + 1,
248
+ add_upsample=not is_final_block,
249
+ dropout=self.dropout,
250
+ dtype=self.dtype,
251
+ )
252
+
253
+ up_blocks.append(up_block)
254
+ prev_output_channel = output_channel
255
+ self.up_blocks = up_blocks
256
+
257
+ # out
258
+ self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
259
+ self.conv_out = nn.Conv(
260
+ self.out_channels,
261
+ kernel_size=(3, 3),
262
+ strides=(1, 1),
263
+ padding=((1, 1), (1, 1)),
264
+ dtype=self.dtype,
265
+ )
266
+
267
+ def __call__(
268
+ self,
269
+ sample,
270
+ timesteps,
271
+ encoder_hidden_states,
272
+ down_block_additional_residuals=None,
273
+ mid_block_additional_residual=None,
274
+ return_dict: bool = True,
275
+ train: bool = False,
276
+ ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
277
+ r"""
278
+ Args:
279
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
280
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
281
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
284
+ plain tuple.
285
+ train (`bool`, *optional*, defaults to `False`):
286
+ Use deterministic functions and disable dropout when not training.
287
+
288
+ Returns:
289
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
290
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
291
+ When returning a tuple, the first element is the sample tensor.
292
+ """
293
+ # 1. time
294
+ if not isinstance(timesteps, jnp.ndarray):
295
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
296
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
297
+ timesteps = timesteps.astype(dtype=jnp.float32)
298
+ timesteps = jnp.expand_dims(timesteps, 0)
299
+
300
+ t_emb = self.time_proj(timesteps)
301
+ t_emb = self.time_embedding(t_emb)
302
+
303
+ # 2. pre-process
304
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
305
+ sample = self.conv_in(sample)
306
+
307
+ # 3. down
308
+ down_block_res_samples = (sample,)
309
+ for down_block in self.down_blocks:
310
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
311
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
312
+ else:
313
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
314
+ down_block_res_samples += res_samples
315
+
316
+ if down_block_additional_residuals is not None:
317
+ new_down_block_res_samples = ()
318
+
319
+ for down_block_res_sample, down_block_additional_residual in zip(
320
+ down_block_res_samples, down_block_additional_residuals
321
+ ):
322
+ down_block_res_sample += down_block_additional_residual
323
+ new_down_block_res_samples += (down_block_res_sample,)
324
+
325
+ down_block_res_samples = new_down_block_res_samples
326
+
327
+ # 4. mid
328
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
329
+
330
+ if mid_block_additional_residual is not None:
331
+ sample += mid_block_additional_residual
332
+
333
+ # 5. up
334
+ for up_block in self.up_blocks:
335
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
336
+ down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
337
+ if isinstance(up_block, FlaxCrossAttnUpBlock2D):
338
+ sample = up_block(
339
+ sample,
340
+ temb=t_emb,
341
+ encoder_hidden_states=encoder_hidden_states,
342
+ res_hidden_states_tuple=res_samples,
343
+ deterministic=not train,
344
+ )
345
+ else:
346
+ sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
347
+
348
+ # 6. post-process
349
+ sample = self.conv_norm_out(sample)
350
+ sample = nn.silu(sample)
351
+ sample = self.conv_out(sample)
352
+ sample = jnp.transpose(sample, (0, 3, 1, 2))
353
+
354
+ if not return_dict:
355
+ return (sample,)
356
+
357
+ return FlaxUNet2DConditionOutput(sample=sample)
6DoF/diffusers/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ from .transformer_2d import Transformer2DModel
20
+ from .transformer_temporal import TransformerTemporalModel
21
+
22
+
23
+ def get_down_block(
24
+ down_block_type,
25
+ num_layers,
26
+ in_channels,
27
+ out_channels,
28
+ temb_channels,
29
+ add_downsample,
30
+ resnet_eps,
31
+ resnet_act_fn,
32
+ num_attention_heads,
33
+ resnet_groups=None,
34
+ cross_attention_dim=None,
35
+ downsample_padding=None,
36
+ dual_cross_attention=False,
37
+ use_linear_projection=True,
38
+ only_cross_attention=False,
39
+ upcast_attention=False,
40
+ resnet_time_scale_shift="default",
41
+ ):
42
+ if down_block_type == "DownBlock3D":
43
+ return DownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift,
54
+ )
55
+ elif down_block_type == "CrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return CrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ num_attention_heads=num_attention_heads,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ )
76
+ raise ValueError(f"{down_block_type} does not exist.")
77
+
78
+
79
+ def get_up_block(
80
+ up_block_type,
81
+ num_layers,
82
+ in_channels,
83
+ out_channels,
84
+ prev_output_channel,
85
+ temb_channels,
86
+ add_upsample,
87
+ resnet_eps,
88
+ resnet_act_fn,
89
+ num_attention_heads,
90
+ resnet_groups=None,
91
+ cross_attention_dim=None,
92
+ dual_cross_attention=False,
93
+ use_linear_projection=True,
94
+ only_cross_attention=False,
95
+ upcast_attention=False,
96
+ resnet_time_scale_shift="default",
97
+ ):
98
+ if up_block_type == "UpBlock3D":
99
+ return UpBlock3D(
100
+ num_layers=num_layers,
101
+ in_channels=in_channels,
102
+ out_channels=out_channels,
103
+ prev_output_channel=prev_output_channel,
104
+ temb_channels=temb_channels,
105
+ add_upsample=add_upsample,
106
+ resnet_eps=resnet_eps,
107
+ resnet_act_fn=resnet_act_fn,
108
+ resnet_groups=resnet_groups,
109
+ resnet_time_scale_shift=resnet_time_scale_shift,
110
+ )
111
+ elif up_block_type == "CrossAttnUpBlock3D":
112
+ if cross_attention_dim is None:
113
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
114
+ return CrossAttnUpBlock3D(
115
+ num_layers=num_layers,
116
+ in_channels=in_channels,
117
+ out_channels=out_channels,
118
+ prev_output_channel=prev_output_channel,
119
+ temb_channels=temb_channels,
120
+ add_upsample=add_upsample,
121
+ resnet_eps=resnet_eps,
122
+ resnet_act_fn=resnet_act_fn,
123
+ resnet_groups=resnet_groups,
124
+ cross_attention_dim=cross_attention_dim,
125
+ num_attention_heads=num_attention_heads,
126
+ dual_cross_attention=dual_cross_attention,
127
+ use_linear_projection=use_linear_projection,
128
+ only_cross_attention=only_cross_attention,
129
+ upcast_attention=upcast_attention,
130
+ resnet_time_scale_shift=resnet_time_scale_shift,
131
+ )
132
+ raise ValueError(f"{up_block_type} does not exist.")
133
+
134
+
135
+ class UNetMidBlock3DCrossAttn(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_channels: int,
139
+ temb_channels: int,
140
+ dropout: float = 0.0,
141
+ num_layers: int = 1,
142
+ resnet_eps: float = 1e-6,
143
+ resnet_time_scale_shift: str = "default",
144
+ resnet_act_fn: str = "swish",
145
+ resnet_groups: int = 32,
146
+ resnet_pre_norm: bool = True,
147
+ num_attention_heads=1,
148
+ output_scale_factor=1.0,
149
+ cross_attention_dim=1280,
150
+ dual_cross_attention=False,
151
+ use_linear_projection=True,
152
+ upcast_attention=False,
153
+ ):
154
+ super().__init__()
155
+
156
+ self.has_cross_attention = True
157
+ self.num_attention_heads = num_attention_heads
158
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
159
+
160
+ # there is always at least one resnet
161
+ resnets = [
162
+ ResnetBlock2D(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ temb_channels=temb_channels,
166
+ eps=resnet_eps,
167
+ groups=resnet_groups,
168
+ dropout=dropout,
169
+ time_embedding_norm=resnet_time_scale_shift,
170
+ non_linearity=resnet_act_fn,
171
+ output_scale_factor=output_scale_factor,
172
+ pre_norm=resnet_pre_norm,
173
+ )
174
+ ]
175
+ temp_convs = [
176
+ TemporalConvLayer(
177
+ in_channels,
178
+ in_channels,
179
+ dropout=0.1,
180
+ )
181
+ ]
182
+ attentions = []
183
+ temp_attentions = []
184
+
185
+ for _ in range(num_layers):
186
+ attentions.append(
187
+ Transformer2DModel(
188
+ in_channels // num_attention_heads,
189
+ num_attention_heads,
190
+ in_channels=in_channels,
191
+ num_layers=1,
192
+ cross_attention_dim=cross_attention_dim,
193
+ norm_num_groups=resnet_groups,
194
+ use_linear_projection=use_linear_projection,
195
+ upcast_attention=upcast_attention,
196
+ )
197
+ )
198
+ temp_attentions.append(
199
+ TransformerTemporalModel(
200
+ in_channels // num_attention_heads,
201
+ num_attention_heads,
202
+ in_channels=in_channels,
203
+ num_layers=1,
204
+ cross_attention_dim=cross_attention_dim,
205
+ norm_num_groups=resnet_groups,
206
+ )
207
+ )
208
+ resnets.append(
209
+ ResnetBlock2D(
210
+ in_channels=in_channels,
211
+ out_channels=in_channels,
212
+ temb_channels=temb_channels,
213
+ eps=resnet_eps,
214
+ groups=resnet_groups,
215
+ dropout=dropout,
216
+ time_embedding_norm=resnet_time_scale_shift,
217
+ non_linearity=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ pre_norm=resnet_pre_norm,
220
+ )
221
+ )
222
+ temp_convs.append(
223
+ TemporalConvLayer(
224
+ in_channels,
225
+ in_channels,
226
+ dropout=0.1,
227
+ )
228
+ )
229
+
230
+ self.resnets = nn.ModuleList(resnets)
231
+ self.temp_convs = nn.ModuleList(temp_convs)
232
+ self.attentions = nn.ModuleList(attentions)
233
+ self.temp_attentions = nn.ModuleList(temp_attentions)
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states,
238
+ temb=None,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ num_frames=1,
242
+ cross_attention_kwargs=None,
243
+ ):
244
+ hidden_states = self.resnets[0](hidden_states, temb)
245
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
246
+ for attn, temp_attn, resnet, temp_conv in zip(
247
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
248
+ ):
249
+ hidden_states = attn(
250
+ hidden_states,
251
+ encoder_hidden_states=encoder_hidden_states,
252
+ cross_attention_kwargs=cross_attention_kwargs,
253
+ return_dict=False,
254
+ )[0]
255
+ hidden_states = temp_attn(
256
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
257
+ )[0]
258
+ hidden_states = resnet(hidden_states, temb)
259
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
260
+
261
+ return hidden_states
262
+
263
+
264
+ class CrossAttnDownBlock3D(nn.Module):
265
+ def __init__(
266
+ self,
267
+ in_channels: int,
268
+ out_channels: int,
269
+ temb_channels: int,
270
+ dropout: float = 0.0,
271
+ num_layers: int = 1,
272
+ resnet_eps: float = 1e-6,
273
+ resnet_time_scale_shift: str = "default",
274
+ resnet_act_fn: str = "swish",
275
+ resnet_groups: int = 32,
276
+ resnet_pre_norm: bool = True,
277
+ num_attention_heads=1,
278
+ cross_attention_dim=1280,
279
+ output_scale_factor=1.0,
280
+ downsample_padding=1,
281
+ add_downsample=True,
282
+ dual_cross_attention=False,
283
+ use_linear_projection=False,
284
+ only_cross_attention=False,
285
+ upcast_attention=False,
286
+ ):
287
+ super().__init__()
288
+ resnets = []
289
+ attentions = []
290
+ temp_attentions = []
291
+ temp_convs = []
292
+
293
+ self.has_cross_attention = True
294
+ self.num_attention_heads = num_attention_heads
295
+
296
+ for i in range(num_layers):
297
+ in_channels = in_channels if i == 0 else out_channels
298
+ resnets.append(
299
+ ResnetBlock2D(
300
+ in_channels=in_channels,
301
+ out_channels=out_channels,
302
+ temb_channels=temb_channels,
303
+ eps=resnet_eps,
304
+ groups=resnet_groups,
305
+ dropout=dropout,
306
+ time_embedding_norm=resnet_time_scale_shift,
307
+ non_linearity=resnet_act_fn,
308
+ output_scale_factor=output_scale_factor,
309
+ pre_norm=resnet_pre_norm,
310
+ )
311
+ )
312
+ temp_convs.append(
313
+ TemporalConvLayer(
314
+ out_channels,
315
+ out_channels,
316
+ dropout=0.1,
317
+ )
318
+ )
319
+ attentions.append(
320
+ Transformer2DModel(
321
+ out_channels // num_attention_heads,
322
+ num_attention_heads,
323
+ in_channels=out_channels,
324
+ num_layers=1,
325
+ cross_attention_dim=cross_attention_dim,
326
+ norm_num_groups=resnet_groups,
327
+ use_linear_projection=use_linear_projection,
328
+ only_cross_attention=only_cross_attention,
329
+ upcast_attention=upcast_attention,
330
+ )
331
+ )
332
+ temp_attentions.append(
333
+ TransformerTemporalModel(
334
+ out_channels // num_attention_heads,
335
+ num_attention_heads,
336
+ in_channels=out_channels,
337
+ num_layers=1,
338
+ cross_attention_dim=cross_attention_dim,
339
+ norm_num_groups=resnet_groups,
340
+ )
341
+ )
342
+ self.resnets = nn.ModuleList(resnets)
343
+ self.temp_convs = nn.ModuleList(temp_convs)
344
+ self.attentions = nn.ModuleList(attentions)
345
+ self.temp_attentions = nn.ModuleList(temp_attentions)
346
+
347
+ if add_downsample:
348
+ self.downsamplers = nn.ModuleList(
349
+ [
350
+ Downsample2D(
351
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
352
+ )
353
+ ]
354
+ )
355
+ else:
356
+ self.downsamplers = None
357
+
358
+ self.gradient_checkpointing = False
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states,
363
+ temb=None,
364
+ encoder_hidden_states=None,
365
+ attention_mask=None,
366
+ num_frames=1,
367
+ cross_attention_kwargs=None,
368
+ ):
369
+ # TODO(Patrick, William) - attention mask is not used
370
+ output_states = ()
371
+
372
+ for resnet, temp_conv, attn, temp_attn in zip(
373
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
374
+ ):
375
+ hidden_states = resnet(hidden_states, temb)
376
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
377
+ hidden_states = attn(
378
+ hidden_states,
379
+ encoder_hidden_states=encoder_hidden_states,
380
+ cross_attention_kwargs=cross_attention_kwargs,
381
+ return_dict=False,
382
+ )[0]
383
+ hidden_states = temp_attn(
384
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
385
+ )[0]
386
+
387
+ output_states += (hidden_states,)
388
+
389
+ if self.downsamplers is not None:
390
+ for downsampler in self.downsamplers:
391
+ hidden_states = downsampler(hidden_states)
392
+
393
+ output_states += (hidden_states,)
394
+
395
+ return hidden_states, output_states
396
+
397
+
398
+ class DownBlock3D(nn.Module):
399
+ def __init__(
400
+ self,
401
+ in_channels: int,
402
+ out_channels: int,
403
+ temb_channels: int,
404
+ dropout: float = 0.0,
405
+ num_layers: int = 1,
406
+ resnet_eps: float = 1e-6,
407
+ resnet_time_scale_shift: str = "default",
408
+ resnet_act_fn: str = "swish",
409
+ resnet_groups: int = 32,
410
+ resnet_pre_norm: bool = True,
411
+ output_scale_factor=1.0,
412
+ add_downsample=True,
413
+ downsample_padding=1,
414
+ ):
415
+ super().__init__()
416
+ resnets = []
417
+ temp_convs = []
418
+
419
+ for i in range(num_layers):
420
+ in_channels = in_channels if i == 0 else out_channels
421
+ resnets.append(
422
+ ResnetBlock2D(
423
+ in_channels=in_channels,
424
+ out_channels=out_channels,
425
+ temb_channels=temb_channels,
426
+ eps=resnet_eps,
427
+ groups=resnet_groups,
428
+ dropout=dropout,
429
+ time_embedding_norm=resnet_time_scale_shift,
430
+ non_linearity=resnet_act_fn,
431
+ output_scale_factor=output_scale_factor,
432
+ pre_norm=resnet_pre_norm,
433
+ )
434
+ )
435
+ temp_convs.append(
436
+ TemporalConvLayer(
437
+ out_channels,
438
+ out_channels,
439
+ dropout=0.1,
440
+ )
441
+ )
442
+
443
+ self.resnets = nn.ModuleList(resnets)
444
+ self.temp_convs = nn.ModuleList(temp_convs)
445
+
446
+ if add_downsample:
447
+ self.downsamplers = nn.ModuleList(
448
+ [
449
+ Downsample2D(
450
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
451
+ )
452
+ ]
453
+ )
454
+ else:
455
+ self.downsamplers = None
456
+
457
+ self.gradient_checkpointing = False
458
+
459
+ def forward(self, hidden_states, temb=None, num_frames=1):
460
+ output_states = ()
461
+
462
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
463
+ hidden_states = resnet(hidden_states, temb)
464
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
465
+
466
+ output_states += (hidden_states,)
467
+
468
+ if self.downsamplers is not None:
469
+ for downsampler in self.downsamplers:
470
+ hidden_states = downsampler(hidden_states)
471
+
472
+ output_states += (hidden_states,)
473
+
474
+ return hidden_states, output_states
475
+
476
+
477
+ class CrossAttnUpBlock3D(nn.Module):
478
+ def __init__(
479
+ self,
480
+ in_channels: int,
481
+ out_channels: int,
482
+ prev_output_channel: int,
483
+ temb_channels: int,
484
+ dropout: float = 0.0,
485
+ num_layers: int = 1,
486
+ resnet_eps: float = 1e-6,
487
+ resnet_time_scale_shift: str = "default",
488
+ resnet_act_fn: str = "swish",
489
+ resnet_groups: int = 32,
490
+ resnet_pre_norm: bool = True,
491
+ num_attention_heads=1,
492
+ cross_attention_dim=1280,
493
+ output_scale_factor=1.0,
494
+ add_upsample=True,
495
+ dual_cross_attention=False,
496
+ use_linear_projection=False,
497
+ only_cross_attention=False,
498
+ upcast_attention=False,
499
+ ):
500
+ super().__init__()
501
+ resnets = []
502
+ temp_convs = []
503
+ attentions = []
504
+ temp_attentions = []
505
+
506
+ self.has_cross_attention = True
507
+ self.num_attention_heads = num_attention_heads
508
+
509
+ for i in range(num_layers):
510
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
511
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
512
+
513
+ resnets.append(
514
+ ResnetBlock2D(
515
+ in_channels=resnet_in_channels + res_skip_channels,
516
+ out_channels=out_channels,
517
+ temb_channels=temb_channels,
518
+ eps=resnet_eps,
519
+ groups=resnet_groups,
520
+ dropout=dropout,
521
+ time_embedding_norm=resnet_time_scale_shift,
522
+ non_linearity=resnet_act_fn,
523
+ output_scale_factor=output_scale_factor,
524
+ pre_norm=resnet_pre_norm,
525
+ )
526
+ )
527
+ temp_convs.append(
528
+ TemporalConvLayer(
529
+ out_channels,
530
+ out_channels,
531
+ dropout=0.1,
532
+ )
533
+ )
534
+ attentions.append(
535
+ Transformer2DModel(
536
+ out_channels // num_attention_heads,
537
+ num_attention_heads,
538
+ in_channels=out_channels,
539
+ num_layers=1,
540
+ cross_attention_dim=cross_attention_dim,
541
+ norm_num_groups=resnet_groups,
542
+ use_linear_projection=use_linear_projection,
543
+ only_cross_attention=only_cross_attention,
544
+ upcast_attention=upcast_attention,
545
+ )
546
+ )
547
+ temp_attentions.append(
548
+ TransformerTemporalModel(
549
+ out_channels // num_attention_heads,
550
+ num_attention_heads,
551
+ in_channels=out_channels,
552
+ num_layers=1,
553
+ cross_attention_dim=cross_attention_dim,
554
+ norm_num_groups=resnet_groups,
555
+ )
556
+ )
557
+ self.resnets = nn.ModuleList(resnets)
558
+ self.temp_convs = nn.ModuleList(temp_convs)
559
+ self.attentions = nn.ModuleList(attentions)
560
+ self.temp_attentions = nn.ModuleList(temp_attentions)
561
+
562
+ if add_upsample:
563
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
564
+ else:
565
+ self.upsamplers = None
566
+
567
+ self.gradient_checkpointing = False
568
+
569
+ def forward(
570
+ self,
571
+ hidden_states,
572
+ res_hidden_states_tuple,
573
+ temb=None,
574
+ encoder_hidden_states=None,
575
+ upsample_size=None,
576
+ attention_mask=None,
577
+ num_frames=1,
578
+ cross_attention_kwargs=None,
579
+ ):
580
+ # TODO(Patrick, William) - attention mask is not used
581
+ for resnet, temp_conv, attn, temp_attn in zip(
582
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
583
+ ):
584
+ # pop res hidden states
585
+ res_hidden_states = res_hidden_states_tuple[-1]
586
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
587
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
588
+
589
+ hidden_states = resnet(hidden_states, temb)
590
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
591
+ hidden_states = attn(
592
+ hidden_states,
593
+ encoder_hidden_states=encoder_hidden_states,
594
+ cross_attention_kwargs=cross_attention_kwargs,
595
+ return_dict=False,
596
+ )[0]
597
+ hidden_states = temp_attn(
598
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
599
+ )[0]
600
+
601
+ if self.upsamplers is not None:
602
+ for upsampler in self.upsamplers:
603
+ hidden_states = upsampler(hidden_states, upsample_size)
604
+
605
+ return hidden_states
606
+
607
+
608
+ class UpBlock3D(nn.Module):
609
+ def __init__(
610
+ self,
611
+ in_channels: int,
612
+ prev_output_channel: int,
613
+ out_channels: int,
614
+ temb_channels: int,
615
+ dropout: float = 0.0,
616
+ num_layers: int = 1,
617
+ resnet_eps: float = 1e-6,
618
+ resnet_time_scale_shift: str = "default",
619
+ resnet_act_fn: str = "swish",
620
+ resnet_groups: int = 32,
621
+ resnet_pre_norm: bool = True,
622
+ output_scale_factor=1.0,
623
+ add_upsample=True,
624
+ ):
625
+ super().__init__()
626
+ resnets = []
627
+ temp_convs = []
628
+
629
+ for i in range(num_layers):
630
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
631
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
632
+
633
+ resnets.append(
634
+ ResnetBlock2D(
635
+ in_channels=resnet_in_channels + res_skip_channels,
636
+ out_channels=out_channels,
637
+ temb_channels=temb_channels,
638
+ eps=resnet_eps,
639
+ groups=resnet_groups,
640
+ dropout=dropout,
641
+ time_embedding_norm=resnet_time_scale_shift,
642
+ non_linearity=resnet_act_fn,
643
+ output_scale_factor=output_scale_factor,
644
+ pre_norm=resnet_pre_norm,
645
+ )
646
+ )
647
+ temp_convs.append(
648
+ TemporalConvLayer(
649
+ out_channels,
650
+ out_channels,
651
+ dropout=0.1,
652
+ )
653
+ )
654
+
655
+ self.resnets = nn.ModuleList(resnets)
656
+ self.temp_convs = nn.ModuleList(temp_convs)
657
+
658
+ if add_upsample:
659
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
660
+ else:
661
+ self.upsamplers = None
662
+
663
+ self.gradient_checkpointing = False
664
+
665
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
666
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
667
+ # pop res hidden states
668
+ res_hidden_states = res_hidden_states_tuple[-1]
669
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
670
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
671
+
672
+ hidden_states = resnet(hidden_states, temb)
673
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
674
+
675
+ if self.upsamplers is not None:
676
+ for upsampler in self.upsamplers:
677
+ hidden_states = upsampler(hidden_states, upsample_size)
678
+
679
+ return hidden_states
6DoF/diffusers/models/unet_3d_condition.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..loaders import UNet2DConditionLoadersMixin
24
+ from ..utils import BaseOutput, logging
25
+ from .attention_processor import AttentionProcessor, AttnProcessor
26
+ from .embeddings import TimestepEmbedding, Timesteps
27
+ from .modeling_utils import ModelMixin
28
+ from .transformer_temporal import TransformerTemporalModel
29
+ from .unet_3d_blocks import (
30
+ CrossAttnDownBlock3D,
31
+ CrossAttnUpBlock3D,
32
+ DownBlock3D,
33
+ UNetMidBlock3DCrossAttn,
34
+ UpBlock3D,
35
+ get_down_block,
36
+ get_up_block,
37
+ )
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class UNet3DConditionOutput(BaseOutput):
45
+ """
46
+ The output of [`UNet3DConditionModel`].
47
+
48
+ Args:
49
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
50
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
51
+ """
52
+
53
+ sample: torch.FloatTensor
54
+
55
+
56
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
57
+ r"""
58
+ A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
59
+ shaped output.
60
+
61
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
62
+ for all models (such as downloading or saving).
63
+
64
+ Parameters:
65
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
66
+ Height and width of input/output sample.
67
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
68
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
69
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
70
+ The tuple of downsample blocks to use.
71
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
72
+ The tuple of upsample blocks to use.
73
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
74
+ The tuple of output channels for each block.
75
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
76
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
77
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
78
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
79
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
80
+ If `None`, normalization and activation layers is skipped in post-processing.
81
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
82
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
83
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
84
+ num_attention_heads (`int`, *optional*): The number of attention heads.
85
+ """
86
+
87
+ _supports_gradient_checkpointing = False
88
+
89
+ @register_to_config
90
+ def __init__(
91
+ self,
92
+ sample_size: Optional[int] = None,
93
+ in_channels: int = 4,
94
+ out_channels: int = 4,
95
+ down_block_types: Tuple[str] = (
96
+ "CrossAttnDownBlock3D",
97
+ "CrossAttnDownBlock3D",
98
+ "CrossAttnDownBlock3D",
99
+ "DownBlock3D",
100
+ ),
101
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
102
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
+ layers_per_block: int = 2,
104
+ downsample_padding: int = 1,
105
+ mid_block_scale_factor: float = 1,
106
+ act_fn: str = "silu",
107
+ norm_num_groups: Optional[int] = 32,
108
+ norm_eps: float = 1e-5,
109
+ cross_attention_dim: int = 1024,
110
+ attention_head_dim: Union[int, Tuple[int]] = 64,
111
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.sample_size = sample_size
116
+
117
+ if num_attention_heads is not None:
118
+ raise NotImplementedError(
119
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
120
+ )
121
+
122
+ # If `num_attention_heads` is not defined (which is the case for most models)
123
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
124
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
125
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
126
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
127
+ # which is why we correct for the naming here.
128
+ num_attention_heads = num_attention_heads or attention_head_dim
129
+
130
+ # Check inputs
131
+ if len(down_block_types) != len(up_block_types):
132
+ raise ValueError(
133
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
134
+ )
135
+
136
+ if len(block_out_channels) != len(down_block_types):
137
+ raise ValueError(
138
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
139
+ )
140
+
141
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
142
+ raise ValueError(
143
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
144
+ )
145
+
146
+ # input
147
+ conv_in_kernel = 3
148
+ conv_out_kernel = 3
149
+ conv_in_padding = (conv_in_kernel - 1) // 2
150
+ self.conv_in = nn.Conv2d(
151
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
152
+ )
153
+
154
+ # time
155
+ time_embed_dim = block_out_channels[0] * 4
156
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
157
+ timestep_input_dim = block_out_channels[0]
158
+
159
+ self.time_embedding = TimestepEmbedding(
160
+ timestep_input_dim,
161
+ time_embed_dim,
162
+ act_fn=act_fn,
163
+ )
164
+
165
+ self.transformer_in = TransformerTemporalModel(
166
+ num_attention_heads=8,
167
+ attention_head_dim=attention_head_dim,
168
+ in_channels=block_out_channels[0],
169
+ num_layers=1,
170
+ )
171
+
172
+ # class embedding
173
+ self.down_blocks = nn.ModuleList([])
174
+ self.up_blocks = nn.ModuleList([])
175
+
176
+ if isinstance(num_attention_heads, int):
177
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
178
+
179
+ # down
180
+ output_channel = block_out_channels[0]
181
+ for i, down_block_type in enumerate(down_block_types):
182
+ input_channel = output_channel
183
+ output_channel = block_out_channels[i]
184
+ is_final_block = i == len(block_out_channels) - 1
185
+
186
+ down_block = get_down_block(
187
+ down_block_type,
188
+ num_layers=layers_per_block,
189
+ in_channels=input_channel,
190
+ out_channels=output_channel,
191
+ temb_channels=time_embed_dim,
192
+ add_downsample=not is_final_block,
193
+ resnet_eps=norm_eps,
194
+ resnet_act_fn=act_fn,
195
+ resnet_groups=norm_num_groups,
196
+ cross_attention_dim=cross_attention_dim,
197
+ num_attention_heads=num_attention_heads[i],
198
+ downsample_padding=downsample_padding,
199
+ dual_cross_attention=False,
200
+ )
201
+ self.down_blocks.append(down_block)
202
+
203
+ # mid
204
+ self.mid_block = UNetMidBlock3DCrossAttn(
205
+ in_channels=block_out_channels[-1],
206
+ temb_channels=time_embed_dim,
207
+ resnet_eps=norm_eps,
208
+ resnet_act_fn=act_fn,
209
+ output_scale_factor=mid_block_scale_factor,
210
+ cross_attention_dim=cross_attention_dim,
211
+ num_attention_heads=num_attention_heads[-1],
212
+ resnet_groups=norm_num_groups,
213
+ dual_cross_attention=False,
214
+ )
215
+
216
+ # count how many layers upsample the images
217
+ self.num_upsamplers = 0
218
+
219
+ # up
220
+ reversed_block_out_channels = list(reversed(block_out_channels))
221
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
222
+
223
+ output_channel = reversed_block_out_channels[0]
224
+ for i, up_block_type in enumerate(up_block_types):
225
+ is_final_block = i == len(block_out_channels) - 1
226
+
227
+ prev_output_channel = output_channel
228
+ output_channel = reversed_block_out_channels[i]
229
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
230
+
231
+ # add upsample block for all BUT final layer
232
+ if not is_final_block:
233
+ add_upsample = True
234
+ self.num_upsamplers += 1
235
+ else:
236
+ add_upsample = False
237
+
238
+ up_block = get_up_block(
239
+ up_block_type,
240
+ num_layers=layers_per_block + 1,
241
+ in_channels=input_channel,
242
+ out_channels=output_channel,
243
+ prev_output_channel=prev_output_channel,
244
+ temb_channels=time_embed_dim,
245
+ add_upsample=add_upsample,
246
+ resnet_eps=norm_eps,
247
+ resnet_act_fn=act_fn,
248
+ resnet_groups=norm_num_groups,
249
+ cross_attention_dim=cross_attention_dim,
250
+ num_attention_heads=reversed_num_attention_heads[i],
251
+ dual_cross_attention=False,
252
+ )
253
+ self.up_blocks.append(up_block)
254
+ prev_output_channel = output_channel
255
+
256
+ # out
257
+ if norm_num_groups is not None:
258
+ self.conv_norm_out = nn.GroupNorm(
259
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
260
+ )
261
+ self.conv_act = nn.SiLU()
262
+ else:
263
+ self.conv_norm_out = None
264
+ self.conv_act = None
265
+
266
+ conv_out_padding = (conv_out_kernel - 1) // 2
267
+ self.conv_out = nn.Conv2d(
268
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
269
+ )
270
+
271
+ @property
272
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
273
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
274
+ r"""
275
+ Returns:
276
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
277
+ indexed by its weight name.
278
+ """
279
+ # set recursively
280
+ processors = {}
281
+
282
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
283
+ if hasattr(module, "set_processor"):
284
+ processors[f"{name}.processor"] = module.processor
285
+
286
+ for sub_name, child in module.named_children():
287
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
288
+
289
+ return processors
290
+
291
+ for name, module in self.named_children():
292
+ fn_recursive_add_processors(name, module, processors)
293
+
294
+ return processors
295
+
296
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
297
+ def set_attention_slice(self, slice_size):
298
+ r"""
299
+ Enable sliced attention computation.
300
+
301
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
302
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
303
+
304
+ Args:
305
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
306
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
307
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
308
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
309
+ must be a multiple of `slice_size`.
310
+ """
311
+ sliceable_head_dims = []
312
+
313
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
314
+ if hasattr(module, "set_attention_slice"):
315
+ sliceable_head_dims.append(module.sliceable_head_dim)
316
+
317
+ for child in module.children():
318
+ fn_recursive_retrieve_sliceable_dims(child)
319
+
320
+ # retrieve number of attention layers
321
+ for module in self.children():
322
+ fn_recursive_retrieve_sliceable_dims(module)
323
+
324
+ num_sliceable_layers = len(sliceable_head_dims)
325
+
326
+ if slice_size == "auto":
327
+ # half the attention head size is usually a good trade-off between
328
+ # speed and memory
329
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
330
+ elif slice_size == "max":
331
+ # make smallest slice possible
332
+ slice_size = num_sliceable_layers * [1]
333
+
334
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
335
+
336
+ if len(slice_size) != len(sliceable_head_dims):
337
+ raise ValueError(
338
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
339
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
340
+ )
341
+
342
+ for i in range(len(slice_size)):
343
+ size = slice_size[i]
344
+ dim = sliceable_head_dims[i]
345
+ if size is not None and size > dim:
346
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
347
+
348
+ # Recursively walk through all the children.
349
+ # Any children which exposes the set_attention_slice method
350
+ # gets the message
351
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
352
+ if hasattr(module, "set_attention_slice"):
353
+ module.set_attention_slice(slice_size.pop())
354
+
355
+ for child in module.children():
356
+ fn_recursive_set_attention_slice(child, slice_size)
357
+
358
+ reversed_slice_size = list(reversed(slice_size))
359
+ for module in self.children():
360
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
361
+
362
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
363
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
364
+ r"""
365
+ Sets the attention processor to use to compute attention.
366
+
367
+ Parameters:
368
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
369
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
370
+ for **all** `Attention` layers.
371
+
372
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
373
+ processor. This is strongly recommended when setting trainable attention processors.
374
+
375
+ """
376
+ count = len(self.attn_processors.keys())
377
+
378
+ if isinstance(processor, dict) and len(processor) != count:
379
+ raise ValueError(
380
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
381
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
382
+ )
383
+
384
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
385
+ if hasattr(module, "set_processor"):
386
+ if not isinstance(processor, dict):
387
+ module.set_processor(processor)
388
+ else:
389
+ module.set_processor(processor.pop(f"{name}.processor"))
390
+
391
+ for sub_name, child in module.named_children():
392
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
393
+
394
+ for name, module in self.named_children():
395
+ fn_recursive_attn_processor(name, module, processor)
396
+
397
+ def enable_forward_chunking(self, chunk_size=None, dim=0):
398
+ """
399
+ Sets the attention processor to use [feed forward
400
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
401
+
402
+ Parameters:
403
+ chunk_size (`int`, *optional*):
404
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
405
+ over each tensor of dim=`dim`.
406
+ dim (`int`, *optional*, defaults to `0`):
407
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
408
+ or dim=1 (sequence length).
409
+ """
410
+ if dim not in [0, 1]:
411
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
412
+
413
+ # By default chunk size is 1
414
+ chunk_size = chunk_size or 1
415
+
416
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
417
+ if hasattr(module, "set_chunk_feed_forward"):
418
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
419
+
420
+ for child in module.children():
421
+ fn_recursive_feed_forward(child, chunk_size, dim)
422
+
423
+ for module in self.children():
424
+ fn_recursive_feed_forward(module, chunk_size, dim)
425
+
426
+ def disable_forward_chunking(self):
427
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
428
+ if hasattr(module, "set_chunk_feed_forward"):
429
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
430
+
431
+ for child in module.children():
432
+ fn_recursive_feed_forward(child, chunk_size, dim)
433
+
434
+ for module in self.children():
435
+ fn_recursive_feed_forward(module, None, 0)
436
+
437
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
438
+ def set_default_attn_processor(self):
439
+ """
440
+ Disables custom attention processors and sets the default attention implementation.
441
+ """
442
+ self.set_attn_processor(AttnProcessor())
443
+
444
+ def _set_gradient_checkpointing(self, module, value=False):
445
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
446
+ module.gradient_checkpointing = value
447
+
448
+ def forward(
449
+ self,
450
+ sample: torch.FloatTensor,
451
+ timestep: Union[torch.Tensor, float, int],
452
+ encoder_hidden_states: torch.Tensor,
453
+ class_labels: Optional[torch.Tensor] = None,
454
+ timestep_cond: Optional[torch.Tensor] = None,
455
+ attention_mask: Optional[torch.Tensor] = None,
456
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
457
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
458
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
459
+ return_dict: bool = True,
460
+ ) -> Union[UNet3DConditionOutput, Tuple]:
461
+ r"""
462
+ The [`UNet3DConditionModel`] forward method.
463
+
464
+ Args:
465
+ sample (`torch.FloatTensor`):
466
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
467
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
468
+ encoder_hidden_states (`torch.FloatTensor`):
469
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
470
+ return_dict (`bool`, *optional*, defaults to `True`):
471
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
472
+ tuple.
473
+ cross_attention_kwargs (`dict`, *optional*):
474
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
475
+
476
+ Returns:
477
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
478
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
479
+ a `tuple` is returned where the first element is the sample tensor.
480
+ """
481
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
482
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
483
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
484
+ # on the fly if necessary.
485
+ default_overall_up_factor = 2**self.num_upsamplers
486
+
487
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
488
+ forward_upsample_size = False
489
+ upsample_size = None
490
+
491
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
492
+ logger.info("Forward upsample size to force interpolation output size.")
493
+ forward_upsample_size = True
494
+
495
+ # prepare attention_mask
496
+ if attention_mask is not None:
497
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
498
+ attention_mask = attention_mask.unsqueeze(1)
499
+
500
+ # 1. time
501
+ timesteps = timestep
502
+ if not torch.is_tensor(timesteps):
503
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
504
+ # This would be a good case for the `match` statement (Python 3.10+)
505
+ is_mps = sample.device.type == "mps"
506
+ if isinstance(timestep, float):
507
+ dtype = torch.float32 if is_mps else torch.float64
508
+ else:
509
+ dtype = torch.int32 if is_mps else torch.int64
510
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
511
+ elif len(timesteps.shape) == 0:
512
+ timesteps = timesteps[None].to(sample.device)
513
+
514
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
515
+ num_frames = sample.shape[2]
516
+ timesteps = timesteps.expand(sample.shape[0])
517
+
518
+ t_emb = self.time_proj(timesteps)
519
+
520
+ # timesteps does not contain any weights and will always return f32 tensors
521
+ # but time_embedding might actually be running in fp16. so we need to cast here.
522
+ # there might be better ways to encapsulate this.
523
+ t_emb = t_emb.to(dtype=self.dtype)
524
+
525
+ emb = self.time_embedding(t_emb, timestep_cond)
526
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
527
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
528
+
529
+ # 2. pre-process
530
+ sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
531
+ sample = self.conv_in(sample)
532
+
533
+ sample = self.transformer_in(
534
+ sample,
535
+ num_frames=num_frames,
536
+ cross_attention_kwargs=cross_attention_kwargs,
537
+ return_dict=False,
538
+ )[0]
539
+
540
+ # 3. down
541
+ down_block_res_samples = (sample,)
542
+ for downsample_block in self.down_blocks:
543
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
544
+ sample, res_samples = downsample_block(
545
+ hidden_states=sample,
546
+ temb=emb,
547
+ encoder_hidden_states=encoder_hidden_states,
548
+ attention_mask=attention_mask,
549
+ num_frames=num_frames,
550
+ cross_attention_kwargs=cross_attention_kwargs,
551
+ )
552
+ else:
553
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
554
+
555
+ down_block_res_samples += res_samples
556
+
557
+ if down_block_additional_residuals is not None:
558
+ new_down_block_res_samples = ()
559
+
560
+ for down_block_res_sample, down_block_additional_residual in zip(
561
+ down_block_res_samples, down_block_additional_residuals
562
+ ):
563
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
564
+ new_down_block_res_samples += (down_block_res_sample,)
565
+
566
+ down_block_res_samples = new_down_block_res_samples
567
+
568
+ # 4. mid
569
+ if self.mid_block is not None:
570
+ sample = self.mid_block(
571
+ sample,
572
+ emb,
573
+ encoder_hidden_states=encoder_hidden_states,
574
+ attention_mask=attention_mask,
575
+ num_frames=num_frames,
576
+ cross_attention_kwargs=cross_attention_kwargs,
577
+ )
578
+
579
+ if mid_block_additional_residual is not None:
580
+ sample = sample + mid_block_additional_residual
581
+
582
+ # 5. up
583
+ for i, upsample_block in enumerate(self.up_blocks):
584
+ is_final_block = i == len(self.up_blocks) - 1
585
+
586
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
587
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
588
+
589
+ # if we have not reached the final block and need to forward the
590
+ # upsample size, we do it here
591
+ if not is_final_block and forward_upsample_size:
592
+ upsample_size = down_block_res_samples[-1].shape[2:]
593
+
594
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
595
+ sample = upsample_block(
596
+ hidden_states=sample,
597
+ temb=emb,
598
+ res_hidden_states_tuple=res_samples,
599
+ encoder_hidden_states=encoder_hidden_states,
600
+ upsample_size=upsample_size,
601
+ attention_mask=attention_mask,
602
+ num_frames=num_frames,
603
+ cross_attention_kwargs=cross_attention_kwargs,
604
+ )
605
+ else:
606
+ sample = upsample_block(
607
+ hidden_states=sample,
608
+ temb=emb,
609
+ res_hidden_states_tuple=res_samples,
610
+ upsample_size=upsample_size,
611
+ num_frames=num_frames,
612
+ )
613
+
614
+ # 6. post-process
615
+ if self.conv_norm_out:
616
+ sample = self.conv_norm_out(sample)
617
+ sample = self.conv_act(sample)
618
+
619
+ sample = self.conv_out(sample)
620
+
621
+ # reshape to (batch, channel, framerate, width, height)
622
+ sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
623
+
624
+ if not return_dict:
625
+ return (sample,)
626
+
627
+ return UNet3DConditionOutput(sample=sample)
6DoF/diffusers/models/vae.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ..utils import BaseOutput, is_torch_version, randn_tensor
22
+ from .attention_processor import SpatialNorm
23
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
24
+
25
+
26
+ @dataclass
27
+ class DecoderOutput(BaseOutput):
28
+ """
29
+ Output of decoding method.
30
+
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
33
+ The decoded output sample from the last layer of the model.
34
+ """
35
+
36
+ sample: torch.FloatTensor
37
+
38
+
39
+ class Encoder(nn.Module):
40
+ def __init__(
41
+ self,
42
+ in_channels=3,
43
+ out_channels=3,
44
+ down_block_types=("DownEncoderBlock2D",),
45
+ block_out_channels=(64,),
46
+ layers_per_block=2,
47
+ norm_num_groups=32,
48
+ act_fn="silu",
49
+ double_z=True,
50
+ ):
51
+ super().__init__()
52
+ self.layers_per_block = layers_per_block
53
+
54
+ self.conv_in = torch.nn.Conv2d(
55
+ in_channels,
56
+ block_out_channels[0],
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1,
60
+ )
61
+
62
+ self.mid_block = None
63
+ self.down_blocks = nn.ModuleList([])
64
+
65
+ # down
66
+ output_channel = block_out_channels[0]
67
+ for i, down_block_type in enumerate(down_block_types):
68
+ input_channel = output_channel
69
+ output_channel = block_out_channels[i]
70
+ is_final_block = i == len(block_out_channels) - 1
71
+
72
+ down_block = get_down_block(
73
+ down_block_type,
74
+ num_layers=self.layers_per_block,
75
+ in_channels=input_channel,
76
+ out_channels=output_channel,
77
+ add_downsample=not is_final_block,
78
+ resnet_eps=1e-6,
79
+ downsample_padding=0,
80
+ resnet_act_fn=act_fn,
81
+ resnet_groups=norm_num_groups,
82
+ attention_head_dim=output_channel,
83
+ temb_channels=None,
84
+ )
85
+ self.down_blocks.append(down_block)
86
+
87
+ # mid
88
+ self.mid_block = UNetMidBlock2D(
89
+ in_channels=block_out_channels[-1],
90
+ resnet_eps=1e-6,
91
+ resnet_act_fn=act_fn,
92
+ output_scale_factor=1,
93
+ resnet_time_scale_shift="default",
94
+ attention_head_dim=block_out_channels[-1],
95
+ resnet_groups=norm_num_groups,
96
+ temb_channels=None,
97
+ )
98
+
99
+ # out
100
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
101
+ self.conv_act = nn.SiLU()
102
+
103
+ conv_out_channels = 2 * out_channels if double_z else out_channels
104
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
105
+
106
+ self.gradient_checkpointing = False
107
+
108
+ def forward(self, x):
109
+ sample = x
110
+ sample = self.conv_in(sample)
111
+
112
+ if self.training and self.gradient_checkpointing:
113
+
114
+ def create_custom_forward(module):
115
+ def custom_forward(*inputs):
116
+ return module(*inputs)
117
+
118
+ return custom_forward
119
+
120
+ # down
121
+ if is_torch_version(">=", "1.11.0"):
122
+ for down_block in self.down_blocks:
123
+ sample = torch.utils.checkpoint.checkpoint(
124
+ create_custom_forward(down_block), sample, use_reentrant=False
125
+ )
126
+ # middle
127
+ sample = torch.utils.checkpoint.checkpoint(
128
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
129
+ )
130
+ else:
131
+ for down_block in self.down_blocks:
132
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
133
+ # middle
134
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
135
+
136
+ else:
137
+ # down
138
+ for down_block in self.down_blocks:
139
+ sample = down_block(sample)
140
+
141
+ # middle
142
+ sample = self.mid_block(sample)
143
+
144
+ # post-process
145
+ sample = self.conv_norm_out(sample)
146
+ sample = self.conv_act(sample)
147
+ sample = self.conv_out(sample)
148
+
149
+ return sample
150
+
151
+
152
+ class Decoder(nn.Module):
153
+ def __init__(
154
+ self,
155
+ in_channels=3,
156
+ out_channels=3,
157
+ up_block_types=("UpDecoderBlock2D",),
158
+ block_out_channels=(64,),
159
+ layers_per_block=2,
160
+ norm_num_groups=32,
161
+ act_fn="silu",
162
+ norm_type="group", # group, spatial
163
+ ):
164
+ super().__init__()
165
+ self.layers_per_block = layers_per_block
166
+
167
+ self.conv_in = nn.Conv2d(
168
+ in_channels,
169
+ block_out_channels[-1],
170
+ kernel_size=3,
171
+ stride=1,
172
+ padding=1,
173
+ )
174
+
175
+ self.mid_block = None
176
+ self.up_blocks = nn.ModuleList([])
177
+
178
+ temb_channels = in_channels if norm_type == "spatial" else None
179
+
180
+ # mid
181
+ self.mid_block = UNetMidBlock2D(
182
+ in_channels=block_out_channels[-1],
183
+ resnet_eps=1e-6,
184
+ resnet_act_fn=act_fn,
185
+ output_scale_factor=1,
186
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
187
+ attention_head_dim=block_out_channels[-1],
188
+ resnet_groups=norm_num_groups,
189
+ temb_channels=temb_channels,
190
+ )
191
+
192
+ # up
193
+ reversed_block_out_channels = list(reversed(block_out_channels))
194
+ output_channel = reversed_block_out_channels[0]
195
+ for i, up_block_type in enumerate(up_block_types):
196
+ prev_output_channel = output_channel
197
+ output_channel = reversed_block_out_channels[i]
198
+
199
+ is_final_block = i == len(block_out_channels) - 1
200
+
201
+ up_block = get_up_block(
202
+ up_block_type,
203
+ num_layers=self.layers_per_block + 1,
204
+ in_channels=prev_output_channel,
205
+ out_channels=output_channel,
206
+ prev_output_channel=None,
207
+ add_upsample=not is_final_block,
208
+ resnet_eps=1e-6,
209
+ resnet_act_fn=act_fn,
210
+ resnet_groups=norm_num_groups,
211
+ attention_head_dim=output_channel,
212
+ temb_channels=temb_channels,
213
+ resnet_time_scale_shift=norm_type,
214
+ )
215
+ self.up_blocks.append(up_block)
216
+ prev_output_channel = output_channel
217
+
218
+ # out
219
+ if norm_type == "spatial":
220
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
221
+ else:
222
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
223
+ self.conv_act = nn.SiLU()
224
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
225
+
226
+ self.gradient_checkpointing = False
227
+
228
+ def forward(self, z, latent_embeds=None):
229
+ sample = z
230
+ sample = self.conv_in(sample)
231
+
232
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
233
+ if self.training and self.gradient_checkpointing:
234
+
235
+ def create_custom_forward(module):
236
+ def custom_forward(*inputs):
237
+ return module(*inputs)
238
+
239
+ return custom_forward
240
+
241
+ if is_torch_version(">=", "1.11.0"):
242
+ # middle
243
+ sample = torch.utils.checkpoint.checkpoint(
244
+ create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
245
+ )
246
+ sample = sample.to(upscale_dtype)
247
+
248
+ # up
249
+ for up_block in self.up_blocks:
250
+ sample = torch.utils.checkpoint.checkpoint(
251
+ create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
252
+ )
253
+ else:
254
+ # middle
255
+ sample = torch.utils.checkpoint.checkpoint(
256
+ create_custom_forward(self.mid_block), sample, latent_embeds
257
+ )
258
+ sample = sample.to(upscale_dtype)
259
+
260
+ # up
261
+ for up_block in self.up_blocks:
262
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
263
+ else:
264
+ # middle
265
+ sample = self.mid_block(sample, latent_embeds)
266
+ sample = sample.to(upscale_dtype)
267
+
268
+ # up
269
+ for up_block in self.up_blocks:
270
+ sample = up_block(sample, latent_embeds)
271
+
272
+ # post-process
273
+ if latent_embeds is None:
274
+ sample = self.conv_norm_out(sample)
275
+ else:
276
+ sample = self.conv_norm_out(sample, latent_embeds)
277
+ sample = self.conv_act(sample)
278
+ sample = self.conv_out(sample)
279
+
280
+ return sample
281
+
282
+
283
+ class VectorQuantizer(nn.Module):
284
+ """
285
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
286
+ multiplications and allows for post-hoc remapping of indices.
287
+ """
288
+
289
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
290
+ # backwards compatibility we use the buggy version by default, but you can
291
+ # specify legacy=False to fix it.
292
+ def __init__(
293
+ self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
294
+ ):
295
+ super().__init__()
296
+ self.n_e = n_e
297
+ self.vq_embed_dim = vq_embed_dim
298
+ self.beta = beta
299
+ self.legacy = legacy
300
+
301
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
302
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
303
+
304
+ self.remap = remap
305
+ if self.remap is not None:
306
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
307
+ self.re_embed = self.used.shape[0]
308
+ self.unknown_index = unknown_index # "random" or "extra" or integer
309
+ if self.unknown_index == "extra":
310
+ self.unknown_index = self.re_embed
311
+ self.re_embed = self.re_embed + 1
312
+ print(
313
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
314
+ f"Using {self.unknown_index} for unknown indices."
315
+ )
316
+ else:
317
+ self.re_embed = n_e
318
+
319
+ self.sane_index_shape = sane_index_shape
320
+
321
+ def remap_to_used(self, inds):
322
+ ishape = inds.shape
323
+ assert len(ishape) > 1
324
+ inds = inds.reshape(ishape[0], -1)
325
+ used = self.used.to(inds)
326
+ match = (inds[:, :, None] == used[None, None, ...]).long()
327
+ new = match.argmax(-1)
328
+ unknown = match.sum(2) < 1
329
+ if self.unknown_index == "random":
330
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
331
+ else:
332
+ new[unknown] = self.unknown_index
333
+ return new.reshape(ishape)
334
+
335
+ def unmap_to_all(self, inds):
336
+ ishape = inds.shape
337
+ assert len(ishape) > 1
338
+ inds = inds.reshape(ishape[0], -1)
339
+ used = self.used.to(inds)
340
+ if self.re_embed > self.used.shape[0]: # extra token
341
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
342
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
343
+ return back.reshape(ishape)
344
+
345
+ def forward(self, z):
346
+ # reshape z -> (batch, height, width, channel) and flatten
347
+ z = z.permute(0, 2, 3, 1).contiguous()
348
+ z_flattened = z.view(-1, self.vq_embed_dim)
349
+
350
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
351
+ min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
352
+
353
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
354
+ perplexity = None
355
+ min_encodings = None
356
+
357
+ # compute loss for embedding
358
+ if not self.legacy:
359
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
360
+ else:
361
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
362
+
363
+ # preserve gradients
364
+ z_q = z + (z_q - z).detach()
365
+
366
+ # reshape back to match original input shape
367
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
368
+
369
+ if self.remap is not None:
370
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
371
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
372
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
373
+
374
+ if self.sane_index_shape:
375
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
376
+
377
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
378
+
379
+ def get_codebook_entry(self, indices, shape):
380
+ # shape specifying (batch, height, width, channel)
381
+ if self.remap is not None:
382
+ indices = indices.reshape(shape[0], -1) # add batch axis
383
+ indices = self.unmap_to_all(indices)
384
+ indices = indices.reshape(-1) # flatten again
385
+
386
+ # get quantized latent vectors
387
+ z_q = self.embedding(indices)
388
+
389
+ if shape is not None:
390
+ z_q = z_q.view(shape)
391
+ # reshape back to match original input shape
392
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
393
+
394
+ return z_q
395
+
396
+
397
+ class DiagonalGaussianDistribution(object):
398
+ def __init__(self, parameters, deterministic=False):
399
+ self.parameters = parameters
400
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
401
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
402
+ self.deterministic = deterministic
403
+ self.std = torch.exp(0.5 * self.logvar)
404
+ self.var = torch.exp(self.logvar)
405
+ if self.deterministic:
406
+ self.var = self.std = torch.zeros_like(
407
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
408
+ )
409
+
410
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
411
+ # make sure sample is on the same device as the parameters and has same dtype
412
+ sample = randn_tensor(
413
+ self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
414
+ )
415
+ x = self.mean + self.std * sample
416
+ return x
417
+
418
+ def kl(self, other=None):
419
+ if self.deterministic:
420
+ return torch.Tensor([0.0])
421
+ else:
422
+ if other is None:
423
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
424
+ else:
425
+ return 0.5 * torch.sum(
426
+ torch.pow(self.mean - other.mean, 2) / other.var
427
+ + self.var / other.var
428
+ - 1.0
429
+ - self.logvar
430
+ + other.logvar,
431
+ dim=[1, 2, 3],
432
+ )
433
+
434
+ def nll(self, sample, dims=[1, 2, 3]):
435
+ if self.deterministic:
436
+ return torch.Tensor([0.0])
437
+ logtwopi = np.log(2.0 * np.pi)
438
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
439
+
440
+ def mode(self):
441
+ return self.mean
6DoF/diffusers/models/vae_flax.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Tuple
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict
26
+
27
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
28
+ from ..utils import BaseOutput
29
+ from .modeling_flax_utils import FlaxModelMixin
30
+
31
+
32
+ @flax.struct.dataclass
33
+ class FlaxDecoderOutput(BaseOutput):
34
+ """
35
+ Output of decoding method.
36
+
37
+ Args:
38
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
39
+ The decoded output sample from the last layer of the model.
40
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
41
+ The `dtype` of the parameters.
42
+ """
43
+
44
+ sample: jnp.ndarray
45
+
46
+
47
+ @flax.struct.dataclass
48
+ class FlaxAutoencoderKLOutput(BaseOutput):
49
+ """
50
+ Output of AutoencoderKL encoding method.
51
+
52
+ Args:
53
+ latent_dist (`FlaxDiagonalGaussianDistribution`):
54
+ Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
55
+ `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
56
+ """
57
+
58
+ latent_dist: "FlaxDiagonalGaussianDistribution"
59
+
60
+
61
+ class FlaxUpsample2D(nn.Module):
62
+ """
63
+ Flax implementation of 2D Upsample layer
64
+
65
+ Args:
66
+ in_channels (`int`):
67
+ Input channels
68
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
69
+ Parameters `dtype`
70
+ """
71
+
72
+ in_channels: int
73
+ dtype: jnp.dtype = jnp.float32
74
+
75
+ def setup(self):
76
+ self.conv = nn.Conv(
77
+ self.in_channels,
78
+ kernel_size=(3, 3),
79
+ strides=(1, 1),
80
+ padding=((1, 1), (1, 1)),
81
+ dtype=self.dtype,
82
+ )
83
+
84
+ def __call__(self, hidden_states):
85
+ batch, height, width, channels = hidden_states.shape
86
+ hidden_states = jax.image.resize(
87
+ hidden_states,
88
+ shape=(batch, height * 2, width * 2, channels),
89
+ method="nearest",
90
+ )
91
+ hidden_states = self.conv(hidden_states)
92
+ return hidden_states
93
+
94
+
95
+ class FlaxDownsample2D(nn.Module):
96
+ """
97
+ Flax implementation of 2D Downsample layer
98
+
99
+ Args:
100
+ in_channels (`int`):
101
+ Input channels
102
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
103
+ Parameters `dtype`
104
+ """
105
+
106
+ in_channels: int
107
+ dtype: jnp.dtype = jnp.float32
108
+
109
+ def setup(self):
110
+ self.conv = nn.Conv(
111
+ self.in_channels,
112
+ kernel_size=(3, 3),
113
+ strides=(2, 2),
114
+ padding="VALID",
115
+ dtype=self.dtype,
116
+ )
117
+
118
+ def __call__(self, hidden_states):
119
+ pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
120
+ hidden_states = jnp.pad(hidden_states, pad_width=pad)
121
+ hidden_states = self.conv(hidden_states)
122
+ return hidden_states
123
+
124
+
125
+ class FlaxResnetBlock2D(nn.Module):
126
+ """
127
+ Flax implementation of 2D Resnet Block.
128
+
129
+ Args:
130
+ in_channels (`int`):
131
+ Input channels
132
+ out_channels (`int`):
133
+ Output channels
134
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
135
+ Dropout rate
136
+ groups (:obj:`int`, *optional*, defaults to `32`):
137
+ The number of groups to use for group norm.
138
+ use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
139
+ Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
140
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
141
+ Parameters `dtype`
142
+ """
143
+
144
+ in_channels: int
145
+ out_channels: int = None
146
+ dropout: float = 0.0
147
+ groups: int = 32
148
+ use_nin_shortcut: bool = None
149
+ dtype: jnp.dtype = jnp.float32
150
+
151
+ def setup(self):
152
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
153
+
154
+ self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
155
+ self.conv1 = nn.Conv(
156
+ out_channels,
157
+ kernel_size=(3, 3),
158
+ strides=(1, 1),
159
+ padding=((1, 1), (1, 1)),
160
+ dtype=self.dtype,
161
+ )
162
+
163
+ self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
164
+ self.dropout_layer = nn.Dropout(self.dropout)
165
+ self.conv2 = nn.Conv(
166
+ out_channels,
167
+ kernel_size=(3, 3),
168
+ strides=(1, 1),
169
+ padding=((1, 1), (1, 1)),
170
+ dtype=self.dtype,
171
+ )
172
+
173
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
174
+
175
+ self.conv_shortcut = None
176
+ if use_nin_shortcut:
177
+ self.conv_shortcut = nn.Conv(
178
+ out_channels,
179
+ kernel_size=(1, 1),
180
+ strides=(1, 1),
181
+ padding="VALID",
182
+ dtype=self.dtype,
183
+ )
184
+
185
+ def __call__(self, hidden_states, deterministic=True):
186
+ residual = hidden_states
187
+ hidden_states = self.norm1(hidden_states)
188
+ hidden_states = nn.swish(hidden_states)
189
+ hidden_states = self.conv1(hidden_states)
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+ hidden_states = nn.swish(hidden_states)
193
+ hidden_states = self.dropout_layer(hidden_states, deterministic)
194
+ hidden_states = self.conv2(hidden_states)
195
+
196
+ if self.conv_shortcut is not None:
197
+ residual = self.conv_shortcut(residual)
198
+
199
+ return hidden_states + residual
200
+
201
+
202
+ class FlaxAttentionBlock(nn.Module):
203
+ r"""
204
+ Flax Convolutional based multi-head attention block for diffusion-based VAE.
205
+
206
+ Parameters:
207
+ channels (:obj:`int`):
208
+ Input channels
209
+ num_head_channels (:obj:`int`, *optional*, defaults to `None`):
210
+ Number of attention heads
211
+ num_groups (:obj:`int`, *optional*, defaults to `32`):
212
+ The number of groups to use for group norm
213
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
214
+ Parameters `dtype`
215
+
216
+ """
217
+ channels: int
218
+ num_head_channels: int = None
219
+ num_groups: int = 32
220
+ dtype: jnp.dtype = jnp.float32
221
+
222
+ def setup(self):
223
+ self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
224
+
225
+ dense = partial(nn.Dense, self.channels, dtype=self.dtype)
226
+
227
+ self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
228
+ self.query, self.key, self.value = dense(), dense(), dense()
229
+ self.proj_attn = dense()
230
+
231
+ def transpose_for_scores(self, projection):
232
+ new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
233
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
234
+ new_projection = projection.reshape(new_projection_shape)
235
+ # (B, T, H, D) -> (B, H, T, D)
236
+ new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
237
+ return new_projection
238
+
239
+ def __call__(self, hidden_states):
240
+ residual = hidden_states
241
+ batch, height, width, channels = hidden_states.shape
242
+
243
+ hidden_states = self.group_norm(hidden_states)
244
+
245
+ hidden_states = hidden_states.reshape((batch, height * width, channels))
246
+
247
+ query = self.query(hidden_states)
248
+ key = self.key(hidden_states)
249
+ value = self.value(hidden_states)
250
+
251
+ # transpose
252
+ query = self.transpose_for_scores(query)
253
+ key = self.transpose_for_scores(key)
254
+ value = self.transpose_for_scores(value)
255
+
256
+ # compute attentions
257
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
258
+ attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
259
+ attn_weights = nn.softmax(attn_weights, axis=-1)
260
+
261
+ # attend to values
262
+ hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
263
+
264
+ hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
265
+ new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
266
+ hidden_states = hidden_states.reshape(new_hidden_states_shape)
267
+
268
+ hidden_states = self.proj_attn(hidden_states)
269
+ hidden_states = hidden_states.reshape((batch, height, width, channels))
270
+ hidden_states = hidden_states + residual
271
+ return hidden_states
272
+
273
+
274
+ class FlaxDownEncoderBlock2D(nn.Module):
275
+ r"""
276
+ Flax Resnet blocks-based Encoder block for diffusion-based VAE.
277
+
278
+ Parameters:
279
+ in_channels (:obj:`int`):
280
+ Input channels
281
+ out_channels (:obj:`int`):
282
+ Output channels
283
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
284
+ Dropout rate
285
+ num_layers (:obj:`int`, *optional*, defaults to 1):
286
+ Number of Resnet layer block
287
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
288
+ The number of groups to use for the Resnet block group norm
289
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
290
+ Whether to add downsample layer
291
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
292
+ Parameters `dtype`
293
+ """
294
+ in_channels: int
295
+ out_channels: int
296
+ dropout: float = 0.0
297
+ num_layers: int = 1
298
+ resnet_groups: int = 32
299
+ add_downsample: bool = True
300
+ dtype: jnp.dtype = jnp.float32
301
+
302
+ def setup(self):
303
+ resnets = []
304
+ for i in range(self.num_layers):
305
+ in_channels = self.in_channels if i == 0 else self.out_channels
306
+
307
+ res_block = FlaxResnetBlock2D(
308
+ in_channels=in_channels,
309
+ out_channels=self.out_channels,
310
+ dropout=self.dropout,
311
+ groups=self.resnet_groups,
312
+ dtype=self.dtype,
313
+ )
314
+ resnets.append(res_block)
315
+ self.resnets = resnets
316
+
317
+ if self.add_downsample:
318
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
319
+
320
+ def __call__(self, hidden_states, deterministic=True):
321
+ for resnet in self.resnets:
322
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
323
+
324
+ if self.add_downsample:
325
+ hidden_states = self.downsamplers_0(hidden_states)
326
+
327
+ return hidden_states
328
+
329
+
330
+ class FlaxUpDecoderBlock2D(nn.Module):
331
+ r"""
332
+ Flax Resnet blocks-based Decoder block for diffusion-based VAE.
333
+
334
+ Parameters:
335
+ in_channels (:obj:`int`):
336
+ Input channels
337
+ out_channels (:obj:`int`):
338
+ Output channels
339
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
340
+ Dropout rate
341
+ num_layers (:obj:`int`, *optional*, defaults to 1):
342
+ Number of Resnet layer block
343
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
344
+ The number of groups to use for the Resnet block group norm
345
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
346
+ Whether to add upsample layer
347
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
348
+ Parameters `dtype`
349
+ """
350
+ in_channels: int
351
+ out_channels: int
352
+ dropout: float = 0.0
353
+ num_layers: int = 1
354
+ resnet_groups: int = 32
355
+ add_upsample: bool = True
356
+ dtype: jnp.dtype = jnp.float32
357
+
358
+ def setup(self):
359
+ resnets = []
360
+ for i in range(self.num_layers):
361
+ in_channels = self.in_channels if i == 0 else self.out_channels
362
+ res_block = FlaxResnetBlock2D(
363
+ in_channels=in_channels,
364
+ out_channels=self.out_channels,
365
+ dropout=self.dropout,
366
+ groups=self.resnet_groups,
367
+ dtype=self.dtype,
368
+ )
369
+ resnets.append(res_block)
370
+
371
+ self.resnets = resnets
372
+
373
+ if self.add_upsample:
374
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
375
+
376
+ def __call__(self, hidden_states, deterministic=True):
377
+ for resnet in self.resnets:
378
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
379
+
380
+ if self.add_upsample:
381
+ hidden_states = self.upsamplers_0(hidden_states)
382
+
383
+ return hidden_states
384
+
385
+
386
+ class FlaxUNetMidBlock2D(nn.Module):
387
+ r"""
388
+ Flax Unet Mid-Block module.
389
+
390
+ Parameters:
391
+ in_channels (:obj:`int`):
392
+ Input channels
393
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
394
+ Dropout rate
395
+ num_layers (:obj:`int`, *optional*, defaults to 1):
396
+ Number of Resnet layer block
397
+ resnet_groups (:obj:`int`, *optional*, defaults to `32`):
398
+ The number of groups to use for the Resnet and Attention block group norm
399
+ num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
400
+ Number of attention heads for each attention block
401
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
402
+ Parameters `dtype`
403
+ """
404
+ in_channels: int
405
+ dropout: float = 0.0
406
+ num_layers: int = 1
407
+ resnet_groups: int = 32
408
+ num_attention_heads: int = 1
409
+ dtype: jnp.dtype = jnp.float32
410
+
411
+ def setup(self):
412
+ resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
413
+
414
+ # there is always at least one resnet
415
+ resnets = [
416
+ FlaxResnetBlock2D(
417
+ in_channels=self.in_channels,
418
+ out_channels=self.in_channels,
419
+ dropout=self.dropout,
420
+ groups=resnet_groups,
421
+ dtype=self.dtype,
422
+ )
423
+ ]
424
+
425
+ attentions = []
426
+
427
+ for _ in range(self.num_layers):
428
+ attn_block = FlaxAttentionBlock(
429
+ channels=self.in_channels,
430
+ num_head_channels=self.num_attention_heads,
431
+ num_groups=resnet_groups,
432
+ dtype=self.dtype,
433
+ )
434
+ attentions.append(attn_block)
435
+
436
+ res_block = FlaxResnetBlock2D(
437
+ in_channels=self.in_channels,
438
+ out_channels=self.in_channels,
439
+ dropout=self.dropout,
440
+ groups=resnet_groups,
441
+ dtype=self.dtype,
442
+ )
443
+ resnets.append(res_block)
444
+
445
+ self.resnets = resnets
446
+ self.attentions = attentions
447
+
448
+ def __call__(self, hidden_states, deterministic=True):
449
+ hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
450
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
451
+ hidden_states = attn(hidden_states)
452
+ hidden_states = resnet(hidden_states, deterministic=deterministic)
453
+
454
+ return hidden_states
455
+
456
+
457
+ class FlaxEncoder(nn.Module):
458
+ r"""
459
+ Flax Implementation of VAE Encoder.
460
+
461
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
462
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
463
+ general usage and behavior.
464
+
465
+ Finally, this model supports inherent JAX features such as:
466
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
467
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
468
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
469
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
470
+
471
+ Parameters:
472
+ in_channels (:obj:`int`, *optional*, defaults to 3):
473
+ Input channels
474
+ out_channels (:obj:`int`, *optional*, defaults to 3):
475
+ Output channels
476
+ down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
477
+ DownEncoder block type
478
+ block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
479
+ Tuple containing the number of output channels for each block
480
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
481
+ Number of Resnet layer for each block
482
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
483
+ norm num group
484
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
485
+ Activation function
486
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
487
+ Whether to double the last output channels
488
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
489
+ Parameters `dtype`
490
+ """
491
+ in_channels: int = 3
492
+ out_channels: int = 3
493
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
494
+ block_out_channels: Tuple[int] = (64,)
495
+ layers_per_block: int = 2
496
+ norm_num_groups: int = 32
497
+ act_fn: str = "silu"
498
+ double_z: bool = False
499
+ dtype: jnp.dtype = jnp.float32
500
+
501
+ def setup(self):
502
+ block_out_channels = self.block_out_channels
503
+ # in
504
+ self.conv_in = nn.Conv(
505
+ block_out_channels[0],
506
+ kernel_size=(3, 3),
507
+ strides=(1, 1),
508
+ padding=((1, 1), (1, 1)),
509
+ dtype=self.dtype,
510
+ )
511
+
512
+ # downsampling
513
+ down_blocks = []
514
+ output_channel = block_out_channels[0]
515
+ for i, _ in enumerate(self.down_block_types):
516
+ input_channel = output_channel
517
+ output_channel = block_out_channels[i]
518
+ is_final_block = i == len(block_out_channels) - 1
519
+
520
+ down_block = FlaxDownEncoderBlock2D(
521
+ in_channels=input_channel,
522
+ out_channels=output_channel,
523
+ num_layers=self.layers_per_block,
524
+ resnet_groups=self.norm_num_groups,
525
+ add_downsample=not is_final_block,
526
+ dtype=self.dtype,
527
+ )
528
+ down_blocks.append(down_block)
529
+ self.down_blocks = down_blocks
530
+
531
+ # middle
532
+ self.mid_block = FlaxUNetMidBlock2D(
533
+ in_channels=block_out_channels[-1],
534
+ resnet_groups=self.norm_num_groups,
535
+ num_attention_heads=None,
536
+ dtype=self.dtype,
537
+ )
538
+
539
+ # end
540
+ conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
541
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
542
+ self.conv_out = nn.Conv(
543
+ conv_out_channels,
544
+ kernel_size=(3, 3),
545
+ strides=(1, 1),
546
+ padding=((1, 1), (1, 1)),
547
+ dtype=self.dtype,
548
+ )
549
+
550
+ def __call__(self, sample, deterministic: bool = True):
551
+ # in
552
+ sample = self.conv_in(sample)
553
+
554
+ # downsampling
555
+ for block in self.down_blocks:
556
+ sample = block(sample, deterministic=deterministic)
557
+
558
+ # middle
559
+ sample = self.mid_block(sample, deterministic=deterministic)
560
+
561
+ # end
562
+ sample = self.conv_norm_out(sample)
563
+ sample = nn.swish(sample)
564
+ sample = self.conv_out(sample)
565
+
566
+ return sample
567
+
568
+
569
+ class FlaxDecoder(nn.Module):
570
+ r"""
571
+ Flax Implementation of VAE Decoder.
572
+
573
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
574
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
575
+ general usage and behavior.
576
+
577
+ Finally, this model supports inherent JAX features such as:
578
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
579
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
580
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
581
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
582
+
583
+ Parameters:
584
+ in_channels (:obj:`int`, *optional*, defaults to 3):
585
+ Input channels
586
+ out_channels (:obj:`int`, *optional*, defaults to 3):
587
+ Output channels
588
+ up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
589
+ UpDecoder block type
590
+ block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
591
+ Tuple containing the number of output channels for each block
592
+ layers_per_block (:obj:`int`, *optional*, defaults to `2`):
593
+ Number of Resnet layer for each block
594
+ norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
595
+ norm num group
596
+ act_fn (:obj:`str`, *optional*, defaults to `silu`):
597
+ Activation function
598
+ double_z (:obj:`bool`, *optional*, defaults to `False`):
599
+ Whether to double the last output channels
600
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
601
+ parameters `dtype`
602
+ """
603
+ in_channels: int = 3
604
+ out_channels: int = 3
605
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
606
+ block_out_channels: int = (64,)
607
+ layers_per_block: int = 2
608
+ norm_num_groups: int = 32
609
+ act_fn: str = "silu"
610
+ dtype: jnp.dtype = jnp.float32
611
+
612
+ def setup(self):
613
+ block_out_channels = self.block_out_channels
614
+
615
+ # z to block_in
616
+ self.conv_in = nn.Conv(
617
+ block_out_channels[-1],
618
+ kernel_size=(3, 3),
619
+ strides=(1, 1),
620
+ padding=((1, 1), (1, 1)),
621
+ dtype=self.dtype,
622
+ )
623
+
624
+ # middle
625
+ self.mid_block = FlaxUNetMidBlock2D(
626
+ in_channels=block_out_channels[-1],
627
+ resnet_groups=self.norm_num_groups,
628
+ num_attention_heads=None,
629
+ dtype=self.dtype,
630
+ )
631
+
632
+ # upsampling
633
+ reversed_block_out_channels = list(reversed(block_out_channels))
634
+ output_channel = reversed_block_out_channels[0]
635
+ up_blocks = []
636
+ for i, _ in enumerate(self.up_block_types):
637
+ prev_output_channel = output_channel
638
+ output_channel = reversed_block_out_channels[i]
639
+
640
+ is_final_block = i == len(block_out_channels) - 1
641
+
642
+ up_block = FlaxUpDecoderBlock2D(
643
+ in_channels=prev_output_channel,
644
+ out_channels=output_channel,
645
+ num_layers=self.layers_per_block + 1,
646
+ resnet_groups=self.norm_num_groups,
647
+ add_upsample=not is_final_block,
648
+ dtype=self.dtype,
649
+ )
650
+ up_blocks.append(up_block)
651
+ prev_output_channel = output_channel
652
+
653
+ self.up_blocks = up_blocks
654
+
655
+ # end
656
+ self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
657
+ self.conv_out = nn.Conv(
658
+ self.out_channels,
659
+ kernel_size=(3, 3),
660
+ strides=(1, 1),
661
+ padding=((1, 1), (1, 1)),
662
+ dtype=self.dtype,
663
+ )
664
+
665
+ def __call__(self, sample, deterministic: bool = True):
666
+ # z to block_in
667
+ sample = self.conv_in(sample)
668
+
669
+ # middle
670
+ sample = self.mid_block(sample, deterministic=deterministic)
671
+
672
+ # upsampling
673
+ for block in self.up_blocks:
674
+ sample = block(sample, deterministic=deterministic)
675
+
676
+ sample = self.conv_norm_out(sample)
677
+ sample = nn.swish(sample)
678
+ sample = self.conv_out(sample)
679
+
680
+ return sample
681
+
682
+
683
+ class FlaxDiagonalGaussianDistribution(object):
684
+ def __init__(self, parameters, deterministic=False):
685
+ # Last axis to account for channels-last
686
+ self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
687
+ self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
688
+ self.deterministic = deterministic
689
+ self.std = jnp.exp(0.5 * self.logvar)
690
+ self.var = jnp.exp(self.logvar)
691
+ if self.deterministic:
692
+ self.var = self.std = jnp.zeros_like(self.mean)
693
+
694
+ def sample(self, key):
695
+ return self.mean + self.std * jax.random.normal(key, self.mean.shape)
696
+
697
+ def kl(self, other=None):
698
+ if self.deterministic:
699
+ return jnp.array([0.0])
700
+
701
+ if other is None:
702
+ return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
703
+
704
+ return 0.5 * jnp.sum(
705
+ jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
706
+ axis=[1, 2, 3],
707
+ )
708
+
709
+ def nll(self, sample, axis=[1, 2, 3]):
710
+ if self.deterministic:
711
+ return jnp.array([0.0])
712
+
713
+ logtwopi = jnp.log(2.0 * jnp.pi)
714
+ return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
715
+
716
+ def mode(self):
717
+ return self.mean
718
+
719
+
720
+ @flax_register_to_config
721
+ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
722
+ r"""
723
+ Flax implementation of a VAE model with KL loss for decoding latent representations.
724
+
725
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
726
+ implemented for all models (such as downloading or saving).
727
+
728
+ This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
729
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its
730
+ general usage and behavior.
731
+
732
+ Inherent JAX features such as the following are supported:
733
+
734
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
735
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
736
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
737
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
738
+
739
+ Parameters:
740
+ in_channels (`int`, *optional*, defaults to 3):
741
+ Number of channels in the input image.
742
+ out_channels (`int`, *optional*, defaults to 3):
743
+ Number of channels in the output.
744
+ down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
745
+ Tuple of downsample block types.
746
+ up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
747
+ Tuple of upsample block types.
748
+ block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`):
749
+ Tuple of block output channels.
750
+ layers_per_block (`int`, *optional*, defaults to `2`):
751
+ Number of ResNet layer for each block.
752
+ act_fn (`str`, *optional*, defaults to `silu`):
753
+ The activation function to use.
754
+ latent_channels (`int`, *optional*, defaults to `4`):
755
+ Number of channels in the latent space.
756
+ norm_num_groups (`int`, *optional*, defaults to `32`):
757
+ The number of groups for normalization.
758
+ sample_size (`int`, *optional*, defaults to 32):
759
+ Sample input size.
760
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
761
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
762
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
763
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
764
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
765
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
766
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
767
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
768
+ The `dtype` of the parameters.
769
+ """
770
+ in_channels: int = 3
771
+ out_channels: int = 3
772
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
773
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
774
+ block_out_channels: Tuple[int] = (64,)
775
+ layers_per_block: int = 1
776
+ act_fn: str = "silu"
777
+ latent_channels: int = 4
778
+ norm_num_groups: int = 32
779
+ sample_size: int = 32
780
+ scaling_factor: float = 0.18215
781
+ dtype: jnp.dtype = jnp.float32
782
+
783
+ def setup(self):
784
+ self.encoder = FlaxEncoder(
785
+ in_channels=self.config.in_channels,
786
+ out_channels=self.config.latent_channels,
787
+ down_block_types=self.config.down_block_types,
788
+ block_out_channels=self.config.block_out_channels,
789
+ layers_per_block=self.config.layers_per_block,
790
+ act_fn=self.config.act_fn,
791
+ norm_num_groups=self.config.norm_num_groups,
792
+ double_z=True,
793
+ dtype=self.dtype,
794
+ )
795
+ self.decoder = FlaxDecoder(
796
+ in_channels=self.config.latent_channels,
797
+ out_channels=self.config.out_channels,
798
+ up_block_types=self.config.up_block_types,
799
+ block_out_channels=self.config.block_out_channels,
800
+ layers_per_block=self.config.layers_per_block,
801
+ norm_num_groups=self.config.norm_num_groups,
802
+ act_fn=self.config.act_fn,
803
+ dtype=self.dtype,
804
+ )
805
+ self.quant_conv = nn.Conv(
806
+ 2 * self.config.latent_channels,
807
+ kernel_size=(1, 1),
808
+ strides=(1, 1),
809
+ padding="VALID",
810
+ dtype=self.dtype,
811
+ )
812
+ self.post_quant_conv = nn.Conv(
813
+ self.config.latent_channels,
814
+ kernel_size=(1, 1),
815
+ strides=(1, 1),
816
+ padding="VALID",
817
+ dtype=self.dtype,
818
+ )
819
+
820
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
821
+ # init input tensors
822
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
823
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
824
+
825
+ params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
826
+ rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
827
+
828
+ return self.init(rngs, sample)["params"]
829
+
830
+ def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
831
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
832
+
833
+ hidden_states = self.encoder(sample, deterministic=deterministic)
834
+ moments = self.quant_conv(hidden_states)
835
+ posterior = FlaxDiagonalGaussianDistribution(moments)
836
+
837
+ if not return_dict:
838
+ return (posterior,)
839
+
840
+ return FlaxAutoencoderKLOutput(latent_dist=posterior)
841
+
842
+ def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
843
+ if latents.shape[-1] != self.config.latent_channels:
844
+ latents = jnp.transpose(latents, (0, 2, 3, 1))
845
+
846
+ hidden_states = self.post_quant_conv(latents)
847
+ hidden_states = self.decoder(hidden_states, deterministic=deterministic)
848
+
849
+ hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
850
+
851
+ if not return_dict:
852
+ return (hidden_states,)
853
+
854
+ return FlaxDecoderOutput(sample=hidden_states)
855
+
856
+ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
857
+ posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
858
+ if sample_posterior:
859
+ rng = self.make_rng("gaussian")
860
+ hidden_states = posterior.latent_dist.sample(rng)
861
+ else:
862
+ hidden_states = posterior.latent_dist.mode()
863
+
864
+ sample = self.decode(hidden_states, return_dict=return_dict).sample
865
+
866
+ if not return_dict:
867
+ return (sample,)
868
+
869
+ return FlaxDecoderOutput(sample=sample)
6DoF/diffusers/models/vq_model.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput, apply_forward_hook
22
+ from .modeling_utils import ModelMixin
23
+ from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
24
+
25
+
26
+ @dataclass
27
+ class VQEncoderOutput(BaseOutput):
28
+ """
29
+ Output of VQModel encoding method.
30
+
31
+ Args:
32
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
33
+ The encoded output sample from the last layer of the model.
34
+ """
35
+
36
+ latents: torch.FloatTensor
37
+
38
+
39
+ class VQModel(ModelMixin, ConfigMixin):
40
+ r"""
41
+ A VQ-VAE model for decoding latent representations.
42
+
43
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
44
+ for all models (such as downloading or saving).
45
+
46
+ Parameters:
47
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
48
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
49
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
50
+ Tuple of downsample block types.
51
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
52
+ Tuple of upsample block types.
53
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
54
+ Tuple of block output channels.
55
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
56
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
57
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
58
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
59
+ vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
60
+ scaling_factor (`float`, *optional*, defaults to `0.18215`):
61
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
62
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
63
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67
+ """
68
+
69
+ @register_to_config
70
+ def __init__(
71
+ self,
72
+ in_channels: int = 3,
73
+ out_channels: int = 3,
74
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
75
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
76
+ block_out_channels: Tuple[int] = (64,),
77
+ layers_per_block: int = 1,
78
+ act_fn: str = "silu",
79
+ latent_channels: int = 3,
80
+ sample_size: int = 32,
81
+ num_vq_embeddings: int = 256,
82
+ norm_num_groups: int = 32,
83
+ vq_embed_dim: Optional[int] = None,
84
+ scaling_factor: float = 0.18215,
85
+ norm_type: str = "group", # group, spatial
86
+ ):
87
+ super().__init__()
88
+
89
+ # pass init params to Encoder
90
+ self.encoder = Encoder(
91
+ in_channels=in_channels,
92
+ out_channels=latent_channels,
93
+ down_block_types=down_block_types,
94
+ block_out_channels=block_out_channels,
95
+ layers_per_block=layers_per_block,
96
+ act_fn=act_fn,
97
+ norm_num_groups=norm_num_groups,
98
+ double_z=False,
99
+ )
100
+
101
+ vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
102
+
103
+ self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
104
+ self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
105
+ self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
106
+
107
+ # pass init params to Decoder
108
+ self.decoder = Decoder(
109
+ in_channels=latent_channels,
110
+ out_channels=out_channels,
111
+ up_block_types=up_block_types,
112
+ block_out_channels=block_out_channels,
113
+ layers_per_block=layers_per_block,
114
+ act_fn=act_fn,
115
+ norm_num_groups=norm_num_groups,
116
+ norm_type=norm_type,
117
+ )
118
+
119
+ @apply_forward_hook
120
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
121
+ h = self.encoder(x)
122
+ h = self.quant_conv(h)
123
+
124
+ if not return_dict:
125
+ return (h,)
126
+
127
+ return VQEncoderOutput(latents=h)
128
+
129
+ @apply_forward_hook
130
+ def decode(
131
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
132
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
133
+ # also go through quantization layer
134
+ if not force_not_quantize:
135
+ quant, emb_loss, info = self.quantize(h)
136
+ else:
137
+ quant = h
138
+ quant2 = self.post_quant_conv(quant)
139
+ dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
140
+
141
+ if not return_dict:
142
+ return (dec,)
143
+
144
+ return DecoderOutput(sample=dec)
145
+
146
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
147
+ r"""
148
+ The [`VQModel`] forward method.
149
+
150
+ Args:
151
+ sample (`torch.FloatTensor`): Input sample.
152
+ return_dict (`bool`, *optional*, defaults to `True`):
153
+ Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
154
+
155
+ Returns:
156
+ [`~models.vq_model.VQEncoderOutput`] or `tuple`:
157
+ If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
158
+ is returned.
159
+ """
160
+ x = sample
161
+ h = self.encode(x).latents
162
+ dec = self.decode(h).sample
163
+
164
+ if not return_dict:
165
+ return (dec,)
166
+
167
+ return DecoderOutput(sample=dec)
6DoF/diffusers/optimization.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class SchedulerType(Enum):
31
+ LINEAR = "linear"
32
+ COSINE = "cosine"
33
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
34
+ POLYNOMIAL = "polynomial"
35
+ CONSTANT = "constant"
36
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
37
+ PIECEWISE_CONSTANT = "piecewise_constant"
38
+
39
+
40
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
41
+ """
42
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
43
+
44
+ Args:
45
+ optimizer ([`~torch.optim.Optimizer`]):
46
+ The optimizer for which to schedule the learning rate.
47
+ last_epoch (`int`, *optional*, defaults to -1):
48
+ The index of the last epoch when resuming training.
49
+
50
+ Return:
51
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
52
+ """
53
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
54
+
55
+
56
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
57
+ """
58
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
59
+ increases linearly between 0 and the initial lr set in the optimizer.
60
+
61
+ Args:
62
+ optimizer ([`~torch.optim.Optimizer`]):
63
+ The optimizer for which to schedule the learning rate.
64
+ num_warmup_steps (`int`):
65
+ The number of steps for the warmup phase.
66
+ last_epoch (`int`, *optional*, defaults to -1):
67
+ The index of the last epoch when resuming training.
68
+
69
+ Return:
70
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
71
+ """
72
+
73
+ def lr_lambda(current_step: int):
74
+ if current_step < num_warmup_steps:
75
+ return float(current_step) / float(max(1.0, num_warmup_steps))
76
+ return 1.0
77
+
78
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
79
+
80
+
81
+ def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
82
+ """
83
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
84
+
85
+ Args:
86
+ optimizer ([`~torch.optim.Optimizer`]):
87
+ The optimizer for which to schedule the learning rate.
88
+ step_rules (`string`):
89
+ The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
90
+ if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
91
+ steps and multiple 0.005 for the other steps.
92
+ last_epoch (`int`, *optional*, defaults to -1):
93
+ The index of the last epoch when resuming training.
94
+
95
+ Return:
96
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
97
+ """
98
+
99
+ rules_dict = {}
100
+ rule_list = step_rules.split(",")
101
+ for rule_str in rule_list[:-1]:
102
+ value_str, steps_str = rule_str.split(":")
103
+ steps = int(steps_str)
104
+ value = float(value_str)
105
+ rules_dict[steps] = value
106
+ last_lr_multiple = float(rule_list[-1])
107
+
108
+ def create_rules_function(rules_dict, last_lr_multiple):
109
+ def rule_func(steps: int) -> float:
110
+ sorted_steps = sorted(rules_dict.keys())
111
+ for i, sorted_step in enumerate(sorted_steps):
112
+ if steps < sorted_step:
113
+ return rules_dict[sorted_steps[i]]
114
+ return last_lr_multiple
115
+
116
+ return rule_func
117
+
118
+ rules_func = create_rules_function(rules_dict, last_lr_multiple)
119
+
120
+ return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
121
+
122
+
123
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
124
+ """
125
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
126
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
127
+
128
+ Args:
129
+ optimizer ([`~torch.optim.Optimizer`]):
130
+ The optimizer for which to schedule the learning rate.
131
+ num_warmup_steps (`int`):
132
+ The number of steps for the warmup phase.
133
+ num_training_steps (`int`):
134
+ The total number of training steps.
135
+ last_epoch (`int`, *optional*, defaults to -1):
136
+ The index of the last epoch when resuming training.
137
+
138
+ Return:
139
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
140
+ """
141
+
142
+ def lr_lambda(current_step: int):
143
+ if current_step < num_warmup_steps:
144
+ return float(current_step) / float(max(1, num_warmup_steps))
145
+ return max(
146
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
147
+ )
148
+
149
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
150
+
151
+
152
+ def get_cosine_schedule_with_warmup(
153
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
154
+ ):
155
+ """
156
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
157
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
158
+ initial lr set in the optimizer.
159
+
160
+ Args:
161
+ optimizer ([`~torch.optim.Optimizer`]):
162
+ The optimizer for which to schedule the learning rate.
163
+ num_warmup_steps (`int`):
164
+ The number of steps for the warmup phase.
165
+ num_training_steps (`int`):
166
+ The total number of training steps.
167
+ num_periods (`float`, *optional*, defaults to 0.5):
168
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
169
+ value to 0 following a half-cosine).
170
+ last_epoch (`int`, *optional*, defaults to -1):
171
+ The index of the last epoch when resuming training.
172
+
173
+ Return:
174
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
175
+ """
176
+
177
+ def lr_lambda(current_step):
178
+ if current_step < num_warmup_steps:
179
+ return float(current_step) / float(max(1, num_warmup_steps))
180
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
181
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
182
+
183
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
184
+
185
+
186
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
187
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
188
+ ):
189
+ """
190
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
191
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
192
+ linearly between 0 and the initial lr set in the optimizer.
193
+
194
+ Args:
195
+ optimizer ([`~torch.optim.Optimizer`]):
196
+ The optimizer for which to schedule the learning rate.
197
+ num_warmup_steps (`int`):
198
+ The number of steps for the warmup phase.
199
+ num_training_steps (`int`):
200
+ The total number of training steps.
201
+ num_cycles (`int`, *optional*, defaults to 1):
202
+ The number of hard restarts to use.
203
+ last_epoch (`int`, *optional*, defaults to -1):
204
+ The index of the last epoch when resuming training.
205
+
206
+ Return:
207
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
208
+ """
209
+
210
+ def lr_lambda(current_step):
211
+ if current_step < num_warmup_steps:
212
+ return float(current_step) / float(max(1, num_warmup_steps))
213
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
214
+ if progress >= 1.0:
215
+ return 0.0
216
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
217
+
218
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
219
+
220
+
221
+ def get_polynomial_decay_schedule_with_warmup(
222
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
223
+ ):
224
+ """
225
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
226
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
227
+ initial lr set in the optimizer.
228
+
229
+ Args:
230
+ optimizer ([`~torch.optim.Optimizer`]):
231
+ The optimizer for which to schedule the learning rate.
232
+ num_warmup_steps (`int`):
233
+ The number of steps for the warmup phase.
234
+ num_training_steps (`int`):
235
+ The total number of training steps.
236
+ lr_end (`float`, *optional*, defaults to 1e-7):
237
+ The end LR.
238
+ power (`float`, *optional*, defaults to 1.0):
239
+ Power factor.
240
+ last_epoch (`int`, *optional*, defaults to -1):
241
+ The index of the last epoch when resuming training.
242
+
243
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
244
+ implementation at
245
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
246
+
247
+ Return:
248
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
249
+
250
+ """
251
+
252
+ lr_init = optimizer.defaults["lr"]
253
+ if not (lr_init > lr_end):
254
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
255
+
256
+ def lr_lambda(current_step: int):
257
+ if current_step < num_warmup_steps:
258
+ return float(current_step) / float(max(1, num_warmup_steps))
259
+ elif current_step > num_training_steps:
260
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
261
+ else:
262
+ lr_range = lr_init - lr_end
263
+ decay_steps = num_training_steps - num_warmup_steps
264
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
265
+ decay = lr_range * pct_remaining**power + lr_end
266
+ return decay / lr_init # as LambdaLR multiplies by lr_init
267
+
268
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
269
+
270
+
271
+ TYPE_TO_SCHEDULER_FUNCTION = {
272
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
273
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
274
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
275
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
276
+ SchedulerType.CONSTANT: get_constant_schedule,
277
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
278
+ SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
279
+ }
280
+
281
+
282
+ def get_scheduler(
283
+ name: Union[str, SchedulerType],
284
+ optimizer: Optimizer,
285
+ step_rules: Optional[str] = None,
286
+ num_warmup_steps: Optional[int] = None,
287
+ num_training_steps: Optional[int] = None,
288
+ num_cycles: int = 1,
289
+ power: float = 1.0,
290
+ last_epoch: int = -1,
291
+ ):
292
+ """
293
+ Unified API to get any scheduler from its name.
294
+
295
+ Args:
296
+ name (`str` or `SchedulerType`):
297
+ The name of the scheduler to use.
298
+ optimizer (`torch.optim.Optimizer`):
299
+ The optimizer that will be used during training.
300
+ step_rules (`str`, *optional*):
301
+ A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
302
+ num_warmup_steps (`int`, *optional*):
303
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
304
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
305
+ num_training_steps (`int``, *optional*):
306
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
307
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
308
+ num_cycles (`int`, *optional*):
309
+ The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
310
+ power (`float`, *optional*, defaults to 1.0):
311
+ Power factor. See `POLYNOMIAL` scheduler
312
+ last_epoch (`int`, *optional*, defaults to -1):
313
+ The index of the last epoch when resuming training.
314
+ """
315
+ name = SchedulerType(name)
316
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
317
+ if name == SchedulerType.CONSTANT:
318
+ return schedule_func(optimizer, last_epoch=last_epoch)
319
+
320
+ if name == SchedulerType.PIECEWISE_CONSTANT:
321
+ return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
322
+
323
+ # All other schedulers require `num_warmup_steps`
324
+ if num_warmup_steps is None:
325
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
326
+
327
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
328
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
329
+
330
+ # All other schedulers require `num_training_steps`
331
+ if num_training_steps is None:
332
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
333
+
334
+ if name == SchedulerType.COSINE_WITH_RESTARTS:
335
+ return schedule_func(
336
+ optimizer,
337
+ num_warmup_steps=num_warmup_steps,
338
+ num_training_steps=num_training_steps,
339
+ num_cycles=num_cycles,
340
+ last_epoch=last_epoch,
341
+ )
342
+
343
+ if name == SchedulerType.POLYNOMIAL:
344
+ return schedule_func(
345
+ optimizer,
346
+ num_warmup_steps=num_warmup_steps,
347
+ num_training_steps=num_training_steps,
348
+ power=power,
349
+ last_epoch=last_epoch,
350
+ )
351
+
352
+ return schedule_func(
353
+ optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
354
+ )
6DoF/diffusers/pipeline_utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+
14
+ # limitations under the License.
15
+
16
+ # NOTE: This file is deprecated and will be removed in a future version.
17
+ # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
18
+
19
+ from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
20
+ from .utils import deprecate
21
+
22
+
23
+ deprecate(
24
+ "pipelines_utils",
25
+ "0.22.0",
26
+ "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
27
+ standard_warn=False,
28
+ stacklevel=3,
29
+ )