kxic committed on
Commit
6d86936
·
verified ·
1 Parent(s): 23aae87

Delete gradio_demo

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. gradio_demo/6DoF/CN_encoder.py +0 -36
  2. gradio_demo/6DoF/dataset.py +0 -176
  3. gradio_demo/6DoF/diffusers/__init__.py +0 -281
  4. gradio_demo/6DoF/diffusers/commands/__init__.py +0 -27
  5. gradio_demo/6DoF/diffusers/commands/diffusers_cli.py +0 -41
  6. gradio_demo/6DoF/diffusers/commands/env.py +0 -84
  7. gradio_demo/6DoF/diffusers/configuration_utils.py +0 -664
  8. gradio_demo/6DoF/diffusers/dependency_versions_check.py +0 -47
  9. gradio_demo/6DoF/diffusers/dependency_versions_table.py +0 -44
  10. gradio_demo/6DoF/diffusers/experimental/__init__.py +0 -1
  11. gradio_demo/6DoF/diffusers/experimental/rl/__init__.py +0 -1
  12. gradio_demo/6DoF/diffusers/experimental/rl/value_guided_sampling.py +0 -152
  13. gradio_demo/6DoF/diffusers/image_processor.py +0 -366
  14. gradio_demo/6DoF/diffusers/loaders.py +0 -1492
  15. gradio_demo/6DoF/diffusers/models/__init__.py +0 -35
  16. gradio_demo/6DoF/diffusers/models/activations.py +0 -12
  17. gradio_demo/6DoF/diffusers/models/attention.py +0 -392
  18. gradio_demo/6DoF/diffusers/models/attention_flax.py +0 -446
  19. gradio_demo/6DoF/diffusers/models/attention_processor.py +0 -1684
  20. gradio_demo/6DoF/diffusers/models/autoencoder_kl.py +0 -411
  21. gradio_demo/6DoF/diffusers/models/controlnet.py +0 -705
  22. gradio_demo/6DoF/diffusers/models/controlnet_flax.py +0 -394
  23. gradio_demo/6DoF/diffusers/models/cross_attention.py +0 -94
  24. gradio_demo/6DoF/diffusers/models/dual_transformer_2d.py +0 -151
  25. gradio_demo/6DoF/diffusers/models/embeddings.py +0 -546
  26. gradio_demo/6DoF/diffusers/models/embeddings_flax.py +0 -95
  27. gradio_demo/6DoF/diffusers/models/modeling_flax_pytorch_utils.py +0 -118
  28. gradio_demo/6DoF/diffusers/models/modeling_flax_utils.py +0 -534
  29. gradio_demo/6DoF/diffusers/models/modeling_pytorch_flax_utils.py +0 -161
  30. gradio_demo/6DoF/diffusers/models/modeling_utils.py +0 -980
  31. gradio_demo/6DoF/diffusers/models/prior_transformer.py +0 -364
  32. gradio_demo/6DoF/diffusers/models/resnet.py +0 -877
  33. gradio_demo/6DoF/diffusers/models/resnet_flax.py +0 -124
  34. gradio_demo/6DoF/diffusers/models/t5_film_transformer.py +0 -321
  35. gradio_demo/6DoF/diffusers/models/transformer_2d.py +0 -343
  36. gradio_demo/6DoF/diffusers/models/transformer_temporal.py +0 -179
  37. gradio_demo/6DoF/diffusers/models/unet_1d.py +0 -255
  38. gradio_demo/6DoF/diffusers/models/unet_1d_blocks.py +0 -656
  39. gradio_demo/6DoF/diffusers/models/unet_2d.py +0 -329
  40. gradio_demo/6DoF/diffusers/models/unet_2d_blocks.py +0 -0
  41. gradio_demo/6DoF/diffusers/models/unet_2d_blocks_flax.py +0 -377
  42. gradio_demo/6DoF/diffusers/models/unet_2d_condition.py +0 -980
  43. gradio_demo/6DoF/diffusers/models/unet_2d_condition_flax.py +0 -357
  44. gradio_demo/6DoF/diffusers/models/unet_3d_blocks.py +0 -679
  45. gradio_demo/6DoF/diffusers/models/unet_3d_condition.py +0 -627
  46. gradio_demo/6DoF/diffusers/models/vae.py +0 -441
  47. gradio_demo/6DoF/diffusers/models/vae_flax.py +0 -869
  48. gradio_demo/6DoF/diffusers/models/vq_model.py +0 -167
  49. gradio_demo/6DoF/diffusers/optimization.py +0 -354
  50. gradio_demo/6DoF/diffusers/pipeline_utils.py +0 -29
gradio_demo/6DoF/CN_encoder.py DELETED
@@ -1,36 +0,0 @@
1
- from transformers import ConvNextV2Model
2
- import torch
3
- from typing import Optional
4
- import einops
5
-
6
- class CN_encoder(ConvNextV2Model):
7
- def __init__(self, config):
8
- super().__init__(config)
9
-
10
- def forward(
11
- self,
12
- pixel_values: torch.FloatTensor = None,
13
- output_hidden_states: Optional[bool] = None,
14
- return_dict: Optional[bool] = None,
15
- ):
16
- output_hidden_states = (
17
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
18
- )
19
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
20
-
21
- if pixel_values is None:
22
- raise ValueError("You have to specify pixel_values")
23
-
24
- embedding_output = self.embeddings(pixel_values)
25
-
26
- encoder_outputs = self.encoder(
27
- embedding_output,
28
- output_hidden_states=output_hidden_states,
29
- return_dict=return_dict,
30
- )
31
-
32
- last_hidden_state = encoder_outputs[0]
33
- image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
34
- image_embeddings = self.layernorm(image_embeddings)
35
-
36
- return image_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/dataset.py DELETED
@@ -1,176 +0,0 @@
1
- import os
2
- import math
3
- from pathlib import Path
4
- import torch
5
- import torchvision
6
- from torch.utils.data import Dataset, DataLoader
7
- from torchvision import transforms
8
- from PIL import Image
9
- import numpy as np
10
- import webdataset as wds
11
- from torch.utils.data.distributed import DistributedSampler
12
- import matplotlib.pyplot as plt
13
- import sys
14
-
15
- class ObjaverseDataLoader():
16
- def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
17
- self.root_dir = root_dir
18
- self.batch_size = batch_size
19
- self.num_workers = num_workers
20
- self.total_view = total_view
21
-
22
- image_transforms = [torchvision.transforms.Resize((256, 256)),
23
- transforms.ToTensor(),
24
- transforms.Normalize([0.5], [0.5])]
25
- self.image_transforms = torchvision.transforms.Compose(image_transforms)
26
-
27
- def train_dataloader(self):
28
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
29
- image_transforms=self.image_transforms)
30
- # sampler = DistributedSampler(dataset)
31
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
32
- # sampler=sampler)
33
-
34
- def val_dataloader(self):
35
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
36
- image_transforms=self.image_transforms)
37
- sampler = DistributedSampler(dataset)
38
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
39
-
40
- def get_pose(transformation):
41
- # transformation: 4x4
42
- return transformation
43
-
44
- class ObjaverseData(Dataset):
45
- def __init__(self,
46
- root_dir='.objaverse/hf-objaverse-v1/views',
47
- image_transforms=None,
48
- total_view=12,
49
- validation=False,
50
- T_in=1,
51
- T_out=1,
52
- fix_sample=False,
53
- ) -> None:
54
- """Create a dataset from a folder of images.
55
- If you pass in a root directory it will be searched for images
56
- ending in ext (ext can be a list)
57
- """
58
- self.root_dir = Path(root_dir)
59
- self.total_view = total_view
60
- self.T_in = T_in
61
- self.T_out = T_out
62
- self.fix_sample = fix_sample
63
-
64
- self.paths = []
65
- # # include all folders
66
- # for folder in os.listdir(self.root_dir):
67
- # if os.path.isdir(os.path.join(self.root_dir, folder)):
68
- # self.paths.append(folder)
69
- # load ids from .npy so we have exactly the same ids/order
70
- self.paths = np.load("../scripts/obj_ids.npy")
71
- # # only use 100K objects for ablation study
72
- # self.paths = self.paths[:100000]
73
- total_objects = len(self.paths)
74
- assert total_objects == 790152, 'total objects %d' % total_objects
75
- if validation:
76
- self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
77
- else:
78
- self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
79
- print('============= length of dataset %d =============' % len(self.paths))
80
- self.tform = image_transforms
81
-
82
- downscale = 512 / 256.
83
- self.fx = 560. / downscale
84
- self.fy = 560. / downscale
85
- self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
86
-
87
- def __len__(self):
88
- return len(self.paths)
89
-
90
- def get_pose(self, transformation):
91
- # transformation: 4x4
92
- return transformation
93
-
94
-
95
- def load_im(self, path, color):
96
- '''
97
- replace background pixel with random color in rendering
98
- '''
99
- try:
100
- img = plt.imread(path)
101
- except:
102
- print(path)
103
- sys.exit()
104
- img[img[:, :, -1] == 0.] = color
105
- img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
106
- return img
107
-
108
- def __getitem__(self, index):
109
- data = {}
110
- total_view = 12
111
-
112
- if self.fix_sample:
113
- if self.T_out > 1:
114
- indexes = range(total_view)
115
- index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
116
- index_inputs = indexes[1:self.T_in+1] # one overlap identity
117
- else:
118
- indexes = range(total_view)
119
- index_targets = indexes[:self.T_out]
120
- index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
121
- else:
122
- assert self.T_in + self.T_out <= total_view
123
- # training with replace, including identity
124
- indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
125
- index_inputs = indexes[:self.T_in]
126
- index_targets = indexes[self.T_in:]
127
- filename = os.path.join(self.root_dir, self.paths[index])
128
-
129
- color = [1., 1., 1., 1.]
130
-
131
- try:
132
- input_ims = []
133
- target_ims = []
134
- target_Ts = []
135
- cond_Ts = []
136
- for i, index_input in enumerate(index_inputs):
137
- input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
138
- input_ims.append(input_im)
139
- input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
140
- cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
141
- for i, index_target in enumerate(index_targets):
142
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
143
- target_ims.append(target_im)
144
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
145
- target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
146
- except:
147
- print('error loading data ', filename)
148
- filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8') # this one we know is valid
149
- input_ims = []
150
- target_ims = []
151
- target_Ts = []
152
- cond_Ts = []
153
- # very hacky solution, sorry about this
154
- for i, index_input in enumerate(index_inputs):
155
- input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
156
- input_ims.append(input_im)
157
- input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
158
- cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
159
- for i, index_target in enumerate(index_targets):
160
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
161
- target_ims.append(target_im)
162
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
163
- target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
164
-
165
- # stack to batch
166
- data['image_input'] = torch.stack(input_ims, dim=0)
167
- data['image_target'] = torch.stack(target_ims, dim=0)
168
- data['pose_out'] = np.stack(target_Ts)
169
- data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
170
- data['pose_in'] = np.stack(cond_Ts)
171
- data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
172
- return data
173
-
174
- def process_im(self, im):
175
- im = im.convert("RGB")
176
- return self.tform(im)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/__init__.py DELETED
@@ -1,281 +0,0 @@
1
- __version__ = "0.18.2"
2
-
3
- from .configuration_utils import ConfigMixin
4
- from .utils import (
5
- OptionalDependencyNotAvailable,
6
- is_flax_available,
7
- is_inflect_available,
8
- is_invisible_watermark_available,
9
- is_k_diffusion_available,
10
- is_k_diffusion_version,
11
- is_librosa_available,
12
- is_note_seq_available,
13
- is_onnx_available,
14
- is_scipy_available,
15
- is_torch_available,
16
- is_torchsde_available,
17
- is_transformers_available,
18
- is_transformers_version,
19
- is_unidecode_available,
20
- logging,
21
- )
22
-
23
-
24
- try:
25
- if not is_onnx_available():
26
- raise OptionalDependencyNotAvailable()
27
- except OptionalDependencyNotAvailable:
28
- from .utils.dummy_onnx_objects import * # noqa F403
29
- else:
30
- from .pipelines import OnnxRuntimeModel
31
-
32
- try:
33
- if not is_torch_available():
34
- raise OptionalDependencyNotAvailable()
35
- except OptionalDependencyNotAvailable:
36
- from .utils.dummy_pt_objects import * # noqa F403
37
- else:
38
- from .models import (
39
- AutoencoderKL,
40
- ControlNetModel,
41
- ModelMixin,
42
- PriorTransformer,
43
- T5FilmDecoder,
44
- Transformer2DModel,
45
- UNet1DModel,
46
- UNet2DConditionModel,
47
- UNet2DModel,
48
- UNet3DConditionModel,
49
- VQModel,
50
- )
51
- from .optimization import (
52
- get_constant_schedule,
53
- get_constant_schedule_with_warmup,
54
- get_cosine_schedule_with_warmup,
55
- get_cosine_with_hard_restarts_schedule_with_warmup,
56
- get_linear_schedule_with_warmup,
57
- get_polynomial_decay_schedule_with_warmup,
58
- get_scheduler,
59
- )
60
- from .pipelines import (
61
- AudioPipelineOutput,
62
- ConsistencyModelPipeline,
63
- DanceDiffusionPipeline,
64
- DDIMPipeline,
65
- DDPMPipeline,
66
- DiffusionPipeline,
67
- DiTPipeline,
68
- ImagePipelineOutput,
69
- KarrasVePipeline,
70
- LDMPipeline,
71
- LDMSuperResolutionPipeline,
72
- PNDMPipeline,
73
- RePaintPipeline,
74
- ScoreSdeVePipeline,
75
- )
76
- from .schedulers import (
77
- CMStochasticIterativeScheduler,
78
- DDIMInverseScheduler,
79
- DDIMParallelScheduler,
80
- DDIMScheduler,
81
- DDPMParallelScheduler,
82
- DDPMScheduler,
83
- DEISMultistepScheduler,
84
- DPMSolverMultistepInverseScheduler,
85
- DPMSolverMultistepScheduler,
86
- DPMSolverSinglestepScheduler,
87
- EulerAncestralDiscreteScheduler,
88
- EulerDiscreteScheduler,
89
- HeunDiscreteScheduler,
90
- IPNDMScheduler,
91
- KarrasVeScheduler,
92
- KDPM2AncestralDiscreteScheduler,
93
- KDPM2DiscreteScheduler,
94
- PNDMScheduler,
95
- RePaintScheduler,
96
- SchedulerMixin,
97
- ScoreSdeVeScheduler,
98
- UnCLIPScheduler,
99
- UniPCMultistepScheduler,
100
- VQDiffusionScheduler,
101
- )
102
- from .training_utils import EMAModel
103
-
104
- try:
105
- if not (is_torch_available() and is_scipy_available()):
106
- raise OptionalDependencyNotAvailable()
107
- except OptionalDependencyNotAvailable:
108
- from .utils.dummy_torch_and_scipy_objects import * # noqa F403
109
- else:
110
- from .schedulers import LMSDiscreteScheduler
111
-
112
- try:
113
- if not (is_torch_available() and is_torchsde_available()):
114
- raise OptionalDependencyNotAvailable()
115
- except OptionalDependencyNotAvailable:
116
- from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
117
- else:
118
- from .schedulers import DPMSolverSDEScheduler
119
-
120
- try:
121
- if not (is_torch_available() and is_transformers_available()):
122
- raise OptionalDependencyNotAvailable()
123
- except OptionalDependencyNotAvailable:
124
- from .utils.dummy_torch_and_transformers_objects import * # noqa F403
125
- else:
126
- from .pipelines import (
127
- AltDiffusionImg2ImgPipeline,
128
- AltDiffusionPipeline,
129
- AudioLDMPipeline,
130
- CycleDiffusionPipeline,
131
- IFImg2ImgPipeline,
132
- IFImg2ImgSuperResolutionPipeline,
133
- IFInpaintingPipeline,
134
- IFInpaintingSuperResolutionPipeline,
135
- IFPipeline,
136
- IFSuperResolutionPipeline,
137
- ImageTextPipelineOutput,
138
- KandinskyImg2ImgPipeline,
139
- KandinskyInpaintPipeline,
140
- KandinskyPipeline,
141
- KandinskyPriorPipeline,
142
- KandinskyV22ControlnetImg2ImgPipeline,
143
- KandinskyV22ControlnetPipeline,
144
- KandinskyV22Img2ImgPipeline,
145
- KandinskyV22InpaintPipeline,
146
- KandinskyV22Pipeline,
147
- KandinskyV22PriorEmb2EmbPipeline,
148
- KandinskyV22PriorPipeline,
149
- LDMTextToImagePipeline,
150
- PaintByExamplePipeline,
151
- SemanticStableDiffusionPipeline,
152
- ShapEImg2ImgPipeline,
153
- ShapEPipeline,
154
- StableDiffusionAttendAndExcitePipeline,
155
- StableDiffusionControlNetImg2ImgPipeline,
156
- StableDiffusionControlNetInpaintPipeline,
157
- StableDiffusionControlNetPipeline,
158
- StableDiffusionDepth2ImgPipeline,
159
- StableDiffusionDiffEditPipeline,
160
- StableDiffusionImageVariationPipeline,
161
- StableDiffusionImg2ImgPipeline,
162
- StableDiffusionInpaintPipeline,
163
- StableDiffusionInpaintPipelineLegacy,
164
- StableDiffusionInstructPix2PixPipeline,
165
- StableDiffusionLatentUpscalePipeline,
166
- StableDiffusionLDM3DPipeline,
167
- StableDiffusionModelEditingPipeline,
168
- StableDiffusionPanoramaPipeline,
169
- StableDiffusionParadigmsPipeline,
170
- StableDiffusionPipeline,
171
- StableDiffusionPipelineSafe,
172
- StableDiffusionPix2PixZeroPipeline,
173
- StableDiffusionSAGPipeline,
174
- StableDiffusionUpscalePipeline,
175
- StableUnCLIPImg2ImgPipeline,
176
- StableUnCLIPPipeline,
177
- TextToVideoSDPipeline,
178
- TextToVideoZeroPipeline,
179
- UnCLIPImageVariationPipeline,
180
- UnCLIPPipeline,
181
- UniDiffuserModel,
182
- UniDiffuserPipeline,
183
- UniDiffuserTextDecoder,
184
- VersatileDiffusionDualGuidedPipeline,
185
- VersatileDiffusionImageVariationPipeline,
186
- VersatileDiffusionPipeline,
187
- VersatileDiffusionTextToImagePipeline,
188
- VideoToVideoSDPipeline,
189
- VQDiffusionPipeline,
190
- )
191
-
192
- try:
193
- if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
194
- raise OptionalDependencyNotAvailable()
195
- except OptionalDependencyNotAvailable:
196
- from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
197
- else:
198
- from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
199
-
200
- try:
201
- if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
202
- raise OptionalDependencyNotAvailable()
203
- except OptionalDependencyNotAvailable:
204
- from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
205
- else:
206
- from .pipelines import StableDiffusionKDiffusionPipeline
207
-
208
- try:
209
- if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
210
- raise OptionalDependencyNotAvailable()
211
- except OptionalDependencyNotAvailable:
212
- from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
213
- else:
214
- from .pipelines import (
215
- OnnxStableDiffusionImg2ImgPipeline,
216
- OnnxStableDiffusionInpaintPipeline,
217
- OnnxStableDiffusionInpaintPipelineLegacy,
218
- OnnxStableDiffusionPipeline,
219
- OnnxStableDiffusionUpscalePipeline,
220
- StableDiffusionOnnxPipeline,
221
- )
222
-
223
- try:
224
- if not (is_torch_available() and is_librosa_available()):
225
- raise OptionalDependencyNotAvailable()
226
- except OptionalDependencyNotAvailable:
227
- from .utils.dummy_torch_and_librosa_objects import * # noqa F403
228
- else:
229
- from .pipelines import AudioDiffusionPipeline, Mel
230
-
231
- try:
232
- if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
233
- raise OptionalDependencyNotAvailable()
234
- except OptionalDependencyNotAvailable:
235
- from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
236
- else:
237
- from .pipelines import SpectrogramDiffusionPipeline
238
-
239
- try:
240
- if not is_flax_available():
241
- raise OptionalDependencyNotAvailable()
242
- except OptionalDependencyNotAvailable:
243
- from .utils.dummy_flax_objects import * # noqa F403
244
- else:
245
- from .models.controlnet_flax import FlaxControlNetModel
246
- from .models.modeling_flax_utils import FlaxModelMixin
247
- from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
248
- from .models.vae_flax import FlaxAutoencoderKL
249
- from .pipelines import FlaxDiffusionPipeline
250
- from .schedulers import (
251
- FlaxDDIMScheduler,
252
- FlaxDDPMScheduler,
253
- FlaxDPMSolverMultistepScheduler,
254
- FlaxKarrasVeScheduler,
255
- FlaxLMSDiscreteScheduler,
256
- FlaxPNDMScheduler,
257
- FlaxSchedulerMixin,
258
- FlaxScoreSdeVeScheduler,
259
- )
260
-
261
-
262
- try:
263
- if not (is_flax_available() and is_transformers_available()):
264
- raise OptionalDependencyNotAvailable()
265
- except OptionalDependencyNotAvailable:
266
- from .utils.dummy_flax_and_transformers_objects import * # noqa F403
267
- else:
268
- from .pipelines import (
269
- FlaxStableDiffusionControlNetPipeline,
270
- FlaxStableDiffusionImg2ImgPipeline,
271
- FlaxStableDiffusionInpaintPipeline,
272
- FlaxStableDiffusionPipeline,
273
- )
274
-
275
- try:
276
- if not (is_note_seq_available()):
277
- raise OptionalDependencyNotAvailable()
278
- except OptionalDependencyNotAvailable:
279
- from .utils.dummy_note_seq_objects import * # noqa F403
280
- else:
281
- from .pipelines import MidiProcessor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/commands/__init__.py DELETED
@@ -1,27 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from abc import ABC, abstractmethod
16
- from argparse import ArgumentParser
17
-
18
-
19
- class BaseDiffusersCLICommand(ABC):
20
- @staticmethod
21
- @abstractmethod
22
- def register_subcommand(parser: ArgumentParser):
23
- raise NotImplementedError()
24
-
25
- @abstractmethod
26
- def run(self):
27
- raise NotImplementedError()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/commands/diffusers_cli.py DELETED
@@ -1,41 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2023 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from argparse import ArgumentParser
17
-
18
- from .env import EnvironmentCommand
19
-
20
-
21
- def main():
22
- parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
- commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
-
25
- # Register commands
26
- EnvironmentCommand.register_subcommand(commands_parser)
27
-
28
- # Let's go
29
- args = parser.parse_args()
30
-
31
- if not hasattr(args, "func"):
32
- parser.print_help()
33
- exit(1)
34
-
35
- # Run
36
- service = args.func(args)
37
- service.run()
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/commands/env.py DELETED
@@ -1,84 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import platform
16
- from argparse import ArgumentParser
17
-
18
- import huggingface_hub
19
-
20
- from .. import __version__ as version
21
- from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
- from . import BaseDiffusersCLICommand
23
-
24
-
25
- def info_command_factory(_):
26
- return EnvironmentCommand()
27
-
28
-
29
- class EnvironmentCommand(BaseDiffusersCLICommand):
30
- @staticmethod
31
- def register_subcommand(parser: ArgumentParser):
32
- download_parser = parser.add_parser("env")
33
- download_parser.set_defaults(func=info_command_factory)
34
-
35
- def run(self):
36
- hub_version = huggingface_hub.__version__
37
-
38
- pt_version = "not installed"
39
- pt_cuda_available = "NA"
40
- if is_torch_available():
41
- import torch
42
-
43
- pt_version = torch.__version__
44
- pt_cuda_available = torch.cuda.is_available()
45
-
46
- transformers_version = "not installed"
47
- if is_transformers_available():
48
- import transformers
49
-
50
- transformers_version = transformers.__version__
51
-
52
- accelerate_version = "not installed"
53
- if is_accelerate_available():
54
- import accelerate
55
-
56
- accelerate_version = accelerate.__version__
57
-
58
- xformers_version = "not installed"
59
- if is_xformers_available():
60
- import xformers
61
-
62
- xformers_version = xformers.__version__
63
-
64
- info = {
65
- "`diffusers` version": version,
66
- "Platform": platform.platform(),
67
- "Python version": platform.python_version(),
68
- "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
- "Huggingface_hub version": hub_version,
70
- "Transformers version": transformers_version,
71
- "Accelerate version": accelerate_version,
72
- "xFormers version": xformers_version,
73
- "Using GPU in script?": "<fill in>",
74
- "Using distributed or parallel set-up in script?": "<fill in>",
75
- }
76
-
77
- print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
- print(self.format_dict(info))
79
-
80
- return info
81
-
82
- @staticmethod
83
- def format_dict(d):
84
- return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/configuration_utils.py DELETED
@@ -1,664 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """ ConfigMixin base class and utilities."""
17
- import dataclasses
18
- import functools
19
- import importlib
20
- import inspect
21
- import json
22
- import os
23
- import re
24
- from collections import OrderedDict
25
- from pathlib import PosixPath
26
- from typing import Any, Dict, Tuple, Union
27
-
28
- import numpy as np
29
- from huggingface_hub import hf_hub_download
30
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
- from requests import HTTPError
32
-
33
- from . import __version__
34
- from .utils import (
35
- DIFFUSERS_CACHE,
36
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
- DummyObject,
38
- deprecate,
39
- extract_commit_hash,
40
- http_user_agent,
41
- logging,
42
- )
43
-
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
- _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
-
49
-
50
- class FrozenDict(OrderedDict):
51
- def __init__(self, *args, **kwargs):
52
- super().__init__(*args, **kwargs)
53
-
54
- for key, value in self.items():
55
- setattr(self, key, value)
56
-
57
- self.__frozen = True
58
-
59
- def __delitem__(self, *args, **kwargs):
60
- raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
-
62
- def setdefault(self, *args, **kwargs):
63
- raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
-
65
- def pop(self, *args, **kwargs):
66
- raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
-
68
- def update(self, *args, **kwargs):
69
- raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
-
71
- def __setattr__(self, name, value):
72
- if hasattr(self, "__frozen") and self.__frozen:
73
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
- super().__setattr__(name, value)
75
-
76
- def __setitem__(self, name, value):
77
- if hasattr(self, "__frozen") and self.__frozen:
78
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
- super().__setitem__(name, value)
80
-
81
-
82
- class ConfigMixin:
83
- r"""
84
- Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
- provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
- saving classes that inherit from [`ConfigMixin`].
87
-
88
- Class attributes:
89
- - **config_name** (`str`) -- A filename under which the config should stored when calling
90
- [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
- - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
- overridden by subclass).
93
- - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
- - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
- should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
- subclass).
97
- """
98
- config_name = None
99
- ignore_for_config = []
100
- has_compatibles = False
101
-
102
- _deprecated_kwargs = []
103
-
104
- def register_to_config(self, **kwargs):
105
- if self.config_name is None:
106
- raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
- # Special case for `kwargs` used in deprecation warning added to schedulers
108
- # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
- # or solve in a more general way.
110
- kwargs.pop("kwargs", None)
111
-
112
- if not hasattr(self, "_internal_dict"):
113
- internal_dict = kwargs
114
- else:
115
- previous_dict = dict(self._internal_dict)
116
- internal_dict = {**self._internal_dict, **kwargs}
117
- logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
-
119
- self._internal_dict = FrozenDict(internal_dict)
120
-
121
- def __getattr__(self, name: str) -> Any:
122
- """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
- config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
-
125
- Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
126
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
- """
128
-
129
- is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
- is_attribute = name in self.__dict__
131
-
132
- if is_in_config and not is_attribute:
133
- deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
- deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
- return self._internal_dict[name]
136
-
137
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
-
139
- def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
- """
141
- Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
- [`~ConfigMixin.from_config`] class method.
143
-
144
- Args:
145
- save_directory (`str` or `os.PathLike`):
146
- Directory where the configuration JSON file is saved (will be created if it does not exist).
147
- """
148
- if os.path.isfile(save_directory):
149
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
150
-
151
- os.makedirs(save_directory, exist_ok=True)
152
-
153
- # If we save using the predefined names, we can load using `from_config`
154
- output_config_file = os.path.join(save_directory, self.config_name)
155
-
156
- self.to_json_file(output_config_file)
157
- logger.info(f"Configuration saved in {output_config_file}")
158
-
159
- @classmethod
160
- def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
161
- r"""
162
- Instantiate a Python class from a config dictionary.
163
-
164
- Parameters:
165
- config (`Dict[str, Any]`):
166
- A config dictionary from which the Python class is instantiated. Make sure to only load configuration
167
- files of compatible classes.
168
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
169
- Whether kwargs that are not consumed by the Python class should be returned or not.
170
- kwargs (remaining dictionary of keyword arguments, *optional*):
171
- Can be used to update the configuration object (after it is loaded) and initiate the Python class.
172
- `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
173
- overwrite the same named arguments in `config`.
174
-
175
- Returns:
176
- [`ModelMixin`] or [`SchedulerMixin`]:
177
- A model or scheduler object instantiated from a config dictionary.
178
-
179
- Examples:
180
-
181
- ```python
182
- >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
183
-
184
- >>> # Download scheduler from huggingface.co and cache.
185
- >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
186
-
187
- >>> # Instantiate DDIM scheduler class with same config as DDPM
188
- >>> scheduler = DDIMScheduler.from_config(scheduler.config)
189
-
190
- >>> # Instantiate PNDM scheduler class with same config as DDPM
191
- >>> scheduler = PNDMScheduler.from_config(scheduler.config)
192
- ```
193
- """
194
- # <===== TO BE REMOVED WITH DEPRECATION
195
- # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
196
- if "pretrained_model_name_or_path" in kwargs:
197
- config = kwargs.pop("pretrained_model_name_or_path")
198
-
199
- if config is None:
200
- raise ValueError("Please make sure to provide a config as the first positional argument.")
201
- # ======>
202
-
203
- if not isinstance(config, dict):
204
- deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
205
- if "Scheduler" in cls.__name__:
206
- deprecation_message += (
207
- f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
208
- " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
209
- " be removed in v1.0.0."
210
- )
211
- elif "Model" in cls.__name__:
212
- deprecation_message += (
213
- f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
214
- f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
215
- " instead. This functionality will be removed in v1.0.0."
216
- )
217
- deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
218
- config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
219
-
220
- init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
221
-
222
- # Allow dtype to be specified on initialization
223
- if "dtype" in unused_kwargs:
224
- init_dict["dtype"] = unused_kwargs.pop("dtype")
225
-
226
- # add possible deprecated kwargs
227
- for deprecated_kwarg in cls._deprecated_kwargs:
228
- if deprecated_kwarg in unused_kwargs:
229
- init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
230
-
231
- # Return model and optionally state and/or unused_kwargs
232
- model = cls(**init_dict)
233
-
234
- # make sure to also save config parameters that might be used for compatible classes
235
- model.register_to_config(**hidden_dict)
236
-
237
- # add hidden kwargs of compatible classes to unused_kwargs
238
- unused_kwargs = {**unused_kwargs, **hidden_dict}
239
-
240
- if return_unused_kwargs:
241
- return (model, unused_kwargs)
242
- else:
243
- return model
244
-
245
- @classmethod
246
- def get_config_dict(cls, *args, **kwargs):
247
- deprecation_message = (
248
- f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
249
- " removed in version v1.0.0"
250
- )
251
- deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
252
- return cls.load_config(*args, **kwargs)
253
-
254
- @classmethod
255
- def load_config(
256
- cls,
257
- pretrained_model_name_or_path: Union[str, os.PathLike],
258
- return_unused_kwargs=False,
259
- return_commit_hash=False,
260
- **kwargs,
261
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
262
- r"""
263
- Load a model or scheduler configuration.
264
-
265
- Parameters:
266
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
267
- Can be either:
268
-
269
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
270
- the Hub.
271
- - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
272
- [`~ConfigMixin.save_config`].
273
-
274
- cache_dir (`Union[str, os.PathLike]`, *optional*):
275
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
276
- is not used.
277
- force_download (`bool`, *optional*, defaults to `False`):
278
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
279
- cached versions if they exist.
280
- resume_download (`bool`, *optional*, defaults to `False`):
281
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
282
- incompletely downloaded files are deleted.
283
- proxies (`Dict[str, str]`, *optional*):
284
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
285
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
286
- output_loading_info(`bool`, *optional*, defaults to `False`):
287
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
288
- local_files_only (`bool`, *optional*, defaults to `False`):
289
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
290
- won't be downloaded from the Hub.
291
- use_auth_token (`str` or *bool*, *optional*):
292
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
293
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
294
- revision (`str`, *optional*, defaults to `"main"`):
295
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
296
- allowed by Git.
297
- subfolder (`str`, *optional*, defaults to `""`):
298
- The subfolder location of a model file within a larger model repository on the Hub or locally.
299
- return_unused_kwargs (`bool`, *optional*, defaults to `False):
300
- Whether unused keyword arguments of the config are returned.
301
- return_commit_hash (`bool`, *optional*, defaults to `False):
302
- Whether the `commit_hash` of the loaded configuration are returned.
303
-
304
- Returns:
305
- `dict`:
306
- A dictionary of all the parameters stored in a JSON configuration file.
307
-
308
- """
309
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
310
- force_download = kwargs.pop("force_download", False)
311
- resume_download = kwargs.pop("resume_download", False)
312
- proxies = kwargs.pop("proxies", None)
313
- use_auth_token = kwargs.pop("use_auth_token", None)
314
- local_files_only = kwargs.pop("local_files_only", False)
315
- revision = kwargs.pop("revision", None)
316
- _ = kwargs.pop("mirror", None)
317
- subfolder = kwargs.pop("subfolder", None)
318
- user_agent = kwargs.pop("user_agent", {})
319
-
320
- user_agent = {**user_agent, "file_type": "config"}
321
- user_agent = http_user_agent(user_agent)
322
-
323
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
324
-
325
- if cls.config_name is None:
326
- raise ValueError(
327
- "`self.config_name` is not defined. Note that one should not load a config from "
328
- "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
329
- )
330
-
331
- if os.path.isfile(pretrained_model_name_or_path):
332
- config_file = pretrained_model_name_or_path
333
- elif os.path.isdir(pretrained_model_name_or_path):
334
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
335
- # Load from a PyTorch checkpoint
336
- config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
337
- elif subfolder is not None and os.path.isfile(
338
- os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
339
- ):
340
- config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
341
- else:
342
- raise EnvironmentError(
343
- f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
344
- )
345
- else:
346
- try:
347
- # Load from URL or cache if already cached
348
- config_file = hf_hub_download(
349
- pretrained_model_name_or_path,
350
- filename=cls.config_name,
351
- cache_dir=cache_dir,
352
- force_download=force_download,
353
- proxies=proxies,
354
- resume_download=resume_download,
355
- local_files_only=local_files_only,
356
- use_auth_token=use_auth_token,
357
- user_agent=user_agent,
358
- subfolder=subfolder,
359
- revision=revision,
360
- )
361
- except RepositoryNotFoundError:
362
- raise EnvironmentError(
363
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
364
- " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
365
- " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
366
- " login`."
367
- )
368
- except RevisionNotFoundError:
369
- raise EnvironmentError(
370
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
371
- " this model name. Check the model page at"
372
- f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
373
- )
374
- except EntryNotFoundError:
375
- raise EnvironmentError(
376
- f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
377
- )
378
- except HTTPError as err:
379
- raise EnvironmentError(
380
- "There was a specific connection error when trying to load"
381
- f" {pretrained_model_name_or_path}:\n{err}"
382
- )
383
- except ValueError:
384
- raise EnvironmentError(
385
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
386
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
387
- f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
388
- " run the library in offline mode at"
389
- " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
390
- )
391
- except EnvironmentError:
392
- raise EnvironmentError(
393
- f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
394
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
395
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
396
- f"containing a {cls.config_name} file"
397
- )
398
-
399
- try:
400
- # Load config dict
401
- config_dict = cls._dict_from_json_file(config_file)
402
-
403
- commit_hash = extract_commit_hash(config_file)
404
- except (json.JSONDecodeError, UnicodeDecodeError):
405
- raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
406
-
407
- if not (return_unused_kwargs or return_commit_hash):
408
- return config_dict
409
-
410
- outputs = (config_dict,)
411
-
412
- if return_unused_kwargs:
413
- outputs += (kwargs,)
414
-
415
- if return_commit_hash:
416
- outputs += (commit_hash,)
417
-
418
- return outputs
419
-
420
- @staticmethod
421
- def _get_init_keys(cls):
422
- return set(dict(inspect.signature(cls.__init__).parameters).keys())
423
-
424
- @classmethod
425
- def extract_init_dict(cls, config_dict, **kwargs):
426
- # Skip keys that were not present in the original config, so default __init__ values were used
427
- used_defaults = config_dict.get("_use_default_values", [])
428
- config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
429
-
430
- # 0. Copy origin config dict
431
- original_dict = dict(config_dict.items())
432
-
433
- # 1. Retrieve expected config attributes from __init__ signature
434
- expected_keys = cls._get_init_keys(cls)
435
- expected_keys.remove("self")
436
- # remove general kwargs if present in dict
437
- if "kwargs" in expected_keys:
438
- expected_keys.remove("kwargs")
439
- # remove flax internal keys
440
- if hasattr(cls, "_flax_internal_args"):
441
- for arg in cls._flax_internal_args:
442
- expected_keys.remove(arg)
443
-
444
- # 2. Remove attributes that cannot be expected from expected config attributes
445
- # remove keys to be ignored
446
- if len(cls.ignore_for_config) > 0:
447
- expected_keys = expected_keys - set(cls.ignore_for_config)
448
-
449
- # load diffusers library to import compatible and original scheduler
450
- diffusers_library = importlib.import_module(__name__.split(".")[0])
451
-
452
- if cls.has_compatibles:
453
- compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
454
- else:
455
- compatible_classes = []
456
-
457
- expected_keys_comp_cls = set()
458
- for c in compatible_classes:
459
- expected_keys_c = cls._get_init_keys(c)
460
- expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
461
- expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
462
- config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
463
-
464
- # remove attributes from orig class that cannot be expected
465
- orig_cls_name = config_dict.pop("_class_name", cls.__name__)
466
- if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
467
- orig_cls = getattr(diffusers_library, orig_cls_name)
468
- unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
469
- config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
470
-
471
- # remove private attributes
472
- config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
473
-
474
- # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
475
- init_dict = {}
476
- for key in expected_keys:
477
- # if config param is passed to kwarg and is present in config dict
478
- # it should overwrite existing config dict key
479
- if key in kwargs and key in config_dict:
480
- config_dict[key] = kwargs.pop(key)
481
-
482
- if key in kwargs:
483
- # overwrite key
484
- init_dict[key] = kwargs.pop(key)
485
- elif key in config_dict:
486
- # use value from config dict
487
- init_dict[key] = config_dict.pop(key)
488
-
489
- # 4. Give nice warning if unexpected values have been passed
490
- if len(config_dict) > 0:
491
- logger.warning(
492
- f"The config attributes {config_dict} were passed to {cls.__name__}, "
493
- "but are not expected and will be ignored. Please verify your "
494
- f"{cls.config_name} configuration file."
495
- )
496
-
497
- # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
498
- passed_keys = set(init_dict.keys())
499
- if len(expected_keys - passed_keys) > 0:
500
- logger.info(
501
- f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
502
- )
503
-
504
- # 6. Define unused keyword arguments
505
- unused_kwargs = {**config_dict, **kwargs}
506
-
507
- # 7. Define "hidden" config parameters that were saved for compatible classes
508
- hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
509
-
510
- return init_dict, unused_kwargs, hidden_config_dict
511
-
512
- @classmethod
513
- def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
514
- with open(json_file, "r", encoding="utf-8") as reader:
515
- text = reader.read()
516
- return json.loads(text)
517
-
518
- def __repr__(self):
519
- return f"{self.__class__.__name__} {self.to_json_string()}"
520
-
521
- @property
522
- def config(self) -> Dict[str, Any]:
523
- """
524
- Returns the config of the class as a frozen dictionary
525
-
526
- Returns:
527
- `Dict[str, Any]`: Config of the class.
528
- """
529
- return self._internal_dict
530
-
531
- def to_json_string(self) -> str:
532
- """
533
- Serializes the configuration instance to a JSON string.
534
-
535
- Returns:
536
- `str`:
537
- String containing all the attributes that make up the configuration instance in JSON format.
538
- """
539
- config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
540
- config_dict["_class_name"] = self.__class__.__name__
541
- config_dict["_diffusers_version"] = __version__
542
-
543
- def to_json_saveable(value):
544
- if isinstance(value, np.ndarray):
545
- value = value.tolist()
546
- elif isinstance(value, PosixPath):
547
- value = str(value)
548
- return value
549
-
550
- config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
551
- # Don't save "_ignore_files" or "_use_default_values"
552
- config_dict.pop("_ignore_files", None)
553
- config_dict.pop("_use_default_values", None)
554
-
555
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
556
-
557
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
558
- """
559
- Save the configuration instance's parameters to a JSON file.
560
-
561
- Args:
562
- json_file_path (`str` or `os.PathLike`):
563
- Path to the JSON file to save a configuration instance's parameters.
564
- """
565
- with open(json_file_path, "w", encoding="utf-8") as writer:
566
- writer.write(self.to_json_string())
567
-
568
-
569
- def register_to_config(init):
570
- r"""
571
- Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
572
- automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
573
- shouldn't be registered in the config, use the `ignore_for_config` class variable
574
-
575
- Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
576
- """
577
-
578
- @functools.wraps(init)
579
- def inner_init(self, *args, **kwargs):
580
- # Ignore private kwargs in the init.
581
- init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
582
- config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
583
- if not isinstance(self, ConfigMixin):
584
- raise RuntimeError(
585
- f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
586
- "not inherit from `ConfigMixin`."
587
- )
588
-
589
- ignore = getattr(self, "ignore_for_config", [])
590
- # Get positional arguments aligned with kwargs
591
- new_kwargs = {}
592
- signature = inspect.signature(init)
593
- parameters = {
594
- name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
595
- }
596
- for arg, name in zip(args, parameters.keys()):
597
- new_kwargs[name] = arg
598
-
599
- # Then add all kwargs
600
- new_kwargs.update(
601
- {
602
- k: init_kwargs.get(k, default)
603
- for k, default in parameters.items()
604
- if k not in ignore and k not in new_kwargs
605
- }
606
- )
607
-
608
- # Take note of the parameters that were not present in the loaded config
609
- if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
610
- new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
611
-
612
- new_kwargs = {**config_init_kwargs, **new_kwargs}
613
- getattr(self, "register_to_config")(**new_kwargs)
614
- init(self, *args, **init_kwargs)
615
-
616
- return inner_init
617
-
618
-
619
- def flax_register_to_config(cls):
620
- original_init = cls.__init__
621
-
622
- @functools.wraps(original_init)
623
- def init(self, *args, **kwargs):
624
- if not isinstance(self, ConfigMixin):
625
- raise RuntimeError(
626
- f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
627
- "not inherit from `ConfigMixin`."
628
- )
629
-
630
- # Ignore private kwargs in the init. Retrieve all passed attributes
631
- init_kwargs = dict(kwargs.items())
632
-
633
- # Retrieve default values
634
- fields = dataclasses.fields(self)
635
- default_kwargs = {}
636
- for field in fields:
637
- # ignore flax specific attributes
638
- if field.name in self._flax_internal_args:
639
- continue
640
- if type(field.default) == dataclasses._MISSING_TYPE:
641
- default_kwargs[field.name] = None
642
- else:
643
- default_kwargs[field.name] = getattr(self, field.name)
644
-
645
- # Make sure init_kwargs override default kwargs
646
- new_kwargs = {**default_kwargs, **init_kwargs}
647
- # dtype should be part of `init_kwargs`, but not `new_kwargs`
648
- if "dtype" in new_kwargs:
649
- new_kwargs.pop("dtype")
650
-
651
- # Get positional arguments aligned with kwargs
652
- for i, arg in enumerate(args):
653
- name = fields[i].name
654
- new_kwargs[name] = arg
655
-
656
- # Take note of the parameters that were not present in the loaded config
657
- if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
658
- new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
659
-
660
- getattr(self, "register_to_config")(**new_kwargs)
661
- original_init(self, *args, **kwargs)
662
-
663
- cls.__init__ = init
664
- return cls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/dependency_versions_check.py DELETED
@@ -1,47 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import sys
15
-
16
- from .dependency_versions_table import deps
17
- from .utils.versions import require_version, require_version_core
18
-
19
-
20
- # define which module versions we always want to check at run time
21
- # (usually the ones defined in `install_requires` in setup.py)
22
- #
23
- # order specific notes:
24
- # - tqdm must be checked before tokenizers
25
-
26
- pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
- if sys.version_info < (3, 7):
28
- pkgs_to_check_at_runtime.append("dataclasses")
29
- if sys.version_info < (3, 8):
30
- pkgs_to_check_at_runtime.append("importlib_metadata")
31
-
32
- for pkg in pkgs_to_check_at_runtime:
33
- if pkg in deps:
34
- if pkg == "tokenizers":
35
- # must be loaded here, or else tqdm check may fail
36
- from .utils import is_tokenizers_available
37
-
38
- if not is_tokenizers_available():
39
- continue # not required, check version only if installed
40
-
41
- require_version_core(deps[pkg])
42
- else:
43
- raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
-
45
-
46
- def dep_version_check(pkg, hint=None):
47
- require_version(deps[pkg], hint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/dependency_versions_table.py DELETED
@@ -1,44 +0,0 @@
1
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
#
# Maps a package name to its pip-style requirement string.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.11.0",
    "compel": "compel==0.1.8",
    "black": "black~=23.1",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.13.2",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2",
    "jaxlib": "jaxlib>=0.1.65",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion>=0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "omegaconf": "omegaconf",
    "parameterized": "parameterized",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "ruff": "ruff>=0.0.241",
    "safetensors": "safetensors",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "scipy": "scipy",
    "onnx": "onnx",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.25.1",
    "urllib3": "urllib3<=2.0.0",
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/experimental/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .rl import ValueGuidedRLPipeline
 
 
gradio_demo/6DoF/diffusers/experimental/rl/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .value_guided_sampling import ValueGuidedRLPipeline
 
 
gradio_demo/6DoF/diffusers/experimental/rl/value_guided_sampling.py DELETED
@@ -1,152 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
- import torch
17
- import tqdm
18
-
19
- from ...models.unet_1d import UNet1DModel
20
- from ...pipelines import DiffusionPipeline
21
- from ...utils import randn_tensor
22
- from ...utils.dummy_pt_objects import DDPMScheduler
23
-
24
-
25
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()

        # Per-key normalization statistics. Dataset entries for which a statistic cannot be computed are skipped
        # on purpose (best effort), but only ordinary errors are swallowed so that KeyboardInterrupt / SystemExit
        # still propagate.
        self.means = {}
        self.stds = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` with the dataset mean and standard deviation stored under `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert `normalize`: scale `x_in` by the stored standard deviation and add the stored mean for `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """
        Move `x_in` to the device of `self.unet`.

        Dicts are converted value by value, tensors are moved with `.to(...)`, and anything else is wrapped with
        `torch.tensor(...)`.
        """
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """
        Overwrite, in place, the entries of `x_in` after the first `act_dim` features at every step index listed in
        `cond` with the corresponding conditioning value, and return `x_in`.
        """
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """
        Run the reverse diffusion loop over `self.scheduler.timesteps`.

        At every timestep the trajectory is first nudged `n_guide_steps` times along the gradient of the value
        function (scaled by `scale`), then denoised by one scheduler step, and finally re-conditioned.

        Returns:
            A tuple `(x, y)` with the final trajectories and the last value-function output. `y` is `None` when no
            guidance step was executed (for example `n_guide_steps == 0`).
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance on the last two timesteps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """
        Plan from observation `obs` and return the first action of the selected trajectory.

        Args:
            obs: Current observation. It is indexed with `obs[None]` and repeated with `.repeat(batch_size, axis=0)`,
                so presumably a NumPy array -- TODO confirm against callers.
            batch_size: Number of candidate trajectories sampled in parallel.
            planning_horizon: Number of steps in each sampled trajectory.
            n_guide_steps: Value-guidance steps per diffusion timestep.
            scale: Step size applied to the value gradient.

        Returns:
            The de-normalized action at step 0 of the highest-value trajectory, or of a randomly chosen trajectory
            when no value guidance was run.
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        if y is not None:
            # sort output trajectories by value and select the action with the highest value
            sorted_idx = y.argsort(0, descending=True).squeeze()
            x = x[sorted_idx]
            selected_index = 0
        else:
            # if we didn't run value guiding there is nothing to sort by: select a random action
            selected_index = np.random.randint(0, batch_size)

        actions = x[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/image_processor.py DELETED
@@ -1,366 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import warnings
16
- from typing import List, Optional, Union
17
-
18
- import numpy as np
19
- import PIL
20
- import torch
21
- from PIL import Image
22
-
23
- from .configuration_utils import ConfigMixin, register_to_config
24
- from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
-
26
-
27
class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to RGB format.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_convert_rgb: bool = False,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        Values are multiplied by 255, rounded and cast to `uint8` before conversion.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    @staticmethod
    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to a single stacked `float32` NumPy array scaled by 1/255.
        """
        if not isinstance(images, list):
            images = [images]
        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
        images = np.stack(images, axis=0)

        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
        """
        Convert a NumPy image batch to a PyTorch tensor, moving the last axis to position 1.

        A 3-dimensional input first gets a trailing axis of size 1 appended.
        """
        if images.ndim == 3:
            images = images[..., None]

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images

    @staticmethod
    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
        """
        Convert a PyTorch tensor to a float NumPy array on the CPU, moving axis 1 to the last position.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1].
        """
        return 2.0 * images - 1.0

    @staticmethod
    def denormalize(images):
        """
        Denormalize an image array to [0,1].

        The result is clamped with `.clamp(0, 1)`, so `images` must provide that method (e.g. a torch tensor).
        """
        return (images / 2 + 0.5).clamp(0, 1)

    @staticmethod
    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Converts an image to RGB format.
        """
        image = image.convert("RGB")
        return image

    def resize(
        self,
        image: PIL.Image.Image,
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> PIL.Image.Image:
        """
        Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.

        `height` and `width` default to the current size of `image` when not given.
        """
        if height is None:
            height = image.height
        if width is None:
            width = image.width

        width, height = (
            x - x % self.config.vae_scale_factor for x in (width, height)
        )  # resize to integer multiple of vae_scale_factor
        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
        return image

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors, either as a
        single object or as a list of objects.

        Args:
            image: The input image(s).
            height: Target height, only used when resizing PIL images.
            width: Target width, only used when resizing PIL images.

        Returns:
            A batched tensor. Tensor inputs with 4 channels are treated as latents and returned without any further
            processing.

        Raises:
            ValueError: If the input is of an unsupported type, or if `do_resize` is enabled and a NumPy/PyTorch
                input has a height or width that is not a multiple of `vae_scale_factor`.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            # Build the message defensively: `image` may not be iterable, and the supported formats are classes,
            # not strings, so they cannot be joined directly.
            received = [type(i) for i in image] if isinstance(image, list) else type(image)
            supported_names = ", ".join(fmt.__name__ for fmt in supported_formats)
            raise ValueError(
                f"Input is in incorrect format: {received}. Currently, we only support {supported_names}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            if self.config.do_resize:
                image = [self.resize(i, height, width) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            _, _, height, width = image.shape
            if self.config.do_resize and (
                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
            ):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
            _, channel, height, width = image.shape

            # don't need any preprocess if the image is latents
            if channel == 4:
                return image

            if self.config.do_resize and (
                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
            ):
                raise ValueError(
                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}, "
                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
                )

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            # negative values mean the input is already normalized
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        return image

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        """
        Convert a batched image tensor to the requested output type.

        Args:
            image: The tensor to convert.
            output_type: One of `"latent"`, `"pt"`, `"np"` or `"pil"`. Any other value triggers a deprecation
                message and is treated as `"np"`.
            do_denormalize: Per-image flags telling whether to denormalize. Defaults to the `do_normalize` config
                value for every image.

        Returns:
            The unchanged tensor for `"latent"`, a tensor for `"pt"`, a NumPy array for `"np"` and a list of PIL
            images for `"pil"`.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        if output_type == "pil":
            return self.numpy_to_pil(image)
253
-
254
-
255
class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a NumPy image or a batch of images to a list of PIL images.

        Only the first three channels are used for multi-channel inputs; the remaining channels are dropped.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        return pil_images

    @staticmethod
    def rgblike_to_depthmap(image):
        """
        Args:
            image: RGB-like depth image

        Returns: depth map, computed as channel 1 times 256 plus channel 2

        """
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    def numpy_to_depth(self, images):
        """
        Convert a NumPy depth image or a batch of images to a list of 16-bit (`I;16`) PIL images.

        The depth information is read from channel 3 onwards: 6-channel inputs carry an RGB-like depth encoding,
        4-channel inputs carry a single depth channel.

        Raises:
            Exception: If the input has neither 4 nor 6 channels.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            # Quantize to 8 bits, then widen to 16 bits before combining the bytes. Multiplying a uint8 array by
            # 256 raises OverflowError under NumPy >= 2 promotion rules; on older NumPy the product was promoted
            # to uint16 anyway, so the result is unchanged there.
            images_depth = (images_depth * 255).round().astype("uint8").astype(np.uint16)
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        """
        Convert a batched RGB + depth tensor to the requested output type.

        Args:
            image: The tensor to convert.
            output_type: `"np"` or `"pil"`. Values outside `"latent"`, `"pt"`, `"np"`, `"pil"` trigger a deprecation
                message and are treated as `"np"`.
            do_denormalize: Per-image flags telling whether to denormalize. Defaults to the `do_normalize` config
                value for every image.

        Returns:
            A tuple `(rgb, depth)`: two NumPy arrays for `"np"`, two lists of PIL images for `"pil"`.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
            Exception: If `output_type` is `"latent"` or `"pt"`, which this processor does not implement.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
        )

        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/loaders.py DELETED
@@ -1,1492 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import os
15
- import warnings
16
- from collections import defaultdict
17
- from pathlib import Path
18
- from typing import Callable, Dict, List, Optional, Union
19
-
20
- import torch
21
- import torch.nn.functional as F
22
- from huggingface_hub import hf_hub_download
23
-
24
- from .models.attention_processor import (
25
- AttnAddedKVProcessor,
26
- AttnAddedKVProcessor2_0,
27
- CustomDiffusionAttnProcessor,
28
- CustomDiffusionXFormersAttnProcessor,
29
- LoRAAttnAddedKVProcessor,
30
- LoRAAttnProcessor,
31
- LoRAAttnProcessor2_0,
32
- LoRAXFormersAttnProcessor,
33
- SlicedAttnAddedKVProcessor,
34
- XFormersAttnProcessor,
35
- )
36
- from .utils import (
37
- DIFFUSERS_CACHE,
38
- HF_HUB_OFFLINE,
39
- TEXT_ENCODER_ATTN_MODULE,
40
- _get_model_file,
41
- deprecate,
42
- is_safetensors_available,
43
- is_transformers_available,
44
- logging,
45
- )
46
-
47
-
48
- if is_safetensors_available():
49
- import safetensors
50
-
51
- if is_transformers_available():
52
- from transformers import PreTrainedModel, PreTrainedTokenizer
53
-
54
-
55
# Module-level logger.
logger = logging.get_logger(__name__)

# Names exposed as `text_encoder_name` / `unet_name` on the loader mixin and used there as state-dict key prefixes.
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"

# Default LoRA weight file names (pickle and safetensors variants) used when no `weight_name` is passed.
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

# File names for learned embeddings (pickle and safetensors variants).
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"

# File names for custom diffusion weights (pickle and safetensors variants).
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
68
-
69
-
70
class AttnProcsLayers(torch.nn.Module):
    """
    Container that stores the values of `state_dict` in a `ModuleList` while keeping their original names.

    Hooks registered on `state_dict()` and `load_state_dict()` translate between the internal `layers.<index>`
    names and the original keys, so that the naming fits with `unet.attn_processors`.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        names = list(state_dict.keys())
        self.layers = torch.nn.ModuleList(state_dict.values())
        # index -> original name, and original name -> index
        self.mapping = {index: name for index, name in enumerate(names)}
        self.rev_mapping = {name: index for index, name in enumerate(names)}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        def map_to(module, state_dict, *args, **kwargs):
            # state_dict() hook: rename "layers.<index>..." entries back to their original names
            renamed = {}
            for internal_key, tensor in state_dict.items():
                # component 0 is always "layers", component 1 is the index
                index = int(internal_key.split(".")[1])
                renamed[internal_key.replace(f"layers.{index}", module.mapping[index])] = tensor
            return renamed

        def remap_key(key, state_dict):
            # return the part of `key` up to and including the first matching split key
            marker = next((candidate for candidate in self.split_keys if candidate in key), None)
            if marker is None:
                raise ValueError(
                    f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
                )
            return key.split(marker)[0] + marker

        def map_from(module, state_dict, *args, **kwargs):
            # load_state_dict() pre-hook: rename original keys to "layers.<index>..." in place
            for external_key in list(state_dict.keys()):
                owner = remap_key(external_key, state_dict)
                internal_key = external_key.replace(owner, f"layers.{module.rev_mapping[owner]}")
                state_dict[internal_key] = state_dict[external_key]
                del state_dict[external_key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
110
-
111
-
112
- class UNet2DConditionLoadersMixin:
113
- text_encoder_name = TEXT_ENCODER_NAME
114
- unet_name = UNET_NAME
115
-
116
- def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
117
- r"""
118
- Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
119
- defined in
120
- [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
121
- and be a `torch.nn.Module` class.
122
-
123
- Parameters:
124
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
125
- Can be either:
126
-
127
- - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
128
- the Hub.
129
- - A path to a directory (for example `./my_model_directory`) containing the model weights saved
130
- with [`ModelMixin.save_pretrained`].
131
- - A [torch state
132
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
133
-
134
- cache_dir (`Union[str, os.PathLike]`, *optional*):
135
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
136
- is not used.
137
- force_download (`bool`, *optional*, defaults to `False`):
138
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
- cached versions if they exist.
140
- resume_download (`bool`, *optional*, defaults to `False`):
141
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
142
- incompletely downloaded files are deleted.
143
- proxies (`Dict[str, str]`, *optional*):
144
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
145
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
146
- local_files_only (`bool`, *optional*, defaults to `False`):
147
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
148
- won't be downloaded from the Hub.
149
- use_auth_token (`str` or *bool*, *optional*):
150
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
151
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
152
- revision (`str`, *optional*, defaults to `"main"`):
153
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
154
- allowed by Git.
155
- subfolder (`str`, *optional*, defaults to `""`):
156
- The subfolder location of a model file within a larger model repository on the Hub or locally.
157
- mirror (`str`, *optional*):
158
- Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
159
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
160
- information.
161
-
162
- """
163
-
164
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
165
- force_download = kwargs.pop("force_download", False)
166
- resume_download = kwargs.pop("resume_download", False)
167
- proxies = kwargs.pop("proxies", None)
168
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
169
- use_auth_token = kwargs.pop("use_auth_token", None)
170
- revision = kwargs.pop("revision", None)
171
- subfolder = kwargs.pop("subfolder", None)
172
- weight_name = kwargs.pop("weight_name", None)
173
- use_safetensors = kwargs.pop("use_safetensors", None)
174
- # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
175
- # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
176
- network_alpha = kwargs.pop("network_alpha", None)
177
-
178
- if use_safetensors and not is_safetensors_available():
179
- raise ValueError(
180
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
181
- )
182
-
183
- allow_pickle = False
184
- if use_safetensors is None:
185
- use_safetensors = is_safetensors_available()
186
- allow_pickle = True
187
-
188
- user_agent = {
189
- "file_type": "attn_procs_weights",
190
- "framework": "pytorch",
191
- }
192
-
193
- model_file = None
194
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
195
- # Let's first try to load .safetensors weights
196
- if (use_safetensors and weight_name is None) or (
197
- weight_name is not None and weight_name.endswith(".safetensors")
198
- ):
199
- try:
200
- model_file = _get_model_file(
201
- pretrained_model_name_or_path_or_dict,
202
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
203
- cache_dir=cache_dir,
204
- force_download=force_download,
205
- resume_download=resume_download,
206
- proxies=proxies,
207
- local_files_only=local_files_only,
208
- use_auth_token=use_auth_token,
209
- revision=revision,
210
- subfolder=subfolder,
211
- user_agent=user_agent,
212
- )
213
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
214
- except IOError as e:
215
- if not allow_pickle:
216
- raise e
217
- # try loading non-safetensors weights
218
- pass
219
- if model_file is None:
220
- model_file = _get_model_file(
221
- pretrained_model_name_or_path_or_dict,
222
- weights_name=weight_name or LORA_WEIGHT_NAME,
223
- cache_dir=cache_dir,
224
- force_download=force_download,
225
- resume_download=resume_download,
226
- proxies=proxies,
227
- local_files_only=local_files_only,
228
- use_auth_token=use_auth_token,
229
- revision=revision,
230
- subfolder=subfolder,
231
- user_agent=user_agent,
232
- )
233
- state_dict = torch.load(model_file, map_location="cpu")
234
- else:
235
- state_dict = pretrained_model_name_or_path_or_dict
236
-
237
- # fill attn processors
238
- attn_processors = {}
239
-
240
- is_lora = all("lora" in k for k in state_dict.keys())
241
- is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
242
-
243
- if is_lora:
244
- is_new_lora_format = all(
245
- key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
246
- )
247
- if is_new_lora_format:
248
- # Strip the `"unet"` prefix.
249
- is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
250
- if is_text_encoder_present:
251
- warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
252
- warnings.warn(warn_message)
253
- unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
254
- state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
255
-
256
- lora_grouped_dict = defaultdict(dict)
257
- for key, value in state_dict.items():
258
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
259
- lora_grouped_dict[attn_processor_key][sub_key] = value
260
-
261
- for key, value_dict in lora_grouped_dict.items():
262
- rank = value_dict["to_k_lora.down.weight"].shape[0]
263
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
264
-
265
- attn_processor = self
266
- for sub_key in key.split("."):
267
- attn_processor = getattr(attn_processor, sub_key)
268
-
269
- if isinstance(
270
- attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
271
- ):
272
- cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
273
- attn_processor_class = LoRAAttnAddedKVProcessor
274
- else:
275
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
276
- if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
277
- attn_processor_class = LoRAXFormersAttnProcessor
278
- else:
279
- attn_processor_class = (
280
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
281
- )
282
-
283
- attn_processors[key] = attn_processor_class(
284
- hidden_size=hidden_size,
285
- cross_attention_dim=cross_attention_dim,
286
- rank=rank,
287
- network_alpha=network_alpha,
288
- )
289
- attn_processors[key].load_state_dict(value_dict)
290
- elif is_custom_diffusion:
291
- custom_diffusion_grouped_dict = defaultdict(dict)
292
- for key, value in state_dict.items():
293
- if len(value) == 0:
294
- custom_diffusion_grouped_dict[key] = {}
295
- else:
296
- if "to_out" in key:
297
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
298
- else:
299
- attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
300
- custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
301
-
302
- for key, value_dict in custom_diffusion_grouped_dict.items():
303
- if len(value_dict) == 0:
304
- attn_processors[key] = CustomDiffusionAttnProcessor(
305
- train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
306
- )
307
- else:
308
- cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
309
- hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
310
- train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
311
- attn_processors[key] = CustomDiffusionAttnProcessor(
312
- train_kv=True,
313
- train_q_out=train_q_out,
314
- hidden_size=hidden_size,
315
- cross_attention_dim=cross_attention_dim,
316
- )
317
- attn_processors[key].load_state_dict(value_dict)
318
- else:
319
- raise ValueError(
320
- f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
321
- )
322
-
323
- # set correct dtype & device
324
- attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
325
-
326
- # set layers
327
- self.set_attn_processor(attn_processors)
328
-
329
def save_attn_procs(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = False,
    **kwargs,
):
    r"""
    Serialize this model's attention processors into `save_directory` so they can be reloaded with
    [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`].

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory; it is created when missing. If the path is an existing file, an error is logged
            and nothing is written.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Intended to mark the main process during distributed training. NOTE: this implementation never
            reads the flag.
        weight_name (`str`, *optional*):
            File name to write inside `save_directory`. When falsy, the deprecated `weights_name` keyword is
            consulted; if that is missing too, a LoRA or Custom Diffusion default name is picked.
        save_function (`Callable`, *optional*):
            Called as `save_function(state_dict, path)`. Defaults to `torch.save`, or to a
            `safetensors.torch.save_file` wrapper when `safe_serialization=True`.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Selects the safetensors writer and the `*_SAFE` default file names.
    """
    if not weight_name:
        # Backwards compatibility: fall back to the deprecated `weights_name` keyword.
        weight_name = deprecate(
            "weights_name",
            "0.20.0",
            "`weights_name` is deprecated, please use `weight_name` instead.",
            take_from=kwargs,
        )

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    def _write_safetensors(weights, filename):
        return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

    if save_function is None:
        save_function = _write_safetensors if safe_serialization else torch.save

    os.makedirs(save_directory, exist_ok=True)

    custom_types = (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
    is_custom_diffusion = any(isinstance(proc, custom_types) for proc in self.attn_processors.values())

    if not is_custom_diffusion:
        state_dict = AttnProcsLayers(self.attn_processors).state_dict()
    else:
        custom_only = {
            proc_name: proc for proc_name, proc in self.attn_processors.items() if isinstance(proc, custom_types)
        }
        state_dict = AttnProcsLayers(custom_only).state_dict()
        # Processors without parameters are recorded under their name with an empty entry.
        for proc_name, proc in self.attn_processors.items():
            if len(proc.state_dict()) == 0:
                state_dict[proc_name] = {}

    if weight_name is None:
        if is_custom_diffusion:
            weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if safe_serialization else CUSTOM_DIFFUSION_WEIGHT_NAME
        else:
            weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME

    # Save the model
    destination = os.path.join(save_directory, weight_name)
    save_function(state_dict, destination)
    logger.info(f"Model weights saved in {destination}")
405
-
406
-
407
class TextualInversionLoaderMixin:
    r"""
    Load textual inversion tokens and embeddings to the tokenizer and text encoder.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
        r"""
        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
        be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or if the textual inversion token is a single vector, the input prompt is returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        if isinstance(prompt, list):
            return [self._maybe_convert_prompt(p, tokenizer) for p in prompt]
        return self._maybe_convert_prompt(prompt, tokenizer)

    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        added_tokens = tokenizer.added_tokens_encoder
        for token in set(tokenizer.tokenize(prompt)):
            if token not in added_tokens:
                continue

            # Expand `tok` into `tok tok_1 tok_2 ...` for as many numbered siblings as are registered.
            replacement = token
            i = 1
            while f"{token}_{i}" in added_tokens:
                replacement += f" {token}_{i}"
                i += 1

            prompt = prompt.replace(token, replacement)

        return prompt

    @staticmethod
    def _extract_token_and_embedding(state_dict, token=None):
        r"""
        Pull the `(loaded_token, embedding)` pair out of a loaded textual inversion object.

        Three layouts are recognised: a bare `torch.Tensor` (requires `token`, loaded token is `None`), a
        single-entry dict `{token: embedding}` (🤗 Diffusers), and a dict holding `"string_to_param"` and `"name"`
        (Automatic1111).

        Raises:
            `ValueError`: if a bare tensor is given without `token`, or if the layout is not recognised.
        """
        if isinstance(state_dict, torch.Tensor):
            if token is None:
                raise ValueError(
                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                )
            return None, state_dict

        if len(state_dict) == 1:
            # diffusers
            return next(iter(state_dict.items()))

        if "string_to_param" in state_dict:
            # A1111
            return state_dict["name"], state_dict["string_to_param"]["*"]

        # Previously this case fell through and left `embedding` unbound (or stale from an earlier item).
        raise ValueError(
            f"Unrecognized textual inversion state dict with keys {list(state_dict.keys())}. Expected either a"
            " single `{token: embedding}` entry or a dictionary containing the `string_to_param` key."
        )

    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
        token: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ):
        r"""
        Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
        Automatic1111 formats are supported).

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                Can be either one of the following or a list of them:

                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                      pretrained model hosted on the Hub.
                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                      inversion weights.
                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            token (`str` or `List[str]`, *optional*):
                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                list, then `token` must also be a list of equal length.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used when:

                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                      name such as `text_inv.bin`.
                    - The saved textual inversion file is in the Automatic1111 format.
            use_safetensors (`bool`, *optional*):
                `True` requires safetensors (an error is raised when it is not installed, and loading failures are
                not retried). When left unset, safetensors is tried first if available and any failure falls back
                to the regular weight file.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.

        Raises:
            `ValueError`: if the pipeline lacks a suitable tokenizer/text encoder, if the number of tokens does not
                match the number of sources, if tokens are duplicated or already in the vocabulary, or if a loaded
                object has an unrecognised layout.

        Example:

        To load a textual inversion embedding vector in 🤗 Diffusers format:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("sd-concepts-library/cat-toy")

        prompt = "A <cat-toy> backpack"

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("cat-backpack.png")
        ```

        To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
        (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
        locally:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

        prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("character.png")
        ```

        """
        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        # Everything forwarded unchanged to `_get_model_file` is gathered once.
        download_kwargs = {
            "cache_dir": kwargs.pop("cache_dir", DIFFUSERS_CACHE),
            "force_download": kwargs.pop("force_download", False),
            "resume_download": kwargs.pop("resume_download", False),
            "proxies": kwargs.pop("proxies", None),
            "local_files_only": kwargs.pop("local_files_only", HF_HUB_OFFLINE),
            "use_auth_token": kwargs.pop("use_auth_token", None),
            "revision": kwargs.pop("revision", None),
            "subfolder": kwargs.pop("subfolder", None),
            "user_agent": {
                "file_type": "text_inversion",
                "framework": "pytorch",
            },
        }
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`"
            )

        # Falling back to pickle-based weights is only allowed when the caller did not choose a format.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        if isinstance(pretrained_model_name_or_path, list):
            sources = pretrained_model_name_or_path
        else:
            sources = [pretrained_model_name_or_path]

        if isinstance(token, str):
            tokens = [token]
        elif token is None:
            tokens = [None] * len(sources)
        else:
            tokens = token

        if len(sources) != len(tokens):
            raise ValueError(
                f"You have passed a list of models of length {len(sources)}, and list of tokens of length {len(tokens)}."
                " Make sure both lists have the same length."
            )

        valid_tokens = [t for t in tokens if t is not None]
        if len(set(valid_tokens)) < len(valid_tokens):
            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

        try_safetensors_first = (use_safetensors and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        )

        token_ids_and_embeddings = []

        for source, token in zip(sources, tokens):
            # 1. Load textual inversion file (or take the provided dict as-is)
            if isinstance(source, dict):
                state_dict = source
            else:
                model_file = None
                if try_safetensors_first:
                    try:
                        model_file = _get_model_file(
                            source,
                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                            **download_kwargs,
                        )
                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
                    except Exception:
                        if not allow_pickle:
                            raise
                        # Forget the safetensors path so the fallback below runs.
                        model_file = None

                if model_file is None:
                    model_file = _get_model_file(
                        source,
                        weights_name=weight_name or TEXT_INVERSION_NAME,
                        **download_kwargs,
                    )
                    # SECURITY NOTE: `torch.load` unpickles the file; only load weights from trusted sources.
                    state_dict = torch.load(model_file, map_location="cpu")

            # 2. Load token and embedding correctly from file
            loaded_token, embedding = self._extract_token_and_embedding(state_dict, token)

            if token is not None and loaded_token != token:
                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
            else:
                token = loaded_token

            embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

            # 3. Make sure we don't mess up the tokenizer or text encoder
            vocab = self.tokenizer.get_vocab()
            if token in vocab:
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                )
            elif f"{token}_1" in vocab:
                multi_vector_tokens = [token]
                i = 1
                while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
                    multi_vector_tokens.append(f"{token}_{i}")
                    i += 1

                raise ValueError(
                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )

            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

            # `new_tokens` is kept distinct from the outer `tokens` list that drives this loop.
            if is_multi_vector:
                new_tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
                new_embeddings = list(embedding)
            else:
                new_tokens = [token]
                new_embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

            # add tokens and get ids
            self.tokenizer.add_tokens(new_tokens)
            token_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
            token_ids_and_embeddings.extend(zip(token_ids, new_embeddings))

            logger.info(f"Loaded textual inversion embedding for {token}.")

        # resize token embeddings and set all new embeddings
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
        input_embeddings = self.text_encoder.get_input_embeddings()
        for token_id, embedding in token_ids_and_embeddings:
            input_embeddings.weight.data[token_id] = embedding
736
-
737
-
738
- class LoraLoaderMixin:
739
- r"""
740
- Load LoRA layers into [`UNet2DConditionModel`] and
741
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
742
- """
743
- text_encoder_name = TEXT_ENCODER_NAME
744
- unet_name = UNET_NAME
745
-
746
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
    r"""
    Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        weight_name (`str`, *optional*):
            Name of the weight file to fetch. A name ending in `.safetensors` is loaded with safetensors.
        use_safetensors (`bool`, *optional*):
            `True` requires safetensors (an error is raised when it is not installed, and an `IOError` while
            loading is re-raised). When left unset, safetensors is tried first if available and an `IOError`
            falls back to the regular weight file.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
            incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*):
            The subfolder location of a model file within a larger model repository on the Hub or locally.

    Side effects:
        Resets `self._lora_scale` to `1.0`, loads attention processors into `self.unet`, and, when text encoder
        weights are present, patches `self.text_encoder` and stores the processors on
        `self._text_encoder_lora_attn_procs`.
    """
    # Load the main state dict first which has the LoRA layers for either of
    # UNet and text encoder or both.
    download_kwargs = {
        "cache_dir": kwargs.pop("cache_dir", DIFFUSERS_CACHE),
        "force_download": kwargs.pop("force_download", False),
        "resume_download": kwargs.pop("resume_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", HF_HUB_OFFLINE),
        "use_auth_token": kwargs.pop("use_auth_token", None),
        "revision": kwargs.pop("revision", None),
        "subfolder": kwargs.pop("subfolder", None),
        "user_agent": {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        },
    }
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # set lora scale to a reasonable default
    self._lora_scale = 1.0

    if use_safetensors and not is_safetensors_available():
        raise ValueError(
            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors`"
        )

    # Falling back to pickle-based weights is only allowed when the caller did not choose a format.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = is_safetensors_available()
        allow_pickle = True

    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        state_dict = pretrained_model_name_or_path_or_dict
    else:
        model_file = None
        # Let's first try to load .safetensors weights
        if (use_safetensors and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        ):
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                    **download_kwargs,
                )
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            except IOError:
                if not allow_pickle:
                    raise
                # Reset so the non-safetensors fallback runs even when the file was found but could not
                # be read; otherwise `state_dict` would be left unbound.
                model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name or LORA_WEIGHT_NAME,
                **download_kwargs,
            )
            # SECURITY NOTE: `torch.load` unpickles the file; only load weights from trusted sources.
            state_dict = torch.load(model_file, map_location="cpu")

    # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
    network_alpha = None
    if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
        state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
    # their prefixes.
    is_new_format = all(
        key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
    )
    if is_new_format:
        # Load the layers corresponding to UNet.
        logger.info(f"Loading {self.unet_name}.")
        unet_lora_state_dict = {
            k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k.startswith(self.unet_name)
        }
        self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

        # Load the layers corresponding to text encoder and make necessary adjustments.
        text_encoder_lora_state_dict = {
            k.replace(f"{self.text_encoder_name}.", ""): v
            for k, v in state_dict.items()
            if k.startswith(self.text_encoder_name)
        }
        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {self.text_encoder_name}.")
            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                text_encoder_lora_state_dict, network_alpha=network_alpha
            )
            self._modify_text_encoder(attn_procs_text_encoder)

            # save lora attn procs of text encoder so that it can be easily retrieved
            self._text_encoder_lora_attn_procs = attn_procs_text_encoder

    # Otherwise, we're dealing with the old format. This means the `state_dict` should only
    # contain the module names of the `unet` as its keys WITHOUT any prefix.
    else:
        self.unet.load_attn_procs(state_dict)
        warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
        warnings.warn(warn_message)
907
-
908
@property
def lora_scale(self) -> float:
    """Return the LoRA scale set at run time by the pipeline, or `1.0` when `_lora_scale` was never assigned."""
    return getattr(self, "_lora_scale", 1.0)
913
-
914
@property
def text_encoder_lora_attn_procs(self):
    """Return the stored `_text_encoder_lora_attn_procs`, or `None` when that attribute has not been set."""
    return getattr(self, "_text_encoder_lora_attn_procs", None)
919
-
920
def _remove_text_encoder_monkey_patch(self):
    """Restore the saved `old_forward` on every patched projection of the text encoder's attention modules."""
    for module_name, attention in self.text_encoder.named_modules():
        # Only modules whose name ends with the attention-module suffix are considered.
        if not module_name.endswith(TEXT_ENCODER_ATTN_MODULE):
            continue

        # Visit the q/k/v/out projections named by the LoRA-attr -> text-encoder-attr mapping.
        for projection_name in self._lora_attn_processor_attr_to_text_encoder_attr.values():
            projection = attention.get_submodule(projection_name)
            if not hasattr(projection, "old_forward"):
                continue
            # Put the original `forward` back and drop the marker attribute.
            projection.forward = projection.old_forward
            del projection.old_forward
932
-
933
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
    r"""
    Monkey-patch the `forward` of the text encoder's attention projections so that each one adds the
    output of its LoRA layer, scaled by `self.lora_scale`.

    Parameters:
        attn_processors: Dict[str, `LoRAAttnProcessor`]:
            Maps attention module names to the [`~LoRAAttnProcessor`] holding their LoRA layers.
    """

    # A factory gives every patched forward its own (original forward, LoRA layer) pair instead of the
    # loop variables; see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
    def make_new_forward(old_forward, lora_layer):
        def new_forward(x):
            return old_forward(x) + self.lora_scale * lora_layer(x)

        return new_forward

    # Start from a clean slate: undo any patch applied earlier.
    self._remove_text_encoder_monkey_patch()

    for module_name, attention in self.text_encoder.named_modules():
        if not module_name.endswith(TEXT_ENCODER_ATTN_MODULE):
            continue

        attr_pairs = self._lora_attn_processor_attr_to_text_encoder_attr.items()
        for lora_attr, projection_name in attr_pairs:
            # Pair the q/k/v/out projection with its LoRA layer.
            projection = attention.get_submodule(projection_name)
            lora_layer = attn_processors[module_name].get_submodule(lora_attr)

            # Keep the original forward on the module so the patch can be removed later.
            projection.old_forward = projection.forward
            original_forward = projection.old_forward

            # Monkey-patch.
            projection.forward = make_new_forward(original_forward, lora_layer)
968
-
969
- @property
970
- def _lora_attn_processor_attr_to_text_encoder_attr(self):
971
- return {
972
- "to_q_lora": "q_proj",
973
- "to_k_lora": "k_proj",
974
- "to_v_lora": "v_proj",
975
- "to_out_lora": "out_proj",
976
- }
977
-
978
- def _load_text_encoder_attn_procs(
979
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
980
- ):
981
- r"""
982
- Load pretrained attention processor layers for
983
- [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
984
-
985
- <Tip warning={true}>
986
-
987
- This function is experimental and might change in the future.
988
-
989
- </Tip>
990
-
991
- Parameters:
992
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
993
- Can be either:
994
-
995
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
996
- Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
997
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
998
- `./my_model_directory/`.
999
- - A [torch state
1000
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1001
-
1002
- cache_dir (`Union[str, os.PathLike]`, *optional*):
1003
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
1004
- standard cache should not be used.
1005
- force_download (`bool`, *optional*, defaults to `False`):
1006
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1007
- cached versions if they exist.
1008
- resume_download (`bool`, *optional*, defaults to `False`):
1009
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
1010
- file exists.
1011
- proxies (`Dict[str, str]`, *optional*):
1012
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1013
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1014
- local_files_only (`bool`, *optional*, defaults to `False`):
1015
- Whether or not to only look at local files (i.e., do not try to download the model).
1016
- use_auth_token (`str` or *bool*, *optional*):
1017
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
1018
- when running `diffusers-cli login` (stored in `~/.huggingface`).
1019
- revision (`str`, *optional*, defaults to `"main"`):
1020
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1021
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1022
- identifier allowed by git.
1023
- subfolder (`str`, *optional*, defaults to `""`):
1024
- In case the relevant files are located inside a subfolder of the model repo (either remote in
1025
- huggingface.co or downloaded locally), you can specify the folder name here.
1026
- mirror (`str`, *optional*):
1027
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
1028
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
1029
- Please refer to the mirror site for more information.
1030
-
1031
- Returns:
1032
- `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
1033
- [`LoRAAttnProcessor`].
1034
-
1035
- <Tip>
1036
-
1037
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
1038
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
1039
-
1040
- </Tip>
1041
- """
1042
-
1043
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1044
- force_download = kwargs.pop("force_download", False)
1045
- resume_download = kwargs.pop("resume_download", False)
1046
- proxies = kwargs.pop("proxies", None)
1047
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1048
- use_auth_token = kwargs.pop("use_auth_token", None)
1049
- revision = kwargs.pop("revision", None)
1050
- subfolder = kwargs.pop("subfolder", None)
1051
- weight_name = kwargs.pop("weight_name", None)
1052
- use_safetensors = kwargs.pop("use_safetensors", None)
1053
- network_alpha = kwargs.pop("network_alpha", None)
1054
-
1055
- if use_safetensors and not is_safetensors_available():
1056
- raise ValueError(
1057
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1058
- )
1059
-
1060
- allow_pickle = False
1061
- if use_safetensors is None:
1062
- use_safetensors = is_safetensors_available()
1063
- allow_pickle = True
1064
-
1065
- user_agent = {
1066
- "file_type": "attn_procs_weights",
1067
- "framework": "pytorch",
1068
- }
1069
-
1070
- model_file = None
1071
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
1072
- # Let's first try to load .safetensors weights
1073
- if (use_safetensors and weight_name is None) or (
1074
- weight_name is not None and weight_name.endswith(".safetensors")
1075
- ):
1076
- try:
1077
- model_file = _get_model_file(
1078
- pretrained_model_name_or_path_or_dict,
1079
- weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
1080
- cache_dir=cache_dir,
1081
- force_download=force_download,
1082
- resume_download=resume_download,
1083
- proxies=proxies,
1084
- local_files_only=local_files_only,
1085
- use_auth_token=use_auth_token,
1086
- revision=revision,
1087
- subfolder=subfolder,
1088
- user_agent=user_agent,
1089
- )
1090
- state_dict = safetensors.torch.load_file(model_file, device="cpu")
1091
- except IOError as e:
1092
- if not allow_pickle:
1093
- raise e
1094
- # try loading non-safetensors weights
1095
- pass
1096
- if model_file is None:
1097
- model_file = _get_model_file(
1098
- pretrained_model_name_or_path_or_dict,
1099
- weights_name=weight_name or LORA_WEIGHT_NAME,
1100
- cache_dir=cache_dir,
1101
- force_download=force_download,
1102
- resume_download=resume_download,
1103
- proxies=proxies,
1104
- local_files_only=local_files_only,
1105
- use_auth_token=use_auth_token,
1106
- revision=revision,
1107
- subfolder=subfolder,
1108
- user_agent=user_agent,
1109
- )
1110
- state_dict = torch.load(model_file, map_location="cpu")
1111
- else:
1112
- state_dict = pretrained_model_name_or_path_or_dict
1113
-
1114
- # fill attn processors
1115
- attn_processors = {}
1116
-
1117
- is_lora = all("lora" in k for k in state_dict.keys())
1118
-
1119
- if is_lora:
1120
- lora_grouped_dict = defaultdict(dict)
1121
- for key, value in state_dict.items():
1122
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
1123
- lora_grouped_dict[attn_processor_key][sub_key] = value
1124
-
1125
- for key, value_dict in lora_grouped_dict.items():
1126
- rank = value_dict["to_k_lora.down.weight"].shape[0]
1127
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
1128
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
1129
-
1130
- attn_processor_class = (
1131
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
1132
- )
1133
- attn_processors[key] = attn_processor_class(
1134
- hidden_size=hidden_size,
1135
- cross_attention_dim=cross_attention_dim,
1136
- rank=rank,
1137
- network_alpha=network_alpha,
1138
- )
1139
- attn_processors[key].load_state_dict(value_dict)
1140
-
1141
- else:
1142
- raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
1143
-
1144
- # set correct dtype & device
1145
- attn_processors = {
1146
- k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
1147
- }
1148
- return attn_processors
1149
-
1150
- @classmethod
1151
- def save_lora_weights(
1152
- self,
1153
- save_directory: Union[str, os.PathLike],
1154
- unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1155
- text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1156
- is_main_process: bool = True,
1157
- weight_name: str = None,
1158
- save_function: Callable = None,
1159
- safe_serialization: bool = False,
1160
- ):
1161
- r"""
1162
- Save the LoRA parameters corresponding to the UNet and text encoder.
1163
-
1164
- Arguments:
1165
- save_directory (`str` or `os.PathLike`):
1166
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
1167
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1168
- State dict of the LoRA layers corresponding to the UNet.
1169
- text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`):
1170
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1171
- encoder LoRA state dict because it comes 🤗 Transformers.
1172
- is_main_process (`bool`, *optional*, defaults to `True`):
1173
- Whether the process calling this is the main process or not. Useful during distributed training and you
1174
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1175
- process to avoid race conditions.
1176
- save_function (`Callable`):
1177
- The function to use to save the state dictionary. Useful during distributed training when you need to
1178
- replace `torch.save` with another method. Can be configured with the environment variable
1179
- `DIFFUSERS_SAVE_MODE`.
1180
- """
1181
- if os.path.isfile(save_directory):
1182
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1183
- return
1184
-
1185
- if save_function is None:
1186
- if safe_serialization:
1187
-
1188
- def save_function(weights, filename):
1189
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
1190
-
1191
- else:
1192
- save_function = torch.save
1193
-
1194
- os.makedirs(save_directory, exist_ok=True)
1195
-
1196
- # Create a flat dictionary.
1197
- state_dict = {}
1198
- if unet_lora_layers is not None:
1199
- weights = (
1200
- unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
1201
- )
1202
-
1203
- unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
1204
- state_dict.update(unet_lora_state_dict)
1205
-
1206
- if text_encoder_lora_layers is not None:
1207
- weights = (
1208
- text_encoder_lora_layers.state_dict()
1209
- if isinstance(text_encoder_lora_layers, torch.nn.Module)
1210
- else text_encoder_lora_layers
1211
- )
1212
-
1213
- text_encoder_lora_state_dict = {
1214
- f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
1215
- }
1216
- state_dict.update(text_encoder_lora_state_dict)
1217
-
1218
- # Save the model
1219
- if weight_name is None:
1220
- if safe_serialization:
1221
- weight_name = LORA_WEIGHT_NAME_SAFE
1222
- else:
1223
- weight_name = LORA_WEIGHT_NAME
1224
-
1225
- save_function(state_dict, os.path.join(save_directory, weight_name))
1226
- logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1227
-
1228
- def _convert_kohya_lora_to_diffusers(self, state_dict):
1229
- unet_state_dict = {}
1230
- te_state_dict = {}
1231
- network_alpha = None
1232
-
1233
- for key, value in state_dict.items():
1234
- if "lora_down" in key:
1235
- lora_name = key.split(".")[0]
1236
- lora_name_up = lora_name + ".lora_up.weight"
1237
- lora_name_alpha = lora_name + ".alpha"
1238
- if lora_name_alpha in state_dict:
1239
- alpha = state_dict[lora_name_alpha].item()
1240
- if network_alpha is None:
1241
- network_alpha = alpha
1242
- elif network_alpha != alpha:
1243
- raise ValueError("Network alpha is not consistent")
1244
-
1245
- if lora_name.startswith("lora_unet_"):
1246
- diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
1247
- diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
1248
- diffusers_name = diffusers_name.replace("mid.block", "mid_block")
1249
- diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
1250
- diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
1251
- diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
1252
- diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
1253
- diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
1254
- diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
1255
- if "transformer_blocks" in diffusers_name:
1256
- if "attn1" in diffusers_name or "attn2" in diffusers_name:
1257
- diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
1258
- diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
1259
- unet_state_dict[diffusers_name] = value
1260
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1261
- elif lora_name.startswith("lora_te_"):
1262
- diffusers_name = key.replace("lora_te_", "").replace("_", ".")
1263
- diffusers_name = diffusers_name.replace("text.model", "text_model")
1264
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
1265
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
1266
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
1267
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
1268
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
1269
- if "self_attn" in diffusers_name:
1270
- te_state_dict[diffusers_name] = value
1271
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1272
-
1273
- unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
1274
- te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
1275
- new_state_dict = {**unet_state_dict, **te_state_dict}
1276
- return new_state_dict, network_alpha
1277
-
1278
-
1279
- class FromSingleFileMixin:
1280
- """
1281
- Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
1282
- """
1283
-
1284
- @classmethod
1285
- def from_ckpt(cls, *args, **kwargs):
1286
- deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
1287
- deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
1288
- return cls.from_single_file(*args, **kwargs)
1289
-
1290
- @classmethod
1291
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
1292
- r"""
1293
- Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
1294
- is set in evaluation mode (`model.eval()`) by default.
1295
-
1296
- Parameters:
1297
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
1298
- Can be either:
1299
- - A link to the `.ckpt` file (for example
1300
- `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
1301
- - A path to a *file* containing all pipeline weights.
1302
- torch_dtype (`str` or `torch.dtype`, *optional*):
1303
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1304
- dtype is automatically derived from the model's weights.
1305
- force_download (`bool`, *optional*, defaults to `False`):
1306
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1307
- cached versions if they exist.
1308
- cache_dir (`Union[str, os.PathLike]`, *optional*):
1309
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1310
- is not used.
1311
- resume_download (`bool`, *optional*, defaults to `False`):
1312
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1313
- incompletely downloaded files are deleted.
1314
- proxies (`Dict[str, str]`, *optional*):
1315
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1316
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1317
- local_files_only (`bool`, *optional*, defaults to `False`):
1318
- Whether to only load local model weights and configuration files or not. If set to True, the model
1319
- won't be downloaded from the Hub.
1320
- use_auth_token (`str` or *bool*, *optional*):
1321
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1322
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
1323
- revision (`str`, *optional*, defaults to `"main"`):
1324
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1325
- allowed by Git.
1326
- use_safetensors (`bool`, *optional*, defaults to `None`):
1327
- If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1328
- safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1329
- weights. If set to `False`, safetensors weights are not loaded.
1330
- extract_ema (`bool`, *optional*, defaults to `False`):
1331
- Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
1332
- higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
1333
- upcast_attention (`bool`, *optional*, defaults to `None`):
1334
- Whether the attention computation should always be upcasted.
1335
- image_size (`int`, *optional*, defaults to 512):
1336
- The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
1337
- Diffusion v2 base model. Use 768 for Stable Diffusion v2.
1338
- prediction_type (`str`, *optional*):
1339
- The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
1340
- the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
1341
- num_in_channels (`int`, *optional*, defaults to `None`):
1342
- The number of input channels. If `None`, it will be automatically inferred.
1343
- scheduler_type (`str`, *optional*, defaults to `"pndm"`):
1344
- Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
1345
- "ddim"]`.
1346
- load_safety_checker (`bool`, *optional*, defaults to `True`):
1347
- Whether to load the safety checker or not.
1348
- text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
1349
- An instance of
1350
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
1351
- specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
1352
- variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
1353
- needed.
1354
- tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
1355
- An instance of
1356
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
1357
- to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
1358
- itself, if needed.
1359
- kwargs (remaining dictionary of keyword arguments, *optional*):
1360
- Can be used to overwrite load and saveable variables (for example the pipeline components of the
1361
- specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
1362
- method. See example below for more information.
1363
-
1364
- Examples:
1365
-
1366
- ```py
1367
- >>> from diffusers import StableDiffusionPipeline
1368
-
1369
- >>> # Download pipeline from huggingface.co and cache.
1370
- >>> pipeline = StableDiffusionPipeline.from_single_file(
1371
- ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
1372
- ... )
1373
-
1374
- >>> # Download pipeline from local file
1375
- >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
1376
- >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
1377
-
1378
- >>> # Enable float16 and move to GPU
1379
- >>> pipeline = StableDiffusionPipeline.from_single_file(
1380
- ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
1381
- ... torch_dtype=torch.float16,
1382
- ... )
1383
- >>> pipeline.to("cuda")
1384
- ```
1385
- """
1386
- # import here to avoid circular dependency
1387
- from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
1388
-
1389
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1390
- resume_download = kwargs.pop("resume_download", False)
1391
- force_download = kwargs.pop("force_download", False)
1392
- proxies = kwargs.pop("proxies", None)
1393
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1394
- use_auth_token = kwargs.pop("use_auth_token", None)
1395
- revision = kwargs.pop("revision", None)
1396
- extract_ema = kwargs.pop("extract_ema", False)
1397
- image_size = kwargs.pop("image_size", None)
1398
- scheduler_type = kwargs.pop("scheduler_type", "pndm")
1399
- num_in_channels = kwargs.pop("num_in_channels", None)
1400
- upcast_attention = kwargs.pop("upcast_attention", None)
1401
- load_safety_checker = kwargs.pop("load_safety_checker", True)
1402
- prediction_type = kwargs.pop("prediction_type", None)
1403
- text_encoder = kwargs.pop("text_encoder", None)
1404
- tokenizer = kwargs.pop("tokenizer", None)
1405
-
1406
- torch_dtype = kwargs.pop("torch_dtype", None)
1407
-
1408
- use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
1409
-
1410
- pipeline_name = cls.__name__
1411
- file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
1412
- from_safetensors = file_extension == "safetensors"
1413
-
1414
- if from_safetensors and use_safetensors is False:
1415
- raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
1416
-
1417
- # TODO: For now we only support stable diffusion
1418
- stable_unclip = None
1419
- model_type = None
1420
- controlnet = False
1421
-
1422
- if pipeline_name == "StableDiffusionControlNetPipeline":
1423
- # Model type will be inferred from the checkpoint.
1424
- controlnet = True
1425
- elif "StableDiffusion" in pipeline_name:
1426
- # Model type will be inferred from the checkpoint.
1427
- pass
1428
- elif pipeline_name == "StableUnCLIPPipeline":
1429
- model_type = "FrozenOpenCLIPEmbedder"
1430
- stable_unclip = "txt2img"
1431
- elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
1432
- model_type = "FrozenOpenCLIPEmbedder"
1433
- stable_unclip = "img2img"
1434
- elif pipeline_name == "PaintByExamplePipeline":
1435
- model_type = "PaintByExample"
1436
- elif pipeline_name == "LDMTextToImagePipeline":
1437
- model_type = "LDMTextToImage"
1438
- else:
1439
- raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
1440
-
1441
- # remove huggingface url
1442
- for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
1443
- if pretrained_model_link_or_path.startswith(prefix):
1444
- pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
1445
-
1446
- # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
1447
- ckpt_path = Path(pretrained_model_link_or_path)
1448
- if not ckpt_path.is_file():
1449
- # get repo_id and (potentially nested) file path of ckpt in repo
1450
- repo_id = "/".join(ckpt_path.parts[:2])
1451
- file_path = "/".join(ckpt_path.parts[2:])
1452
-
1453
- if file_path.startswith("blob/"):
1454
- file_path = file_path[len("blob/") :]
1455
-
1456
- if file_path.startswith("main/"):
1457
- file_path = file_path[len("main/") :]
1458
-
1459
- pretrained_model_link_or_path = hf_hub_download(
1460
- repo_id,
1461
- filename=file_path,
1462
- cache_dir=cache_dir,
1463
- resume_download=resume_download,
1464
- proxies=proxies,
1465
- local_files_only=local_files_only,
1466
- use_auth_token=use_auth_token,
1467
- revision=revision,
1468
- force_download=force_download,
1469
- )
1470
-
1471
- pipe = download_from_original_stable_diffusion_ckpt(
1472
- pretrained_model_link_or_path,
1473
- pipeline_class=cls,
1474
- model_type=model_type,
1475
- stable_unclip=stable_unclip,
1476
- controlnet=controlnet,
1477
- from_safetensors=from_safetensors,
1478
- extract_ema=extract_ema,
1479
- image_size=image_size,
1480
- scheduler_type=scheduler_type,
1481
- num_in_channels=num_in_channels,
1482
- upcast_attention=upcast_attention,
1483
- load_safety_checker=load_safety_checker,
1484
- prediction_type=prediction_type,
1485
- text_encoder=text_encoder,
1486
- tokenizer=tokenizer,
1487
- )
1488
-
1489
- if torch_dtype is not None:
1490
- pipe.to(torch_dtype=torch_dtype)
1491
-
1492
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/__init__.py DELETED
@@ -1,35 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from ..utils import is_flax_available, is_torch_available
16
-
17
-
18
- if is_torch_available():
19
- from .autoencoder_kl import AutoencoderKL
20
- from .controlnet import ControlNetModel
21
- from .dual_transformer_2d import DualTransformer2DModel
22
- from .modeling_utils import ModelMixin
23
- from .prior_transformer import PriorTransformer
24
- from .t5_film_transformer import T5FilmDecoder
25
- from .transformer_2d import Transformer2DModel
26
- from .unet_1d import UNet1DModel
27
- from .unet_2d import UNet2DModel
28
- from .unet_2d_condition import UNet2DConditionModel
29
- from .unet_3d_condition import UNet3DConditionModel
30
- from .vq_model import VQModel
31
-
32
- if is_flax_available():
33
- from .controlnet_flax import FlaxControlNetModel
34
- from .unet_2d_condition_flax import FlaxUNet2DConditionModel
35
- from .vae_flax import FlaxAutoencoderKL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/activations.py DELETED
@@ -1,12 +0,0 @@
1
- from torch import nn
2
-
3
-
4
- def get_activation(act_fn):
5
- if act_fn in ["swish", "silu"]:
6
- return nn.SiLU()
7
- elif act_fn == "mish":
8
- return nn.Mish()
9
- elif act_fn == "gelu":
10
- return nn.GELU()
11
- else:
12
- raise ValueError(f"Unsupported activation function: {act_fn}")
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/attention.py DELETED
@@ -1,392 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, Optional
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from ..utils import maybe_allow_in_graph
21
- from .activations import get_activation
22
- from .attention_processor import Attention
23
- from .embeddings import CombinedTimestepLabelEmbeddings
24
-
25
-
26
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: attention, optional second attention, then feed-forward, each preceded by its own
    normalization layer and wrapped in a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`. Must be given when
            `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded unchanged to both `Attention` layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Passed as `elementwise_affine` to the plain `nn.LayerNorm` layers.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` selects `AdaLayerNorm`, `"ada_norm_zero"` selects `AdaLayerNormZero` for the first norm;
            anything else falls back to `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Forwarded to `FeedForward`.

    Raises:
        ValueError: If an adaptive `norm_type` is requested without `num_embeds_ada_norm`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # The adaptive norms are only enabled when an embedding count was supplied.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        # attn1 only receives a cross-attention dimension in `only_cross_attention` mode.
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Configure chunked feed-forward: `chunk_size` elements per chunk along axis `dim` (`None` disables it)."""
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        posemb: Optional[Any] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Run the block on `hidden_states` and return the updated hidden states.

        `attention_mask` is given to the first attention layer and `encoder_attention_mask` to the second.
        `timestep` and `class_labels` are only consumed by the adaptive norm layers. `posemb` is forwarded
        as-is to both attention layers (its expected structure is defined by `Attention`, not visible here).
        `cross_attention_kwargs` is unpacked into both attention calls; `None` is treated as an empty dict.

        Raises:
            ValueError: If chunked feed-forward is enabled and the chunked axis is not divisible by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # adaLN-Zero also yields the gates / shifts / scales that are applied later in this method.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            posemb=posemb,  # TODO: in self-attention, should posemb be [pose_in, pose_in]?
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                posemb=posemb,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            # Apply the feed-forward to each slice separately and stitch the results back together.
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
208
-
209
-
210
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: input projection with activation, dropout, output projection and an optional final
    dropout.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward. One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or
            `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # Single elif chain so that an unknown name fails loudly here instead of
        # surfacing later as an unbound local variable.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply every sub-module of `self.net` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
260
-
261
-
262
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU; pass `approximate="tanh"` to use the tanh approximation.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): Forwarded to `F.gelu`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        """Apply GELU to `gate`, going through float32 on `mps` devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast, approximate=self.approximate).to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        """Project `hidden_states` and apply the activation."""
        return self.gelu(self.proj(hidden_states))
282
-
283
-
284
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    The input is projected to twice `dim_out`; one half is the value and the other half, passed through GELU,
    is the gate.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        """Apply GELU to `gate`, going through float32 on `mps` devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        """Return `value * gelu(gate)` where both halves come from one projection."""
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
306
-
307
-
308
class ApproximateGELU(nn.Module):
    """
    Linear projection followed by the sigmoid approximation of the Gaussian Error Linear Unit (GELU),
    `x * sigmoid(1.702 * x)`.

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        """Project `x`, then gate the projection with `sigmoid(1.702 * projection)`."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
322
-
323
-
324
class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.

    The timestep is looked up in an embedding table, passed through SiLU and a linear layer, and split into a
    scale and a shift that modulate the (non-affine) layer-normalized input.

    Parameters:
        embedding_dim (`int`): Size of the normalized (last) axis of the input.
        num_embeddings (`int`): Number of entries in the timestep embedding table.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        """
        Normalize `x` and modulate it with the scale/shift derived from `timestep`.

        `timestep` may be a scalar index tensor (one modulation shared by the whole input) or carry leading
        axes, e.g. one index per batch element, in which case the modulation is applied per sample.
        """
        emb = self.linear(self.silu(self.emb(timestep)))
        # Split along the feature axis. Splitting along dim 0 is only correct for a
        # 1-D embedding (scalar timestep); with a batched timestep it would cut the
        # batch axis in half instead of separating scale from shift.
        scale, shift = torch.chunk(emb, 2, dim=-1)
        if scale.ndim > 1:
            # Per-sample modulation: insert singleton axes after the batch axis so
            # that (batch, dim) broadcasts against (batch, ..., dim).
            for _ in range(x.ndim - scale.ndim):
                scale = scale.unsqueeze(1)
                shift = shift.unsqueeze(1)
        x = self.norm(x) * (1 + scale) + shift
        return x
341
-
342
-
343
class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).

    A combined timestep/class-label embedding is projected to six modulation tensors. Shift and scale for the
    attention branch are applied here; the remaining four are returned to the caller.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        """Return `(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`."""
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(conditioning))
        # Six equal slices along dim 1, in this fixed order.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)
        normalized = self.norm(x)
        x = normalized * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
362
-
363
-
364
class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.

    The embedding is optionally passed through an activation, projected to `2 * out_dim` values and split into a
    per-channel scale and shift applied to the group-normalized input.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        # No activation module is built unless a name was given.
        self.act = None if act_fn is None else get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        """Group-normalize `x` and modulate it with the scale/shift computed from `emb`."""
        if self.act:
            emb = self.act(emb)
        # Add two trailing singleton axes so the modulation broadcasts over them.
        modulation = self.linear(emb)[:, :, None, None]
        scale, shift = modulation.chunk(2, dim=1)

        normalized = F.group_norm(x, self.num_groups, eps=self.eps)
        return normalized * (1 + scale) + shift
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/attention_flax.py DELETED
@@ -1,446 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import functools
16
- import math
17
-
18
- import flax.linen as nn
19
- import jax
20
- import jax.numpy as jnp
21
-
22
-
23
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries.

    The key/value arrays are processed in chunks of `key_chunk_size` entries (capped at the number of
    keys). Each chunk yields un-normalized exponentiated values, the sum of exponentiated weights and the
    per-query maximum score; the chunks are then rescaled to a common maximum and combined.

    The last three axes of `query`, `key` and `value` are read as (length, heads, features).
    NOTE(review): chunk start offsets come from `jnp.arange(0, num_kv, key_chunk_size)` with a fixed slice
    size, so this presumably requires `num_kv` to be a multiple of `key_chunk_size` -- confirm with callers.
    """
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    # Scale the queries once up front by 1 / sqrt(key feature size).
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        # Subtract the per-query maximum before exponentiating; the maximum is excluded from gradients.
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        # Drop the kept size-1 key axis from the maximum.
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # slice one chunk out of the key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # slice the matching chunk out of the value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    # One summary per chunk; results are stacked along a new leading axis.
    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    # Rescale every chunk from its own maximum to the maximum over all chunks.
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    # Sum over chunks, then normalize the values by the total weight.
    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights
72
-
73
-
74
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    The query array is processed in chunks of `query_chunk_size` via `jax.lax.scan`; each chunk is attended
    against the full key/value arrays by `_query_chunk_attention`, which in turn chunks the keys.

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array value must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array value must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # slice one chunk out of the query array, starting at the carried offset
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # carry: start offset of the next chunk (final value is discarded)
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        # `init=0` is the offset of the first chunk; `length` is the number of chunks to process.
        f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
117
-
118
-
119
class FlaxAttention(nn.Module):
    r"""
    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762

    Parameters:
        query_dim (:obj:`int`):
            Input hidden states dimension
        heads (:obj:`int`, *optional*, defaults to 8):
            Number of heads
        dim_head (:obj:`int`, *optional*, defaults to 64):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def reshape_heads_to_batch_dim(self, tensor):
        """Reshape (batch, seq, dim) into (batch * heads, seq, dim // heads)."""
        batch_size, seq_len, dim = tensor.shape
        num_heads = self.heads
        per_head = dim // num_heads
        split = tensor.reshape(batch_size, seq_len, num_heads, per_head)
        split = jnp.transpose(split, (0, 2, 1, 3))
        return split.reshape(batch_size * num_heads, seq_len, per_head)

    def reshape_batch_dim_to_heads(self, tensor):
        """Reshape (batch * heads, seq, dim) back into (batch, seq, dim * heads)."""
        folded_batch, seq_len, dim = tensor.shape
        num_heads = self.heads
        batch_size = folded_batch // num_heads
        merged = tensor.reshape(batch_size, num_heads, seq_len, dim)
        merged = jnp.transpose(merged, (0, 2, 1, 3))
        return merged.reshape(batch_size, seq_len, dim * num_heads)

    def __call__(self, hidden_states, context=None, deterministic=True):
        """Attend from `hidden_states` to `context` (or to `hidden_states` itself when `context` is None)."""
        if context is None:
            context = hidden_states

        query_states = self.reshape_heads_to_batch_dim(self.query(hidden_states))
        key_states = self.reshape_heads_to_batch_dim(self.key(context))
        value_states = self.reshape_heads_to_batch_dim(self.value(context))

        if not self.use_memory_efficient_attention:
            # compute attentions
            attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
            attention_probs = nn.softmax(attention_scores * self.scale, axis=2)

            # attend to values
            hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
        else:
            query_states = query_states.transpose(1, 0, 2)
            key_states = key_states.transpose(1, 0, 2)
            value_states = value_states.transpose(1, 0, 2)

            # this block picks a chunk size for each layer of the unet
            # the chunk size is equal to the query_length dimension of the deepest layer of the unet

            flatten_latent_dim = query_states.shape[-3]
            for divisor in (64, 16, 4):
                if flatten_latent_dim % divisor == 0:
                    query_chunk_size = int(flatten_latent_dim / divisor)
                    break
            else:
                query_chunk_size = int(flatten_latent_dim)

            hidden_states = jax_memory_efficient_attention(
                query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
            )

            hidden_states = hidden_states.transpose(1, 0, 2)

        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return self.dropout_layer(hidden_states, deterministic=deterministic)
219
-
220
-
221
class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
    https://arxiv.org/abs/1706.03762

    The block applies, in order and each with a residual connection: a first attention (self-attention, or
    attention over `context` when `only_cross_attention` is set), a second attention over `context`, and a
    feed-forward layer. Dropout is applied to the final result.

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether to only apply cross attention.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(
            query_dim=self.dim,
            heads=self.n_heads,
            dim_head=self.d_head,
            dropout=self.dropout,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )
        # cross attention
        self.attn2 = FlaxAttention(
            query_dim=self.dim,
            heads=self.n_heads,
            dim_head=self.d_head,
            dropout=self.dropout,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        """Run the three residual sub-layers on `hidden_states` and apply the final dropout."""
        # self attention
        normed = self.norm1(hidden_states)
        if self.only_cross_attention:
            attended = self.attn1(normed, context, deterministic=deterministic)
        else:
            attended = self.attn1(normed, deterministic=deterministic)
        hidden_states = attended + hidden_states

        # cross attention
        attended = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = attended + hidden_states

        # feed forward
        fed_forward = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = fed_forward + hidden_states

        return self.dropout_layer(hidden_states, deterministic=deterministic)
286
-
287
-
288
class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
    https://arxiv.org/pdf/1506.02025.pdf

    The input is group-normalized, projected, flattened over its two spatial axes, run through `depth`
    transformer blocks, projected back, un-flattened and added to the original input.

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`):
            Use `nn.Dense` instead of a 1x1 `nn.Conv` for the input and output projections
        only_cross_attention (`bool`, defaults to `False`):
            Forwarded to every `FlaxBasicTransformerBlock`
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        # Shared settings of the 1x1 convolutions used when linear projection is off.
        pointwise_conv = {"kernel_size": (1, 1), "strides": (1, 1), "padding": "VALID", "dtype": self.dtype}

        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(inner_dim, **pointwise_conv)

        blocks = []
        for _ in range(self.depth):
            blocks.append(
                FlaxBasicTransformerBlock(
                    inner_dim,
                    self.n_heads,
                    self.d_head,
                    dropout=self.dropout,
                    only_cross_attention=self.only_cross_attention,
                    dtype=self.dtype,
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                )
            )
        self.transformer_blocks = blocks

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(inner_dim, **pointwise_conv)

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        """Apply the transformer to a (batch, height, width, channels) input and return the same shape."""
        batch, height, width, channels = hidden_states.shape
        flat_shape = (batch, height * width, channels)
        spatial_shape = (batch, height, width, channels)
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            # Dense projections operate on the flattened sequence.
            hidden_states = self.proj_in(hidden_states.reshape(*flat_shape))
        else:
            # Conv projections operate on the spatial layout.
            hidden_states = self.proj_in(hidden_states).reshape(*flat_shape)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context, deterministic=deterministic)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states).reshape(*spatial_shape)
        else:
            hidden_states = self.proj_out(hidden_states.reshape(*spatial_shape))

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)
386
-
387
-
388
class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
    https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        """Apply the gated projection, then project back to `dim`."""
        gated = self.net_0(hidden_states, deterministic=deterministic)
        return self.net_2(gated)
419
-
420
-
421
class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    The input is projected to `2 * (dim * 4)` features, the projection is split in half along the feature (last)
    axis, and one half gates the other through a GELU: `linear * gelu(gate)`. Dropout is applied to the result.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        # Twice the inner width: one half is the linear path, the other half the gate.
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        # `nn.Dense` projects the last axis, so split on the last axis rather than on a
        # hard-coded `axis=2`; identical for rank-3 input, and also valid for other ranks.
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/attention_processor.py DELETED
@@ -1,1684 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Callable, Optional, Union
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from ..utils import deprecate, logging, maybe_allow_in_graph
21
- from ..utils.import_utils import is_xformers_available
22
-
23
-
24
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
-
26
-
27
- if is_xformers_available():
28
- import xformers
29
- import xformers.ops
30
- else:
31
- xformers = None
32
-
33
-
34
- # 6DoF CaPE
35
- import einops
36
def cape_embed(f, P):
    """
    Apply the transformation `P` to a feature vector, four channels at a time.

    Args:
        f: feature vector of shape [..., d]; the last axis is regrouped into chunks of 4.
        P: 4x4 transformation matrix, right-multiplied onto every chunk.

    Returns:
        The transformed features with the chunks flattened back into the last axis.
    """
    # [..., d] -> [..., d / 4, 4] so that every 4-channel chunk becomes a row vector
    chunks = einops.rearrange(f, '... (d k) -> ... d k', k=4)
    transformed = chunks @ P
    # flatten the chunks back into a single feature axis
    return einops.rearrange(transformed, '... d k -> ... (d k)', k=4)
41
-
42
@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    The module owns the projection weights and the shape/mask helpers; the actual attention computation is
    delegated to a swappable attention processor (see `set_processor`).

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, defaults to False):
            Cast query and key to float32 before computing the attention scores.
        upcast_softmax (`bool`, defaults to False):
            Cast the attention scores to float32 before the softmax.
        cross_attention_norm (`str`, *optional*):
            Normalization applied to `encoder_hidden_states`: `None`, `"layer_norm"` or `"group_norm"`.
        cross_attention_norm_num_groups (`int`, defaults to 32):
            Number of groups when `cross_attention_norm="group_norm"`.
        added_kv_proj_dim (`int`, *optional*):
            If given, extra `add_k_proj` / `add_v_proj` linear layers with this input width are created.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` over `query_dim` channels is created as `group_norm`.
        spatial_norm_dim (`int`, *optional*):
            If given, a `SpatialNorm` is created as `spatial_norm`.
        out_bias (`bool`, defaults to True): Whether the output projection has a bias.
        scale_qk (`bool`, defaults to True):
            Scale the scores by `dim_head**-0.5`; when `False` the scale is 1.0.
        only_cross_attention (`bool`, defaults to False):
            Do not create `to_k` / `to_v`; requires `added_kv_proj_dim`.
        eps (`float`, defaults to 1e-5): Epsilon of `group_norm`.
        rescale_output_factor (`float`, defaults to 1.0): Stored for the attention processors to read.
        residual_connection (`bool`, defaults to False): Stored for the attention processors to read.
        processor (`AttnProcessor`, *optional*):
            The attention processor to use. When `None`, a default is picked (see `_default_processor`).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        # index 0: output projection, index 1: dropout (processors address them by index)
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        if processor is None:
            processor = self._default_processor()
        self.set_processor(processor)

    def _default_processor(self):
        r"""
        Build the processor used when no explicit or specialised one applies.

        We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        """
        if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
            return AttnProcessor2_0()
        return AttnProcessor()

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        r"""
        Enable or disable the xformers-backed attention processors.

        The replacement processor is chosen from the type of the current one. For LoRA and Custom Diffusion
        processors a new processor is built and the current processor's state dict is loaded into it.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                `True` to switch to an xformers processor, `False` to switch back to a regular one.
            attention_op (`Callable`, *optional*):
                Forwarded to the xformers processors as `attention_op`.

        Raises:
            NotImplementedError: when enabling xformers for an added-KV processor that is also a LoRA or Custom
                Diffusion processor.
            ModuleNotFoundError: when enabling xformers but xformers is not installed.
            ValueError: when enabling xformers but CUDA is not available.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor,
            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )
        is_added_kv_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                # Make sure we can run the memory efficient attention; any failure of this
                # probe propagates to the caller unchanged.
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                processor = self._default_processor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        r"""
        Select the processor matching `slice_size`.

        Args:
            slice_size (`int`, *optional*):
                When not `None`, a sliced processor with this slice size is installed. When `None`, the
                non-sliced processor is restored.

        Raises:
            ValueError: if `slice_size` is larger than `self.sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            processor = self._default_processor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        r"""
        Install `processor` as the object that `forward` delegates to.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        r"""
        Reshape `(batch * heads, seq_len, dim)` into `(batch, seq_len, dim * heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        r"""
        Reshape `(batch, seq_len, dim)` into `(batch, heads, seq_len, dim // heads)`; when `out_dim == 3` the
        first two axes are additionally merged into `(batch * heads, seq_len, dim // heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        r"""
        Return the attention probabilities `softmax(attention_mask + scale * query @ key^T)` over the last axis,
        cast back to the dtype `query` had on entry.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 the (uninitialised) input only provides shape/dtype/device for baddbmm.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        r"""
        Pad `attention_mask` along its last axis and repeat it per attention head.

        Args:
            attention_mask: The mask to prepare; `None` is returned unchanged.
            target_length (`int`): Padding is applied when the mask's last axis differs from this length.
            batch_size (`int`, *optional*): Not passing it is deprecated; it then falls back to 1.
            out_dim (`int`, defaults to 3):
                `3` repeats the mask along axis 0 (only if it has fewer than `batch_size * heads` rows), `4`
                inserts a head axis at position 1 and repeats along it.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        r"""
        Apply `self.norm_cross` (which must be set) to `encoder_hidden_states` and return the result.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            # `__init__` only ever creates a LayerNorm or a GroupNorm
            assert False

        return encoder_hidden_states
448
-
449
-
450
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.

    It runs the plain (non-fused) attention path using the projections and helpers of the `Attention` module it
    is called with.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D input (batch, channel, height, width) is flattened to a token sequence
        # here and restored to its 4-D layout before returning.
        is_4d = hidden_states.ndim == 4
        if is_4d:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask length follows the key/value source: the context if given, else the input itself.
        kv_source_shape = encoder_hidden_states.shape if encoder_hidden_states is not None else hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (attn.head_to_batch_dim(projection) for projection in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        if is_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
514
-
515
-
516
class LoRALinearLayer(nn.Module):
    r"""
    Low-rank linear update: a bias-free `down` projection to `rank` features followed by a bias-free `up`
    projection to `out_features`. `up` starts at zero, so a fresh layer outputs zeros.

    Args:
        in_features (`int`): Width of the input.
        out_features (`int`): Width of the output.
        rank (`int`, defaults to 4): Width of the bottleneck; must not exceed `min(in_features, out_features)`.
        network_alpha (*optional*): When set, the output is multiplied by `network_alpha / rank`.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        if min(in_features, out_features) < rank:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        # Compute in the dtype of the LoRA weights, hand the result back in the caller's dtype.
        input_dtype = hidden_states.dtype
        weight_dtype = self.down.weight.dtype

        lora_update = self.up(self.down(hidden_states.to(weight_dtype)))

        if self.network_alpha is not None:
            lora_update *= self.network_alpha / self.rank

        return lora_update.to(input_dtype)
544
-
545
-
546
class LoRAAttnProcessor(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism.

    Each of the query, key, value and output projections of the attention module gets a `LoRALinearLayer` whose
    output is added to the base projection, weighted by the call-time `scale`.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        # Keys and values read the context, whose width falls back to the hidden size.
        kv_in_features = cross_attention_dim or hidden_size

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(kv_in_features, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(kv_in_features, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        skip_input = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D input (batch, channel, height, width) is flattened to a token sequence
        # here and restored to its 4-D layout before returning.
        is_4d = hidden_states.ndim == 4
        if is_4d:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        kv_source_shape = encoder_hidden_states.shape if encoder_hidden_states is not None else hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states))

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context) + scale * self.to_k_lora(context)
        value = attn.to_v(context) + scale * self.to_v_lora(context)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # linear proj (base + LoRA update), then dropout
        projected = attn.to_out[0](attended) + scale * self.to_out_lora(attended)
        hidden_states = attn.to_out[1](projected)

        if is_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + skip_input

        return hidden_states / attn.rescale_output_factor
627
-
628
-
629
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Processor for implementing attention for the Custom Diffusion method.

    Depending on the flags, the processor's own projections replace the key/value and/or the query/output
    projections of the attention module it is called with.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query_proj = self.to_q_custom_diffusion if self.train_q_out else attn.to_q
        query = query_proj(hidden_states)

        crossattn = encoder_hidden_states is not None
        if not crossattn:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key_proj, value_proj = self.to_k_custom_diffusion, self.to_v_custom_diffusion
        else:
            key_proj, value_proj = attn.to_k, attn.to_v
        key = key_proj(encoder_hidden_states)
        value = value_proj(encoder_hidden_states)

        if crossattn:
            # Gradient gate: 0 for the first context token, 1 elsewhere, so the first
            # token's key/value contribute only through their detached copies.
            grad_gate = torch.ones_like(key)
            grad_gate[:, :1, :] = 0.0
            key = grad_gate * key + (1 - grad_gate) * key.detach()
            value = grad_gate * value + (1 - grad_gate) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # index 0 is the linear projection, index 1 the dropout
        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        hidden_states = out_layers[0](hidden_states)
        hidden_states = out_layers[1](hidden_states)

        return hidden_states
723
-
724
-
725
class AttnAddedKVProcessor:
    r"""
    Attention processor that attends over the extra text-encoder key/value projections (`attn.add_k_proj` /
    `attn.add_v_proj`), optionally concatenated with the regular self-attention keys and values.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        skip = hidden_states
        # Collapse everything after the channel axis into one token axis and put tokens before channels.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        n_batch, n_tokens, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        # The context is picked *before* group norm, so self-attention sees the un-normalised tokens here.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(tokens))

        extra_key = attn.add_k_proj(context)
        extra_value = attn.add_v_proj(context)
        extra_key = attn.head_to_batch_dim(extra_key)
        extra_value = attn.head_to_batch_dim(extra_value)

        if attn.only_cross_attention:
            key, value = extra_key, extra_value
        else:
            own_key = attn.to_k(tokens)
            own_value = attn.to_v(tokens)
            # Added projections come first along the token axis, followed by the layer's own keys/values.
            key = torch.cat([extra_key, attn.head_to_batch_dim(own_key)], dim=1)
            value = torch.cat([extra_value, attn.head_to_batch_dim(own_value)], dim=1)

        weights = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        projected = attn.to_out[1](projected)

        # Restore the caller's layout and add the skip connection.
        projected = projected.transpose(-1, -2).reshape(skip.shape)
        return projected + skip
777
-
778
-
779
class AttnAddedKVProcessor2_0:
    r"""
    Added-key/value attention processor built on `F.scaled_dot_product_attention` (PyTorch 2.0+). It attends over
    the extra text-encoder projections `attn.add_k_proj` / `attn.add_v_proj`, optionally concatenated with the
    layer's own keys and values.
    """

    def __init__(self):
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        skip = hidden_states
        # Collapse everything after the channel axis into one token axis and put tokens before channels.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        n_batch, n_tokens, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch, out_dim=4)

        # The context is picked *before* group norm, so self-attention sees the un-normalised tokens here.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(tokens), out_dim=4)

        extra_key = attn.add_k_proj(context)
        extra_value = attn.add_v_proj(context)
        extra_key = attn.head_to_batch_dim(extra_key, out_dim=4)
        extra_value = attn.head_to_batch_dim(extra_value, out_dim=4)

        if attn.only_cross_attention:
            key, value = extra_key, extra_value
        else:
            own_key = attn.to_k(tokens)
            own_value = attn.to_v(tokens)
            # With the 4-D head layout the concatenation happens on axis 2.
            key = torch.cat([extra_key, attn.head_to_batch_dim(own_key, out_dim=4)], dim=2)
            value = torch.cat([extra_value, attn.head_to_batch_dim(own_value, out_dim=4)], dim=2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge the head axis back into the channel axis (channel count taken from the input).
        attended = attended.transpose(1, 2).reshape(n_batch, -1, skip.shape[1])

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        projected = attn.to_out[1](projected)

        # Restore the caller's layout and add the skip connection.
        projected = projected.transpose(-1, -2).reshape(skip.shape)
        return projected + skip
840
-
841
-
842
class LoRAAttnAddedKVProcessor(nn.Module):
    r"""
    LoRA flavour of the added-key/value attention processor: every projection used by the attention layer (query,
    added key/value, own key/value and output) receives an additive low-rank update.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when falsy.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (*optional*, defaults to `None`):
            Forwarded unchanged to every `LoRALinearLayer`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        # Input width of the adapters that consume `encoder_hidden_states`.
        context_dim = cross_attention_dim or hidden_size

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.add_k_proj_lora = LoRALinearLayer(context_dim, hidden_size, rank, network_alpha)
        self.add_v_proj_lora = LoRALinearLayer(context_dim, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        skip = hidden_states
        # Collapse everything after the channel axis into one token axis and put tokens before channels.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        n_batch, n_tokens, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        # The context is picked *before* group norm, so self-attention sees the un-normalised tokens here.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        # Each projection is the frozen layer output plus `scale` times its LoRA update.
        query = attn.to_q(tokens) + scale * self.to_q_lora(tokens)
        query = attn.head_to_batch_dim(query)

        extra_key = attn.add_k_proj(context) + scale * self.add_k_proj_lora(context)
        extra_value = attn.add_v_proj(context) + scale * self.add_v_proj_lora(context)
        extra_key = attn.head_to_batch_dim(extra_key)
        extra_value = attn.head_to_batch_dim(extra_value)

        if attn.only_cross_attention:
            key, value = extra_key, extra_value
        else:
            own_key = attn.to_k(tokens) + scale * self.to_k_lora(tokens)
            own_value = attn.to_v(tokens) + scale * self.to_v_lora(tokens)
            key = torch.cat([extra_key, attn.head_to_batch_dim(own_key)], dim=1)
            value = torch.cat([extra_value, attn.head_to_batch_dim(own_value)], dim=1)

        weights = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, value))

        # Linear projection (with LoRA update) followed by dropout.
        projected = attn.to_out[0](attended) + scale * self.to_out_lora(attended)
        projected = attn.to_out[1](projected)

        # Restore the caller's layout and add the skip connection.
        projected = projected.transpose(-1, -2).reshape(skip.shape)
        return projected + skip
921
-
922
-
923
class XFormersAttnAddedKVProcessor:
    r"""
    Added-key/value attention processor that delegates the attention computation to
    `xformers.ops.memory_efficient_attention`.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        skip = hidden_states
        # Collapse everything after the channel axis into one token axis and put tokens before channels.
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        n_batch, n_tokens, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        # The context is picked *before* group norm, so self-attention sees the un-normalised tokens here.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(tokens))

        extra_key = attn.add_k_proj(context)
        extra_value = attn.add_v_proj(context)
        extra_key = attn.head_to_batch_dim(extra_key)
        extra_value = attn.head_to_batch_dim(extra_value)

        if attn.only_cross_attention:
            key, value = extra_key, extra_value
        else:
            own_key = attn.to_k(tokens)
            own_value = attn.to_v(tokens)
            key = torch.cat([extra_key, attn.head_to_batch_dim(own_key)], dim=1)
            value = torch.cat([extra_value, attn.head_to_batch_dim(own_value)], dim=1)

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        # Cast back to the query dtype before merging heads.
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        projected = attn.to_out[1](projected)

        # Restore the caller's layout and add the skip connection.
        projected = projected.transpose(-1, -2).reshape(skip.shape)
        return projected + skip
986
-
987
-
988
class XFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers. When a pose embedding (`posemb`) is
    supplied, the per-view 2D attention is turned into multiview attention: the views are folded into the token
    axis and the pose matrices are applied to the queries and keys through `cape_embed`.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ):
        r"""
        Run attention for `attn` on `hidden_states`.

        Args:
            attn: The attention layer whose projections, norms and output layers are used.
            hidden_states: 3D `(batch, tokens, channels)` or 4D `(batch, channels, height, width)` input.
            encoder_hidden_states: Cross-attention context; `None` selects self-attention.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask` and used as `attn_bias`.
            temb: Only forwarded to `attn.spatial_norm` when that module is set.
            posemb: Optional nested pair `[[p_out, p_out_inv], [p_in, p_in_inv]]`. The number of output/input
                views is read from axis 1 of `p_out` / `p_in`.

        Returns:
            A tensor with the same layout as the `hidden_states` that was passed in.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Remember the exact 4D shape: `batch_size` is rebound below (and shrinks in the multiview path),
            # so it must not be reused when restoring the spatial layout at the end.
            spatial_shape = hidden_states.shape
            batch_size, channel, height, width = spatial_shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if posemb is not None:
            # turn 2d attention into multiview attention
            self_attn = encoder_hidden_states is None  # check if self attn or cross attn
            [p_out, p_out_inv], [p_in, p_in_inv] = posemb
            t_out, t_in = p_out.shape[1], p_in.shape[1]  # number of output / input views
            # Fold the output views into the token axis.
            hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # apply 6DoF, todo now only for xformer processor
        if posemb is not None:
            # Repeat each view's pose matrix for every token of that view (query layout).
            p_out_inv = einops.repeat(p_out_inv, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out)
            if self_attn:
                # Self-attention: keys share the query layout, so the output-view poses are used.
                p_in = einops.repeat(p_out, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out)
            else:
                # Cross-attention: keys follow the input-view layout.
                p_in = einops.repeat(p_in, 'b t_in f g -> b (t_in l) f g', l=key.shape[1] // t_in)
            query = cape_embed(query, p_out_inv)  # query f_q @ (p_out)^(-T)
            key = cape_embed(key, p_in)  # key f_k @ p_in

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # self-attn   (bm) l c x (bm) l c -> (bm) l c
        # cross-attn  (bm) l c x b (nl) c -> (bm) l c
        # reuse 2d attention for multiview attention
        # self-attn   b (ml) c x b (ml) c -> b (ml) c
        # cross-attn  b (ml) c x b (nl) c -> b (ml) c
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if posemb is not None:
            # reshape back: unfold the views out of the token axis
            hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)

        if input_ndim == 4:
            # Use the shape captured before flattening; the rebound `batch_size` no longer counts the views
            # in the multiview path and would make this reshape fail.
            hidden_states = hidden_states.transpose(-1, -2).reshape(spatial_shape)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
1103
-
1104
-
1105
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Raises:
        ImportError: on construction, if `torch.nn.functional` has no `scaled_dot_product_attention`.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        r"""
        Run attention for `attn` on `hidden_states`.

        Args:
            attn: The attention layer whose projections, norms and output layers are used.
            hidden_states: 3D `(batch, tokens, channels)` or 4D `(batch, channels, height, width)` input.
            encoder_hidden_states: Cross-attention context; `None` selects self-attention.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask`.
            temb: Only forwarded to `attn.spatial_norm` when that module is set.

        Returns:
            A tensor with the same layout as the `hidden_states` that was passed in.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Split heads using the width of the *projected* tensors. The width of the incoming hidden states is
        # only the same number when the layer's inner dimension equals its query dimension.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back into a single channel axis.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
1187
-
1188
-
1189
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    LoRA attention processor that runs the attention itself through `xformers.ops.memory_efficient_attention`.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when falsy.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.

    """

    def __init__(
        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        # Input width of the adapters that consume `encoder_hidden_states`.
        context_dim = cross_attention_dim or hidden_size

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(context_dim, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(context_dim, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            # Flatten the spatial axes into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized from the tensor that provides the keys.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Each projection is the frozen layer output plus `scale` times its LoRA update.
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context) + scale * self.to_k_lora(context)
        value = attn.to_v(context) + scale * self.to_v_lora(context)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended)

        # Linear projection (with LoRA update) followed by dropout.
        output = attn.to_out[0](attended) + scale * self.to_out_lora(attended)
        output = attn.to_out[1](output)

        if is_spatial:
            output = output.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            output = output + skip

        return output / attn.rescale_output_factor
1280
-
1281
-
1282
class LoRAAttnProcessor2_0(nn.Module):
    r"""
    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
    attention.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when falsy.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.

    Raises:
        ImportError: on construction, if `torch.nn.functional` has no `scaled_dot_product_attention`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            # Name this class (not `AttnProcessor2_0`) so the error points at the processor actually requested.
            raise ImportError("LoRAAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        r"""
        Run LoRA-augmented attention for `attn` on `hidden_states`.

        Args:
            attn: The attention layer whose projections, norms and output layers are used.
            hidden_states: 3D `(batch, tokens, channels)` or 4D `(batch, channels, height, width)` input.
            encoder_hidden_states: Cross-attention context; `None` selects self-attention.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask`.
            scale: Multiplier applied to every LoRA update before it is added to the frozen projection.

        Returns:
            A tensor with the same layout as the `hidden_states` that was passed in.
        """
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        # Split heads using the width of the projected tensors.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge the heads back into a single channel axis.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
1371
-
1372
-
1373
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Custom Diffusion attention processor that runs the attention through
    `xformers.ops.memory_efficient_attention`.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether this processor owns (and uses) its own key and value projections.
        train_q_out (`bool`, defaults to `False`):
            Whether this processor owns (and uses) its own query and output projections.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when falsy.
        out_bias (`bool`, defaults to `True`):
            Whether the output projection created for `train_q_out` has a bias.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
            as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=False,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            # Index 0 is the linear projection, index 1 the dropout.
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        is_cross = encoder_hidden_states is not None

        # The mask is sized from the tensor that provides the keys.
        shape_source = encoder_hidden_states if is_cross else hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        q_proj = self.to_q_custom_diffusion if self.train_q_out else attn.to_q
        query = q_proj(hidden_states)

        if not is_cross:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        if self.train_kv:
            key = self.to_k_custom_diffusion(context)
            value = self.to_v_custom_diffusion(context)
        else:
            key = attn.to_k(context)
            value = attn.to_v(context)

        if is_cross:
            # 0 for the first token along axis 1, 1 elsewhere: the first token's key/value go through
            # `.detach()`, so no gradient flows through them.
            grad_mask = torch.ones_like(key)
            grad_mask[:, :1, :] = grad_mask[:, :1, :] * 0.0
            key = grad_mask * key + (1 - grad_mask) * key.detach()
            value = grad_mask * value + (1 - grad_mask) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        # Cast back to the query dtype before merging heads.
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # Linear projection followed by dropout, from whichever module owns the output layers.
        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        output = out_layers[0](attended)
        output = out_layers[1](output)
        return output
1478
-
1479
-
1480
- class SlicedAttnProcessor:
1481
- r"""
1482
- Processor for implementing sliced attention.
1483
-
1484
- Args:
1485
- slice_size (`int`, *optional*):
1486
- The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1487
- `attention_head_dim` must be a multiple of the `slice_size`.
1488
- """
1489
-
1490
- def __init__(self, slice_size):
1491
- self.slice_size = slice_size
1492
-
1493
- def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1494
- residual = hidden_states
1495
-
1496
- input_ndim = hidden_states.ndim
1497
-
1498
- if input_ndim == 4:
1499
- batch_size, channel, height, width = hidden_states.shape
1500
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1501
-
1502
- batch_size, sequence_length, _ = (
1503
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1504
- )
1505
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1506
-
1507
- if attn.group_norm is not None:
1508
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1509
-
1510
- query = attn.to_q(hidden_states)
1511
- dim = query.shape[-1]
1512
- query = attn.head_to_batch_dim(query)
1513
-
1514
- if encoder_hidden_states is None:
1515
- encoder_hidden_states = hidden_states
1516
- elif attn.norm_cross:
1517
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1518
-
1519
- key = attn.to_k(encoder_hidden_states)
1520
- value = attn.to_v(encoder_hidden_states)
1521
- key = attn.head_to_batch_dim(key)
1522
- value = attn.head_to_batch_dim(value)
1523
-
1524
- batch_size_attention, query_tokens, _ = query.shape
1525
- hidden_states = torch.zeros(
1526
- (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1527
- )
1528
-
1529
- for i in range(batch_size_attention // self.slice_size):
1530
- start_idx = i * self.slice_size
1531
- end_idx = (i + 1) * self.slice_size
1532
-
1533
- query_slice = query[start_idx:end_idx]
1534
- key_slice = key[start_idx:end_idx]
1535
- attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1536
-
1537
- attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1538
-
1539
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1540
-
1541
- hidden_states[start_idx:end_idx] = attn_slice
1542
-
1543
- hidden_states = attn.batch_to_head_dim(hidden_states)
1544
-
1545
- # linear proj
1546
- hidden_states = attn.to_out[0](hidden_states)
1547
- # dropout
1548
- hidden_states = attn.to_out[1](hidden_states)
1549
-
1550
- if input_ndim == 4:
1551
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1552
-
1553
- if attn.residual_connection:
1554
- hidden_states = hidden_states + residual
1555
-
1556
- hidden_states = hidden_states / attn.rescale_output_factor
1557
-
1558
- return hidden_states
1559
-
1560
-
1561
- class SlicedAttnAddedKVProcessor:
1562
- r"""
1563
- Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1564
-
1565
- Args:
1566
- slice_size (`int`, *optional*):
1567
- The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1568
- `attention_head_dim` must be a multiple of the `slice_size`.
1569
- """
1570
-
1571
- def __init__(self, slice_size):
1572
- self.slice_size = slice_size
1573
-
1574
- def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1575
- residual = hidden_states
1576
-
1577
- if attn.spatial_norm is not None:
1578
- hidden_states = attn.spatial_norm(hidden_states, temb)
1579
-
1580
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1581
-
1582
- batch_size, sequence_length, _ = hidden_states.shape
1583
-
1584
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1585
-
1586
- if encoder_hidden_states is None:
1587
- encoder_hidden_states = hidden_states
1588
- elif attn.norm_cross:
1589
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1590
-
1591
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1592
-
1593
- query = attn.to_q(hidden_states)
1594
- dim = query.shape[-1]
1595
- query = attn.head_to_batch_dim(query)
1596
-
1597
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1598
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1599
-
1600
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1601
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1602
-
1603
- if not attn.only_cross_attention:
1604
- key = attn.to_k(hidden_states)
1605
- value = attn.to_v(hidden_states)
1606
- key = attn.head_to_batch_dim(key)
1607
- value = attn.head_to_batch_dim(value)
1608
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1609
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1610
- else:
1611
- key = encoder_hidden_states_key_proj
1612
- value = encoder_hidden_states_value_proj
1613
-
1614
- batch_size_attention, query_tokens, _ = query.shape
1615
- hidden_states = torch.zeros(
1616
- (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1617
- )
1618
-
1619
- for i in range(batch_size_attention // self.slice_size):
1620
- start_idx = i * self.slice_size
1621
- end_idx = (i + 1) * self.slice_size
1622
-
1623
- query_slice = query[start_idx:end_idx]
1624
- key_slice = key[start_idx:end_idx]
1625
- attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1626
-
1627
- attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1628
-
1629
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1630
-
1631
- hidden_states[start_idx:end_idx] = attn_slice
1632
-
1633
- hidden_states = attn.batch_to_head_dim(hidden_states)
1634
-
1635
- # linear proj
1636
- hidden_states = attn.to_out[0](hidden_states)
1637
- # dropout
1638
- hidden_states = attn.to_out[1](hidden_states)
1639
-
1640
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1641
- hidden_states = hidden_states + residual
1642
-
1643
- return hidden_states
1644
-
1645
-
1646
- AttentionProcessor = Union[
1647
- AttnProcessor,
1648
- AttnProcessor2_0,
1649
- XFormersAttnProcessor,
1650
- SlicedAttnProcessor,
1651
- AttnAddedKVProcessor,
1652
- SlicedAttnAddedKVProcessor,
1653
- AttnAddedKVProcessor2_0,
1654
- XFormersAttnAddedKVProcessor,
1655
- LoRAAttnProcessor,
1656
- LoRAXFormersAttnProcessor,
1657
- LoRAAttnProcessor2_0,
1658
- LoRAAttnAddedKVProcessor,
1659
- CustomDiffusionAttnProcessor,
1660
- CustomDiffusionXFormersAttnProcessor,
1661
- ]
1662
-
1663
-
1664
- class SpatialNorm(nn.Module):
1665
- """
1666
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1667
- """
1668
-
1669
- def __init__(
1670
- self,
1671
- f_channels,
1672
- zq_channels,
1673
- ):
1674
- super().__init__()
1675
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1676
- self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1677
- self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1678
-
1679
- def forward(self, f, zq):
1680
- f_size = f.shape[-2:]
1681
- zq = F.interpolate(zq, size=f_size, mode="nearest")
1682
- norm_f = self.norm_layer(f)
1683
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1684
- return new_f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/autoencoder_kl.py DELETED
@@ -1,411 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Dict, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from ..configuration_utils import ConfigMixin, register_to_config
21
- from ..utils import BaseOutput, apply_forward_hook
22
- from .attention_processor import AttentionProcessor, AttnProcessor
23
- from .modeling_utils import ModelMixin
24
- from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
25
-
26
-
27
- @dataclass
28
- class AutoencoderKLOutput(BaseOutput):
29
- """
30
- Output of AutoencoderKL encoding method.
31
-
32
- Args:
33
- latent_dist (`DiagonalGaussianDistribution`):
34
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
35
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
36
- """
37
-
38
- latent_dist: "DiagonalGaussianDistribution"
39
-
40
-
41
- class AutoencoderKL(ModelMixin, ConfigMixin):
42
- r"""
43
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
44
-
45
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
46
- for all models (such as downloading or saving).
47
-
48
- Parameters:
49
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
50
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
51
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
52
- Tuple of downsample block types.
53
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
54
- Tuple of upsample block types.
55
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
56
- Tuple of block output channels.
57
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
58
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
59
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
60
- scaling_factor (`float`, *optional*, defaults to 0.18215):
61
- The component-wise standard deviation of the trained latent space computed using the first batch of the
62
- training set. This is used to scale the latent space to have unit variance when training the diffusion
63
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
64
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
65
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
66
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
67
- """
68
-
69
- _supports_gradient_checkpointing = True
70
-
71
- @register_to_config
72
- def __init__(
73
- self,
74
- in_channels: int = 3,
75
- out_channels: int = 3,
76
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
77
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
78
- block_out_channels: Tuple[int] = (64,),
79
- layers_per_block: int = 1,
80
- act_fn: str = "silu",
81
- latent_channels: int = 4,
82
- norm_num_groups: int = 32,
83
- sample_size: int = 32,
84
- scaling_factor: float = 0.18215,
85
- ):
86
- super().__init__()
87
-
88
- # pass init params to Encoder
89
- self.encoder = Encoder(
90
- in_channels=in_channels,
91
- out_channels=latent_channels,
92
- down_block_types=down_block_types,
93
- block_out_channels=block_out_channels,
94
- layers_per_block=layers_per_block,
95
- act_fn=act_fn,
96
- norm_num_groups=norm_num_groups,
97
- double_z=True,
98
- )
99
-
100
- # pass init params to Decoder
101
- self.decoder = Decoder(
102
- in_channels=latent_channels,
103
- out_channels=out_channels,
104
- up_block_types=up_block_types,
105
- block_out_channels=block_out_channels,
106
- layers_per_block=layers_per_block,
107
- norm_num_groups=norm_num_groups,
108
- act_fn=act_fn,
109
- )
110
-
111
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
112
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
113
-
114
- self.use_slicing = False
115
- self.use_tiling = False
116
-
117
- # only relevant if vae tiling is enabled
118
- self.tile_sample_min_size = self.config.sample_size
119
- sample_size = (
120
- self.config.sample_size[0]
121
- if isinstance(self.config.sample_size, (list, tuple))
122
- else self.config.sample_size
123
- )
124
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
125
- self.tile_overlap_factor = 0.25
126
-
127
- def _set_gradient_checkpointing(self, module, value=False):
128
- if isinstance(module, (Encoder, Decoder)):
129
- module.gradient_checkpointing = value
130
-
131
- def enable_tiling(self, use_tiling: bool = True):
132
- r"""
133
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
134
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
135
- processing larger images.
136
- """
137
- self.use_tiling = use_tiling
138
-
139
- def disable_tiling(self):
140
- r"""
141
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
142
- decoding in one step.
143
- """
144
- self.enable_tiling(False)
145
-
146
- def enable_slicing(self):
147
- r"""
148
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
149
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
150
- """
151
- self.use_slicing = True
152
-
153
- def disable_slicing(self):
154
- r"""
155
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
156
- decoding in one step.
157
- """
158
- self.use_slicing = False
159
-
160
- @property
161
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
162
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
163
- r"""
164
- Returns:
165
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
166
- indexed by its weight name.
167
- """
168
- # set recursively
169
- processors = {}
170
-
171
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
172
- if hasattr(module, "set_processor"):
173
- processors[f"{name}.processor"] = module.processor
174
-
175
- for sub_name, child in module.named_children():
176
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
177
-
178
- return processors
179
-
180
- for name, module in self.named_children():
181
- fn_recursive_add_processors(name, module, processors)
182
-
183
- return processors
184
-
185
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
186
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
187
- r"""
188
- Sets the attention processor to use to compute attention.
189
-
190
- Parameters:
191
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
192
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
193
- for **all** `Attention` layers.
194
-
195
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
196
- processor. This is strongly recommended when setting trainable attention processors.
197
-
198
- """
199
- count = len(self.attn_processors.keys())
200
-
201
- if isinstance(processor, dict) and len(processor) != count:
202
- raise ValueError(
203
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
204
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
205
- )
206
-
207
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
208
- if hasattr(module, "set_processor"):
209
- if not isinstance(processor, dict):
210
- module.set_processor(processor)
211
- else:
212
- module.set_processor(processor.pop(f"{name}.processor"))
213
-
214
- for sub_name, child in module.named_children():
215
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
216
-
217
- for name, module in self.named_children():
218
- fn_recursive_attn_processor(name, module, processor)
219
-
220
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
221
- def set_default_attn_processor(self):
222
- """
223
- Disables custom attention processors and sets the default attention implementation.
224
- """
225
- self.set_attn_processor(AttnProcessor())
226
-
227
- @apply_forward_hook
228
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
229
- if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
230
- return self.tiled_encode(x, return_dict=return_dict)
231
-
232
- if self.use_slicing and x.shape[0] > 1:
233
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
234
- h = torch.cat(encoded_slices)
235
- else:
236
- h = self.encoder(x)
237
-
238
- moments = self.quant_conv(h)
239
- posterior = DiagonalGaussianDistribution(moments)
240
-
241
- if not return_dict:
242
- return (posterior,)
243
-
244
- return AutoencoderKLOutput(latent_dist=posterior)
245
-
246
- def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
247
- if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
248
- return self.tiled_decode(z, return_dict=return_dict)
249
-
250
- z = self.post_quant_conv(z)
251
- dec = self.decoder(z)
252
-
253
- if not return_dict:
254
- return (dec,)
255
-
256
- return DecoderOutput(sample=dec)
257
-
258
- @apply_forward_hook
259
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
260
- if self.use_slicing and z.shape[0] > 1:
261
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
262
- decoded = torch.cat(decoded_slices)
263
- else:
264
- decoded = self._decode(z).sample
265
-
266
- if not return_dict:
267
- return (decoded,)
268
-
269
- return DecoderOutput(sample=decoded)
270
-
271
- def blend_v(self, a, b, blend_extent):
272
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
273
- for y in range(blend_extent):
274
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
275
- return b
276
-
277
- def blend_h(self, a, b, blend_extent):
278
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
279
- for x in range(blend_extent):
280
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
281
- return b
282
-
283
- def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
284
- r"""Encode a batch of images using a tiled encoder.
285
-
286
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
287
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
288
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
289
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
290
- output, but they should be much less noticeable.
291
-
292
- Args:
293
- x (`torch.FloatTensor`): Input batch of images.
294
- return_dict (`bool`, *optional*, defaults to `True`):
295
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
296
-
297
- Returns:
298
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
299
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
300
- `tuple` is returned.
301
- """
302
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
303
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
304
- row_limit = self.tile_latent_min_size - blend_extent
305
-
306
- # Split the image into 512x512 tiles and encode them separately.
307
- rows = []
308
- for i in range(0, x.shape[2], overlap_size):
309
- row = []
310
- for j in range(0, x.shape[3], overlap_size):
311
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
312
- tile = self.encoder(tile)
313
- tile = self.quant_conv(tile)
314
- row.append(tile)
315
- rows.append(row)
316
- result_rows = []
317
- for i, row in enumerate(rows):
318
- result_row = []
319
- for j, tile in enumerate(row):
320
- # blend the above tile and the left tile
321
- # to the current tile and add the current tile to the result row
322
- if i > 0:
323
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
324
- if j > 0:
325
- tile = self.blend_h(row[j - 1], tile, blend_extent)
326
- result_row.append(tile[:, :, :row_limit, :row_limit])
327
- result_rows.append(torch.cat(result_row, dim=3))
328
-
329
- moments = torch.cat(result_rows, dim=2)
330
- posterior = DiagonalGaussianDistribution(moments)
331
-
332
- if not return_dict:
333
- return (posterior,)
334
-
335
- return AutoencoderKLOutput(latent_dist=posterior)
336
-
337
- def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
338
- r"""
339
- Decode a batch of images using a tiled decoder.
340
-
341
- Args:
342
- z (`torch.FloatTensor`): Input batch of latent vectors.
343
- return_dict (`bool`, *optional*, defaults to `True`):
344
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
345
-
346
- Returns:
347
- [`~models.vae.DecoderOutput`] or `tuple`:
348
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
349
- returned.
350
- """
351
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
352
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
353
- row_limit = self.tile_sample_min_size - blend_extent
354
-
355
- # Split z into overlapping 64x64 tiles and decode them separately.
356
- # The tiles have an overlap to avoid seams between tiles.
357
- rows = []
358
- for i in range(0, z.shape[2], overlap_size):
359
- row = []
360
- for j in range(0, z.shape[3], overlap_size):
361
- tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
362
- tile = self.post_quant_conv(tile)
363
- decoded = self.decoder(tile)
364
- row.append(decoded)
365
- rows.append(row)
366
- result_rows = []
367
- for i, row in enumerate(rows):
368
- result_row = []
369
- for j, tile in enumerate(row):
370
- # blend the above tile and the left tile
371
- # to the current tile and add the current tile to the result row
372
- if i > 0:
373
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
374
- if j > 0:
375
- tile = self.blend_h(row[j - 1], tile, blend_extent)
376
- result_row.append(tile[:, :, :row_limit, :row_limit])
377
- result_rows.append(torch.cat(result_row, dim=3))
378
-
379
- dec = torch.cat(result_rows, dim=2)
380
- if not return_dict:
381
- return (dec,)
382
-
383
- return DecoderOutput(sample=dec)
384
-
385
- def forward(
386
- self,
387
- sample: torch.FloatTensor,
388
- sample_posterior: bool = False,
389
- return_dict: bool = True,
390
- generator: Optional[torch.Generator] = None,
391
- ) -> Union[DecoderOutput, torch.FloatTensor]:
392
- r"""
393
- Args:
394
- sample (`torch.FloatTensor`): Input sample.
395
- sample_posterior (`bool`, *optional*, defaults to `False`):
396
- Whether to sample from the posterior.
397
- return_dict (`bool`, *optional*, defaults to `True`):
398
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
399
- """
400
- x = sample
401
- posterior = self.encode(x).latent_dist
402
- if sample_posterior:
403
- z = posterior.sample(generator=generator)
404
- else:
405
- z = posterior.mode()
406
- dec = self.decode(z).sample
407
-
408
- if not return_dict:
409
- return (dec,)
410
-
411
- return DecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/controlnet.py DELETED
@@ -1,705 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- from torch import nn
19
- from torch.nn import functional as F
20
-
21
- from ..configuration_utils import ConfigMixin, register_to_config
22
- from ..utils import BaseOutput, logging
23
- from .attention_processor import AttentionProcessor, AttnProcessor
24
- from .embeddings import TimestepEmbedding, Timesteps
25
- from .modeling_utils import ModelMixin
26
- from .unet_2d_blocks import (
27
- CrossAttnDownBlock2D,
28
- DownBlock2D,
29
- UNetMidBlock2DCrossAttn,
30
- get_down_block,
31
- )
32
- from .unet_2d_condition import UNet2DConditionModel
33
-
34
-
35
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
-
37
-
38
- @dataclass
39
- class ControlNetOutput(BaseOutput):
40
- """
41
- The output of [`ControlNetModel`].
42
-
43
- Args:
44
- down_block_res_samples (`tuple[torch.Tensor]`):
45
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
46
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
47
- used to condition the original UNet's downsampling activations.
48
- mid_down_block_re_sample (`torch.Tensor`):
49
- The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
50
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
51
- Output can be used to condition the original UNet's middle block activation.
52
- """
53
-
54
- down_block_res_samples: Tuple[torch.Tensor]
55
- mid_block_res_sample: torch.Tensor
56
-
57
-
58
- class ControlNetConditioningEmbedding(nn.Module):
59
- """
60
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
61
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
62
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
63
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
64
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
65
- model) to encode image-space conditions ... into feature maps ..."
66
- """
67
-
68
- def __init__(
69
- self,
70
- conditioning_embedding_channels: int,
71
- conditioning_channels: int = 3,
72
- block_out_channels: Tuple[int] = (16, 32, 96, 256),
73
- ):
74
- super().__init__()
75
-
76
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
77
-
78
- self.blocks = nn.ModuleList([])
79
-
80
- for i in range(len(block_out_channels) - 1):
81
- channel_in = block_out_channels[i]
82
- channel_out = block_out_channels[i + 1]
83
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
84
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
85
-
86
- self.conv_out = zero_module(
87
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
88
- )
89
-
90
- def forward(self, conditioning):
91
- embedding = self.conv_in(conditioning)
92
- embedding = F.silu(embedding)
93
-
94
- for block in self.blocks:
95
- embedding = block(embedding)
96
- embedding = F.silu(embedding)
97
-
98
- embedding = self.conv_out(embedding)
99
-
100
- return embedding
101
-
102
-
103
- class ControlNetModel(ModelMixin, ConfigMixin):
104
- """
105
- A ControlNet model.
106
-
107
- Args:
108
- in_channels (`int`, defaults to 4):
109
- The number of channels in the input sample.
110
- flip_sin_to_cos (`bool`, defaults to `True`):
111
- Whether to flip the sin to cos in the time embedding.
112
- freq_shift (`int`, defaults to 0):
113
- The frequency shift to apply to the time embedding.
114
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
115
- The tuple of downsample blocks to use.
116
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
117
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
118
- The tuple of output channels for each block.
119
- layers_per_block (`int`, defaults to 2):
120
- The number of layers per block.
121
- downsample_padding (`int`, defaults to 1):
122
- The padding to use for the downsampling convolution.
123
- mid_block_scale_factor (`float`, defaults to 1):
124
- The scale factor to use for the mid block.
125
- act_fn (`str`, defaults to "silu"):
126
- The activation function to use.
127
- norm_num_groups (`int`, *optional*, defaults to 32):
128
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
129
- in post-processing.
130
- norm_eps (`float`, defaults to 1e-5):
131
- The epsilon to use for the normalization.
132
- cross_attention_dim (`int`, defaults to 1280):
133
- The dimension of the cross attention features.
134
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
135
- The dimension of the attention heads.
136
- use_linear_projection (`bool`, defaults to `False`):
137
- class_embed_type (`str`, *optional*, defaults to `None`):
138
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
139
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
140
- num_class_embeds (`int`, *optional*, defaults to 0):
141
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
142
- class conditioning with `class_embed_type` equal to `None`.
143
- upcast_attention (`bool`, defaults to `False`):
144
- resnet_time_scale_shift (`str`, defaults to `"default"`):
145
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
146
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
147
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
148
- `class_embed_type="projection"`.
149
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
150
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
151
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
152
- The tuple of output channel for each block in the `conditioning_embedding` layer.
153
- global_pool_conditions (`bool`, defaults to `False`):
154
- """
155
-
156
- _supports_gradient_checkpointing = True
157
-
158
- @register_to_config
159
- def __init__(
160
- self,
161
- in_channels: int = 4,
162
- conditioning_channels: int = 3,
163
- flip_sin_to_cos: bool = True,
164
- freq_shift: int = 0,
165
- down_block_types: Tuple[str] = (
166
- "CrossAttnDownBlock2D",
167
- "CrossAttnDownBlock2D",
168
- "CrossAttnDownBlock2D",
169
- "DownBlock2D",
170
- ),
171
- only_cross_attention: Union[bool, Tuple[bool]] = False,
172
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
173
- layers_per_block: int = 2,
174
- downsample_padding: int = 1,
175
- mid_block_scale_factor: float = 1,
176
- act_fn: str = "silu",
177
- norm_num_groups: Optional[int] = 32,
178
- norm_eps: float = 1e-5,
179
- cross_attention_dim: int = 1280,
180
- attention_head_dim: Union[int, Tuple[int]] = 8,
181
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
182
- use_linear_projection: bool = False,
183
- class_embed_type: Optional[str] = None,
184
- num_class_embeds: Optional[int] = None,
185
- upcast_attention: bool = False,
186
- resnet_time_scale_shift: str = "default",
187
- projection_class_embeddings_input_dim: Optional[int] = None,
188
- controlnet_conditioning_channel_order: str = "rgb",
189
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
190
- global_pool_conditions: bool = False,
191
- ):
192
- super().__init__()
193
-
194
- # If `num_attention_heads` is not defined (which is the case for most models)
195
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
196
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
197
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
198
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
199
- # which is why we correct for the naming here.
200
- num_attention_heads = num_attention_heads or attention_head_dim
201
-
202
- # Check inputs
203
- if len(block_out_channels) != len(down_block_types):
204
- raise ValueError(
205
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
206
- )
207
-
208
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
209
- raise ValueError(
210
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
211
- )
212
-
213
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
214
- raise ValueError(
215
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
216
- )
217
-
218
- # input
219
- conv_in_kernel = 3
220
- conv_in_padding = (conv_in_kernel - 1) // 2
221
- self.conv_in = nn.Conv2d(
222
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
223
- )
224
-
225
- # time
226
- time_embed_dim = block_out_channels[0] * 4
227
-
228
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
- timestep_input_dim = block_out_channels[0]
230
-
231
- self.time_embedding = TimestepEmbedding(
232
- timestep_input_dim,
233
- time_embed_dim,
234
- act_fn=act_fn,
235
- )
236
-
237
- # class embedding
238
- if class_embed_type is None and num_class_embeds is not None:
239
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
240
- elif class_embed_type == "timestep":
241
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
242
- elif class_embed_type == "identity":
243
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
244
- elif class_embed_type == "projection":
245
- if projection_class_embeddings_input_dim is None:
246
- raise ValueError(
247
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
248
- )
249
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
250
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
251
- # 2. it projects from an arbitrary input dimension.
252
- #
253
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
254
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
255
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
256
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
257
- else:
258
- self.class_embedding = None
259
-
260
- # control net conditioning embedding
261
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
262
- conditioning_embedding_channels=block_out_channels[0],
263
- block_out_channels=conditioning_embedding_out_channels,
264
- conditioning_channels=conditioning_channels,
265
- )
266
-
267
- self.down_blocks = nn.ModuleList([])
268
- self.controlnet_down_blocks = nn.ModuleList([])
269
-
270
- if isinstance(only_cross_attention, bool):
271
- only_cross_attention = [only_cross_attention] * len(down_block_types)
272
-
273
- if isinstance(attention_head_dim, int):
274
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
275
-
276
- if isinstance(num_attention_heads, int):
277
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
278
-
279
- # down
280
- output_channel = block_out_channels[0]
281
-
282
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
283
- controlnet_block = zero_module(controlnet_block)
284
- self.controlnet_down_blocks.append(controlnet_block)
285
-
286
- for i, down_block_type in enumerate(down_block_types):
287
- input_channel = output_channel
288
- output_channel = block_out_channels[i]
289
- is_final_block = i == len(block_out_channels) - 1
290
-
291
- down_block = get_down_block(
292
- down_block_type,
293
- num_layers=layers_per_block,
294
- in_channels=input_channel,
295
- out_channels=output_channel,
296
- temb_channels=time_embed_dim,
297
- add_downsample=not is_final_block,
298
- resnet_eps=norm_eps,
299
- resnet_act_fn=act_fn,
300
- resnet_groups=norm_num_groups,
301
- cross_attention_dim=cross_attention_dim,
302
- num_attention_heads=num_attention_heads[i],
303
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
304
- downsample_padding=downsample_padding,
305
- use_linear_projection=use_linear_projection,
306
- only_cross_attention=only_cross_attention[i],
307
- upcast_attention=upcast_attention,
308
- resnet_time_scale_shift=resnet_time_scale_shift,
309
- )
310
- self.down_blocks.append(down_block)
311
-
312
- for _ in range(layers_per_block):
313
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
314
- controlnet_block = zero_module(controlnet_block)
315
- self.controlnet_down_blocks.append(controlnet_block)
316
-
317
- if not is_final_block:
318
- controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
319
- controlnet_block = zero_module(controlnet_block)
320
- self.controlnet_down_blocks.append(controlnet_block)
321
-
322
- # mid
323
- mid_block_channel = block_out_channels[-1]
324
-
325
- controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
326
- controlnet_block = zero_module(controlnet_block)
327
- self.controlnet_mid_block = controlnet_block
328
-
329
- self.mid_block = UNetMidBlock2DCrossAttn(
330
- in_channels=mid_block_channel,
331
- temb_channels=time_embed_dim,
332
- resnet_eps=norm_eps,
333
- resnet_act_fn=act_fn,
334
- output_scale_factor=mid_block_scale_factor,
335
- resnet_time_scale_shift=resnet_time_scale_shift,
336
- cross_attention_dim=cross_attention_dim,
337
- num_attention_heads=num_attention_heads[-1],
338
- resnet_groups=norm_num_groups,
339
- use_linear_projection=use_linear_projection,
340
- upcast_attention=upcast_attention,
341
- )
342
-
343
- @classmethod
344
- def from_unet(
345
- cls,
346
- unet: UNet2DConditionModel,
347
- controlnet_conditioning_channel_order: str = "rgb",
348
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
349
- load_weights_from_unet: bool = True,
350
- ):
351
- r"""
352
- Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
353
-
354
- Parameters:
355
- unet (`UNet2DConditionModel`):
356
- The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
357
- where applicable.
358
- """
359
- controlnet = cls(
360
- in_channels=unet.config.in_channels,
361
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
362
- freq_shift=unet.config.freq_shift,
363
- down_block_types=unet.config.down_block_types,
364
- only_cross_attention=unet.config.only_cross_attention,
365
- block_out_channels=unet.config.block_out_channels,
366
- layers_per_block=unet.config.layers_per_block,
367
- downsample_padding=unet.config.downsample_padding,
368
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
369
- act_fn=unet.config.act_fn,
370
- norm_num_groups=unet.config.norm_num_groups,
371
- norm_eps=unet.config.norm_eps,
372
- cross_attention_dim=unet.config.cross_attention_dim,
373
- attention_head_dim=unet.config.attention_head_dim,
374
- num_attention_heads=unet.config.num_attention_heads,
375
- use_linear_projection=unet.config.use_linear_projection,
376
- class_embed_type=unet.config.class_embed_type,
377
- num_class_embeds=unet.config.num_class_embeds,
378
- upcast_attention=unet.config.upcast_attention,
379
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
380
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
381
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
382
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
383
- )
384
-
385
- if load_weights_from_unet:
386
- controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
387
- controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
388
- controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
389
-
390
- if controlnet.class_embedding:
391
- controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
392
-
393
- controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
394
- controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
395
-
396
- return controlnet
397
-
398
- @property
399
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
400
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
401
- r"""
402
- Returns:
403
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
404
- indexed by its weight name.
405
- """
406
- # set recursively
407
- processors = {}
408
-
409
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
410
- if hasattr(module, "set_processor"):
411
- processors[f"{name}.processor"] = module.processor
412
-
413
- for sub_name, child in module.named_children():
414
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
415
-
416
- return processors
417
-
418
- for name, module in self.named_children():
419
- fn_recursive_add_processors(name, module, processors)
420
-
421
- return processors
422
-
423
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
424
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
425
- r"""
426
- Sets the attention processor to use to compute attention.
427
-
428
- Parameters:
429
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
430
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
431
- for **all** `Attention` layers.
432
-
433
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
434
- processor. This is strongly recommended when setting trainable attention processors.
435
-
436
- """
437
- count = len(self.attn_processors.keys())
438
-
439
- if isinstance(processor, dict) and len(processor) != count:
440
- raise ValueError(
441
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
442
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
443
- )
444
-
445
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
446
- if hasattr(module, "set_processor"):
447
- if not isinstance(processor, dict):
448
- module.set_processor(processor)
449
- else:
450
- module.set_processor(processor.pop(f"{name}.processor"))
451
-
452
- for sub_name, child in module.named_children():
453
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
454
-
455
- for name, module in self.named_children():
456
- fn_recursive_attn_processor(name, module, processor)
457
-
458
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
459
- def set_default_attn_processor(self):
460
- """
461
- Disables custom attention processors and sets the default attention implementation.
462
- """
463
- self.set_attn_processor(AttnProcessor())
464
-
465
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
466
- def set_attention_slice(self, slice_size):
467
- r"""
468
- Enable sliced attention computation.
469
-
470
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
471
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
472
-
473
- Args:
474
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
475
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
476
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
477
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
478
- must be a multiple of `slice_size`.
479
- """
480
- sliceable_head_dims = []
481
-
482
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
483
- if hasattr(module, "set_attention_slice"):
484
- sliceable_head_dims.append(module.sliceable_head_dim)
485
-
486
- for child in module.children():
487
- fn_recursive_retrieve_sliceable_dims(child)
488
-
489
- # retrieve number of attention layers
490
- for module in self.children():
491
- fn_recursive_retrieve_sliceable_dims(module)
492
-
493
- num_sliceable_layers = len(sliceable_head_dims)
494
-
495
- if slice_size == "auto":
496
- # half the attention head size is usually a good trade-off between
497
- # speed and memory
498
- slice_size = [dim // 2 for dim in sliceable_head_dims]
499
- elif slice_size == "max":
500
- # make smallest slice possible
501
- slice_size = num_sliceable_layers * [1]
502
-
503
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
504
-
505
- if len(slice_size) != len(sliceable_head_dims):
506
- raise ValueError(
507
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
508
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
509
- )
510
-
511
- for i in range(len(slice_size)):
512
- size = slice_size[i]
513
- dim = sliceable_head_dims[i]
514
- if size is not None and size > dim:
515
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
516
-
517
- # Recursively walk through all the children.
518
- # Any children which exposes the set_attention_slice method
519
- # gets the message
520
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
521
- if hasattr(module, "set_attention_slice"):
522
- module.set_attention_slice(slice_size.pop())
523
-
524
- for child in module.children():
525
- fn_recursive_set_attention_slice(child, slice_size)
526
-
527
- reversed_slice_size = list(reversed(slice_size))
528
- for module in self.children():
529
- fn_recursive_set_attention_slice(module, reversed_slice_size)
530
-
531
- def _set_gradient_checkpointing(self, module, value=False):
532
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
533
- module.gradient_checkpointing = value
534
-
535
- def forward(
536
- self,
537
- sample: torch.FloatTensor,
538
- timestep: Union[torch.Tensor, float, int],
539
- encoder_hidden_states: torch.Tensor,
540
- controlnet_cond: torch.FloatTensor,
541
- conditioning_scale: float = 1.0,
542
- class_labels: Optional[torch.Tensor] = None,
543
- timestep_cond: Optional[torch.Tensor] = None,
544
- attention_mask: Optional[torch.Tensor] = None,
545
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
546
- guess_mode: bool = False,
547
- return_dict: bool = True,
548
- ) -> Union[ControlNetOutput, Tuple]:
549
- """
550
- The [`ControlNetModel`] forward method.
551
-
552
- Args:
553
- sample (`torch.FloatTensor`):
554
- The noisy input tensor.
555
- timestep (`Union[torch.Tensor, float, int]`):
556
- The number of timesteps to denoise an input.
557
- encoder_hidden_states (`torch.Tensor`):
558
- The encoder hidden states.
559
- controlnet_cond (`torch.FloatTensor`):
560
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
561
- conditioning_scale (`float`, defaults to `1.0`):
562
- The scale factor for ControlNet outputs.
563
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
564
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
565
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
566
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
567
- cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
568
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
569
- guess_mode (`bool`, defaults to `False`):
570
- In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
571
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
572
- return_dict (`bool`, defaults to `True`):
573
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
574
-
575
- Returns:
576
- [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
577
- If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
578
- returned where the first element is the sample tensor.
579
- """
580
- # check channel order
581
- channel_order = self.config.controlnet_conditioning_channel_order
582
-
583
- if channel_order == "rgb":
584
- # in rgb order by default
585
- ...
586
- elif channel_order == "bgr":
587
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
588
- else:
589
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
590
-
591
- # prepare attention_mask
592
- if attention_mask is not None:
593
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
594
- attention_mask = attention_mask.unsqueeze(1)
595
-
596
- # 1. time
597
- timesteps = timestep
598
- if not torch.is_tensor(timesteps):
599
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
600
- # This would be a good case for the `match` statement (Python 3.10+)
601
- is_mps = sample.device.type == "mps"
602
- if isinstance(timestep, float):
603
- dtype = torch.float32 if is_mps else torch.float64
604
- else:
605
- dtype = torch.int32 if is_mps else torch.int64
606
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
607
- elif len(timesteps.shape) == 0:
608
- timesteps = timesteps[None].to(sample.device)
609
-
610
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
611
- timesteps = timesteps.expand(sample.shape[0])
612
-
613
- t_emb = self.time_proj(timesteps)
614
-
615
- # timesteps does not contain any weights and will always return f32 tensors
616
- # but time_embedding might actually be running in fp16. so we need to cast here.
617
- # there might be better ways to encapsulate this.
618
- t_emb = t_emb.to(dtype=sample.dtype)
619
-
620
- emb = self.time_embedding(t_emb, timestep_cond)
621
-
622
- if self.class_embedding is not None:
623
- if class_labels is None:
624
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
625
-
626
- if self.config.class_embed_type == "timestep":
627
- class_labels = self.time_proj(class_labels)
628
-
629
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
630
- emb = emb + class_emb
631
-
632
- # 2. pre-process
633
- sample = self.conv_in(sample)
634
-
635
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
636
-
637
- sample = sample + controlnet_cond
638
-
639
- # 3. down
640
- down_block_res_samples = (sample,)
641
- for downsample_block in self.down_blocks:
642
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
643
- sample, res_samples = downsample_block(
644
- hidden_states=sample,
645
- temb=emb,
646
- encoder_hidden_states=encoder_hidden_states,
647
- attention_mask=attention_mask,
648
- cross_attention_kwargs=cross_attention_kwargs,
649
- )
650
- else:
651
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
652
-
653
- down_block_res_samples += res_samples
654
-
655
- # 4. mid
656
- if self.mid_block is not None:
657
- sample = self.mid_block(
658
- sample,
659
- emb,
660
- encoder_hidden_states=encoder_hidden_states,
661
- attention_mask=attention_mask,
662
- cross_attention_kwargs=cross_attention_kwargs,
663
- )
664
-
665
- # 5. Control net blocks
666
-
667
- controlnet_down_block_res_samples = ()
668
-
669
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
670
- down_block_res_sample = controlnet_block(down_block_res_sample)
671
- controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
672
-
673
- down_block_res_samples = controlnet_down_block_res_samples
674
-
675
- mid_block_res_sample = self.controlnet_mid_block(sample)
676
-
677
- # 6. scaling
678
- if guess_mode and not self.config.global_pool_conditions:
679
- scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
680
-
681
- scales = scales * conditioning_scale
682
- down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
683
- mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
684
- else:
685
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
686
- mid_block_res_sample = mid_block_res_sample * conditioning_scale
687
-
688
- if self.config.global_pool_conditions:
689
- down_block_res_samples = [
690
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
691
- ]
692
- mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
693
-
694
- if not return_dict:
695
- return (down_block_res_samples, mid_block_res_sample)
696
-
697
- return ControlNetOutput(
698
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
699
- )
700
-
701
-
702
- def zero_module(module):
703
- for p in module.parameters():
704
- nn.init.zeros_(p)
705
- return module
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/controlnet_flax.py DELETED
@@ -1,394 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Optional, Tuple, Union
15
-
16
- import flax
17
- import flax.linen as nn
18
- import jax
19
- import jax.numpy as jnp
20
- from flax.core.frozen_dict import FrozenDict
21
-
22
- from ..configuration_utils import ConfigMixin, flax_register_to_config
23
- from ..utils import BaseOutput
24
- from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
- from .modeling_flax_utils import FlaxModelMixin
26
- from .unet_2d_blocks_flax import (
27
- FlaxCrossAttnDownBlock2D,
28
- FlaxDownBlock2D,
29
- FlaxUNetMidBlock2DCrossAttn,
30
- )
31
-
32
-
33
- @flax.struct.dataclass
34
- class FlaxControlNetOutput(BaseOutput):
35
- """
36
- The output of [`FlaxControlNetModel`].
37
-
38
- Args:
39
- down_block_res_samples (`jnp.ndarray`):
40
- mid_block_res_sample (`jnp.ndarray`):
41
- """
42
-
43
- down_block_res_samples: jnp.ndarray
44
- mid_block_res_sample: jnp.ndarray
45
-
46
-
47
- class FlaxControlNetConditioningEmbedding(nn.Module):
48
- conditioning_embedding_channels: int
49
- block_out_channels: Tuple[int] = (16, 32, 96, 256)
50
- dtype: jnp.dtype = jnp.float32
51
-
52
- def setup(self):
53
- self.conv_in = nn.Conv(
54
- self.block_out_channels[0],
55
- kernel_size=(3, 3),
56
- padding=((1, 1), (1, 1)),
57
- dtype=self.dtype,
58
- )
59
-
60
- blocks = []
61
- for i in range(len(self.block_out_channels) - 1):
62
- channel_in = self.block_out_channels[i]
63
- channel_out = self.block_out_channels[i + 1]
64
- conv1 = nn.Conv(
65
- channel_in,
66
- kernel_size=(3, 3),
67
- padding=((1, 1), (1, 1)),
68
- dtype=self.dtype,
69
- )
70
- blocks.append(conv1)
71
- conv2 = nn.Conv(
72
- channel_out,
73
- kernel_size=(3, 3),
74
- strides=(2, 2),
75
- padding=((1, 1), (1, 1)),
76
- dtype=self.dtype,
77
- )
78
- blocks.append(conv2)
79
- self.blocks = blocks
80
-
81
- self.conv_out = nn.Conv(
82
- self.conditioning_embedding_channels,
83
- kernel_size=(3, 3),
84
- padding=((1, 1), (1, 1)),
85
- kernel_init=nn.initializers.zeros_init(),
86
- bias_init=nn.initializers.zeros_init(),
87
- dtype=self.dtype,
88
- )
89
-
90
- def __call__(self, conditioning):
91
- embedding = self.conv_in(conditioning)
92
- embedding = nn.silu(embedding)
93
-
94
- for block in self.blocks:
95
- embedding = block(embedding)
96
- embedding = nn.silu(embedding)
97
-
98
- embedding = self.conv_out(embedding)
99
-
100
- return embedding
101
-
102
-
103
- @flax_register_to_config
104
- class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
- r"""
106
- A ControlNet model.
107
-
108
- This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
109
- implemented for all models (such as downloading or saving).
110
-
111
- This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
- subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
- general usage and behavior.
114
-
115
- Inherent JAX features such as the following are supported:
116
-
117
- - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
- - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
- - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
- - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
-
122
- Parameters:
123
- sample_size (`int`, *optional*):
124
- The size of the input sample.
125
- in_channels (`int`, *optional*, defaults to 4):
126
- The number of channels in the input sample.
127
- down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
- The tuple of downsample blocks to use.
129
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
- The tuple of output channels for each block.
131
- layers_per_block (`int`, *optional*, defaults to 2):
132
- The number of layers per block.
133
- attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
- The dimension of the attention heads.
135
- num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
- The number of attention heads.
137
- cross_attention_dim (`int`, *optional*, defaults to 768):
138
- The dimension of the cross attention features.
139
- dropout (`float`, *optional*, defaults to 0):
140
- Dropout probability for down, up and bottleneck blocks.
141
- flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
- Whether to flip the sin to cos in the time embedding.
143
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
- controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
146
- conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
- The tuple of output channel for each block in the `conditioning_embedding` layer.
148
- """
149
- sample_size: int = 32
150
- in_channels: int = 4
151
- down_block_types: Tuple[str] = (
152
- "CrossAttnDownBlock2D",
153
- "CrossAttnDownBlock2D",
154
- "CrossAttnDownBlock2D",
155
- "DownBlock2D",
156
- )
157
- only_cross_attention: Union[bool, Tuple[bool]] = False
158
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
159
- layers_per_block: int = 2
160
- attention_head_dim: Union[int, Tuple[int]] = 8
161
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None
162
- cross_attention_dim: int = 1280
163
- dropout: float = 0.0
164
- use_linear_projection: bool = False
165
- dtype: jnp.dtype = jnp.float32
166
- flip_sin_to_cos: bool = True
167
- freq_shift: int = 0
168
- controlnet_conditioning_channel_order: str = "rgb"
169
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
-
171
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
172
- # init input tensors
173
- sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174
- sample = jnp.zeros(sample_shape, dtype=jnp.float32)
175
- timesteps = jnp.ones((1,), dtype=jnp.int32)
176
- encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
177
- controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
178
- controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
179
-
180
- params_rng, dropout_rng = jax.random.split(rng)
181
- rngs = {"params": params_rng, "dropout": dropout_rng}
182
-
183
- return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
-
185
- def setup(self):
186
- block_out_channels = self.block_out_channels
187
- time_embed_dim = block_out_channels[0] * 4
188
-
189
- # If `num_attention_heads` is not defined (which is the case for most models)
190
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
- # which is why we correct for the naming here.
195
- num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
-
197
- # input
198
- self.conv_in = nn.Conv(
199
- block_out_channels[0],
200
- kernel_size=(3, 3),
201
- strides=(1, 1),
202
- padding=((1, 1), (1, 1)),
203
- dtype=self.dtype,
204
- )
205
-
206
- # time
207
- self.time_proj = FlaxTimesteps(
208
- block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
209
- )
210
- self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
211
-
212
- self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
213
- conditioning_embedding_channels=block_out_channels[0],
214
- block_out_channels=self.conditioning_embedding_out_channels,
215
- )
216
-
217
- only_cross_attention = self.only_cross_attention
218
- if isinstance(only_cross_attention, bool):
219
- only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
220
-
221
- if isinstance(num_attention_heads, int):
222
- num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
223
-
224
- # down
225
- down_blocks = []
226
- controlnet_down_blocks = []
227
-
228
- output_channel = block_out_channels[0]
229
-
230
- controlnet_block = nn.Conv(
231
- output_channel,
232
- kernel_size=(1, 1),
233
- padding="VALID",
234
- kernel_init=nn.initializers.zeros_init(),
235
- bias_init=nn.initializers.zeros_init(),
236
- dtype=self.dtype,
237
- )
238
- controlnet_down_blocks.append(controlnet_block)
239
-
240
- for i, down_block_type in enumerate(self.down_block_types):
241
- input_channel = output_channel
242
- output_channel = block_out_channels[i]
243
- is_final_block = i == len(block_out_channels) - 1
244
-
245
- if down_block_type == "CrossAttnDownBlock2D":
246
- down_block = FlaxCrossAttnDownBlock2D(
247
- in_channels=input_channel,
248
- out_channels=output_channel,
249
- dropout=self.dropout,
250
- num_layers=self.layers_per_block,
251
- num_attention_heads=num_attention_heads[i],
252
- add_downsample=not is_final_block,
253
- use_linear_projection=self.use_linear_projection,
254
- only_cross_attention=only_cross_attention[i],
255
- dtype=self.dtype,
256
- )
257
- else:
258
- down_block = FlaxDownBlock2D(
259
- in_channels=input_channel,
260
- out_channels=output_channel,
261
- dropout=self.dropout,
262
- num_layers=self.layers_per_block,
263
- add_downsample=not is_final_block,
264
- dtype=self.dtype,
265
- )
266
-
267
- down_blocks.append(down_block)
268
-
269
- for _ in range(self.layers_per_block):
270
- controlnet_block = nn.Conv(
271
- output_channel,
272
- kernel_size=(1, 1),
273
- padding="VALID",
274
- kernel_init=nn.initializers.zeros_init(),
275
- bias_init=nn.initializers.zeros_init(),
276
- dtype=self.dtype,
277
- )
278
- controlnet_down_blocks.append(controlnet_block)
279
-
280
- if not is_final_block:
281
- controlnet_block = nn.Conv(
282
- output_channel,
283
- kernel_size=(1, 1),
284
- padding="VALID",
285
- kernel_init=nn.initializers.zeros_init(),
286
- bias_init=nn.initializers.zeros_init(),
287
- dtype=self.dtype,
288
- )
289
- controlnet_down_blocks.append(controlnet_block)
290
-
291
- self.down_blocks = down_blocks
292
- self.controlnet_down_blocks = controlnet_down_blocks
293
-
294
- # mid
295
- mid_block_channel = block_out_channels[-1]
296
- self.mid_block = FlaxUNetMidBlock2DCrossAttn(
297
- in_channels=mid_block_channel,
298
- dropout=self.dropout,
299
- num_attention_heads=num_attention_heads[-1],
300
- use_linear_projection=self.use_linear_projection,
301
- dtype=self.dtype,
302
- )
303
-
304
- self.controlnet_mid_block = nn.Conv(
305
- mid_block_channel,
306
- kernel_size=(1, 1),
307
- padding="VALID",
308
- kernel_init=nn.initializers.zeros_init(),
309
- bias_init=nn.initializers.zeros_init(),
310
- dtype=self.dtype,
311
- )
312
-
313
- def __call__(
314
- self,
315
- sample,
316
- timesteps,
317
- encoder_hidden_states,
318
- controlnet_cond,
319
- conditioning_scale: float = 1.0,
320
- return_dict: bool = True,
321
- train: bool = False,
322
- ) -> Union[FlaxControlNetOutput, Tuple]:
323
- r"""
324
- Args:
325
- sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
- timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
- encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
- controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
- conditioning_scale: (`float`) the scale factor for controlnet outputs
330
- return_dict (`bool`, *optional*, defaults to `True`):
331
- Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
332
- plain tuple.
333
- train (`bool`, *optional*, defaults to `False`):
334
- Use deterministic functions and disable dropout when not training.
335
-
336
- Returns:
337
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
338
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
339
- When returning a tuple, the first element is the sample tensor.
340
- """
341
- channel_order = self.controlnet_conditioning_channel_order
342
- if channel_order == "bgr":
343
- controlnet_cond = jnp.flip(controlnet_cond, axis=1)
344
-
345
- # 1. time
346
- if not isinstance(timesteps, jnp.ndarray):
347
- timesteps = jnp.array([timesteps], dtype=jnp.int32)
348
- elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
349
- timesteps = timesteps.astype(dtype=jnp.float32)
350
- timesteps = jnp.expand_dims(timesteps, 0)
351
-
352
- t_emb = self.time_proj(timesteps)
353
- t_emb = self.time_embedding(t_emb)
354
-
355
- # 2. pre-process
356
- sample = jnp.transpose(sample, (0, 2, 3, 1))
357
- sample = self.conv_in(sample)
358
-
359
- controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
360
- controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
361
- sample += controlnet_cond
362
-
363
- # 3. down
364
- down_block_res_samples = (sample,)
365
- for down_block in self.down_blocks:
366
- if isinstance(down_block, FlaxCrossAttnDownBlock2D):
367
- sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
368
- else:
369
- sample, res_samples = down_block(sample, t_emb, deterministic=not train)
370
- down_block_res_samples += res_samples
371
-
372
- # 4. mid
373
- sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
374
-
375
- # 5. contronet blocks
376
- controlnet_down_block_res_samples = ()
377
- for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
378
- down_block_res_sample = controlnet_block(down_block_res_sample)
379
- controlnet_down_block_res_samples += (down_block_res_sample,)
380
-
381
- down_block_res_samples = controlnet_down_block_res_samples
382
-
383
- mid_block_res_sample = self.controlnet_mid_block(sample)
384
-
385
- # 6. scaling
386
- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
387
- mid_block_res_sample *= conditioning_scale
388
-
389
- if not return_dict:
390
- return (down_block_res_samples, mid_block_res_sample)
391
-
392
- return FlaxControlNetOutput(
393
- down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
394
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/cross_attention.py DELETED
@@ -1,94 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from ..utils import deprecate
15
- from .attention_processor import ( # noqa: F401
16
- Attention,
17
- AttentionProcessor,
18
- AttnAddedKVProcessor,
19
- AttnProcessor2_0,
20
- LoRAAttnProcessor,
21
- LoRALinearLayer,
22
- LoRAXFormersAttnProcessor,
23
- SlicedAttnAddedKVProcessor,
24
- SlicedAttnProcessor,
25
- XFormersAttnProcessor,
26
- )
27
- from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
28
-
29
-
30
- deprecate(
31
- "cross_attention",
32
- "0.20.0",
33
- "Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
34
- standard_warn=False,
35
- )
36
-
37
-
38
- AttnProcessor = AttentionProcessor
39
-
40
-
41
- class CrossAttention(Attention):
42
- def __init__(self, *args, **kwargs):
43
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
44
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
45
- super().__init__(*args, **kwargs)
46
-
47
-
48
- class CrossAttnProcessor(AttnProcessorRename):
49
- def __init__(self, *args, **kwargs):
50
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
51
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
52
- super().__init__(*args, **kwargs)
53
-
54
-
55
- class LoRACrossAttnProcessor(LoRAAttnProcessor):
56
- def __init__(self, *args, **kwargs):
57
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
58
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
59
- super().__init__(*args, **kwargs)
60
-
61
-
62
- class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
63
- def __init__(self, *args, **kwargs):
64
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
65
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
66
- super().__init__(*args, **kwargs)
67
-
68
-
69
- class XFormersCrossAttnProcessor(XFormersAttnProcessor):
70
- def __init__(self, *args, **kwargs):
71
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
72
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
73
- super().__init__(*args, **kwargs)
74
-
75
-
76
- class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
77
- def __init__(self, *args, **kwargs):
78
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
79
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
80
- super().__init__(*args, **kwargs)
81
-
82
-
83
- class SlicedCrossAttnProcessor(SlicedAttnProcessor):
84
- def __init__(self, *args, **kwargs):
85
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
86
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
87
- super().__init__(*args, **kwargs)
88
-
89
-
90
- class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
91
- def __init__(self, *args, **kwargs):
92
- deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
93
- deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
94
- super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/dual_transformer_2d.py DELETED
@@ -1,151 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Optional
15
-
16
- from torch import nn
17
-
18
- from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
-
20
-
21
- class DualTransformer2DModel(nn.Module):
22
- """
23
- Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
-
25
- Parameters:
26
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
- in_channels (`int`, *optional*):
29
- Pass if the input is continuous. The number of channels in the input and output.
30
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
- dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
- cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
- sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
- Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
- `ImagePositionalEmbeddings`.
36
- num_vector_embeds (`int`, *optional*):
37
- Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
- Includes the class for the masked latent pixel.
39
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
- num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
- The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
- to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
- up to but not more than steps than `num_embeds_ada_norm`.
44
- attention_bias (`bool`, *optional*):
45
- Configure if the TransformerBlocks' attention should contain a bias parameter.
46
- """
47
-
48
- def __init__(
49
- self,
50
- num_attention_heads: int = 16,
51
- attention_head_dim: int = 88,
52
- in_channels: Optional[int] = None,
53
- num_layers: int = 1,
54
- dropout: float = 0.0,
55
- norm_num_groups: int = 32,
56
- cross_attention_dim: Optional[int] = None,
57
- attention_bias: bool = False,
58
- sample_size: Optional[int] = None,
59
- num_vector_embeds: Optional[int] = None,
60
- activation_fn: str = "geglu",
61
- num_embeds_ada_norm: Optional[int] = None,
62
- ):
63
- super().__init__()
64
- self.transformers = nn.ModuleList(
65
- [
66
- Transformer2DModel(
67
- num_attention_heads=num_attention_heads,
68
- attention_head_dim=attention_head_dim,
69
- in_channels=in_channels,
70
- num_layers=num_layers,
71
- dropout=dropout,
72
- norm_num_groups=norm_num_groups,
73
- cross_attention_dim=cross_attention_dim,
74
- attention_bias=attention_bias,
75
- sample_size=sample_size,
76
- num_vector_embeds=num_vector_embeds,
77
- activation_fn=activation_fn,
78
- num_embeds_ada_norm=num_embeds_ada_norm,
79
- )
80
- for _ in range(2)
81
- ]
82
- )
83
-
84
- # Variables that can be set by a pipeline:
85
-
86
- # The ratio of transformer1 to transformer2's output states to be combined during inference
87
- self.mix_ratio = 0.5
88
-
89
- # The shape of `encoder_hidden_states` is expected to be
90
- # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
- self.condition_lengths = [77, 257]
92
-
93
- # Which transformer to use to encode which condition.
94
- # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
- self.transformer_index_for_condition = [1, 0]
96
-
97
- def forward(
98
- self,
99
- hidden_states,
100
- encoder_hidden_states,
101
- timestep=None,
102
- attention_mask=None,
103
- cross_attention_kwargs=None,
104
- return_dict: bool = True,
105
- ):
106
- """
107
- Args:
108
- hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
- When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
- hidden_states
111
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
- self-attention.
114
- timestep ( `torch.long`, *optional*):
115
- Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
- attention_mask (`torch.FloatTensor`, *optional*):
117
- Optional attention mask to be applied in Attention
118
- return_dict (`bool`, *optional*, defaults to `True`):
119
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
-
121
- Returns:
122
- [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
- [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
- returning a tuple, the first element is the sample tensor.
125
- """
126
- input_states = hidden_states
127
-
128
- encoded_states = []
129
- tokens_start = 0
130
- # attention_mask is not used yet
131
- for i in range(2):
132
- # for each of the two transformers, pass the corresponding condition tokens
133
- condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
- transformer_index = self.transformer_index_for_condition[i]
135
- encoded_state = self.transformers[transformer_index](
136
- input_states,
137
- encoder_hidden_states=condition_state,
138
- timestep=timestep,
139
- cross_attention_kwargs=cross_attention_kwargs,
140
- return_dict=False,
141
- )[0]
142
- encoded_states.append(encoded_state - input_states)
143
- tokens_start += self.condition_lengths[i]
144
-
145
- output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
- output_states = output_states + input_states
147
-
148
- if not return_dict:
149
- return (output_states,)
150
-
151
- return Transformer2DModelOutput(sample=output_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/embeddings.py DELETED
@@ -1,546 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
- from typing import Optional
16
-
17
- import numpy as np
18
- import torch
19
- from torch import nn
20
-
21
- from .activations import get_activation
22
-
23
-
24
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N values, one per batch element (cast to float before use).
    :param embedding_dim: the dimension of the output.
    :param flip_sin_to_cos: when True the output is laid out as [cos, sin] instead of [sin, cos].
    :param downscale_freq_shift: subtracted from `embedding_dim // 2` in the exponent denominator.
    :param scale: multiplier applied to the angles before sin/cos.
    :param max_period: base of the geometric frequency progression.
    :return: an [N x embedding_dim] Tensor of embeddings (last column is zero padding for odd dims).
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    # log-frequencies: -log(max_period) * i / (half_dim - shift) for i in [0, half_dim)
    positions = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)

    # outer product of timesteps and frequencies, then the global scale
    angles = scale * (timesteps[:, None].float() * torch.exp(log_freqs)[None, :])

    sin_half = torch.sin(angles)
    cos_half = torch.cos(angles)
    if flip_sin_to_cos:
        emb = torch.cat([cos_half, sin_half], dim=-1)
    else:
        emb = torch.cat([sin_half, cos_half], dim=-1)

    # odd target dimension: append one zero column
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
65
-
66
-
67
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a 2D sin/cos positional embedding table for a square grid.

    grid_size: int, the grid height and width.
    return: pos_embed of shape [grid_size*grid_size, embed_dim], or
    [extra_tokens+grid_size*grid_size, embed_dim] when `cls_token` is set and `extra_tokens > 0`
    (the extra rows are zeros, placed first).
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)

    # meshgrid is called with the width axis first
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
82
-
83
-
84
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a 2-channel coordinate grid as sin/cos embeddings.

    Each of `grid[0]` and `grid[1]` is encoded with `embed_dim // 2` features and the two results are
    concatenated along the feature axis. Raises `ValueError` when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    per_axis_dim = embed_dim // 2
    # one (H*W, D/2) table per grid channel, joined into (H*W, D)
    halves = [get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)
94
-
95
-
96
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode positions as [sin | cos] features.

    embed_dim: output dimension for each position (must be even, otherwise `ValueError`).
    pos: array of positions to encode; it is flattened to size (M,).
    out: (M, embed_dim) array, first half sines and second half cosines.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # frequencies 1 / 10000^(i / (D/2)) for i in [0, D/2)
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum("m,d->md", flat_pos, omega)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
115
-
116
-
117
class PatchEmbed(nn.Module):
    """2D image to patch embedding: strided conv projection plus a fixed sin/cos positional table."""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        # non-overlapping patches: kernel and stride are both the patch size
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        # the positional table is built for a square grid of side int(sqrt(num_patches))
        num_patches = (height // patch_size) * (width // patch_size)
        grid_side = int(num_patches**0.5)
        table = get_2d_sincos_pos_embed(embed_dim, grid_side)
        self.register_buffer("pos_embed", torch.from_numpy(table).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        patches = self.proj(latent)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            patches = self.norm(patches)
        return patches + self.pos_embed
155
-
156
-
157
class TimestepEmbedding(nn.Module):
    """
    Small MLP: `linear_1` -> activation -> `linear_2` -> optional post-activation.

    Args:
        in_channels (`int`): Input feature size of `linear_1`.
        time_embed_dim (`int`): Output size of `linear_1` and input size of `linear_2`.
        act_fn (`str`): Name handed to `get_activation` for the activation between the linear layers.
        out_dim (`int`, *optional*): Output size of `linear_2`; `time_embed_dim` is used when `None`.
        post_act_fn (`str`, *optional*): Name handed to `get_activation` for an activation after `linear_2`.
        cond_proj_dim (`int`, *optional*):
            When given, a bias-free linear layer mapping a condition of this size to `in_channels` is created.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim is not None else None
        self.act = get_activation(act_fn)

        output_dim = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, output_dim)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        # the projected condition is added to the input before the first linear layer
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)

        hidden = self.linear_2(hidden)
        if self.post_act is not None:
            hidden = self.post_act(hidden)
        return hidden
202
-
203
-
204
class Timesteps(nn.Module):
    """Parameter-free module that forwards its stored settings to `get_timestep_embedding`."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
219
-
220
-
221
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels, using a fixed (non-trainable) random weight vector."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later: a second draw is registered as `W` and then also used as `weight`
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W

    def forward(self, x):
        values = torch.log(x) if self.log else x

        angles = values[:, None] * self.weight[None, :] * 2 * np.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)

        ordered = [cos_part, sin_part] if self.flip_sin_to_cos else [sin_part, cos_part]
        return torch.cat(ordered, dim=-1)
249
-
250
-
251
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        row_ids = torch.arange(self.height, device=index.device).view(1, self.height)
        col_ids = torch.arange(self.width, device=index.device).view(1, self.width)

        # 1 x H x D -> 1 x H x 1 x D   and   1 x W x D -> 1 x 1 x W x D
        row_emb = self.height_emb(row_ids).unsqueeze(2)
        col_emb = self.width_emb(col_ids).unsqueeze(1)

        # broadcast sum gives 1 x H x W x D, flattened row-major to 1 x L x D
        pos_emb = (row_emb + col_emb).view(1, self.height * self.width, -1)

        # only as many positions as there are tokens along dim 1 are used
        return emb + pos_emb[:, : emb.shape[1], :]
314
-
315
-
316
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the embedding table gets one extra row (index `num_classes`) that dropped labels are
    mapped to.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.LongTensor`): Class labels, one per batch element.
            force_drop_ids (tensor or sequence, *optional*):
                Per-element flags; entries equal to `1` are dropped. When `None`, each label is dropped at random
                with probability `dropout_prob`.

        Returns:
            A tensor shaped like `labels` in which dropped entries are replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first, compare second: comparing a Python list with `== 1` yields a single `False`,
            # which would silently disable forced dropping. `as_tensor` also avoids re-wrapping an existing
            # tensor and keeps the mask on the same device as `labels`.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        return torch.where(drop_ids, torch.full_like(labels, self.num_classes), labels)

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        # random dropout only happens in training mode; forced dropping is always honoured
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
350
-
351
-
352
class TextImageProjection(nn.Module):
    """
    Projects an image embedding into `num_image_text_embeds` tokens and text embeddings into the same feature
    size, then concatenates them along dim 1 (image tokens first).
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        batch_size = text_embeds.shape[0]

        # image: one wide linear output split into `num_image_text_embeds` tokens
        image_tokens = self.image_embeds(image_embeds).reshape(batch_size, self.num_image_text_embeds, -1)

        # text: per-token projection
        text_tokens = self.text_proj(text_embeds)

        return torch.cat([image_tokens, text_tokens], dim=1)
377
-
378
-
379
class ImageProjection(nn.Module):
    """Projects an image embedding into `num_image_text_embeds` layer-normalised tokens."""

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        batch_size = image_embeds.shape[0]

        # image: project, split into tokens, then normalise each token
        projected = self.image_embeds(image_embeds)
        tokens = projected.reshape(batch_size, self.num_image_text_embeds, -1)
        return self.norm(tokens)
400
-
401
-
402
class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sums a timestep embedding and a class-label embedding into one conditioning vector."""

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        # sinusoidal projection, cast to the requested dtype, then the timestep MLP
        projected = self.time_proj(timestep).to(dtype=hidden_dtype)
        time_embedding = self.timestep_embedder(projected)  # (N, D)

        label_embedding = self.class_embedder(class_labels)  # (N, D)

        return time_embedding + label_embedding  # (N, D)
419
-
420
-
421
class TextTimeEmbedding(nn.Module):
    """Pipeline: LayerNorm -> attention pooling -> linear projection -> LayerNorm."""

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        # the four stages run strictly in registration order
        for stage in (self.norm1, self.pool, self.proj, self.norm2):
            hidden_states = stage(hidden_states)
        return hidden_states
435
-
436
-
437
class TextImageTimeEmbedding(nn.Module):
    """Projects text and image embeddings to `time_embed_dim` and sums them (only the text branch is normalised)."""

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text: projection followed by layer norm
        text_branch = self.text_norm(self.text_proj(text_embeds))

        # image: projection only
        image_branch = self.image_proj(image_embeds)

        return image_branch + text_branch
453
-
454
-
455
class ImageTimeEmbedding(nn.Module):
    """Projects an image embedding to `time_embed_dim` and layer-normalises the result."""

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.FloatTensor):
        # image: linear projection, then layer norm
        return self.image_norm(self.image_proj(image_embeds))
466
-
467
-
468
class ImageHintTimeEmbedding(nn.Module):
    """
    Embeds an image embedding (linear + layer norm) and, separately, encodes a 3-channel hint image with a small
    conv stack (three stride-2 convolutions, 4 output channels). Both results are returned as a tuple.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for every conv that is followed by SiLU
        conv_specs = [
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
        ]
        layers = []
        for in_ch, out_ch, stride in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride))
            layers.append(nn.SiLU())
        # final conv has no activation after it
        layers.append(nn.Conv2d(256, 4, 3, padding=1))
        self.input_hint_block = nn.Sequential(*layers)

    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
        # image
        time_image_embeds = self.image_norm(self.image_proj(image_embeds))
        # hint
        encoded_hint = self.input_hint_block(hint)
        return time_image_embeds, encoded_hint
497
-
498
-
499
class AttentionPooling(nn.Module):
    """
    Pools a token sequence into one vector: a class token (sequence mean plus a learned positional embedding)
    attends over itself and all tokens with multi-head attention.
    """

    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        # unpacking three values keeps the requirement that x is 3-dimensional
        bs, _, _ = x.size()

        def split_heads(t):
            # (bs, tokens, width) -> (bs, n_heads, tokens, dim_per_head)
            t = t.view(bs, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
            # -> (bs*n_heads, tokens, dim_per_head) -> (bs*n_heads, dim_per_head, tokens)
            return t.reshape(bs * self.num_heads, -1, self.dim_per_head).transpose(1, 2)

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        tokens = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, 1)
        q = split_heads(self.q_proj(class_token))
        # (bs*n_heads, dim_per_head, length+1)
        k = split_heads(self.k_proj(tokens))
        v = split_heads(self.v_proj(tokens))

        # (bs*n_heads, 1, length+1); scaling q and k separately is more stable with f16 than dividing afterwards
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, 1)
        pooled = torch.einsum("bts,bcs->bct", weight, v)

        # (bs, 1, width)
        pooled = pooled.reshape(bs, -1, 1).transpose(1, 2)

        return pooled[:, 0, :]  # cls_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/embeddings_flax.py DELETED
@@ -1,95 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
-
16
- import flax.linen as nn
17
- import jax.numpy as jnp
18
-
19
-
20
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
          These may be fractional.
        embedding_dim: The number of output channels (must be even).
        freq_shift: Subtracted from `embedding_dim // 2` in the denominator of the log-timescale increment.
        min_timescale: The smallest time unit.
        max_timescale: The largest time unit.
        flip_sin_to_cos: When True the output is laid out as [cos, sin] instead of [sin, cos].
        scale: Multiplier applied to the angles before sin/cos.
    Returns:
        a Tensor of timing signals [N, num_channels]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    num_timescales = float(embedding_dim // 2)
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)

    # geometric progression of inverse timescales starting at min_timescale
    steps = jnp.arange(num_timescales, dtype=jnp.float32)
    inv_timescales = min_timescale * jnp.exp(steps * -log_timescale_increment)

    # outer product (N, 1) * (1, D/2), then the global scale
    scaled_time = scale * (jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0))

    sin_part = jnp.sin(scaled_time)
    cos_part = jnp.cos(scaled_time)
    ordered = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]

    signal = jnp.concatenate(ordered, axis=1)
    return jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
56
-
57
-
58
class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module: two dense layers named `linear_1` and `linear_2` with a SiLU in between.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
                Parameters `dtype`
    """
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        first = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")
        second = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")
        return second(nn.silu(first(temb)))
77
-
78
-
79
class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
                Forwarded to `get_sinusoidal_embeddings`
        freq_shift (`float`, *optional*, defaults to `1`):
                Forwarded to `get_sinusoidal_embeddings`
    """
    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(
            timesteps,
            embedding_dim=self.dim,
            flip_sin_to_cos=self.flip_sin_to_cos,
            freq_shift=self.freq_shift,
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/modeling_flax_pytorch_utils.py DELETED
@@ -1,118 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch - Flax general utilities."""
16
- import re
17
-
18
- import jax.numpy as jnp
19
- from flax.traverse_util import flatten_dict, unflatten_dict
20
- from jax.random import PRNGKey
21
-
22
- from ..utils import logging
23
-
24
-
25
- logger = logging.get_logger(__name__)
26
-
27
-
28
def rename_key(key):
    """Replace the dot in every `<word>.<digits>` segment of `key` with an underscore (e.g. `blocks.0` -> `blocks_0`)."""
    for segment in re.findall(r"\w+[.]\d+", key):
        key = key.replace(segment, segment.replace(".", "_"))
    return key
34
-
35
-
36
- #####################
37
- # PyTorch => Flax #
38
- #####################
39
-
40
-
41
- # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42
- # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary

    Args:
        pt_tuple_key: PyTorch parameter name split into a tuple of path components.
        pt_tensor: The parameter value; must support `.ndim`, `.transpose(...)` and `.T`.
        random_flax_state_dict: Flattened Flax parameters (tuple keys), used only for key membership checks.

    Returns:
        A `(flax_tuple_key, tensor)` pair. The tensor is transposed for conv (4-D) and linear weights and
        returned unchanged otherwise.
    """
    prefix, leaf = pt_tuple_key[:-1], pt_tuple_key[-1]
    scale_key = prefix + ("scale",)
    embedding_key = prefix + ("embedding",)

    # conv norm or layer norm
    # NOTE(review): a norm `bias` is mapped onto the `scale` key when the Flax side has `scale` but no
    # `bias`; kept exactly as in the original.
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and leaf == "bias"
        and (prefix + ("bias",) not in random_flax_state_dict)
        and (scale_key in random_flax_state_dict)
    ):
        return scale_key, pt_tensor
    if leaf in ["weight", "gamma"] and scale_key in random_flax_state_dict:
        return scale_key, pt_tensor

    # embedding: return the `embedding` key (the original returned the stale `scale` key here)
    if leaf == "weight" and embedding_key in random_flax_state_dict:
        return embedding_key, pt_tensor

    if leaf == "weight":
        kernel_key = prefix + ("kernel",)
        if pt_tensor.ndim == 4:
            # conv layer: permute axes (0, 1, 2, 3) -> (2, 3, 1, 0)
            return kernel_key, pt_tensor.transpose(2, 3, 1, 0)
        # linear layer
        return kernel_key, pt_tensor.T

    # old PyTorch layer norm weight
    if leaf == "gamma":
        return prefix + ("weight",), pt_tensor

    # old PyTorch layer norm bias
    if leaf == "beta":
        return prefix + ("bias",), pt_tensor

    return pt_tuple_key, pt_tensor
88
-
89
-
90
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    """
    Convert a PyTorch state dict into a nested Flax parameter dict.

    Freshly initialised Flax parameters (`flax_model.init_weights(PRNGKey(init_key))`) serve as the reference
    for key names and shapes. Raises `ValueError` when a converted tensor's shape differs from the reference.
    """
    # Step 1: Convert pytorch tensor to numpy
    numpy_state_dict = {name: tensor.numpy() for name, tensor in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_state_dict = flatten_dict(flax_model.init_weights(PRNGKey(init_key)))

    converted = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in numpy_state_dict.items():
        pt_tuple_key = tuple(rename_key(pt_key).split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        known_key = flax_key in random_flax_state_dict
        if known_key and flax_tensor.shape != random_flax_state_dict[flax_key].shape:
            raise ValueError(
                f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
            )

        # also add unexpected weight so that warning is thrown
        converted[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(converted)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/modeling_flax_utils.py DELETED
@@ -1,534 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- from pickle import UnpicklingError
18
- from typing import Any, Dict, Union
19
-
20
- import jax
21
- import jax.numpy as jnp
22
- import msgpack.exceptions
23
- from flax.core.frozen_dict import FrozenDict, unfreeze
24
- from flax.serialization import from_bytes, to_bytes
25
- from flax.traverse_util import flatten_dict, unflatten_dict
26
- from huggingface_hub import hf_hub_download
27
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
- from requests import HTTPError
29
-
30
- from .. import __version__, is_torch_available
31
- from ..utils import (
32
- CONFIG_NAME,
33
- DIFFUSERS_CACHE,
34
- FLAX_WEIGHTS_NAME,
35
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
- WEIGHTS_NAME,
37
- logging,
38
- )
39
- from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
40
-
41
-
42
- logger = logging.get_logger(__name__)
43
-
44
-
45
- class FlaxModelMixin:
46
- r"""
47
- Base class for all Flax models.
48
-
49
- [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
50
- saving models.
51
-
52
- - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
53
- """
54
- config_name = CONFIG_NAME
55
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
56
- _flax_internal_args = ["name", "parent", "dtype"]
57
-
58
- @classmethod
59
- def _from_config(cls, config, **kwargs):
60
- """
61
- All context managers that the model should be initialized under go here.
62
- """
63
- return cls(config, **kwargs)
64
-
65
- def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
66
- """
67
- Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
68
- """
69
-
70
- # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
71
- def conditional_cast(param):
72
- if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
73
- param = param.astype(dtype)
74
- return param
75
-
76
- if mask is None:
77
- return jax.tree_map(conditional_cast, params)
78
-
79
- flat_params = flatten_dict(params)
80
- flat_mask, _ = jax.tree_flatten(mask)
81
-
82
- for masked, key in zip(flat_mask, flat_params.keys()):
83
- if masked:
84
- param = flat_params[key]
85
- flat_params[key] = conditional_cast(param)
86
-
87
- return unflatten_dict(flat_params)
88
-
89
- def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
90
- r"""
91
- Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
92
- the `params` in place.
93
-
94
- This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
95
- half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
96
-
97
- Arguments:
98
- params (`Union[Dict, FrozenDict]`):
99
- A `PyTree` of model parameters.
100
- mask (`Union[Dict, FrozenDict]`):
101
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
102
- for params you want to cast, and `False` for those you want to skip.
103
-
104
- Examples:
105
-
106
- ```python
107
- >>> from diffusers import FlaxUNet2DConditionModel
108
-
109
- >>> # load model
110
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
111
- >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
112
- >>> params = model.to_bf16(params)
113
- >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
114
- >>> # then pass the mask as follows
115
- >>> from flax import traverse_util
116
-
117
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
118
- >>> flat_params = traverse_util.flatten_dict(params)
119
- >>> mask = {
120
- ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
121
- ... for path in flat_params
122
- ... }
123
- >>> mask = traverse_util.unflatten_dict(mask)
124
- >>> params = model.to_bf16(params, mask)
125
- ```"""
126
- return self._cast_floating_to(params, jnp.bfloat16, mask)
127
-
128
- def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
129
- r"""
130
- Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
131
- model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
132
-
133
- Arguments:
134
- params (`Union[Dict, FrozenDict]`):
135
- A `PyTree` of model parameters.
136
- mask (`Union[Dict, FrozenDict]`):
137
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
138
- for params you want to cast, and `False` for those you want to skip.
139
-
140
- Examples:
141
-
142
- ```python
143
- >>> from diffusers import FlaxUNet2DConditionModel
144
-
145
- >>> # Download model and configuration from huggingface.co
146
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
147
- >>> # By default, the model params will be in fp32, to illustrate the use of this method,
148
- >>> # we'll first cast to fp16 and back to fp32
149
- >>> params = model.to_f16(params)
150
- >>> # now cast back to fp32
151
- >>> params = model.to_fp32(params)
152
- ```"""
153
- return self._cast_floating_to(params, jnp.float32, mask)
154
-
155
- def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
156
- r"""
157
- Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
158
- `params` in place.
159
-
160
- This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
161
- half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
162
-
163
- Arguments:
164
- params (`Union[Dict, FrozenDict]`):
165
- A `PyTree` of model parameters.
166
- mask (`Union[Dict, FrozenDict]`):
167
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
168
- for params you want to cast, and `False` for those you want to skip.
169
-
170
- Examples:
171
-
172
- ```python
173
- >>> from diffusers import FlaxUNet2DConditionModel
174
-
175
- >>> # load model
176
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
177
- >>> # By default, the model params will be in fp32, to cast these to float16
178
- >>> params = model.to_fp16(params)
179
- >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
180
- >>> # then pass the mask as follows
181
- >>> from flax import traverse_util
182
-
183
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
184
- >>> flat_params = traverse_util.flatten_dict(params)
185
- >>> mask = {
186
- ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
187
- ... for path in flat_params
188
- ... }
189
- >>> mask = traverse_util.unflatten_dict(mask)
190
- >>> params = model.to_fp16(params, mask)
191
- ```"""
192
- return self._cast_floating_to(params, jnp.float16, mask)
193
-
194
- def init_weights(self, rng: jax.random.KeyArray) -> Dict:
195
- raise NotImplementedError(f"init_weights method has to be implemented for {self}")
196
-
197
- @classmethod
198
- def from_pretrained(
199
- cls,
200
- pretrained_model_name_or_path: Union[str, os.PathLike],
201
- dtype: jnp.dtype = jnp.float32,
202
- *model_args,
203
- **kwargs,
204
- ):
205
- r"""
206
- Instantiate a pretrained Flax model from a pretrained model configuration.
207
-
208
- Parameters:
209
- pretrained_model_name_or_path (`str` or `os.PathLike`):
210
- Can be either:
211
-
212
- - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
213
- hosted on the Hub.
214
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
215
- using [`~FlaxModelMixin.save_pretrained`].
216
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
217
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
218
- `jax.numpy.bfloat16` (on TPUs).
219
-
220
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
221
- specified, all the computation will be performed with the given `dtype`.
222
-
223
- <Tip>
224
-
225
- This only specifies the dtype of the *computation* and does not influence the dtype of model
226
- parameters.
227
-
228
- If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
229
- [`~FlaxModelMixin.to_bf16`].
230
-
231
- </Tip>
232
-
233
- model_args (sequence of positional arguments, *optional*):
234
- All remaining positional arguments are passed to the underlying model's `__init__` method.
235
- cache_dir (`Union[str, os.PathLike]`, *optional*):
236
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
237
- is not used.
238
- force_download (`bool`, *optional*, defaults to `False`):
239
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
240
- cached versions if they exist.
241
- resume_download (`bool`, *optional*, defaults to `False`):
242
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
243
- incompletely downloaded files are deleted.
244
- proxies (`Dict[str, str]`, *optional*):
245
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
246
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
247
- local_files_only(`bool`, *optional*, defaults to `False`):
248
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
249
- won't be downloaded from the Hub.
250
- revision (`str`, *optional*, defaults to `"main"`):
251
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
252
- allowed by Git.
253
- from_pt (`bool`, *optional*, defaults to `False`):
254
- Load the model weights from a PyTorch checkpoint save file.
255
- kwargs (remaining dictionary of keyword arguments, *optional*):
256
- Can be used to update the configuration object (after it is loaded) and initiate the model (for
257
- example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
258
- automatically loaded:
259
-
260
- - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
261
- model's `__init__` method (we assume all relevant updates to the configuration have already been
262
- done).
263
- - If a configuration is not provided, `kwargs` are first passed to the configuration class
264
- initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
265
- to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
266
- Remaining keys that do not correspond to any configuration attribute are passed to the underlying
267
- model's `__init__` function.
268
-
269
- Examples:
270
-
271
- ```python
272
- >>> from diffusers import FlaxUNet2DConditionModel
273
-
274
- >>> # Download model and configuration from huggingface.co and cache.
275
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
276
- >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
277
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
278
- ```
279
-
280
- If you get the error message below, you need to finetune the weights for your downstream task:
281
-
282
- ```bash
283
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
284
- - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
285
- You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
286
- ```
287
- """
288
- config = kwargs.pop("config", None)
289
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
290
- force_download = kwargs.pop("force_download", False)
291
- from_pt = kwargs.pop("from_pt", False)
292
- resume_download = kwargs.pop("resume_download", False)
293
- proxies = kwargs.pop("proxies", None)
294
- local_files_only = kwargs.pop("local_files_only", False)
295
- use_auth_token = kwargs.pop("use_auth_token", None)
296
- revision = kwargs.pop("revision", None)
297
- subfolder = kwargs.pop("subfolder", None)
298
-
299
- user_agent = {
300
- "diffusers": __version__,
301
- "file_type": "model",
302
- "framework": "flax",
303
- }
304
-
305
- # Load config if we don't provide a configuration
306
- config_path = config if config is not None else pretrained_model_name_or_path
307
- model, model_kwargs = cls.from_config(
308
- config_path,
309
- cache_dir=cache_dir,
310
- return_unused_kwargs=True,
311
- force_download=force_download,
312
- resume_download=resume_download,
313
- proxies=proxies,
314
- local_files_only=local_files_only,
315
- use_auth_token=use_auth_token,
316
- revision=revision,
317
- subfolder=subfolder,
318
- # model args
319
- dtype=dtype,
320
- **kwargs,
321
- )
322
-
323
- # Load model
324
- pretrained_path_with_subfolder = (
325
- pretrained_model_name_or_path
326
- if subfolder is None
327
- else os.path.join(pretrained_model_name_or_path, subfolder)
328
- )
329
- if os.path.isdir(pretrained_path_with_subfolder):
330
- if from_pt:
331
- if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
332
- raise EnvironmentError(
333
- f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
334
- )
335
- model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
336
- elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
337
- # Load from a Flax checkpoint
338
- model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
339
- # Check if pytorch weights exist instead
340
- elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
341
- raise EnvironmentError(
342
- f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
343
- " using `from_pt=True`."
344
- )
345
- else:
346
- raise EnvironmentError(
347
- f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
348
- f"{pretrained_path_with_subfolder}."
349
- )
350
- else:
351
- try:
352
- model_file = hf_hub_download(
353
- pretrained_model_name_or_path,
354
- filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
355
- cache_dir=cache_dir,
356
- force_download=force_download,
357
- proxies=proxies,
358
- resume_download=resume_download,
359
- local_files_only=local_files_only,
360
- use_auth_token=use_auth_token,
361
- user_agent=user_agent,
362
- subfolder=subfolder,
363
- revision=revision,
364
- )
365
-
366
- except RepositoryNotFoundError:
367
- raise EnvironmentError(
368
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
369
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
370
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
371
- "login`."
372
- )
373
- except RevisionNotFoundError:
374
- raise EnvironmentError(
375
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
376
- "this model name. Check the model page at "
377
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
378
- )
379
- except EntryNotFoundError:
380
- raise EnvironmentError(
381
- f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
382
- )
383
- except HTTPError as err:
384
- raise EnvironmentError(
385
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
386
- f"{err}"
387
- )
388
- except ValueError:
389
- raise EnvironmentError(
390
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
391
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
392
- f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
393
- " internet connection or see how to run the library in offline mode at"
394
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
395
- )
396
- except EnvironmentError:
397
- raise EnvironmentError(
398
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
399
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
400
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
401
- f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
402
- )
403
-
404
- if from_pt:
405
- if is_torch_available():
406
- from .modeling_utils import load_state_dict
407
- else:
408
- raise EnvironmentError(
409
- "Can't load the model in PyTorch format because PyTorch is not installed. "
410
- "Please, install PyTorch or use native Flax weights."
411
- )
412
-
413
- # Step 1: Get the pytorch file
414
- pytorch_model_file = load_state_dict(model_file)
415
-
416
- # Step 2: Convert the weights
417
- state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
418
- else:
419
- try:
420
- with open(model_file, "rb") as state_f:
421
- state = from_bytes(cls, state_f.read())
422
- except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
423
- try:
424
- with open(model_file) as f:
425
- if f.read().startswith("version"):
426
- raise OSError(
427
- "You seem to have cloned a repository without having git-lfs installed. Please"
428
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
429
- " folder you cloned."
430
- )
431
- else:
432
- raise ValueError from e
433
- except (UnicodeDecodeError, ValueError):
434
- raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
435
- # make sure all arrays are stored as jnp.ndarray
436
- # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
437
- # https://github.com/google/flax/issues/1261
438
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
439
-
440
- # flatten dicts
441
- state = flatten_dict(state)
442
-
443
- params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
444
- required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
445
-
446
- shape_state = flatten_dict(unfreeze(params_shape_tree))
447
-
448
- missing_keys = required_params - set(state.keys())
449
- unexpected_keys = set(state.keys()) - required_params
450
-
451
- if missing_keys:
452
- logger.warning(
453
- f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
454
- "Make sure to call model.init_weights to initialize the missing weights."
455
- )
456
- cls._missing_keys = missing_keys
457
-
458
- for key in state.keys():
459
- if key in shape_state and state[key].shape != shape_state[key].shape:
460
- raise ValueError(
461
- f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
462
- f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
463
- )
464
-
465
- # remove unexpected keys to not be saved again
466
- for unexpected_key in unexpected_keys:
467
- del state[unexpected_key]
468
-
469
- if len(unexpected_keys) > 0:
470
- logger.warning(
471
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
472
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
473
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
474
- " with another architecture."
475
- )
476
- else:
477
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
478
-
479
- if len(missing_keys) > 0:
480
- logger.warning(
481
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
482
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
483
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
484
- )
485
- else:
486
- logger.info(
487
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
488
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
489
- f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
490
- " training."
491
- )
492
-
493
- return model, unflatten_dict(state)
494
-
495
- def save_pretrained(
496
- self,
497
- save_directory: Union[str, os.PathLike],
498
- params: Union[Dict, FrozenDict],
499
- is_main_process: bool = True,
500
- ):
501
- """
502
- Save a model and its configuration file to a directory so that it can be reloaded using the
503
- [`~FlaxModelMixin.from_pretrained`] class method.
504
-
505
- Arguments:
506
- save_directory (`str` or `os.PathLike`):
507
- Directory to save a model and its configuration file to. Will be created if it doesn't exist.
508
- params (`Union[Dict, FrozenDict]`):
509
- A `PyTree` of model parameters.
510
- is_main_process (`bool`, *optional*, defaults to `True`):
511
- Whether the process calling this is the main process or not. Useful during distributed training and you
512
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
513
- process to avoid race conditions.
514
- """
515
- if os.path.isfile(save_directory):
516
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
517
- return
518
-
519
- os.makedirs(save_directory, exist_ok=True)
520
-
521
- model_to_save = self
522
-
523
- # Attach architecture to the config
524
- # Save the config
525
- if is_main_process:
526
- model_to_save.save_config(save_directory)
527
-
528
- # save model
529
- output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
530
- with open(output_model_file, "wb") as f:
531
- model_bytes = to_bytes(params)
532
- f.write(model_bytes)
533
-
534
- logger.info(f"Model weights saved in {output_model_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/modeling_pytorch_flax_utils.py DELETED
@@ -1,161 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch - Flax general utilities."""
16
-
17
- from pickle import UnpicklingError
18
-
19
- import jax
20
- import jax.numpy as jnp
21
- import numpy as np
22
- from flax.serialization import from_bytes
23
- from flax.traverse_util import flatten_dict
24
-
25
- from ..utils import logging
26
-
27
-
28
- logger = logging.get_logger(__name__)
29
-
30
-
31
- #####################
32
- # Flax => PyTorch #
33
- #####################
34
-
35
-
36
- # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37
- def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38
- try:
39
- with open(model_file, "rb") as flax_state_f:
40
- flax_state = from_bytes(None, flax_state_f.read())
41
- except UnpicklingError as e:
42
- try:
43
- with open(model_file) as f:
44
- if f.read().startswith("version"):
45
- raise OSError(
46
- "You seem to have cloned a repository without having git-lfs installed. Please"
47
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48
- " folder you cloned."
49
- )
50
- else:
51
- raise ValueError from e
52
- except (UnicodeDecodeError, ValueError):
53
- raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54
-
55
- return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56
-
57
-
58
- def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59
- """Load flax checkpoints in a PyTorch model"""
60
-
61
- try:
62
- import torch # noqa: F401
63
- except ImportError:
64
- logger.error(
65
- "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66
- " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67
- " instructions."
68
- )
69
- raise
70
-
71
- # check if we have bf16 weights
72
- is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73
- if any(is_type_bf16):
74
- # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
75
-
76
- # and bf16 is not fully supported in PT yet.
77
- logger.warning(
78
- "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79
- "before loading those in PyTorch model."
80
- )
81
- flax_state = jax.tree_util.tree_map(
82
- lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83
- )
84
-
85
- pt_model.base_model_prefix = ""
86
-
87
- flax_state_dict = flatten_dict(flax_state, sep=".")
88
- pt_model_dict = pt_model.state_dict()
89
-
90
- # keep track of unexpected & missing keys
91
- unexpected_keys = []
92
- missing_keys = set(pt_model_dict.keys())
93
-
94
- for flax_key_tuple, flax_tensor in flax_state_dict.items():
95
- flax_key_tuple_array = flax_key_tuple.split(".")
96
-
97
- if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99
- flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100
- elif flax_key_tuple_array[-1] == "kernel":
101
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102
- flax_tensor = flax_tensor.T
103
- elif flax_key_tuple_array[-1] == "scale":
104
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105
-
106
- if "time_embedding" not in flax_key_tuple_array:
107
- for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108
- flax_key_tuple_array[i] = (
109
- flax_key_tuple_string.replace("_0", ".0")
110
- .replace("_1", ".1")
111
- .replace("_2", ".2")
112
- .replace("_3", ".3")
113
- .replace("_4", ".4")
114
- .replace("_5", ".5")
115
- .replace("_6", ".6")
116
- .replace("_7", ".7")
117
- .replace("_8", ".8")
118
- .replace("_9", ".9")
119
- )
120
-
121
- flax_key = ".".join(flax_key_tuple_array)
122
-
123
- if flax_key in pt_model_dict:
124
- if flax_tensor.shape != pt_model_dict[flax_key].shape:
125
- raise ValueError(
126
- f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127
- f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128
- )
129
- else:
130
- # add weight to pytorch dict
131
- flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132
- pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133
- # remove from missing keys
134
- missing_keys.remove(flax_key)
135
- else:
136
- # weight is not expected by PyTorch model
137
- unexpected_keys.append(flax_key)
138
-
139
- pt_model.load_state_dict(pt_model_dict)
140
-
141
- # re-transform missing_keys to list
142
- missing_keys = list(missing_keys)
143
-
144
- if len(unexpected_keys) > 0:
145
- logger.warning(
146
- "Some weights of the Flax model were not used when initializing the PyTorch model"
147
- f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148
- f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149
- " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150
- f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151
- " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152
- " FlaxBertForSequenceClassification model)."
153
- )
154
- if len(missing_keys) > 0:
155
- logger.warning(
156
- f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157
- f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158
- " use it for predictions and inference."
159
- )
160
-
161
- return pt_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/modeling_utils.py DELETED
@@ -1,980 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import inspect
18
- import itertools
19
- import os
20
- import re
21
- from functools import partial
22
- from typing import Any, Callable, List, Optional, Tuple, Union
23
-
24
- import torch
25
- from torch import Tensor, device, nn
26
-
27
- from .. import __version__
28
- from ..utils import (
29
- CONFIG_NAME,
30
- DIFFUSERS_CACHE,
31
- FLAX_WEIGHTS_NAME,
32
- HF_HUB_OFFLINE,
33
- SAFETENSORS_WEIGHTS_NAME,
34
- WEIGHTS_NAME,
35
- _add_variant,
36
- _get_model_file,
37
- deprecate,
38
- is_accelerate_available,
39
- is_safetensors_available,
40
- is_torch_version,
41
- logging,
42
- )
43
-
44
-
45
- logger = logging.get_logger(__name__)
46
-
47
-
48
- if is_torch_version(">=", "1.9.0"):
49
- _LOW_CPU_MEM_USAGE_DEFAULT = True
50
- else:
51
- _LOW_CPU_MEM_USAGE_DEFAULT = False
52
-
53
-
54
- if is_accelerate_available():
55
- import accelerate
56
- from accelerate.utils import set_module_tensor_to_device
57
- from accelerate.utils.versions import is_torch_version
58
-
59
- if is_safetensors_available():
60
- import safetensors
61
-
62
-
63
- def get_parameter_device(parameter: torch.nn.Module):
64
- try:
65
- parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
66
- return next(parameters_and_buffers).device
67
- except StopIteration:
68
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
69
-
70
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
71
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
72
- return tuples
73
-
74
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
75
- first_tuple = next(gen)
76
- return first_tuple[1].device
77
-
78
-
79
- def get_parameter_dtype(parameter: torch.nn.Module):
80
- try:
81
- params = tuple(parameter.parameters())
82
- if len(params) > 0:
83
- return params[0].dtype
84
-
85
- buffers = tuple(parameter.buffers())
86
- if len(buffers) > 0:
87
- return buffers[0].dtype
88
-
89
- except StopIteration:
90
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
91
-
92
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
93
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
94
- return tuples
95
-
96
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
97
- first_tuple = next(gen)
98
- return first_tuple[1].dtype
99
-
100
-
101
- def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
102
- """
103
- Reads a checkpoint file, returning properly formatted errors if they arise.
104
- """
105
- try:
106
- if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
107
- return torch.load(checkpoint_file, map_location="cpu")
108
- else:
109
- return safetensors.torch.load_file(checkpoint_file, device="cpu")
110
- except Exception as e:
111
- try:
112
- with open(checkpoint_file) as f:
113
- if f.read().startswith("version"):
114
- raise OSError(
115
- "You seem to have cloned a repository without having git-lfs installed. Please install "
116
- "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
117
- "you cloned."
118
- )
119
- else:
120
- raise ValueError(
121
- f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
122
- "model. Make sure you have saved the model properly."
123
- ) from e
124
- except (UnicodeDecodeError, ValueError):
125
- raise OSError(
126
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
127
- f"at '{checkpoint_file}'. "
128
- "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
129
- )
130
-
131
-
132
- def _load_state_dict_into_model(model_to_load, state_dict):
133
- # Convert old format to new format if needed from a PyTorch state_dict
134
- # copy state_dict so _load_from_state_dict can modify it
135
- state_dict = state_dict.copy()
136
- error_msgs = []
137
-
138
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
139
- # so we need to apply the function recursively.
140
- def load(module: torch.nn.Module, prefix=""):
141
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
142
- module._load_from_state_dict(*args)
143
-
144
- for name, child in module._modules.items():
145
- if child is not None:
146
- load(child, prefix + name + ".")
147
-
148
- load(model_to_load)
149
-
150
- return error_msgs
151
-
152
-
153
- class ModelMixin(torch.nn.Module):
154
- r"""
155
- Base class for all models.
156
-
157
- [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
158
- saving models.
159
-
160
- - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
161
- """
162
- config_name = CONFIG_NAME
163
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164
- _supports_gradient_checkpointing = False
165
- _keys_to_ignore_on_load_unexpected = None
166
-
167
- def __init__(self):
168
- super().__init__()
169
-
170
- def __getattr__(self, name: str) -> Any:
171
- """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
172
- config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
173
- __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
174
- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
175
- """
176
-
177
- is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
178
- is_attribute = name in self.__dict__
179
-
180
- if is_in_config and not is_attribute:
181
- deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
182
- deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
183
- return self._internal_dict[name]
184
-
185
- # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
186
- return super().__getattr__(name)
187
-
188
- @property
189
- def is_gradient_checkpointing(self) -> bool:
190
- """
191
- Whether gradient checkpointing is activated for this model or not.
192
- """
193
- return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
194
-
195
- def enable_gradient_checkpointing(self):
196
- """
197
- Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
198
- *checkpoint activations* in other frameworks).
199
- """
200
- if not self._supports_gradient_checkpointing:
201
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
202
- self.apply(partial(self._set_gradient_checkpointing, value=True))
203
-
204
- def disable_gradient_checkpointing(self):
205
- """
206
- Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
207
- *checkpoint activations* in other frameworks).
208
- """
209
- if self._supports_gradient_checkpointing:
210
- self.apply(partial(self._set_gradient_checkpointing, value=False))
211
-
212
- def set_use_memory_efficient_attention_xformers(
213
- self, valid: bool, attention_op: Optional[Callable] = None
214
- ) -> None:
215
- # Recursively walk through all the children.
216
- # Any children which exposes the set_use_memory_efficient_attention_xformers method
217
- # gets the message
218
- def fn_recursive_set_mem_eff(module: torch.nn.Module):
219
- if hasattr(module, "set_use_memory_efficient_attention_xformers"):
220
- module.set_use_memory_efficient_attention_xformers(valid, attention_op)
221
-
222
- for child in module.children():
223
- fn_recursive_set_mem_eff(child)
224
-
225
- for module in self.children():
226
- if isinstance(module, torch.nn.Module):
227
- fn_recursive_set_mem_eff(module)
228
-
229
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
230
- r"""
231
- Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
232
-
233
- When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
234
- inference. Speed up during training is not guaranteed.
235
-
236
- <Tip warning={true}>
237
-
238
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
239
- precedent.
240
-
241
- </Tip>
242
-
243
- Parameters:
244
- attention_op (`Callable`, *optional*):
245
- Override the default `None` operator for use as `op` argument to the
246
- [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
247
- function of xFormers.
248
-
249
- Examples:
250
-
251
- ```py
252
- >>> import torch
253
- >>> from diffusers import UNet2DConditionModel
254
- >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
255
-
256
- >>> model = UNet2DConditionModel.from_pretrained(
257
- ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
258
- ... )
259
- >>> model = model.to("cuda")
260
- >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
261
- ```
262
- """
263
- self.set_use_memory_efficient_attention_xformers(True, attention_op)
264
-
265
- def disable_xformers_memory_efficient_attention(self):
266
- r"""
267
- Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
268
- """
269
- self.set_use_memory_efficient_attention_xformers(False)
270
-
271
- def save_pretrained(
272
- self,
273
- save_directory: Union[str, os.PathLike],
274
- is_main_process: bool = True,
275
- save_function: Callable = None,
276
- safe_serialization: bool = False,
277
- variant: Optional[str] = None,
278
- ):
279
- """
280
- Save a model and its configuration file to a directory so that it can be reloaded using the
281
- [`~models.ModelMixin.from_pretrained`] class method.
282
-
283
- Arguments:
284
- save_directory (`str` or `os.PathLike`):
285
- Directory to save a model and its configuration file to. Will be created if it doesn't exist.
286
- is_main_process (`bool`, *optional*, defaults to `True`):
287
- Whether the process calling this is the main process or not. Useful during distributed training and you
288
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
289
- process to avoid race conditions.
290
- save_function (`Callable`):
291
- The function to use to save the state dictionary. Useful during distributed training when you need to
292
- replace `torch.save` with another method. Can be configured with the environment variable
293
- `DIFFUSERS_SAVE_MODE`.
294
- safe_serialization (`bool`, *optional*, defaults to `False`):
295
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
296
- variant (`str`, *optional*):
297
- If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
298
- """
299
- if safe_serialization and not is_safetensors_available():
300
- raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
301
-
302
- if os.path.isfile(save_directory):
303
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
304
- return
305
-
306
- os.makedirs(save_directory, exist_ok=True)
307
-
308
- model_to_save = self
309
-
310
- # Attach architecture to the config
311
- # Save the config
312
- if is_main_process:
313
- model_to_save.save_config(save_directory)
314
-
315
- # Save the model
316
- state_dict = model_to_save.state_dict()
317
-
318
- weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
319
- weights_name = _add_variant(weights_name, variant)
320
-
321
- # Save the model
322
- if safe_serialization:
323
- safetensors.torch.save_file(
324
- state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
325
- )
326
- else:
327
- torch.save(state_dict, os.path.join(save_directory, weights_name))
328
-
329
- logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
330
-
331
- @classmethod
332
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
333
- r"""
334
- Instantiate a pretrained PyTorch model from a pretrained model configuration.
335
-
336
- The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
337
- train the model, set it back in training mode with `model.train()`.
338
-
339
- Parameters:
340
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
341
- Can be either:
342
-
343
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
344
- the Hub.
345
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
346
- with [`~ModelMixin.save_pretrained`].
347
-
348
- cache_dir (`Union[str, os.PathLike]`, *optional*):
349
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
350
- is not used.
351
- torch_dtype (`str` or `torch.dtype`, *optional*):
352
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
353
- dtype is automatically derived from the model's weights.
354
- force_download (`bool`, *optional*, defaults to `False`):
355
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
356
- cached versions if they exist.
357
- resume_download (`bool`, *optional*, defaults to `False`):
358
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
359
- incompletely downloaded files are deleted.
360
- proxies (`Dict[str, str]`, *optional*):
361
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
362
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
363
- output_loading_info (`bool`, *optional*, defaults to `False`):
364
- Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
365
- local_files_only(`bool`, *optional*, defaults to `False`):
366
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
367
- won't be downloaded from the Hub.
368
- use_auth_token (`str` or *bool*, *optional*):
369
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
370
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
371
- revision (`str`, *optional*, defaults to `"main"`):
372
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
373
- allowed by Git.
374
- from_flax (`bool`, *optional*, defaults to `False`):
375
- Load the model weights from a Flax checkpoint save file.
376
- subfolder (`str`, *optional*, defaults to `""`):
377
- The subfolder location of a model file within a larger model repository on the Hub or locally.
378
- mirror (`str`, *optional*):
379
- Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
380
- guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
381
- information.
382
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
383
- A map that specifies where each submodule should go. It doesn't need to be defined for each
384
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
385
- same device.
386
-
387
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
388
- more information about each option see [designing a device
389
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
390
- max_memory (`Dict`, *optional*):
391
- A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
392
- each GPU and the available CPU RAM if unset.
393
- offload_folder (`str` or `os.PathLike`, *optional*):
394
- The path to offload weights if `device_map` contains the value `"disk"`.
395
- offload_state_dict (`bool`, *optional*):
396
- If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
397
- the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
398
- when there is some disk offload.
399
- low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
400
- Speed up model loading only loading the pretrained weights and not initializing the weights. This also
401
- tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
402
- Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
403
- argument to `True` will raise an error.
404
- variant (`str`, *optional*):
405
- Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
406
- loading `from_flax`.
407
- use_safetensors (`bool`, *optional*, defaults to `None`):
408
- If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
409
- `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
410
- weights. If set to `False`, `safetensors` weights are not loaded.
411
-
412
- <Tip>
413
-
414
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
415
- `huggingface-cli login`. You can also activate the special
416
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
417
- firewalled environment.
418
-
419
- </Tip>
420
-
421
- Example:
422
-
423
- ```py
424
- from diffusers import UNet2DConditionModel
425
-
426
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
427
- ```
428
-
429
- If you get the error message below, you need to finetune the weights for your downstream task:
430
-
431
- ```bash
432
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
433
- - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
434
- You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
435
- ```
436
- """
437
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
438
- ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
439
- force_download = kwargs.pop("force_download", False)
440
- from_flax = kwargs.pop("from_flax", False)
441
- resume_download = kwargs.pop("resume_download", False)
442
- proxies = kwargs.pop("proxies", None)
443
- output_loading_info = kwargs.pop("output_loading_info", False)
444
- local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
445
- use_auth_token = kwargs.pop("use_auth_token", None)
446
- revision = kwargs.pop("revision", None)
447
- torch_dtype = kwargs.pop("torch_dtype", None)
448
- subfolder = kwargs.pop("subfolder", None)
449
- device_map = kwargs.pop("device_map", None)
450
- max_memory = kwargs.pop("max_memory", None)
451
- offload_folder = kwargs.pop("offload_folder", None)
452
- offload_state_dict = kwargs.pop("offload_state_dict", False)
453
- low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
454
- variant = kwargs.pop("variant", None)
455
- use_safetensors = kwargs.pop("use_safetensors", None)
456
-
457
- if use_safetensors and not is_safetensors_available():
458
- raise ValueError(
459
- "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
460
- )
461
-
462
- allow_pickle = False
463
- if use_safetensors is None:
464
- use_safetensors = is_safetensors_available()
465
- allow_pickle = True
466
-
467
- if low_cpu_mem_usage and not is_accelerate_available():
468
- low_cpu_mem_usage = False
469
- logger.warning(
470
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
471
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
472
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
473
- " install accelerate\n```\n."
474
- )
475
-
476
- if device_map is not None and not is_accelerate_available():
477
- raise NotImplementedError(
478
- "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
479
- " `device_map=None`. You can install accelerate with `pip install accelerate`."
480
- )
481
-
482
- # Check if we can handle device_map and dispatching the weights
483
- if device_map is not None and not is_torch_version(">=", "1.9.0"):
484
- raise NotImplementedError(
485
- "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
486
- " `device_map=None`."
487
- )
488
-
489
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
490
- raise NotImplementedError(
491
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
492
- " `low_cpu_mem_usage=False`."
493
- )
494
-
495
- if low_cpu_mem_usage is False and device_map is not None:
496
- raise ValueError(
497
- f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
498
- " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
499
- )
500
-
501
- # Load config if we don't provide a configuration
502
- config_path = pretrained_model_name_or_path
503
-
504
- user_agent = {
505
- "diffusers": __version__,
506
- "file_type": "model",
507
- "framework": "pytorch",
508
- }
509
-
510
- # load config
511
- config, unused_kwargs, commit_hash = cls.load_config(
512
- config_path,
513
- cache_dir=cache_dir,
514
- return_unused_kwargs=True,
515
- return_commit_hash=True,
516
- force_download=force_download,
517
- resume_download=resume_download,
518
- proxies=proxies,
519
- local_files_only=local_files_only,
520
- use_auth_token=use_auth_token,
521
- revision=revision,
522
- subfolder=subfolder,
523
- device_map=device_map,
524
- max_memory=max_memory,
525
- offload_folder=offload_folder,
526
- offload_state_dict=offload_state_dict,
527
- user_agent=user_agent,
528
- **kwargs,
529
- )
530
-
531
- # load model
532
- model_file = None
533
- if from_flax:
534
- model_file = _get_model_file(
535
- pretrained_model_name_or_path,
536
- weights_name=FLAX_WEIGHTS_NAME,
537
- cache_dir=cache_dir,
538
- force_download=force_download,
539
- resume_download=resume_download,
540
- proxies=proxies,
541
- local_files_only=local_files_only,
542
- use_auth_token=use_auth_token,
543
- revision=revision,
544
- subfolder=subfolder,
545
- user_agent=user_agent,
546
- commit_hash=commit_hash,
547
- )
548
- model = cls.from_config(config, **unused_kwargs)
549
-
550
- # Convert the weights
551
- from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
552
-
553
- model = load_flax_checkpoint_in_pytorch_model(model, model_file)
554
- else:
555
- if use_safetensors:
556
- try:
557
- model_file = _get_model_file(
558
- pretrained_model_name_or_path,
559
- weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
560
- cache_dir=cache_dir,
561
- force_download=force_download,
562
- resume_download=resume_download,
563
- proxies=proxies,
564
- local_files_only=local_files_only,
565
- use_auth_token=use_auth_token,
566
- revision=revision,
567
- subfolder=subfolder,
568
- user_agent=user_agent,
569
- commit_hash=commit_hash,
570
- )
571
- except IOError as e:
572
- if not allow_pickle:
573
- raise e
574
- pass
575
- if model_file is None:
576
- model_file = _get_model_file(
577
- pretrained_model_name_or_path,
578
- weights_name=_add_variant(WEIGHTS_NAME, variant),
579
- cache_dir=cache_dir,
580
- force_download=force_download,
581
- resume_download=resume_download,
582
- proxies=proxies,
583
- local_files_only=local_files_only,
584
- use_auth_token=use_auth_token,
585
- revision=revision,
586
- subfolder=subfolder,
587
- user_agent=user_agent,
588
- commit_hash=commit_hash,
589
- )
590
-
591
- if low_cpu_mem_usage:
592
- # Instantiate model with empty weights
593
- with accelerate.init_empty_weights():
594
- model = cls.from_config(config, **unused_kwargs)
595
-
596
- # if device_map is None, load the state dict and move the params from meta device to the cpu
597
- if device_map is None:
598
- param_device = "cpu"
599
- state_dict = load_state_dict(model_file, variant=variant)
600
- model._convert_deprecated_attention_blocks(state_dict)
601
- # move the params from meta device to cpu
602
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
603
- if len(missing_keys) > 0:
604
- raise ValueError(
605
- f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
606
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
607
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
608
- " those weights or else make sure your checkpoint file is correct."
609
- )
610
- unexpected_keys = []
611
-
612
- empty_state_dict = model.state_dict()
613
- for param_name, param in state_dict.items():
614
- accepts_dtype = "dtype" in set(
615
- inspect.signature(set_module_tensor_to_device).parameters.keys()
616
- )
617
-
618
- if param_name not in empty_state_dict:
619
- unexpected_keys.append(param_name)
620
- continue
621
-
622
- if empty_state_dict[param_name].shape != param.shape:
623
- raise ValueError(
624
- f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
625
- )
626
-
627
- if accepts_dtype:
628
- set_module_tensor_to_device(
629
- model, param_name, param_device, value=param, dtype=torch_dtype
630
- )
631
- else:
632
- set_module_tensor_to_device(model, param_name, param_device, value=param)
633
-
634
- if cls._keys_to_ignore_on_load_unexpected is not None:
635
- for pat in cls._keys_to_ignore_on_load_unexpected:
636
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
637
-
638
- if len(unexpected_keys) > 0:
639
- logger.warn(
640
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
641
- )
642
-
643
- else: # else let accelerate handle loading and dispatching.
644
- # Load weights and dispatch according to the device_map
645
- # by default the device_map is None and the weights are loaded on the CPU
646
- try:
647
- accelerate.load_checkpoint_and_dispatch(
648
- model,
649
- model_file,
650
- device_map,
651
- max_memory=max_memory,
652
- offload_folder=offload_folder,
653
- offload_state_dict=offload_state_dict,
654
- dtype=torch_dtype,
655
- )
656
- except AttributeError as e:
657
- # When using accelerate loading, we do not have the ability to load the state
658
- # dict and rename the weight names manually. Additionally, accelerate skips
659
- # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
660
- # (which look like they should be private variables?), so we can't use the standard hooks
661
- # to rename parameters on load. We need to mimic the original weight names so the correct
662
- # attributes are available. After we have loaded the weights, we convert the deprecated
663
- # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
664
- # the weights so we don't have to do this again.
665
-
666
- if "'Attention' object has no attribute" in str(e):
667
- logger.warn(
668
- f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
669
- " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
670
- " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
671
- " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
672
- " please also re-upload it or open a PR on the original repository."
673
- )
674
- model._temp_convert_self_to_deprecated_attention_blocks()
675
- accelerate.load_checkpoint_and_dispatch(
676
- model,
677
- model_file,
678
- device_map,
679
- max_memory=max_memory,
680
- offload_folder=offload_folder,
681
- offload_state_dict=offload_state_dict,
682
- dtype=torch_dtype,
683
- )
684
- model._undo_temp_convert_self_to_deprecated_attention_blocks()
685
- else:
686
- raise e
687
-
688
- loading_info = {
689
- "missing_keys": [],
690
- "unexpected_keys": [],
691
- "mismatched_keys": [],
692
- "error_msgs": [],
693
- }
694
- else:
695
- model = cls.from_config(config, **unused_kwargs)
696
-
697
- state_dict = load_state_dict(model_file, variant=variant)
698
- model._convert_deprecated_attention_blocks(state_dict)
699
-
700
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
701
- model,
702
- state_dict,
703
- model_file,
704
- pretrained_model_name_or_path,
705
- ignore_mismatched_sizes=ignore_mismatched_sizes,
706
- )
707
-
708
- loading_info = {
709
- "missing_keys": missing_keys,
710
- "unexpected_keys": unexpected_keys,
711
- "mismatched_keys": mismatched_keys,
712
- "error_msgs": error_msgs,
713
- }
714
-
715
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
716
- raise ValueError(
717
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
718
- )
719
- elif torch_dtype is not None:
720
- model = model.to(torch_dtype)
721
-
722
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
723
-
724
- # Set model in evaluation mode to deactivate DropOut modules by default
725
- model.eval()
726
- if output_loading_info:
727
- return model, loading_info
728
-
729
- return model
730
-
731
- @classmethod
732
- def _load_pretrained_model(
733
- cls,
734
- model,
735
- state_dict,
736
- resolved_archive_file,
737
- pretrained_model_name_or_path,
738
- ignore_mismatched_sizes=False,
739
- ):
740
- # Retrieve missing & unexpected_keys
741
- model_state_dict = model.state_dict()
742
- loaded_keys = list(state_dict.keys())
743
-
744
- expected_keys = list(model_state_dict.keys())
745
-
746
- original_loaded_keys = loaded_keys
747
-
748
- missing_keys = list(set(expected_keys) - set(loaded_keys))
749
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
750
-
751
- # Make sure we are able to load base models as well as derived models (with heads)
752
- model_to_load = model
753
-
754
- def _find_mismatched_keys(
755
- state_dict,
756
- model_state_dict,
757
- loaded_keys,
758
- ignore_mismatched_sizes,
759
- ):
760
- mismatched_keys = []
761
- if ignore_mismatched_sizes:
762
- for checkpoint_key in loaded_keys:
763
- model_key = checkpoint_key
764
-
765
- if (
766
- model_key in model_state_dict
767
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
768
- ):
769
- mismatched_keys.append(
770
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
771
- )
772
- del state_dict[checkpoint_key]
773
- return mismatched_keys
774
-
775
- if state_dict is not None:
776
- # Whole checkpoint
777
- mismatched_keys = _find_mismatched_keys(
778
- state_dict,
779
- model_state_dict,
780
- original_loaded_keys,
781
- ignore_mismatched_sizes,
782
- )
783
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
784
-
785
- if len(error_msgs) > 0:
786
- error_msg = "\n\t".join(error_msgs)
787
- if "size mismatch" in error_msg:
788
- error_msg += (
789
- "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
790
- )
791
- raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
792
-
793
- if len(unexpected_keys) > 0:
794
- logger.warning(
795
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
796
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
797
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
798
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
799
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
800
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
801
- " identical (initializing a BertForSequenceClassification model from a"
802
- " BertForSequenceClassification model)."
803
- )
804
- else:
805
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
806
- if len(missing_keys) > 0:
807
- logger.warning(
808
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
809
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
810
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
811
- )
812
- elif len(mismatched_keys) == 0:
813
- logger.info(
814
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
815
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
816
- f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
817
- " without further training."
818
- )
819
- if len(mismatched_keys) > 0:
820
- mismatched_warning = "\n".join(
821
- [
822
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
823
- for key, shape1, shape2 in mismatched_keys
824
- ]
825
- )
826
- logger.warning(
827
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
828
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
829
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
830
- " able to use it for predictions and inference."
831
- )
832
-
833
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
834
-
835
- @property
836
- def device(self) -> device:
837
- """
838
- `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
839
- device).
840
- """
841
- return get_parameter_device(self)
842
-
843
- @property
844
- def dtype(self) -> torch.dtype:
845
- """
846
- `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
847
- """
848
- return get_parameter_dtype(self)
849
-
850
- def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
851
- """
852
- Get number of (trainable or non-embedding) parameters in the module.
853
-
854
- Args:
855
- only_trainable (`bool`, *optional*, defaults to `False`):
856
- Whether or not to return only the number of trainable parameters.
857
- exclude_embeddings (`bool`, *optional*, defaults to `False`):
858
- Whether or not to return only the number of non-embedding parameters.
859
-
860
- Returns:
861
- `int`: The number of parameters.
862
-
863
- Example:
864
-
865
- ```py
866
- from diffusers import UNet2DConditionModel
867
-
868
- model_id = "runwayml/stable-diffusion-v1-5"
869
- unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
870
- unet.num_parameters(only_trainable=True)
871
- 859520964
872
- ```
873
- """
874
-
875
- if exclude_embeddings:
876
- embedding_param_names = [
877
- f"{name}.weight"
878
- for name, module_type in self.named_modules()
879
- if isinstance(module_type, torch.nn.Embedding)
880
- ]
881
- non_embedding_parameters = [
882
- parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
883
- ]
884
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
885
- else:
886
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
887
-
888
- def _convert_deprecated_attention_blocks(self, state_dict):
889
- deprecated_attention_block_paths = []
890
-
891
- def recursive_find_attn_block(name, module):
892
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
893
- deprecated_attention_block_paths.append(name)
894
-
895
- for sub_name, sub_module in module.named_children():
896
- sub_name = sub_name if name == "" else f"{name}.{sub_name}"
897
- recursive_find_attn_block(sub_name, sub_module)
898
-
899
- recursive_find_attn_block("", self)
900
-
901
- # NOTE: we have to check if the deprecated parameters are in the state dict
902
- # because it is possible we are loading from a state dict that was already
903
- # converted
904
-
905
- for path in deprecated_attention_block_paths:
906
- # group_norm path stays the same
907
-
908
- # query -> to_q
909
- if f"{path}.query.weight" in state_dict:
910
- state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
911
- if f"{path}.query.bias" in state_dict:
912
- state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
913
-
914
- # key -> to_k
915
- if f"{path}.key.weight" in state_dict:
916
- state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
917
- if f"{path}.key.bias" in state_dict:
918
- state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
919
-
920
- # value -> to_v
921
- if f"{path}.value.weight" in state_dict:
922
- state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
923
- if f"{path}.value.bias" in state_dict:
924
- state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
925
-
926
- # proj_attn -> to_out.0
927
- if f"{path}.proj_attn.weight" in state_dict:
928
- state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
929
- if f"{path}.proj_attn.bias" in state_dict:
930
- state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
931
-
932
- def _temp_convert_self_to_deprecated_attention_blocks(self):
933
- deprecated_attention_block_modules = []
934
-
935
- def recursive_find_attn_block(module):
936
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
937
- deprecated_attention_block_modules.append(module)
938
-
939
- for sub_module in module.children():
940
- recursive_find_attn_block(sub_module)
941
-
942
- recursive_find_attn_block(self)
943
-
944
- for module in deprecated_attention_block_modules:
945
- module.query = module.to_q
946
- module.key = module.to_k
947
- module.value = module.to_v
948
- module.proj_attn = module.to_out[0]
949
-
950
- # We don't _have_ to delete the old attributes, but it's helpful to ensure
951
- # that _all_ the weights are loaded into the new attributes and we're not
952
- # making an incorrect assumption that this model should be converted when
953
- # it really shouldn't be.
954
- del module.to_q
955
- del module.to_k
956
- del module.to_v
957
- del module.to_out
958
-
959
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
960
- deprecated_attention_block_modules = []
961
-
962
- def recursive_find_attn_block(module):
963
- if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
964
- deprecated_attention_block_modules.append(module)
965
-
966
- for sub_module in module.children():
967
- recursive_find_attn_block(sub_module)
968
-
969
- recursive_find_attn_block(self)
970
-
971
- for module in deprecated_attention_block_modules:
972
- module.to_q = module.query
973
- module.to_k = module.key
974
- module.to_v = module.value
975
- module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
976
-
977
- del module.query
978
- del module.key
979
- del module.value
980
- del module.proj_attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/prior_transformer.py DELETED
@@ -1,364 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict, Optional, Union
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
- from ..configuration_utils import ConfigMixin, register_to_config
9
- from ..utils import BaseOutput
10
- from .attention import BasicTransformerBlock
11
- from .attention_processor import AttentionProcessor, AttnProcessor
12
- from .embeddings import TimestepEmbedding, Timesteps
13
- from .modeling_utils import ModelMixin
14
-
15
-
16
- @dataclass
17
- class PriorTransformerOutput(BaseOutput):
18
- """
19
- The output of [`PriorTransformer`].
20
-
21
- Args:
22
- predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
23
- The predicted CLIP image embedding conditioned on the CLIP text embedding input.
24
- """
25
-
26
- predicted_image_embedding: torch.FloatTensor
27
-
28
-
29
- class PriorTransformer(ModelMixin, ConfigMixin):
30
- """
31
- A Prior Transformer model.
32
-
33
- Parameters:
34
- num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
35
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
36
- num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
37
- embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
38
- num_embeddings (`int`, *optional*, defaults to 77):
39
- The number of embeddings of the model input `hidden_states`
40
- additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
41
- projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
42
- additional_embeddings`.
43
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
44
- time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
45
- The activation function to use to create timestep embeddings.
46
- norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
47
- passing to Transformer blocks. Set it to `None` if normalization is not needed.
48
- embedding_proj_norm_type (`str`, *optional*, defaults to None):
49
- The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
50
- needed.
51
- encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
52
- The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
53
- `encoder_hidden_states` is `None`.
54
- added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
55
- Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
56
- product between the text embedding and image embedding as proposed in the unclip paper
57
- https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
58
- time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
59
- If None, will be set to `num_attention_heads * attention_head_dim`
60
- embedding_proj_dim (`int`, *optional*, default to None):
61
- The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
62
- clip_embed_dim (`int`, *optional*, default to None):
63
- The dimension of the output. If None, will be set to `embedding_dim`.
64
- """
65
-
66
- @register_to_config
67
- def __init__(
68
- self,
69
- num_attention_heads: int = 32,
70
- attention_head_dim: int = 64,
71
- num_layers: int = 20,
72
- embedding_dim: int = 768,
73
- num_embeddings=77,
74
- additional_embeddings=4,
75
- dropout: float = 0.0,
76
- time_embed_act_fn: str = "silu",
77
- norm_in_type: Optional[str] = None, # layer
78
- embedding_proj_norm_type: Optional[str] = None, # layer
79
- encoder_hid_proj_type: Optional[str] = "linear", # linear
80
- added_emb_type: Optional[str] = "prd", # prd
81
- time_embed_dim: Optional[int] = None,
82
- embedding_proj_dim: Optional[int] = None,
83
- clip_embed_dim: Optional[int] = None,
84
- ):
85
- super().__init__()
86
- self.num_attention_heads = num_attention_heads
87
- self.attention_head_dim = attention_head_dim
88
- inner_dim = num_attention_heads * attention_head_dim
89
- self.additional_embeddings = additional_embeddings
90
-
91
- time_embed_dim = time_embed_dim or inner_dim
92
- embedding_proj_dim = embedding_proj_dim or embedding_dim
93
- clip_embed_dim = clip_embed_dim or embedding_dim
94
-
95
- self.time_proj = Timesteps(inner_dim, True, 0)
96
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
97
-
98
- self.proj_in = nn.Linear(embedding_dim, inner_dim)
99
-
100
- if embedding_proj_norm_type is None:
101
- self.embedding_proj_norm = None
102
- elif embedding_proj_norm_type == "layer":
103
- self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
104
- else:
105
- raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
106
-
107
- self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
108
-
109
- if encoder_hid_proj_type is None:
110
- self.encoder_hidden_states_proj = None
111
- elif encoder_hid_proj_type == "linear":
112
- self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
113
- else:
114
- raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
115
-
116
- self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
117
-
118
- if added_emb_type == "prd":
119
- self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
120
- elif added_emb_type is None:
121
- self.prd_embedding = None
122
- else:
123
- raise ValueError(
124
- f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
125
- )
126
-
127
- self.transformer_blocks = nn.ModuleList(
128
- [
129
- BasicTransformerBlock(
130
- inner_dim,
131
- num_attention_heads,
132
- attention_head_dim,
133
- dropout=dropout,
134
- activation_fn="gelu",
135
- attention_bias=True,
136
- )
137
- for d in range(num_layers)
138
- ]
139
- )
140
-
141
- if norm_in_type == "layer":
142
- self.norm_in = nn.LayerNorm(inner_dim)
143
- elif norm_in_type is None:
144
- self.norm_in = None
145
- else:
146
- raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
147
-
148
- self.norm_out = nn.LayerNorm(inner_dim)
149
-
150
- self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
151
-
152
- causal_attention_mask = torch.full(
153
- [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
154
- )
155
- causal_attention_mask.triu_(1)
156
- causal_attention_mask = causal_attention_mask[None, ...]
157
- self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
158
-
159
- self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
160
- self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
161
-
162
- @property
163
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
164
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
165
- r"""
166
- Returns:
167
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
168
- indexed by its weight name.
169
- """
170
- # set recursively
171
- processors = {}
172
-
173
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
174
- if hasattr(module, "set_processor"):
175
- processors[f"{name}.processor"] = module.processor
176
-
177
- for sub_name, child in module.named_children():
178
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
179
-
180
- return processors
181
-
182
- for name, module in self.named_children():
183
- fn_recursive_add_processors(name, module, processors)
184
-
185
- return processors
186
-
187
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
188
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
189
- r"""
190
- Sets the attention processor to use to compute attention.
191
-
192
- Parameters:
193
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
194
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
195
- for **all** `Attention` layers.
196
-
197
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
198
- processor. This is strongly recommended when setting trainable attention processors.
199
-
200
- """
201
- count = len(self.attn_processors.keys())
202
-
203
- if isinstance(processor, dict) and len(processor) != count:
204
- raise ValueError(
205
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
206
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
207
- )
208
-
209
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
210
- if hasattr(module, "set_processor"):
211
- if not isinstance(processor, dict):
212
- module.set_processor(processor)
213
- else:
214
- module.set_processor(processor.pop(f"{name}.processor"))
215
-
216
- for sub_name, child in module.named_children():
217
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
218
-
219
- for name, module in self.named_children():
220
- fn_recursive_attn_processor(name, module, processor)
221
-
222
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
223
- def set_default_attn_processor(self):
224
- """
225
- Disables custom attention processors and sets the default attention implementation.
226
- """
227
- self.set_attn_processor(AttnProcessor())
228
-
229
- def forward(
230
- self,
231
- hidden_states,
232
- timestep: Union[torch.Tensor, float, int],
233
- proj_embedding: torch.FloatTensor,
234
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
235
- attention_mask: Optional[torch.BoolTensor] = None,
236
- return_dict: bool = True,
237
- ):
238
- """
239
- The [`PriorTransformer`] forward method.
240
-
241
- Args:
242
- hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
243
- The currently predicted image embeddings.
244
- timestep (`torch.LongTensor`):
245
- Current denoising step.
246
- proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247
- Projected embedding vector the denoising process is conditioned on.
248
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
249
- Hidden states of the text embeddings the denoising process is conditioned on.
250
- attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
251
- Text mask for the text embeddings.
252
- return_dict (`bool`, *optional*, defaults to `True`):
253
- Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
254
- tuple.
255
-
256
- Returns:
257
- [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
258
- If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
259
- tuple is returned where the first element is the sample tensor.
260
- """
261
- batch_size = hidden_states.shape[0]
262
-
263
- timesteps = timestep
264
- if not torch.is_tensor(timesteps):
265
- timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
266
- elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
267
- timesteps = timesteps[None].to(hidden_states.device)
268
-
269
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
270
- timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
271
-
272
- timesteps_projected = self.time_proj(timesteps)
273
-
274
- # timesteps does not contain any weights and will always return f32 tensors
275
- # but time_embedding might be fp16, so we need to cast here.
276
- timesteps_projected = timesteps_projected.to(dtype=self.dtype)
277
- time_embeddings = self.time_embedding(timesteps_projected)
278
-
279
- if self.embedding_proj_norm is not None:
280
- proj_embedding = self.embedding_proj_norm(proj_embedding)
281
-
282
- proj_embeddings = self.embedding_proj(proj_embedding)
283
- if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
284
- encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
285
- elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
286
- raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
287
-
288
- hidden_states = self.proj_in(hidden_states)
289
-
290
- positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
291
-
292
- additional_embeds = []
293
- additional_embeddings_len = 0
294
-
295
- if encoder_hidden_states is not None:
296
- additional_embeds.append(encoder_hidden_states)
297
- additional_embeddings_len += encoder_hidden_states.shape[1]
298
-
299
- if len(proj_embeddings.shape) == 2:
300
- proj_embeddings = proj_embeddings[:, None, :]
301
-
302
- if len(hidden_states.shape) == 2:
303
- hidden_states = hidden_states[:, None, :]
304
-
305
- additional_embeds = additional_embeds + [
306
- proj_embeddings,
307
- time_embeddings[:, None, :],
308
- hidden_states,
309
- ]
310
-
311
- if self.prd_embedding is not None:
312
- prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
313
- additional_embeds.append(prd_embedding)
314
-
315
- hidden_states = torch.cat(
316
- additional_embeds,
317
- dim=1,
318
- )
319
-
320
- # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
321
- additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
322
- if positional_embeddings.shape[1] < hidden_states.shape[1]:
323
- positional_embeddings = F.pad(
324
- positional_embeddings,
325
- (
326
- 0,
327
- 0,
328
- additional_embeddings_len,
329
- self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
330
- ),
331
- value=0.0,
332
- )
333
-
334
- hidden_states = hidden_states + positional_embeddings
335
-
336
- if attention_mask is not None:
337
- attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
338
- attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
339
- attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
340
- attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
341
-
342
- if self.norm_in is not None:
343
- hidden_states = self.norm_in(hidden_states)
344
-
345
- for block in self.transformer_blocks:
346
- hidden_states = block(hidden_states, attention_mask=attention_mask)
347
-
348
- hidden_states = self.norm_out(hidden_states)
349
-
350
- if self.prd_embedding is not None:
351
- hidden_states = hidden_states[:, -1]
352
- else:
353
- hidden_states = hidden_states[:, additional_embeddings_len:]
354
-
355
- predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
356
-
357
- if not return_dict:
358
- return (predicted_image_embedding,)
359
-
360
- return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
361
-
362
- def post_process_latents(self, prior_latents):
363
- prior_latents = (prior_latents * self.clip_std) + self.clip_mean
364
- return prior_latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/resnet.py DELETED
@@ -1,877 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from functools import partial
17
- from typing import Optional
18
-
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
-
23
- from .activations import get_activation
24
- from .attention import AdaGroupNorm
25
- from .attention_processor import SpatialNorm
26
-
27
-
28
class Upsample1D(nn.Module):
    """Upsample a 1D signal by a factor of two, optionally with a convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            apply a kernel-size-3 convolution after nearest-neighbour interpolation.
        use_conv_transpose (`bool`, default `False`):
            upsample with a stride-2 transposed convolution instead of interpolating.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this class.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # The transposed convolution wins when both flags are set.
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs):
        """Return `inputs` upsampled along the last dimension; dim 1 must equal `self.channels`."""
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
67
-
68
-
69
class Downsample1D(nn.Module):
    """Halve the length of a 1D signal with either a strided convolution or average pooling.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            use a stride-2, kernel-size-3 convolution; otherwise average pooling is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`. Must equal `channels` when pooling.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            stored as `self.name`; not otherwise used by this class.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        factor = 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=factor, stride=factor)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=factor, padding=padding)

    def forward(self, inputs):
        """Return the downsampled `inputs`; dim 1 must equal `self.channels`."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
101
-
102
-
103
class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution after nearest-neighbour interpolation.
        use_conv_transpose (`bool`, default `False`):
            option to upsample with a stride-2 transposed convolution instead.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            attribute under which the layer is registered: `self.conv` when `name == "conv"`, otherwise
            `self.Conv2d_0`. This decides the parameter names in the state dict.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def forward(self, hidden_states, output_size=None):
        """Upsample `hidden_states`.

        Args:
            hidden_states: input whose dim 1 must equal `self.channels`.
            output_size: if given, the interpolation target size; otherwise the input is scaled by 2. Ignored on
                the transposed-convolution path.

        Returns:
            The upsampled tensor.
        """
        assert hidden_states.shape[1] == self.channels

        # The layer is registered under `conv` or `Conv2d_0` depending on `name`. Resolve it once so the
        # transposed-convolution path also works for layers built with a non-"conv" name (it used to
        # dereference `self.conv` unconditionally and fail with AttributeError).
        conv = self.conv if self.name == "conv" else self.Conv2d_0

        if self.use_conv_transpose:
            return conv(hidden_states)

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            hidden_states = conv(hidden_states)

        return hidden_states
173
-
174
-
175
class Downsample2D(nn.Module):
    """Halve the spatial size of a 2D input with either a strided convolution or average pooling.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            use a stride-2, kernel-size-3 convolution; otherwise average pooling is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`. Must equal `channels` when pooling.
        padding (`int`, default `1`):
            padding for the convolution. With `0`, `forward` pads one zero on the right and bottom first.
        name (`str`, default `"conv"`):
            when equal to `"conv"` the layer is registered as both `self.Conv2d_0` and `self.conv`; for any
            other value it is registered only as `self.conv`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        factor = 2
        if not use_conv:
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=factor, stride=factor)
        else:
            layer = nn.Conv2d(self.channels, self.out_channels, 3, stride=factor, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            # Register the alias first so the module / state-dict ordering stays the same.
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states):
        """Return the downsampled `hidden_states`; dim 1 must equal `self.channels`."""
        assert hidden_states.shape[1] == self.channels

        needs_manual_pad = self.use_conv and self.padding == 0
        if needs_manual_pad:
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels
        return self.conv(hidden_states)
223
-
224
-
225
class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `upsample_2d()` followed by `Conv2d()`.

        When `self.use_conv` is set, the upsampling is done by a transposed convolution with `weight` and the
        result is FIR-filtered; otherwise the input is upsampled and FIR-filtered by `upfirdn2d_native` alone.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]`.
            weight: Convolution weight, indexed here as `[outChannels, inChannels, convH, convW]` (the layout of
                `nn.Conv2d.weight`). Required when `self.use_conv` is set.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
            factor: Integer upsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: The upsampled and filtered tensor.
        """

        assert isinstance(factor, int) and factor >= 1

        # Setup filter kernel.
        if kernel is None:
            kernel = [1] * factor

        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * (gain * (factor**2))

        if self.use_conv:
            convH = weight.shape[2]
            convW = weight.shape[3]
            inC = weight.shape[1]

            pad_value = (kernel.shape[0] - factor) - (convW - 1)

            stride = (factor, factor)
            # Determine data dimensions.
            output_shape = (
                (hidden_states.shape[2] - 1) * factor + convH,
                (hidden_states.shape[3] - 1) * factor + convW,
            )
            output_padding = (
                output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
                output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
            )
            assert output_padding[0] >= 0 and output_padding[1] >= 0
            num_groups = hidden_states.shape[1] // inC

            # Transpose weights.
            weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
            weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
            weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))

            inverse_conv = F.conv_transpose2d(
                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
            )

            # `kernel` is already a tensor: move it with `.to()` rather than re-wrapping it in `torch.tensor`,
            # which emits a copy-construct warning.
            output = upfirdn2d_native(
                inverse_conv,
                kernel.to(device=inverse_conv.device),
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
            )
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                up=factor,
                pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Upsample `hidden_states` by 2 with `self.fir_kernel`, adding the conv bias when `use_conv` is set."""
        if self.use_conv:
            upsampled = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
            # `_upsample_2d` only applies the weight, so the bias is added here.
            upsampled = upsampled + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            upsampled = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return upsampled
336
-
337
-
338
class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        When `self.use_conv` is set, the input is FIR-filtered by `upfirdn2d_native` and then convolved with
        `weight` at stride `factor`; otherwise `upfirdn2d_native` filters and downsamples in one call.

        Args:
            hidden_states: Input tensor of the shape `[N, C, H, W]`.
            weight: Convolution weight, unpacked here as `[_, _, convH, convW]` (the layout of
                `nn.Conv2d.weight`). Required when `self.use_conv` is set.
            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
                factor`, which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: The filtered and downsampled tensor.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        # `kernel` is already a tensor: move it with `.to()` rather than re-wrapping it in `torch.tensor`,
        # which emits a copy-construct warning.
        if self.use_conv:
            _, _, convH, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel.to(device=hidden_states.device),
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Downsample `hidden_states` by 2 with `self.fir_kernel`, adding the conv bias when `use_conv` is set."""
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # `_downsample_2d` only applies the weight, so the bias is added here.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
423
-
424
-
425
- # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
426
class KDownsample2D(nn.Module):
    """Downsample a 2D input by 2 with a fixed 4x4 kernel applied to each channel separately.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode passed to `F.pad` before the convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps; kept out of the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        """Pad `inputs`, then run a stride-2 convolution whose weight is non-zero only on the channel diagonal."""
        num_channels = inputs.shape[1]
        kernel_h, kernel_w = self.kernel.shape[0], self.kernel.shape[1]

        padded = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = padded.new_zeros([num_channels, num_channels, kernel_h, kernel_w])
        diagonal = torch.arange(num_channels, device=padded.device)
        # weight[c, c] = kernel, every cross-channel entry stays zero.
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1, -1)
        return F.conv2d(padded, weight, stride=2)
441
-
442
-
443
class KUpsample2D(nn.Module):
    """Upsample a 2D input by 2 with a fixed 4x4 kernel applied to each channel separately.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode passed to `F.pad` before the transposed convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps; kept out of the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        """Pad `inputs`, then run a stride-2 transposed convolution with a channel-diagonal weight."""
        num_channels = inputs.shape[1]
        kernel_h, kernel_w = self.kernel.shape[0], self.kernel.shape[1]

        padded = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        weight = padded.new_zeros([num_channels, num_channels, kernel_h, kernel_w])
        diagonal = torch.arange(num_channels, device=padded.device)
        # weight[c, c] = kernel, every cross-channel entry stays zero.
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1, -1)
        return F.conv_transpose2d(padded, weight, stride=2, padding=self.pad * 2 + 1)
458
-
459
-
460
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block: norm -> activation -> (optional resample) -> conv, timestep conditioning, then
    norm -> activation -> dropout -> conv, added to a (possibly projected) skip connection.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`): Stored as `use_conv_shortcut`; not otherwise
            used by this class.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            If `None`, no timestep projection layer is created.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, defaults to `True`): Accepted for compatibility; the `pre_norm` attribute
            is always set to `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        skip_time_act (`bool`, *optional*, defaults to `False`): If `True`, the activation is not applied to the
            timestep embedding before it is projected.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            `"default"` adds the projected embedding, `"scale_shift"` applies a scale and shift after the second
            norm, `"ada_group"` / `"spatial"` pass the embedding to the normalization layers instead.
        kernel (`str`, *optional*, default to None): Selects the resampling op used when `up` or `down` is
            set: `"fir"`, `"sde_vp"`, or anything else for `Upsample2D` / `Downsample2D`.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the output is divided by this value.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added when
            `in_channels` differs from the block's final output channel count.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer (only when `up` is
            not set).
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        skip_time_act=False,
        time_embedding_norm="default",  # default, scale_shift, ada_group, spatial
        kernel=None,
        output_scale_factor=1.0,
        use_in_shortcut=None,
        up=False,
        down=False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # The `pre_norm` argument is accepted, but the attribute is always True.
        self.pre_norm = True
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        def make_norm(num_channels, num_groups):
            # Conditional norms consume the timestep embedding themselves.
            if time_embedding_norm == "ada_group":
                return AdaGroupNorm(temb_channels, num_channels, num_groups, eps=eps)
            if time_embedding_norm == "spatial":
                return SpatialNorm(num_channels, temb_channels)
            return torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True)

        self.norm1 = make_norm(in_channels, groups)

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is None or time_embedding_norm in ("ada_group", "spatial"):
            self.time_emb_proj = None
        elif time_embedding_norm == "default":
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        elif time_embedding_norm == "scale_shift":
            # Twice the width: one half is the scale, the other the shift.
            self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
        else:
            raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

        self.norm2 = make_norm(out_channels, groups_out)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        if use_in_shortcut is None:
            self.use_in_shortcut = self.in_channels != conv_2d_out_channels
        else:
            self.use_in_shortcut = use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )

    def forward(self, input_tensor, temb):
        """Run the block.

        Args:
            input_tensor: the block input; also the source of the skip connection.
            temb: timestep embedding, used according to `time_embedding_norm`.

        Returns:
            `(skip + residual) / output_scale_factor`.
        """
        norm_takes_temb = self.time_embedding_norm in ("ada_group", "spatial")

        hidden_states = self.norm1(input_tensor, temb) if norm_takes_temb else self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)

        # Both the skip path and the residual path are resampled so they can be summed at the end.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            # Add two trailing singleton dims so the projection broadcasts over the spatial axes.
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states, temb) if norm_takes_temb else self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor
641
-
642
-
643
- # unet_rl.py
644
def rearrange_dims(tensor):
    """Insert or drop a singleton axis at position 2, depending on the tensor's rank.

    * 2 dims: `[a, b]` becomes `[a, b, 1]`
    * 3 dims: `[a, b, c]` becomes `[a, b, 1, c]`
    * 4 dims: `[a, b, x, c]` becomes `[a, b, c]` (index 0 is taken along axis 2)

    Args:
        tensor: the tensor to rearrange.

    Returns:
        An indexed view of `tensor` with the new rank.

    Raises:
        ValueError: if `tensor` does not have 2, 3 or 4 dimensions.
    """
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    # Report the rank. `len(tensor)` is the size of dim 0 and raises TypeError on 0-d tensors.
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")
653
-
654
-
655
class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish

    Parameters:
        inp_channels (`int`): number of input channels of the convolution.
        out_channels (`int`): number of output channels of the convolution and the group norm.
        kernel_size (`int`): convolution kernel size; the padding is `kernel_size // 2`.
        n_groups (`int`, default `8`): number of groups for the group norm.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        """Apply the convolution, group norm and Mish, passing through `rearrange_dims` around the norm."""
        features = rearrange_dims(self.conv1d(inputs))
        features = rearrange_dims(self.group_norm(features))
        return self.mish(features)
674
-
675
-
676
- # unet_rl.py
677
class ResidualTemporalBlock1D(nn.Module):
    """Two `Conv1dBlock`s with a projected time embedding added in between and a residual connection.

    Parameters:
        inp_channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        embed_dim (`int`): size of the time embedding fed to the linear projection.
        kernel_size (`int`, default `5`): kernel size of both convolution blocks.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        # A 1x1 convolution matches the channel count on the skip path when it changes.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        time_bias = rearrange_dims(self.time_emb(self.time_emb_act(t)))
        hidden = self.conv_in(inputs) + time_bias
        hidden = self.conv_out(hidden)
        return hidden + self.residual_conv(inputs)
704
-
705
-
706
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Upsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to one and then scaled by `gain * factor**2` before being handed to
    `upfirdn2d_native` together with the padding derived from the filter size.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
        factor: Integer upsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # A 1D filter is treated as separable.
        fir = torch.outer(fir, fir)
    fir /= torch.sum(fir)
    fir = fir * (gain * (factor**2))

    excess = fir.shape[0] - factor
    padding = ((excess + 1) // 2 + factor - 1, excess // 2)
    return upfirdn2d_native(hidden_states, fir.to(device=hidden_states.device), up=factor, pad=padding)
741
-
742
-
743
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to one and then scaled by `gain` before being handed to
    `upfirdn2d_native` together with the padding derived from the filter size.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
            (separable). The default is `[1] * factor`, which corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """
    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # A 1D filter is treated as separable.
        fir = torch.outer(fir, fir)
    fir /= torch.sum(fir)
    fir = fir * gain

    excess = fir.shape[0] - factor
    padding = ((excess + 1) // 2, excess // 2)
    return upfirdn2d_native(hidden_states, fir.to(device=hidden_states.device), down=factor, pad=padding)
776
-
777
-
778
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample by zero insertion, pad or crop, filter with `kernel`, then subsample.

    The same `up`, `down` and `pad` values are used for both spatial axes.

    Args:
        tensor: input of shape `[N, C, H, W]`.
        kernel: 2D filter of shape `[kernel_h, kernel_w]`.
        up: integer upsampling factor; `up - 1` zeros are inserted after every sample.
        down: integer subsampling step applied after filtering.
        pad: `(before, after)` padding per spatial axis; negative values crop instead of padding.

    Returns:
        Tensor of shape `[N, C, out_h, out_w]`.
    """
    pad_before, pad_after = pad[0], pad[1]

    _, channel, in_h, in_w = tensor.shape
    # Fold batch and channels together and work with a single trailing channel.
    tensor = tensor.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = tensor.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion along both spatial axes.
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up - 1, 0, 0, 0, up - 1])
    up_h, up_w = in_h * up, in_w * up
    out = out.view(-1, up_h, up_w, minor)

    # Positive pads add zeros; negative pads are applied as a crop.
    grow_before, grow_after = max(pad_before, 0), max(pad_after, 0)
    crop_before, crop_after = max(-pad_before, 0), max(-pad_after, 0)
    out = F.pad(out, [0, 0, grow_before, grow_after, grow_before, grow_after])
    out = out.to(tensor.device)  # Move back to mps if necessary
    out = out[:, crop_before : out.shape[1] - crop_after, crop_before : out.shape[2] - crop_after, :]

    padded_h = up_h + pad_before + pad_after
    padded_w = up_w + pad_before + pad_after
    out = out.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])

    # The kernel is flipped before being passed to `F.conv2d`.
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, flipped)
    out = out.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)
    out = out.permute(0, 2, 3, 1)[:, ::down, ::down, :]

    out_h = (padded_h - kernel_h) // down + 1
    out_w = (padded_w - kernel_w) // down + 1
    return out.view(-1, channel, out_h, out_w)
820
-
821
-
822
class TemporalConvLayer(nn.Module):
    """
    Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016

    Four norm -> SiLU -> (dropout) -> Conv3d stages convolve along the frame axis only (kernel `(3, 1, 1)`), and
    the result is added to the input as a residual.

    Args:
        in_dim: Number of input (and output) channels. Must be divisible by 32 because of the GroupNorm layers.
        out_dim: Number of channels used by the intermediate stages. Defaults to `in_dim`.
        dropout: Dropout probability used in stages 2-4.
    """

    def __init__(self, in_dim, out_dim=None, dropout=0.0):
        super().__init__()
        out_dim = out_dim or in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
        )
        # The intermediate stages stay at `out_dim` channels so that the GroupNorm of the following stage sees
        # the channel count it was built for; only the last stage projects back to `in_dim` for the residual
        # sum. With the default `out_dim == in_dim` the layer shapes are unchanged.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
        )
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
        )

        # zero out the last layer params, so the conv block is identity at initialization
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, hidden_states, num_frames=1):
        """
        Apply the temporal convolutions.

        Args:
            hidden_states: Tensor of shape `(batch * num_frames, channels, height, width)`.
            num_frames: Number of consecutive entries of the first axis that form one video.

        Returns:
            Tensor with the same shape as `hidden_states`.
        """
        # (batch * frames, C, H, W) -> (batch, C, frames, H, W), the layout `nn.Conv3d` expects.
        hidden_states = (
            hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
        )

        identity = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.conv3(hidden_states)
        hidden_states = self.conv4(hidden_states)

        hidden_states = identity + hidden_states

        # Back to (batch * frames, C, H, W).
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
        )
        return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/resnet_flax.py DELETED
@@ -1,124 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import flax.linen as nn
15
- import jax
16
- import jax.numpy as jnp
17
-
18
-
19
class FlaxUpsample2D(nn.Module):
    """
    Doubles the spatial size of a feature map with nearest-neighbour resizing, then applies a 3x3 convolution.

    Attributes:
        out_channels: Number of output features of the convolution.
        dtype: Computation dtype passed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Stride 1 with one pixel of padding per side keeps the (already upsampled) spatial size.
        self.conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # The input shape is unpacked as (batch, height, width, channels).
        batch, height, width, channels = hidden_states.shape
        doubled_shape = (batch, height * 2, width * 2, channels)
        upsampled = jax.image.resize(hidden_states, shape=doubled_shape, method="nearest")
        return self.conv(upsampled)
41
-
42
-
43
class FlaxDownsample2D(nn.Module):
    """
    Downsamples a feature map with a single 3x3 convolution of stride 2.

    Attributes:
        out_channels: Number of output features of the convolution.
        dtype: Computation dtype passed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Symmetric one-pixel padding is used here; an earlier variant (asymmetric manual padding followed by a
        # "VALID" convolution) is not active.
        self.conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        return self.conv(hidden_states)
61
-
62
-
63
class FlaxResnetBlock2D(nn.Module):
    """
    Residual block: two GroupNorm -> swish -> 3x3 conv stages with a projected time embedding added in between,
    plus an optional 1x1 convolution on the skip connection.

    Attributes:
        in_channels: Number of input channels.
        out_channels: Number of output channels; `None` means "same as `in_channels`".
        dropout_prob: Dropout rate applied before the second convolution.
        use_nin_shortcut: Whether to project the residual with a 1x1 convolution; `None` enables it exactly when
            the input and output channel counts differ.
        dtype: Computation dtype passed to the convolutions and the dense layer.
    """

    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resolved_out = self.out_channels if self.out_channels is not None else self.in_channels

        # Both main-path convolutions share the same shape-preserving 3x3 configuration.
        conv3x3 = dict(
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(resolved_out, **conv3x3)

        self.time_emb_proj = nn.Dense(resolved_out, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(resolved_out, **conv3x3)

        if self.use_nin_shortcut is None:
            needs_projection = self.in_channels != resolved_out
        else:
            needs_projection = self.use_nin_shortcut

        self.conv_shortcut = None
        if needs_projection:
            self.conv_shortcut = nn.Conv(
                resolved_out,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        skip = hidden_states

        out = self.conv1(nn.swish(self.norm1(hidden_states)))

        # Project the time embedding and add two singleton axes at position 1 so it broadcasts over the
        # feature map.
        time_bias = self.time_emb_proj(nn.swish(temb))
        time_bias = jnp.expand_dims(jnp.expand_dims(time_bias, 1), 1)
        out = out + time_bias

        out = nn.swish(self.norm2(out))
        out = self.dropout(out, deterministic)
        out = self.conv2(out)

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return out + skip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/t5_film_transformer.py DELETED
@@ -1,321 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
-
16
- import torch
17
- from torch import nn
18
-
19
- from ..configuration_utils import ConfigMixin, register_to_config
20
- from .attention_processor import Attention
21
- from .embeddings import get_timestep_embedding
22
- from .modeling_utils import ModelMixin
23
-
24
-
25
class T5FilmDecoder(ModelMixin, ConfigMixin):
    """
    Stack of FiLM-conditioned T5-style decoder layers operating on continuous inputs.

    The noise time is embedded and used as FiLM conditioning in every layer, while the encoder outputs passed to
    `forward` are concatenated and attended to through cross-attention.

    Args:
        input_dims: Feature size of the continuous decoder inputs (and of the output).
        targets_length: Number of learned-size position embeddings (maximum decoder sequence length).
        max_decoder_noise_time: Scale applied to `decoder_noise_time` before it is embedded.
        d_model: Hidden size of the decoder.
        num_layers: Number of `DecoderLayer` blocks.
        num_heads: Number of attention heads.
        d_kv: Size of each attention head.
        d_ff: Hidden size of the feed-forward blocks.
        dropout_rate: Dropout probability used throughout the decoder.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        cond_dim = d_model * 4
        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, cond_dim, bias=False),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim, bias=False),
            nn.SiLU(),
        )

        # Position embeddings are frozen (excluded from gradient updates).
        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        # FiLM conditional T5 decoder layers
        self.decoders = nn.ModuleList(
            [
                DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
                for _ in range(num_layers)
            ]
        )

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input, key_input):
        """Build a mask with a singleton axis at position -3 from the outer product of the two masks."""
        outer = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return outer.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        """
        Run the decoder.

        Args:
            encodings_and_masks: Iterable of `(encoding, mask)` pairs produced by the encoders.
            decoder_input_tokens: Continuous inputs; must be 3-dimensional with the batch on the first axis and
                the sequence on the second.
            decoder_noise_time: Tensor of shape `(batch,)` with values in [0, 1).

        Returns:
            The output of the final linear projection (`spec_out`) applied to the decoder states.
        """
        batch, seq_length, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        max_time = self.config.max_decoder_noise_time
        time_steps = get_timestep_embedding(
            decoder_noise_time * max_time,
            embedding_dim=self.config.d_model,
            max_period=max_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        position_ids = torch.arange(seq_length, device=decoder_input_tokens.device)
        decoder_positions = torch.broadcast_to(position_ids, (batch, seq_length))

        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones(
            decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
        )

        # Translate encoding masks to encoder-decoder masks.
        encodings_and_encdec_masks = [
            (encoding, self.encoder_decoder_mask(decoder_mask, encoding_mask))
            for encoding, encoding_mask in encodings_and_masks
        ]

        # cross attend style: concat encodings
        encoded = torch.cat([pair[0] for pair in encodings_and_encdec_masks], dim=1)
        encoder_decoder_mask = torch.cat([pair[1] for pair in encodings_and_encdec_masks], dim=-1)

        for decoder_layer in self.decoders:
            layer_outputs = decoder_layer(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )
            y = layer_outputs[0]

        y = self.post_dropout(self.decoder_norm(y))

        return self.spec_out(y)
125
-
126
-
127
class DecoderLayer(nn.Module):
    """
    One decoder block: FiLM-conditioned self-attention, cross-attention over the encoder states, and a
    FiLM-conditioned gated feed-forward layer, stored in that order in `self.layer`.
    """

    def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
        super().__init__()
        self.layer = nn.ModuleList(
            [
                # cond self attention: layer 0
                T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate),
                # cross attention: layer 1
                T5LayerCrossAttention(
                    d_model=d_model,
                    d_kv=d_kv,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    layer_norm_epsilon=layer_norm_epsilon,
                ),
                # Film Cond MLP + dropout: last layer
                T5LayerFFCond(
                    d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon
                ),
            ]
        )

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
    ):
        """
        Run the three sub-layers and return a one-element tuple holding the new hidden states.

        Cross-attention is skipped when `encoder_hidden_states` is `None`. `encoder_decoder_position_bias` is
        accepted but not used.
        """
        self_attention, cross_attention, feed_forward = self.layer[0], self.layer[1], self.layer[-1]

        hidden_states = self_attention(
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            # Turn the mask into an additive bias: 0 where the mask is positive, -1e10 elsewhere.
            additive_bias = torch.where(encoder_attention_mask > 0, 0, -1e10)
            encoder_extended_attention_mask = additive_bias.to(encoder_hidden_states.dtype)

            hidden_states = cross_attention(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_extended_attention_mask,
            )

        # Apply Film Conditional Feed Forward layer
        hidden_states = feed_forward(hidden_states, conditioning_emb)

        return (hidden_states,)
183
-
184
-
185
class T5LayerSelfAttentionCond(nn.Module):
    """
    Pre-norm self-attention sub-layer whose normalized input can be modulated by a FiLM layer before attention.
    The attention output passes through dropout and is added back to the input.
    """

    def __init__(self, d_model, d_kv, num_heads, dropout_rate):
        super().__init__()
        self.layer_norm = T5LayerNorm(d_model)
        self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
    ):
        """
        Args:
            hidden_states: Input to the sub-layer.
            conditioning_emb: Optional conditioning passed to the FiLM layer; FiLM is skipped when `None`.
            attention_mask: Accepted for interface compatibility; it is not forwarded to the attention module.

        Returns:
            `hidden_states` plus the (dropped-out) attention output.
        """
        # pre_self_attention_layer_norm
        attn_input = self.layer_norm(hidden_states)

        if conditioning_emb is not None:
            attn_input = self.FiLMLayer(attn_input, conditioning_emb)

        # Self-attention block with residual connection
        return hidden_states + self.dropout(self.attention(attn_input))
211
-
212
-
213
class T5LayerCrossAttention(nn.Module):
    """
    Pre-norm cross-attention sub-layer: the normalized hidden states attend to `key_value_states`, and the
    result passes through dropout before being added back to the input.
    """

    def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        attention_mask=None,
    ):
        """
        Args:
            hidden_states: Query-side input to the sub-layer.
            key_value_states: States to attend to, forwarded to the attention module as
                `encoder_hidden_states`.
            attention_mask: Optional mask/bias with a singleton axis at position 1, which is squeezed away
                before it is handed to the attention module. `None` means no masking.

        Returns:
            `hidden_states` plus the (dropped-out) attention output.
        """
        normed_hidden_states = self.layer_norm(hidden_states)

        # `attention_mask` defaults to None, so only squeeze when a mask was actually supplied.
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(1)

        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask,
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output
234
-
235
-
236
class T5LayerFFCond(nn.Module):
    """
    Pre-norm gated feed-forward sub-layer whose normalized input can be modulated by a FiLM layer. The
    feed-forward output passes through dropout and is added back to the input.
    """

    def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states, conditioning_emb=None):
        """Return `hidden_states` plus the feed-forward branch; FiLM is skipped when `conditioning_emb` is None."""
        branch = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            branch = self.film(branch, conditioning_emb)

        branch = self.dropout(self.DenseReluDense(branch))
        return hidden_states + branch
252
-
253
-
254
class T5DenseGatedActDense(nn.Module):
    """
    Gated feed-forward block: `wo(dropout(act(wi_0(x)) * wi_1(x)))`, with bias-free linear layers.
    """

    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states):
        # The activated branch gates the plain linear branch element-wise.
        gate = self.act(self.wi_0(hidden_states))
        value = self.wi_1(hidden_states)
        gated = self.dropout(gate * value)
        return self.wo(gated)
271
-
272
-
273
class T5LayerNorm(nn.Module):
    """
    Layer normalization in the T5 style, also known as Root Mean Square Layer Normalization
    (https://arxiv.org/abs/1910.07467): the input is only rescaled, with no mean subtraction and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Size of the last (normalized) dimension; one learned scale per feature.
            eps: Constant added to the mean of squares before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The "variance" is computed without subtracting the mean, and accumulated in fp32 so that
        # half-precision inputs do not lose precision.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden_states * inv_rms

        # convert into half-precision if necessary
        half_dtypes = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_dtypes:
            normed = normed.to(self.weight.dtype)

        return self.weight * normed
296
-
297
-
298
class NewGELUActivation(nn.Module):
    """
    Tanh approximation of the GELU activation, as used in the Google BERT repo (identical to OpenAI GPT). See
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
        cubic = input + 0.044715 * torch.pow(input, 3.0)
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * cubic)
        return 0.5 * input * gate
306
-
307
-
308
class T5FiLMLayer(nn.Module):
    """
    FiLM (feature-wise linear modulation) layer: a bias-free linear projection of the conditioning embedding
    yields a scale and a shift, and the input is transformed as `x * (1 + scale) + shift`.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # One projection produces scale and shift together; they are split apart in `forward`.
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x, conditioning_emb):
        modulation = self.scale_bias(conditioning_emb)
        scale, shift = torch.chunk(modulation, 2, -1)
        return x * (1 + scale) + shift
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/transformer_2d.py DELETED
@@ -1,343 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, Optional
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from ..configuration_utils import ConfigMixin, register_to_config
22
- from ..models.embeddings import ImagePositionalEmbeddings
23
- from ..utils import BaseOutput, deprecate
24
- from .attention import BasicTransformerBlock
25
- from .embeddings import PatchEmbed
26
- from .modeling_utils import ModelMixin
27
-
28
-
29
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    Container for the result of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            Hidden states conditioned on the `encoder_hidden_states` input. For discrete inputs, these are the
            probability distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor
41
-
42
-
43
- class Transformer2DModel(ModelMixin, ConfigMixin):
44
- """
45
- A 2D Transformer model for image-like data.
46
-
47
- Parameters:
48
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
49
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
50
- in_channels (`int`, *optional*):
51
- The number of channels in the input and output (specify if the input is **continuous**).
52
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
53
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
55
- sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
56
- This is fixed during training since it is used to learn a number of position embeddings.
57
- num_vector_embeds (`int`, *optional*):
58
- The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
59
- Includes the class for the masked latent pixel.
60
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
61
- num_embeds_ada_norm ( `int`, *optional*):
62
- The number of diffusion steps used during training. Pass if at least one of the norm_layers is
63
- `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
64
- added to the hidden states.
65
-
66
- During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
67
- attention_bias (`bool`, *optional*):
68
- Configure if the `TransformerBlocks` attention should contain a bias parameter.
69
- """
70
-
71
- @register_to_config
72
- def __init__(
73
- self,
74
- num_attention_heads: int = 16,
75
- attention_head_dim: int = 88,
76
- in_channels: Optional[int] = None,
77
- out_channels: Optional[int] = None,
78
- num_layers: int = 1,
79
- dropout: float = 0.0,
80
- norm_num_groups: int = 32,
81
- cross_attention_dim: Optional[int] = None,
82
- attention_bias: bool = False,
83
- sample_size: Optional[int] = None,
84
- num_vector_embeds: Optional[int] = None,
85
- patch_size: Optional[int] = None,
86
- activation_fn: str = "geglu",
87
- num_embeds_ada_norm: Optional[int] = None,
88
- use_linear_projection: bool = False,
89
- only_cross_attention: bool = False,
90
- upcast_attention: bool = False,
91
- norm_type: str = "layer_norm",
92
- norm_elementwise_affine: bool = True,
93
- ):
94
- super().__init__()
95
- self.use_linear_projection = use_linear_projection
96
- self.num_attention_heads = num_attention_heads
97
- self.attention_head_dim = attention_head_dim
98
- inner_dim = num_attention_heads * attention_head_dim
99
-
100
- # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
101
- # Define whether input is continuous or discrete depending on configuration
102
- self.is_input_continuous = (in_channels is not None) and (patch_size is None)
103
- self.is_input_vectorized = num_vector_embeds is not None
104
- self.is_input_patches = in_channels is not None and patch_size is not None
105
-
106
- if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
107
- deprecation_message = (
108
- f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
109
- " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
110
- " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
111
- " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
112
- " would be very nice if you could open a Pull request for the `transformer/config.json` file"
113
- )
114
- deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
115
- norm_type = "ada_norm"
116
-
117
- if self.is_input_continuous and self.is_input_vectorized:
118
- raise ValueError(
119
- f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
120
- " sure that either `in_channels` or `num_vector_embeds` is None."
121
- )
122
- elif self.is_input_vectorized and self.is_input_patches:
123
- raise ValueError(
124
- f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
125
- " sure that either `num_vector_embeds` or `num_patches` is None."
126
- )
127
- elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
128
- raise ValueError(
129
- f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
130
- f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
131
- )
132
-
133
- # 2. Define input layers
134
- if self.is_input_continuous:
135
- self.in_channels = in_channels
136
-
137
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
- if use_linear_projection:
139
- self.proj_in = nn.Linear(in_channels, inner_dim)
140
- else:
141
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
142
- elif self.is_input_vectorized:
143
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
144
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
145
-
146
- self.height = sample_size
147
- self.width = sample_size
148
- self.num_vector_embeds = num_vector_embeds
149
- self.num_latent_pixels = self.height * self.width
150
-
151
- self.latent_image_embedding = ImagePositionalEmbeddings(
152
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
153
- )
154
- elif self.is_input_patches:
155
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
156
-
157
- self.height = sample_size
158
- self.width = sample_size
159
-
160
- self.patch_size = patch_size
161
- self.pos_embed = PatchEmbed(
162
- height=sample_size,
163
- width=sample_size,
164
- patch_size=patch_size,
165
- in_channels=in_channels,
166
- embed_dim=inner_dim,
167
- )
168
-
169
- # 3. Define transformers blocks
170
- self.transformer_blocks = nn.ModuleList(
171
- [
172
- BasicTransformerBlock(
173
- inner_dim,
174
- num_attention_heads,
175
- attention_head_dim,
176
- dropout=dropout,
177
- cross_attention_dim=cross_attention_dim,
178
- activation_fn=activation_fn,
179
- num_embeds_ada_norm=num_embeds_ada_norm,
180
- attention_bias=attention_bias,
181
- only_cross_attention=only_cross_attention,
182
- upcast_attention=upcast_attention,
183
- norm_type=norm_type,
184
- norm_elementwise_affine=norm_elementwise_affine,
185
- )
186
- for d in range(num_layers)
187
- ]
188
- )
189
-
190
- # 4. Define output layers
191
- self.out_channels = in_channels if out_channels is None else out_channels
192
- if self.is_input_continuous:
193
- # TODO: should use out_channels for continuous projections
194
- if use_linear_projection:
195
- self.proj_out = nn.Linear(inner_dim, in_channels)
196
- else:
197
- self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
198
- elif self.is_input_vectorized:
199
- self.norm_out = nn.LayerNorm(inner_dim)
200
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
201
- elif self.is_input_patches:
202
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
203
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
204
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
205
-
206
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    posemb: Optional = None,
    cross_attention_kwargs: Dict[str, Any] = None,
    attention_mask: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
):
    """
    The [`Transformer2DModel`] forward method.

    Args:
        hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
            Input `hidden_states`.
        encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
            Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
            self-attention.
        timestep ( `torch.LongTensor`, *optional*):
            Used to indicate denoising step. Forwarded to every transformer block and, for patched input, to
            the conditioning embedding of the first block.
        class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
            Used to indicate class labels conditioning. Forwarded the same way as `timestep`.
        posemb (*optional*):
            Forwarded unchanged to every transformer block as the `posemb` keyword argument. Its expected
            format is defined by the block implementation, which is not visible in this method.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded unchanged to every transformer block.
        attention_mask ( `torch.Tensor`, *optional*):
            Self-attention mask. Accepts the same two formats as `encoder_attention_mask`.
        encoder_attention_mask ( `torch.Tensor`, *optional*):
            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

            If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
            above. This bias will be added to the cross-attention scores.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """

    def _mask_to_bias(mask):
        # A 2-D tensor is a keep/discard mask of shape [batch, key_tokens] (1 = keep, 0 = discard).
        # Anything else (None, or a bias that was already converted upstream) passes through untouched.
        if mask is None or mask.ndim != 2:
            return mask
        # keep -> +0, discard -> -10000.0, so the result can be added to attention scores.
        bias = (1 - mask.to(hidden_states.dtype)) * -10000.0
        # Singleton query_tokens dimension: [batch, 1, key_tokens]. This broadcasts over scores shaped
        # [batch, heads, query_tokens, key_tokens] or [batch * heads, query_tokens, key_tokens].
        return bias.unsqueeze(1)

    attention_mask = _mask_to_bias(attention_mask)
    encoder_attention_mask = _mask_to_bias(encoder_attention_mask)

    # 1. Input
    if self.is_input_continuous:
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if self.use_linear_projection:
            # Linear projection works on token sequences, so flatten the spatial grid first.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)
        else:
            # Non-linear projection is applied on the image layout, then the grid is flattened.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
    elif self.is_input_vectorized:
        hidden_states = self.latent_image_embedding(hidden_states)
    elif self.is_input_patches:
        hidden_states = self.pos_embed(hidden_states)

    # 2. Blocks
    block_kwargs = dict(
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        timestep=timestep,
        posemb=posemb,
        cross_attention_kwargs=cross_attention_kwargs,
        class_labels=class_labels,
    )
    for block in self.transformer_blocks:
        hidden_states = block(hidden_states, **block_kwargs)

    # 3. Output
    if self.is_input_continuous:
        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        else:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)

        output = hidden_states + residual
    elif self.is_input_vectorized:
        hidden_states = self.norm_out(hidden_states)
        # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) after the permute
        logits = self.out(hidden_states).permute(0, 2, 1)

        # log(p(x_0)); the softmax runs in float64 and the result is cast back to float32
        output = F.log_softmax(logits.double(), dim=1).float()
    elif self.is_input_patches:
        # TODO: cleanup!
        conditioning = self.transformer_blocks[0].norm1.emb(
            timestep, class_labels, hidden_dtype=hidden_states.dtype
        )
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        hidden_states = self.proj_out_2(hidden_states)

        # unpatchify: the token count is taken to be a perfect square (height == width)
        height = width = int(hidden_states.shape[1] ** 0.5)
        patch = self.patch_size
        hidden_states = hidden_states.reshape(shape=(-1, height, width, patch, patch, self.out_channels))
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(shape=(-1, self.out_channels, height * patch, width * patch))

    if return_dict:
        return Transformer2DModelOutput(sample=output)

    return (output,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/transformer_temporal.py DELETED
@@ -1,179 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional
16
-
17
- import torch
18
- from torch import nn
19
-
20
- from ..configuration_utils import ConfigMixin, register_to_config
21
- from ..utils import BaseOutput
22
- from .attention import BasicTransformerBlock
23
- from .modeling_utils import ModelMixin
24
-
25
-
26
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    Result container returned by [`TransformerTemporalModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    # Output tensor of the temporal transformer; same layout as the model input.
    sample: torch.FloatTensor
37
-
38
-
39
class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data.

    Attention is applied along the frame axis: every spatial location becomes one sequence of `num_frames`
    tokens.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output. Used to size the group norm and both linear
            projections, so it must be provided in practice.
        out_channels (`int`, *optional*):
            Accepted for configuration compatibility; not used by the layers built here.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups of the input `GroupNorm`.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*):
            Accepted for configuration compatibility; not used by the layers built here.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Forwarded to every `BasicTransformerBlock`.
        double_self_attention (`bool`, *optional*, defaults to `True`):
            Configure if each `TransformerBlock` should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        # 1. Input normalisation and projection to the attention width
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 2. Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Projection back to the input channel count (required by the residual connection)
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerTemporal`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size * num_frames, channel, height, width)`):
                Input hidden_states, with the frames of each video stored contiguously along the first axis.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Used to indicate denoising step. Forwarded to every transformer block.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Forwarded to every transformer block.
            num_frames (`int`, *optional*, defaults to 1):
                Number of frames per video; the first axis of `hidden_states` is split into
                `(batch size, num_frames)` with it.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded unchanged to every transformer block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`]
                instead of a plain tuple.

        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
                returned, otherwise a `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        # (batch * frames, C, H, W) -> (batch, C, frames, H, W) so GroupNorm sees channels on dim 1
        hidden_states = hidden_states.reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        # (batch, C, frames, H, W) -> (batch * H * W, frames, C): one token sequence per pixel
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # Undo the flattening done before `proj_in`. The tokens are laid out as
        # (batch * H * W, frames, C), so the reshape must list `num_frames` before `channel`;
        # the previous (channel, num_frames) order scrambled values before the residual add.
        hidden_states = (
            hidden_states.reshape(batch_size, height, width, num_frames, channel)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_1d.py DELETED
@@ -1,255 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from dataclasses import dataclass
16
- from typing import Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from ..configuration_utils import ConfigMixin, register_to_config
22
- from ..utils import BaseOutput
23
- from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
24
- from .modeling_utils import ModelMixin
25
- from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
26
-
27
-
28
@dataclass
class UNet1DOutput(BaseOutput):
    """
    Result container returned by [`UNet1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
    """

    # Output of the final stage of the 1D UNet.
    sample: torch.FloatTensor
39
-
40
-
41
class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*, defaults to 65536): Default length of sample. Should be adaptable at runtime.
        sample_rate (`int`, *optional*): Accepted for configuration purposes; not used by the layers built here.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where the
            input data has more channels than what the model was initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of time embedding to use, either `"fourier"` or `"positional"`.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos in the time embedding.
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            Whether to pass the projected timestep through a `TimestepEmbedding` MLP. When `False`, the projected
            timestep is instead repeated along the sample length.
        freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for the positional time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
        downsample_each_block (`bool`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling.

    Raises:
        ValueError: If `time_embedding_type` is neither `"fourier"` nor `"positional"`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: str = "UNetMidBlock1D",
        out_block_type: Optional[str] = None,
        block_out_channels: Tuple[int] = (32, 32, 64),
        act_fn: Optional[str] = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            # NOTE(review): the projection is built with a fixed embedding_size=8 while the MLP input width
            # below is derived from block_out_channels[0]; confirm the two agree for configurations that
            # enable `use_timestep_embedding`.
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]
        else:
            # Without this, `time_proj` would never be created and the failure would only surface later.
            raise ValueError(
                f"Unsupported time_embedding_type: {time_embedding_type!r}. Expected 'fourier' or 'positional'."
            )

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        # With no dedicated output block, the last up block has to emit the final channel count itself.
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        The [`UNet1DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
            timestep (`torch.FloatTensor` or `float` or `int`):
                The current denoising timestep. Python scalars and 0-dim tensors are promoted to a 1-element
                tensor on the device of `sample`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            timestep_embed = self.time_mlp(timestep_embed)
        else:
            # No MLP: tile the projected timestep along the sample length and expand it over the batch.
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        if self.mid_block:
            sample = self.mid_block(sample, timestep_embed)

        # 4. up: consume the stored skip tensors in reverse order, one per up block
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

        # 5. post-process
        if self.out_block:
            sample = self.out_block(sample, timestep_embed)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_1d_blocks.py DELETED
@@ -1,656 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from .activations import get_activation
21
- from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
22
-
23
-
24
class DownResnetBlock1D(nn.Module):
    """
    Down-sampling stage made of `num_layers + 1` `ResidualTemporalBlock1D` layers, an optional activation
    and an optional strided `Downsample1D`.

    `groups`, `groups_out` and `conv_shortcut` are accepted for interface compatibility; no layer built here
    consumes `groups` or `groups_out`, and `conv_shortcut` is only stored.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        # One channel-changing resnet is always present; `num_layers` more keep the width fixed.
        first_resnet = ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)
        extra_resnets = [
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) for _ in range(num_layers)
        ]
        self.resnets = nn.ModuleList([first_resnet] + extra_resnets)

        self.nonlinearity = None if non_linearity is None else get_activation(non_linearity)

        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states, temb=None):
        """Run all resnets, then return `(downsampled_states, (pre_activation_states,))`."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        # The skip connection is taken before the activation and before down-sampling.
        output_states = (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, output_states
84
-
85
-
86
class UpResnetBlock1D(nn.Module):
    """
    Up-sampling stage made of `num_layers + 1` `ResidualTemporalBlock1D` layers, an optional activation and
    an optional transposed-convolution `Upsample1D`.

    The first resnet expects `2 * in_channels` input channels because the skip tensor is concatenated to
    the incoming states on the channel axis. `groups` and `groups_out` are accepted for interface
    compatibility; no layer built here consumes them.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embedding_norm = time_embedding_norm
        self.add_upsample = add_upsample
        self.output_scale_factor = output_scale_factor

        # One channel-changing resnet is always present; `num_layers` more keep the width fixed.
        first_resnet = ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)
        extra_resnets = [
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels) for _ in range(num_layers)
        ]
        self.resnets = nn.ModuleList([first_resnet] + extra_resnets)

        self.nonlinearity = None if non_linearity is None else get_activation(non_linearity)

        self.upsample = Upsample1D(out_channels, use_conv_transpose=True) if add_upsample else None

    def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
        """Concatenate the last skip tensor (if any) on the channel axis, then run resnets, activation, upsample."""
        if res_hidden_states_tuple is not None:
            skip = res_hidden_states_tuple[-1]
            hidden_states = torch.cat((hidden_states, skip), dim=1)

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states
144
-
145
-
146
class ValueFunctionMidBlock1D(nn.Module):
    """
    Middle block that applies two resnet + downsample pairs, reducing the resnet width to
    `in_channels // 2` and then `in_channels // 4`.

    NOTE(review): the resnets are sized from `in_channels` while the downsamplers are sized from
    `out_channels`; the two only line up when both arguments are equal. Confirm against callers.
    """

    def __init__(self, in_channels, out_channels, embed_dim):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.embed_dim = embed_dim

        half_in = in_channels // 2
        quarter_in = in_channels // 4

        self.res1 = ResidualTemporalBlock1D(in_channels, half_in, embed_dim=embed_dim)
        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
        self.res2 = ResidualTemporalBlock1D(half_in, quarter_in, embed_dim=embed_dim)
        self.down2 = Downsample1D(out_channels // 4, use_conv=True)

    def forward(self, x, temb=None):
        """Apply `res1 -> down1 -> res2 -> down2`; only the resnets receive `temb`."""
        x = self.down1(self.res1(x, temb))
        x = self.down2(self.res2(x, temb))
        return x
164
-
165
-
166
class MidResTemporalBlock1D(nn.Module):
    """
    Middle block made of `num_layers + 1` `ResidualTemporalBlock1D` layers followed by an optional
    resampling layer.

    Args:
        in_channels: Channel count of the incoming states.
        out_channels: Channel count produced by every resnet.
        embed_dim: Width of the time embedding handed to each resnet.
        num_layers: Number of resnets added after the first, channel-changing one.
        add_downsample: Append a `Downsample1D` after the resnets.
        add_upsample: Append an `Upsample1D` after the resnets.
        non_linearity: Optional activation name; the activation is built and stored as
            `self.nonlinearity` but is not applied in `forward`.

    Raises:
        ValueError: If both `add_upsample` and `add_downsample` are requested.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
        resnets.extend(
            ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim) for _ in range(num_layers)
        )
        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        # The up-sampling branch previously built a `Downsample1D`, so `add_upsample=True` shrank the
        # sequence instead of growing it.
        self.upsample = None
        if add_upsample:
            self.upsample = Upsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

        if self.upsample is not None and self.downsample is not None:
            raise ValueError("Block cannot downsample and upsample")

    def forward(self, hidden_states, temb):
        """Run every resnet with `temb`, then the configured resampling layer, and return the states."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        if self.downsample is not None:
            # The result used to be assigned to `self.downsample`, which overwrote a registered
            # submodule with a tensor (rejected by `nn.Module`) and discarded the resampled states.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
217
-
218
-
219
class OutConv1DBlock(nn.Module):
    """
    Output head: width-5 convolution, group norm, activation, then a 1x1 convolution that maps
    `embed_dim` channels to `out_channels`.
    """

    def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
        super().__init__()
        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
        self.final_conv1d_act = get_activation(act_fn)
        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

    def forward(self, hidden_states, temb=None):
        """Apply the head to `hidden_states`; `temb` is accepted for interface parity and ignored."""
        # `rearrange_dims` is applied on both sides of the group norm.
        pipeline = (
            self.final_conv1d_1,
            rearrange_dims,
            self.final_conv1d_gn,
            rearrange_dims,
            self.final_conv1d_act,
            self.final_conv1d_2,
        )
        for stage in pipeline:
            hidden_states = stage(hidden_states)
        return hidden_states
235
-
236
-
237
class OutValueFunctionBlock(nn.Module):
    """Small MLP head that maps flattened features plus an embedding to one value.

    Args:
        fc_dim: Size of the flattened hidden states fed to the head.
        embed_dim: Size of the embedding concatenated to the hidden states.
    """

    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        hidden_width = fc_dim // 2
        layers = [
            nn.Linear(fc_dim + embed_dim, hidden_width),
            nn.Mish(),
            nn.Linear(hidden_width, 1),
        ]
        self.final_block = nn.ModuleList(layers)

    def forward(self, hidden_states, temb):
        """Flatten per sample, append `temb`, and run the layers in order."""
        batch_size = hidden_states.shape[0]
        flat = hidden_states.view(batch_size, -1)
        features = torch.cat((flat, temb), dim=-1)

        for module in self.final_block:
            features = module(features)

        return features
255
-
256
-
257
# Resampling filter taps keyed by kernel name. Every kernel is symmetric, so only
# the first half is spelled out and the second half is its mirror image.
_kernels = {
    name: half + half[::-1]
    for name, half in (
        ("linear", [1 / 8, 3 / 8]),
        ("cubic", [-0.01171875, -0.03515625, 0.11328125, 0.43359375]),
        (
            "lanczos3",
            [
                0.003689131001010537,
                0.015056144446134567,
                -0.03399861603975296,
                -0.066637322306633,
                0.13550527393817902,
                0.44638532400131226,
            ],
        ),
    )
}
275
-
276
-
277
class Downsample1d(nn.Module):
    """Fixed-kernel downsampling by a stride-2 1D convolution.

    Args:
        kernel: Key into `_kernels` selecting the filter taps.
        pad_mode: Padding mode handed to `F.pad`.
    """

    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel])
        self.pad = taps.shape[0] // 2 - 1
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("kernel", taps)

    def forward(self, hidden_states):
        """Pad both ends, then filter each channel with the stored kernel at stride 2."""
        padded = F.pad(hidden_states, (self.pad, self.pad), self.pad_mode)
        num_channels = padded.shape[1]

        # Only the (c, c) entries are filled, so channels are never mixed.
        conv_weight = padded.new_zeros([num_channels, num_channels, self.kernel.shape[0]])
        diagonal = torch.arange(num_channels, device=padded.device)
        conv_weight[diagonal, diagonal] = self.kernel.to(conv_weight)[None, :].expand(num_channels, -1)

        return F.conv1d(padded, conv_weight, stride=2)
292
-
293
-
294
class Upsample1d(nn.Module):
    """Fixed-kernel upsampling by a stride-2 transposed 1D convolution.

    Args:
        kernel: Key into `_kernels` selecting the filter taps (stored multiplied by 2).
        pad_mode: Padding mode handed to `F.pad`.
    """

    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("kernel", taps)

    def forward(self, hidden_states, temb=None):
        """Pad both ends, then apply the per-channel transposed convolution; `temb` is unused."""
        edge = (self.pad + 1) // 2
        padded = F.pad(hidden_states, (edge, edge), self.pad_mode)
        num_channels = padded.shape[1]

        # Only the (c, c) entries are filled, so channels are never mixed.
        conv_weight = padded.new_zeros([num_channels, num_channels, self.kernel.shape[0]])
        diagonal = torch.arange(num_channels, device=padded.device)
        conv_weight[diagonal, diagonal] = self.kernel.to(conv_weight)[None, :].expand(num_channels, -1)

        return F.conv_transpose1d(padded, conv_weight, stride=2, padding=self.pad * 2 + 1)
309
-
310
-
311
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over the last axis of a 3D input, with a residual add.

    Args:
        in_channels: Size of dimension 1 of the input; also the width of every projection.
        n_head: Number of attention heads the channels are split into.
        dropout_rate: Probability for the in-place dropout applied before the residual add.
    """

    def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
        super().__init__()
        self.channels = in_channels
        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
        self.num_heads = n_head

        # Created in query/key/value order so parameter registration (and random
        # initialisation order) is the same as writing them out one by one.
        for projection_name in ("query", "key", "value"):
            setattr(self, projection_name, nn.Linear(in_channels, in_channels))

        self.proj_attn = nn.Linear(in_channels, in_channels, bias=True)

        self.dropout = nn.Dropout(dropout_rate, inplace=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads: (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)."""
        per_head = projection.view(*projection.size()[:-1], self.num_heads, -1)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Normalise, attend across the last axis, project, and add the input back."""
        skip = hidden_states
        # Unpacking requires a 3D input; the individual sizes are not needed afterwards.
        _batch, _channel_dim, _seq = hidden_states.shape

        tokens = self.group_norm(hidden_states).transpose(1, 2)

        q = self.transpose_for_scores(self.query(tokens))
        k = self.transpose_for_scores(self.key(tokens))
        v = self.transpose_for_scores(self.value(tokens))

        # The scale is split evenly between queries and keys before the product.
        scale = 1 / math.sqrt(math.sqrt(k.shape[-1]))
        scores = (q * scale) @ (k.transpose(-1, -2) * scale)
        probs = scores.softmax(dim=-1)

        # Weighted sum of values, then merge the heads back into one channel axis.
        context = (probs @ v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.channels)

        projected = self.proj_attn(context).transpose(1, 2)
        projected = self.dropout(projected)

        return projected + skip
367
-
368
-
369
class ResConvBlock(nn.Module):
    """Two-convolution residual block with an optional 1x1 projection on the skip path.

    Args:
        in_channels: Channels of the input.
        mid_channels: Channels between the two convolutions.
        out_channels: Channels of the output.
        is_last: When true, the second group norm and GELU are not created or applied.
    """

    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        # The skip projection exists only when the channel counts differ.
        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, kernel_size=5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, kernel_size=5, padding=2)

        if not is_last:
            self.group_norm_2 = nn.GroupNorm(1, out_channels)
            self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        """Return the main branch output plus the (optionally projected) input."""
        if self.has_conv_skip:
            shortcut = self.conv_skip(hidden_states)
        else:
            shortcut = hidden_states

        branch = self.conv_1(hidden_states)
        branch = self.gelu_1(self.group_norm_1(branch))
        branch = self.conv_2(branch)

        if self.is_last:
            return branch + shortcut

        branch = self.gelu_2(self.group_norm_2(branch))
        return branch + shortcut
401
-
402
-
403
class UNetMidBlock1D(nn.Module):
    """Middle block: downsample, six (resnet, attention) pairs, then upsample.

    Args:
        mid_channels: Channel width used inside the block.
        in_channels: Channels of the input.
        out_channels: Channels of the output; falls back to `in_channels` when `None`.
    """

    def __init__(self, mid_channels, in_channels, out_channels=None):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels

        # (input, output) channels for each of the six stages; the intermediate
        # width of every ResConvBlock is always `mid_channels`.
        stage_channels = (
            [(in_channels, mid_channels)]
            + [(mid_channels, mid_channels)] * 4
            + [(mid_channels, out_channels)]
        )

        self.down = Downsample1d("cubic")
        res_blocks = [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        attn_blocks = [SelfAttention1d(dst, dst // 32) for _, dst in stage_channels]
        self.up = Upsample1d(kernel="cubic")

        self.attentions = nn.ModuleList(attn_blocks)
        self.resnets = nn.ModuleList(res_blocks)

    def forward(self, hidden_states, temb=None):
        """Run the block; `temb` is accepted but not used."""
        hidden_states = self.down(hidden_states)

        for res_block, attn_block in zip(self.resnets, self.attentions):
            hidden_states = attn_block(res_block(hidden_states))

        return self.up(hidden_states)
441
-
442
-
443
class AttnDownBlock1D(nn.Module):
    """Down block: downsample, then three (resnet, attention) pairs.

    Args:
        out_channels: Channels of the output.
        in_channels: Channels of the input.
        mid_channels: Internal width; falls back to `out_channels` when `None`.
    """

    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        # (input, output) channels of the three stages.
        stage_channels = [
            (in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]

        self.down = Downsample1d("cubic")
        res_blocks = [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        attn_blocks = [SelfAttention1d(dst, dst // 32) for _, dst in stage_channels]

        self.attentions = nn.ModuleList(attn_blocks)
        self.resnets = nn.ModuleList(res_blocks)

    def forward(self, hidden_states, temb=None):
        """Return the output and a one-element tuple holding it; `temb` is unused."""
        hidden_states = self.down(hidden_states)

        for index, res_block in enumerate(self.resnets):
            hidden_states = res_block(hidden_states)
            hidden_states = self.attentions[index](hidden_states)

        return hidden_states, (hidden_states,)
471
-
472
-
473
class DownBlock1D(nn.Module):
    """Down block: downsample, then three residual conv blocks.

    Args:
        out_channels: Channels of the output.
        in_channels: Channels of the input.
        mid_channels: Internal width; falls back to `out_channels` when `None`.
    """

    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        self.down = Downsample1d("cubic")

        # (input, output) channels of the three stages.
        stage_channels = [
            (in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        self.resnets = nn.ModuleList(
            [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        )

    def forward(self, hidden_states, temb=None):
        """Return the output and a one-element tuple holding it; `temb` is unused."""
        features = self.down(hidden_states)

        for res_block in self.resnets:
            features = res_block(features)

        return features, (features,)
494
-
495
-
496
class DownBlock1DNoSkip(nn.Module):
    """Three residual conv blocks applied to the input concatenated with `temb`.

    Unlike `DownBlock1D`, no downsampling layer is created here.

    Args:
        out_channels: Channels of the output.
        in_channels: Channels expected by the first resnet (after concatenation).
        mid_channels: Internal width; falls back to `out_channels` when `None`.
    """

    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        # (input, output) channels of the three stages.
        stage_channels = [
            (in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        self.resnets = nn.ModuleList(
            [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        )

    def forward(self, hidden_states, temb=None):
        """Concatenate `temb` on dimension 1, run the resnets, return output and a 1-tuple of it."""
        features = torch.cat([hidden_states, temb], dim=1)

        for res_block in self.resnets:
            features = res_block(features)

        return features, (features,)
515
-
516
-
517
class AttnUpBlock1D(nn.Module):
    """Up block: concatenate the skip input, three (resnet, attention) pairs, then upsample.

    Args:
        in_channels: Channels of the incoming hidden states (the first resnet takes twice this).
        out_channels: Channels of the output.
        mid_channels: Internal width; falls back to `out_channels` when `None`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        # (input, output) channels of the three stages; the first input is doubled
        # because the skip tensor is concatenated on the channel dimension.
        stage_channels = [
            (2 * in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        res_blocks = [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        attn_blocks = [SelfAttention1d(dst, dst // 32) for _, dst in stage_channels]

        self.attentions = nn.ModuleList(attn_blocks)
        self.resnets = nn.ModuleList(res_blocks)
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Use the last entry of `res_hidden_states_tuple` as the skip input; `temb` is unused."""
        skip = res_hidden_states_tuple[-1]
        features = torch.cat([hidden_states, skip], dim=1)

        for index, res_block in enumerate(self.resnets):
            features = res_block(features)
            features = self.attentions[index](features)

        return self.up(features)
548
-
549
-
550
class UpBlock1D(nn.Module):
    """Up block: concatenate the skip input, three residual conv blocks, then upsample.

    Args:
        in_channels: Channels of the incoming hidden states (the first resnet takes twice this).
        out_channels: Channels of the output.
        mid_channels: Internal width; falls back to `in_channels` when `None`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = in_channels

        # (input, output) channels of the three stages.
        stage_channels = [
            (2 * in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        self.resnets = nn.ModuleList(
            [ResConvBlock(src, mid_channels, dst) for src, dst in stage_channels]
        )
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Use the last entry of `res_hidden_states_tuple` as the skip input; `temb` is unused."""
        skip = res_hidden_states_tuple[-1]
        features = torch.cat([hidden_states, skip], dim=1)

        for res_block in self.resnets:
            features = res_block(features)

        return self.up(features)
574
-
575
-
576
class UpBlock1DNoSkip(nn.Module):
    """Final up block: concatenate the skip input and run three residual conv blocks.

    No upsampling layer is created, and the last resnet is built with `is_last=True`.

    Args:
        in_channels: Channels of the incoming hidden states (the first resnet takes twice this).
        out_channels: Channels of the output.
        mid_channels: Internal width; falls back to `in_channels` when `None`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = in_channels

        head = ResConvBlock(2 * in_channels, mid_channels, mid_channels)
        body = ResConvBlock(mid_channels, mid_channels, mid_channels)
        tail = ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True)

        self.resnets = nn.ModuleList([head, body, tail])

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Use the last entry of `res_hidden_states_tuple` as the skip input; `temb` is unused."""
        skip = res_hidden_states_tuple[-1]
        features = torch.cat([hidden_states, skip], dim=1)

        for res_block in self.resnets:
            features = res_block(features)

        return features
597
-
598
-
599
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    """Build the down block named by `down_block_type`.

    `num_layers`, `temb_channels` and `add_downsample` are only forwarded to
    `DownResnetBlock1D`; the other block types receive just the channel counts.

    Raises:
        ValueError: If `down_block_type` is not one of the known names.
    """
    if down_block_type == "DownResnetBlock1D":
        resnet_kwargs = {
            "in_channels": in_channels,
            "num_layers": num_layers,
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "add_downsample": add_downsample,
        }
        return DownResnetBlock1D(**resnet_kwargs)

    # The remaining block types share one constructor signature.
    if down_block_type == "DownBlock1D":
        block_cls = DownBlock1D
    elif down_block_type == "AttnDownBlock1D":
        block_cls = AttnDownBlock1D
    elif down_block_type == "DownBlock1DNoSkip":
        block_cls = DownBlock1DNoSkip
    else:
        raise ValueError(f"{down_block_type} does not exist.")

    return block_cls(out_channels=out_channels, in_channels=in_channels)
615
-
616
-
617
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    """Build the up block named by `up_block_type`.

    `num_layers`, `temb_channels` and `add_upsample` are only forwarded to
    `UpResnetBlock1D`; the other block types receive just the channel counts.

    Raises:
        ValueError: If `up_block_type` is not one of the known names.
    """
    if up_block_type == "UpResnetBlock1D":
        resnet_kwargs = {
            "in_channels": in_channels,
            "num_layers": num_layers,
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "add_upsample": add_upsample,
        }
        return UpResnetBlock1D(**resnet_kwargs)

    # The remaining block types share one constructor signature.
    if up_block_type == "UpBlock1D":
        block_cls = UpBlock1D
    elif up_block_type == "AttnUpBlock1D":
        block_cls = AttnUpBlock1D
    elif up_block_type == "UpBlock1DNoSkip":
        block_cls = UpBlock1DNoSkip
    else:
        raise ValueError(f"{up_block_type} does not exist.")

    return block_cls(in_channels=in_channels, out_channels=out_channels)
633
-
634
-
635
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
    """Build the middle block named by `mid_block_type`.

    Each block type receives only the subset of arguments its constructor takes.

    Raises:
        ValueError: If `mid_block_type` is not one of the known names.
    """
    if mid_block_type == "MidResTemporalBlock1D":
        temporal_kwargs = {
            "num_layers": num_layers,
            "in_channels": in_channels,
            "out_channels": out_channels,
            "embed_dim": embed_dim,
            "add_downsample": add_downsample,
        }
        return MidResTemporalBlock1D(**temporal_kwargs)

    if mid_block_type == "ValueFunctionMidBlock1D":
        return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)

    if mid_block_type == "UNetMidBlock1D":
        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)

    raise ValueError(f"{mid_block_type} does not exist.")
649
-
650
-
651
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    """Build the output block named by `out_block_type`.

    All arguments are keyword-only. Unlike the other factory helpers, an
    unrecognised name does not raise: `None` is returned instead.
    """
    if out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim)
    if out_block_type == "OutConv1DBlock":
        return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
    return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_2d.py DELETED
@@ -1,329 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from ..configuration_utils import ConfigMixin, register_to_config
21
- from ..utils import BaseOutput
22
- from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
- from .modeling_utils import ModelMixin
24
- from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
-
26
-
27
@dataclass
class UNet2DOutput(BaseOutput):
    """
    The output of [`UNet2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output from the last layer of the model.
    """

    # Only field of the container; `UNet2DModel.forward` fills it with its final tensor.
    sample: torch.FloatTensor
38
-
39
-
40
class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    The middle block is always a `UNetMidBlock2D`; it is not configurable through the constructor.

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`):
            Type of time embedding to use. Must be `"positional"` or `"fourier"`.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet"
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`):
            The attention head dimension. When `None`, each block uses its own output channel count instead.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the middle block as its `add_attention` argument.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            # Without this branch `timestep_input_dim` stays unbound and the next
            # statement fails with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported `time_embedding_type`: {time_embedding_type}. Choose `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # `output_channel` still holds the previous iteration's value here.
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            # Blocks exposing `skip_conv` additionally take and return the skip sample.
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            # Pop as many residuals off the end as this block has resnets.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # Reshape to (batch, 1, 1, ...) so the division broadcasts over the sample.
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_2d_blocks.py DELETED
The diff for this file is too large to render. See raw diff
 
gradio_demo/6DoF/diffusers/models/unet_2d_blocks_flax.py DELETED
@@ -1,377 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import flax.linen as nn
16
- import jax.numpy as jnp
17
-
18
- from .attention_flax import FlaxTransformer2DModel
19
- from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
20
-
21
-
22
- class FlaxCrossAttnDownBlock2D(nn.Module):
23
- r"""
24
- Cross Attention 2D Downsizing block - original architecture from Unet transformers:
25
- https://arxiv.org/abs/2103.06104
26
-
27
- Parameters:
28
- in_channels (:obj:`int`):
29
- Input channels
30
- out_channels (:obj:`int`):
31
- Output channels
32
- dropout (:obj:`float`, *optional*, defaults to 0.0):
33
- Dropout rate
34
- num_layers (:obj:`int`, *optional*, defaults to 1):
35
- Number of attention blocks layers
36
- num_attention_heads (:obj:`int`, *optional*, defaults to 1):
37
- Number of attention heads of each spatial transformer block
38
- add_downsample (:obj:`bool`, *optional*, defaults to `True`):
39
- Whether to add downsampling layer before each final output
40
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
41
- enable memory efficient attention https://arxiv.org/abs/2112.05682
42
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
43
- Parameters `dtype`
44
- """
45
- in_channels: int
46
- out_channels: int
47
- dropout: float = 0.0
48
- num_layers: int = 1
49
- num_attention_heads: int = 1
50
- add_downsample: bool = True
51
- use_linear_projection: bool = False
52
- only_cross_attention: bool = False
53
- use_memory_efficient_attention: bool = False
54
- dtype: jnp.dtype = jnp.float32
55
-
56
- def setup(self):
57
- resnets = []
58
- attentions = []
59
-
60
- for i in range(self.num_layers):
61
- in_channels = self.in_channels if i == 0 else self.out_channels
62
-
63
- res_block = FlaxResnetBlock2D(
64
- in_channels=in_channels,
65
- out_channels=self.out_channels,
66
- dropout_prob=self.dropout,
67
- dtype=self.dtype,
68
- )
69
- resnets.append(res_block)
70
-
71
- attn_block = FlaxTransformer2DModel(
72
- in_channels=self.out_channels,
73
- n_heads=self.num_attention_heads,
74
- d_head=self.out_channels // self.num_attention_heads,
75
- depth=1,
76
- use_linear_projection=self.use_linear_projection,
77
- only_cross_attention=self.only_cross_attention,
78
- use_memory_efficient_attention=self.use_memory_efficient_attention,
79
- dtype=self.dtype,
80
- )
81
- attentions.append(attn_block)
82
-
83
- self.resnets = resnets
84
- self.attentions = attentions
85
-
86
- if self.add_downsample:
87
- self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
88
-
89
- def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
90
- output_states = ()
91
-
92
- for resnet, attn in zip(self.resnets, self.attentions):
93
- hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
94
- hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
95
- output_states += (hidden_states,)
96
-
97
- if self.add_downsample:
98
- hidden_states = self.downsamplers_0(hidden_states)
99
- output_states += (hidden_states,)
100
-
101
- return hidden_states, output_states
102
-
103
-
104
- class FlaxDownBlock2D(nn.Module):
105
- r"""
106
- Flax 2D downsizing block
107
-
108
- Parameters:
109
- in_channels (:obj:`int`):
110
- Input channels
111
- out_channels (:obj:`int`):
112
- Output channels
113
- dropout (:obj:`float`, *optional*, defaults to 0.0):
114
- Dropout rate
115
- num_layers (:obj:`int`, *optional*, defaults to 1):
116
- Number of attention blocks layers
117
- add_downsample (:obj:`bool`, *optional*, defaults to `True`):
118
- Whether to add downsampling layer before each final output
119
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
120
- Parameters `dtype`
121
- """
122
- in_channels: int
123
- out_channels: int
124
- dropout: float = 0.0
125
- num_layers: int = 1
126
- add_downsample: bool = True
127
- dtype: jnp.dtype = jnp.float32
128
-
129
- def setup(self):
130
- resnets = []
131
-
132
- for i in range(self.num_layers):
133
- in_channels = self.in_channels if i == 0 else self.out_channels
134
-
135
- res_block = FlaxResnetBlock2D(
136
- in_channels=in_channels,
137
- out_channels=self.out_channels,
138
- dropout_prob=self.dropout,
139
- dtype=self.dtype,
140
- )
141
- resnets.append(res_block)
142
- self.resnets = resnets
143
-
144
- if self.add_downsample:
145
- self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
146
-
147
- def __call__(self, hidden_states, temb, deterministic=True):
148
- output_states = ()
149
-
150
- for resnet in self.resnets:
151
- hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
152
- output_states += (hidden_states,)
153
-
154
- if self.add_downsample:
155
- hidden_states = self.downsamplers_0(hidden_states)
156
- output_states += (hidden_states,)
157
-
158
- return hidden_states, output_states
159
-
160
-
161
- class FlaxCrossAttnUpBlock2D(nn.Module):
162
- r"""
163
- Cross Attention 2D Upsampling block - original architecture from Unet transformers:
164
- https://arxiv.org/abs/2103.06104
165
-
166
- Parameters:
167
- in_channels (:obj:`int`):
168
- Input channels
169
- out_channels (:obj:`int`):
170
- Output channels
171
- dropout (:obj:`float`, *optional*, defaults to 0.0):
172
- Dropout rate
173
- num_layers (:obj:`int`, *optional*, defaults to 1):
174
- Number of attention blocks layers
175
- num_attention_heads (:obj:`int`, *optional*, defaults to 1):
176
- Number of attention heads of each spatial transformer block
177
- add_upsample (:obj:`bool`, *optional*, defaults to `True`):
178
- Whether to add upsampling layer before each final output
179
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
180
- enable memory efficient attention https://arxiv.org/abs/2112.05682
181
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
182
- Parameters `dtype`
183
- """
184
- in_channels: int
185
- out_channels: int
186
- prev_output_channel: int
187
- dropout: float = 0.0
188
- num_layers: int = 1
189
- num_attention_heads: int = 1
190
- add_upsample: bool = True
191
- use_linear_projection: bool = False
192
- only_cross_attention: bool = False
193
- use_memory_efficient_attention: bool = False
194
- dtype: jnp.dtype = jnp.float32
195
-
196
- def setup(self):
197
- resnets = []
198
- attentions = []
199
-
200
- for i in range(self.num_layers):
201
- res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
202
- resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
203
-
204
- res_block = FlaxResnetBlock2D(
205
- in_channels=resnet_in_channels + res_skip_channels,
206
- out_channels=self.out_channels,
207
- dropout_prob=self.dropout,
208
- dtype=self.dtype,
209
- )
210
- resnets.append(res_block)
211
-
212
- attn_block = FlaxTransformer2DModel(
213
- in_channels=self.out_channels,
214
- n_heads=self.num_attention_heads,
215
- d_head=self.out_channels // self.num_attention_heads,
216
- depth=1,
217
- use_linear_projection=self.use_linear_projection,
218
- only_cross_attention=self.only_cross_attention,
219
- use_memory_efficient_attention=self.use_memory_efficient_attention,
220
- dtype=self.dtype,
221
- )
222
- attentions.append(attn_block)
223
-
224
- self.resnets = resnets
225
- self.attentions = attentions
226
-
227
- if self.add_upsample:
228
- self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
229
-
230
- def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
231
- for resnet, attn in zip(self.resnets, self.attentions):
232
- # pop res hidden states
233
- res_hidden_states = res_hidden_states_tuple[-1]
234
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
235
- hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
236
-
237
- hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
238
- hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
239
-
240
- if self.add_upsample:
241
- hidden_states = self.upsamplers_0(hidden_states)
242
-
243
- return hidden_states
244
-
245
-
246
- class FlaxUpBlock2D(nn.Module):
247
- r"""
248
- Flax 2D upsampling block
249
-
250
- Parameters:
251
- in_channels (:obj:`int`):
252
- Input channels
253
- out_channels (:obj:`int`):
254
- Output channels
255
- prev_output_channel (:obj:`int`):
256
- Output channels from the previous block
257
- dropout (:obj:`float`, *optional*, defaults to 0.0):
258
- Dropout rate
259
- num_layers (:obj:`int`, *optional*, defaults to 1):
260
- Number of attention blocks layers
261
- add_downsample (:obj:`bool`, *optional*, defaults to `True`):
262
- Whether to add downsampling layer before each final output
263
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
264
- Parameters `dtype`
265
- """
266
- in_channels: int
267
- out_channels: int
268
- prev_output_channel: int
269
- dropout: float = 0.0
270
- num_layers: int = 1
271
- add_upsample: bool = True
272
- dtype: jnp.dtype = jnp.float32
273
-
274
- def setup(self):
275
- resnets = []
276
-
277
- for i in range(self.num_layers):
278
- res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
279
- resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
280
-
281
- res_block = FlaxResnetBlock2D(
282
- in_channels=resnet_in_channels + res_skip_channels,
283
- out_channels=self.out_channels,
284
- dropout_prob=self.dropout,
285
- dtype=self.dtype,
286
- )
287
- resnets.append(res_block)
288
-
289
- self.resnets = resnets
290
-
291
- if self.add_upsample:
292
- self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
293
-
294
- def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
295
- for resnet in self.resnets:
296
- # pop res hidden states
297
- res_hidden_states = res_hidden_states_tuple[-1]
298
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
299
- hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
300
-
301
- hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
302
-
303
- if self.add_upsample:
304
- hidden_states = self.upsamplers_0(hidden_states)
305
-
306
- return hidden_states
307
-
308
-
309
- class FlaxUNetMidBlock2DCrossAttn(nn.Module):
310
- r"""
311
- Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
312
-
313
- Parameters:
314
- in_channels (:obj:`int`):
315
- Input channels
316
- dropout (:obj:`float`, *optional*, defaults to 0.0):
317
- Dropout rate
318
- num_layers (:obj:`int`, *optional*, defaults to 1):
319
- Number of attention blocks layers
320
- num_attention_heads (:obj:`int`, *optional*, defaults to 1):
321
- Number of attention heads of each spatial transformer block
322
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
323
- enable memory efficient attention https://arxiv.org/abs/2112.05682
324
- dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
325
- Parameters `dtype`
326
- """
327
- in_channels: int
328
- dropout: float = 0.0
329
- num_layers: int = 1
330
- num_attention_heads: int = 1
331
- use_linear_projection: bool = False
332
- use_memory_efficient_attention: bool = False
333
- dtype: jnp.dtype = jnp.float32
334
-
335
- def setup(self):
336
- # there is always at least one resnet
337
- resnets = [
338
- FlaxResnetBlock2D(
339
- in_channels=self.in_channels,
340
- out_channels=self.in_channels,
341
- dropout_prob=self.dropout,
342
- dtype=self.dtype,
343
- )
344
- ]
345
-
346
- attentions = []
347
-
348
- for _ in range(self.num_layers):
349
- attn_block = FlaxTransformer2DModel(
350
- in_channels=self.in_channels,
351
- n_heads=self.num_attention_heads,
352
- d_head=self.in_channels // self.num_attention_heads,
353
- depth=1,
354
- use_linear_projection=self.use_linear_projection,
355
- use_memory_efficient_attention=self.use_memory_efficient_attention,
356
- dtype=self.dtype,
357
- )
358
- attentions.append(attn_block)
359
-
360
- res_block = FlaxResnetBlock2D(
361
- in_channels=self.in_channels,
362
- out_channels=self.in_channels,
363
- dropout_prob=self.dropout,
364
- dtype=self.dtype,
365
- )
366
- resnets.append(res_block)
367
-
368
- self.resnets = resnets
369
- self.attentions = attentions
370
-
371
- def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
372
- hidden_states = self.resnets[0](hidden_states, temb)
373
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
374
- hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
375
- hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
376
-
377
- return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_2d_condition.py DELETED
@@ -1,980 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.utils.checkpoint
20
-
21
- from ..configuration_utils import ConfigMixin, register_to_config
22
- from ..loaders import UNet2DConditionLoadersMixin
23
- from ..utils import BaseOutput, logging
24
- from .activations import get_activation
25
- from .attention_processor import AttentionProcessor, AttnProcessor
26
- from .embeddings import (
27
- GaussianFourierProjection,
28
- ImageHintTimeEmbedding,
29
- ImageProjection,
30
- ImageTimeEmbedding,
31
- TextImageProjection,
32
- TextImageTimeEmbedding,
33
- TextTimeEmbedding,
34
- TimestepEmbedding,
35
- Timesteps,
36
- )
37
- from .modeling_utils import ModelMixin
38
- from .unet_2d_blocks import (
39
- CrossAttnDownBlock2D,
40
- CrossAttnUpBlock2D,
41
- DownBlock2D,
42
- UNetMidBlock2DCrossAttn,
43
- UNetMidBlock2DSimpleCrossAttn,
44
- UpBlock2D,
45
- get_down_block,
46
- get_up_block,
47
- )
48
-
49
-
50
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
-
52
-
53
- @dataclass
54
- class UNet2DConditionOutput(BaseOutput):
55
- """
56
- The output of [`UNet2DConditionModel`].
57
-
58
- Args:
59
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
60
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
61
- """
62
-
63
- sample: torch.FloatTensor = None
64
-
65
-
66
- class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
67
- r"""
68
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
69
- shaped output.
70
-
71
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
72
- for all models (such as downloading or saving).
73
-
74
- Parameters:
75
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
76
- Height and width of input/output sample.
77
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
78
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
79
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
80
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
81
- Whether to flip the sin to cos in the time embedding.
82
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
83
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
84
- The tuple of downsample blocks to use.
85
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
86
- Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
87
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
88
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
89
- The tuple of upsample blocks to use.
90
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
91
- Whether to include self-attention in the basic transformer blocks, see
92
- [`~models.attention.BasicTransformerBlock`].
93
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
94
- The tuple of output channels for each block.
95
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
96
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
97
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
98
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
99
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
100
- If `None`, normalization and activation layers is skipped in post-processing.
101
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
102
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
103
- The dimension of the cross attention features.
104
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
105
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
106
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
107
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
108
- encoder_hid_dim (`int`, *optional*, defaults to None):
109
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
110
- dimension to `cross_attention_dim`.
111
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
112
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
113
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
114
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
115
- num_attention_heads (`int`, *optional*):
116
- The number of attention heads. If not defined, defaults to `attention_head_dim`
117
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
118
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
119
- class_embed_type (`str`, *optional*, defaults to `None`):
120
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
121
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
122
- addition_embed_type (`str`, *optional*, defaults to `None`):
123
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
124
- "text". "text" will use the `TextTimeEmbedding` layer.
125
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
126
- Dimension for the timestep embeddings.
127
- num_class_embeds (`int`, *optional*, defaults to `None`):
128
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
129
- class conditioning with `class_embed_type` equal to `None`.
130
- time_embedding_type (`str`, *optional*, defaults to `positional`):
131
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
132
- time_embedding_dim (`int`, *optional*, defaults to `None`):
133
- An optional override for the dimension of the projected time embedding.
134
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
135
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
136
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
137
- timestep_post_act (`str`, *optional*, defaults to `None`):
138
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
139
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
140
- The dimension of `cond_proj` layer in the timestep embedding.
141
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
142
- conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
143
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
144
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
145
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
146
- embeddings with the class embeddings.
147
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
148
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
149
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
150
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
151
- otherwise.
152
- """
153
-
154
- _supports_gradient_checkpointing = True
155
-
156
- @register_to_config
157
- def __init__(
158
- self,
159
- sample_size: Optional[int] = None,
160
- in_channels: int = 4,
161
- out_channels: int = 4,
162
- center_input_sample: bool = False,
163
- flip_sin_to_cos: bool = True,
164
- freq_shift: int = 0,
165
- down_block_types: Tuple[str] = (
166
- "CrossAttnDownBlock2D",
167
- "CrossAttnDownBlock2D",
168
- "CrossAttnDownBlock2D",
169
- "DownBlock2D",
170
- ),
171
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
172
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
173
- only_cross_attention: Union[bool, Tuple[bool]] = False,
174
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
175
- layers_per_block: Union[int, Tuple[int]] = 2,
176
- downsample_padding: int = 1,
177
- mid_block_scale_factor: float = 1,
178
- act_fn: str = "silu",
179
- norm_num_groups: Optional[int] = 32,
180
- norm_eps: float = 1e-5,
181
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
182
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
183
- encoder_hid_dim: Optional[int] = None,
184
- encoder_hid_dim_type: Optional[str] = None,
185
- attention_head_dim: Union[int, Tuple[int]] = 8,
186
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
187
- dual_cross_attention: bool = False,
188
- use_linear_projection: bool = False,
189
- class_embed_type: Optional[str] = None,
190
- addition_embed_type: Optional[str] = None,
191
- addition_time_embed_dim: Optional[int] = None,
192
- num_class_embeds: Optional[int] = None,
193
- upcast_attention: bool = False,
194
- resnet_time_scale_shift: str = "default",
195
- resnet_skip_time_act: bool = False,
196
- resnet_out_scale_factor: int = 1.0,
197
- time_embedding_type: str = "positional",
198
- time_embedding_dim: Optional[int] = None,
199
- time_embedding_act_fn: Optional[str] = None,
200
- timestep_post_act: Optional[str] = None,
201
- time_cond_proj_dim: Optional[int] = None,
202
- conv_in_kernel: int = 3,
203
- conv_out_kernel: int = 3,
204
- projection_class_embeddings_input_dim: Optional[int] = None,
205
- class_embeddings_concat: bool = False,
206
- mid_block_only_cross_attention: Optional[bool] = None,
207
- cross_attention_norm: Optional[str] = None,
208
- addition_embed_type_num_heads=64,
209
- ):
210
- super().__init__()
211
-
212
- self.sample_size = sample_size
213
-
214
- if num_attention_heads is not None:
215
- raise ValueError(
216
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
217
- )
218
-
219
- # If `num_attention_heads` is not defined (which is the case for most models)
220
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
221
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
222
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
223
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
224
- # which is why we correct for the naming here.
225
- num_attention_heads = num_attention_heads or attention_head_dim
226
-
227
- # Check inputs
228
- if len(down_block_types) != len(up_block_types):
229
- raise ValueError(
230
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
231
- )
232
-
233
- if len(block_out_channels) != len(down_block_types):
234
- raise ValueError(
235
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
236
- )
237
-
238
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
239
- raise ValueError(
240
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
241
- )
242
-
243
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
244
- raise ValueError(
245
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
246
- )
247
-
248
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
249
- raise ValueError(
250
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
251
- )
252
-
253
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
254
- raise ValueError(
255
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
256
- )
257
-
258
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
259
- raise ValueError(
260
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
261
- )
262
-
263
- # input
264
- conv_in_padding = (conv_in_kernel - 1) // 2
265
- self.conv_in = nn.Conv2d(
266
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
267
- )
268
-
269
- # time
270
- if time_embedding_type == "fourier":
271
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
272
- if time_embed_dim % 2 != 0:
273
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
274
- self.time_proj = GaussianFourierProjection(
275
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
276
- )
277
- timestep_input_dim = time_embed_dim
278
- elif time_embedding_type == "positional":
279
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
280
-
281
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
282
- timestep_input_dim = block_out_channels[0]
283
- else:
284
- raise ValueError(
285
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
286
- )
287
-
288
- self.time_embedding = TimestepEmbedding(
289
- timestep_input_dim,
290
- time_embed_dim,
291
- act_fn=act_fn,
292
- post_act_fn=timestep_post_act,
293
- cond_proj_dim=time_cond_proj_dim,
294
- )
295
-
296
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
297
- encoder_hid_dim_type = "text_proj"
298
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
299
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
300
-
301
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
302
- raise ValueError(
303
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
304
- )
305
-
306
- if encoder_hid_dim_type == "text_proj":
307
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
308
- elif encoder_hid_dim_type == "text_image_proj":
309
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
310
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
311
- # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
312
- self.encoder_hid_proj = TextImageProjection(
313
- text_embed_dim=encoder_hid_dim,
314
- image_embed_dim=cross_attention_dim,
315
- cross_attention_dim=cross_attention_dim,
316
- )
317
- elif encoder_hid_dim_type == "image_proj":
318
- # Kandinsky 2.2
319
- self.encoder_hid_proj = ImageProjection(
320
- image_embed_dim=encoder_hid_dim,
321
- cross_attention_dim=cross_attention_dim,
322
- )
323
- elif encoder_hid_dim_type is not None:
324
- raise ValueError(
325
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
326
- )
327
- else:
328
- self.encoder_hid_proj = None
329
-
330
- # class embedding
331
- if class_embed_type is None and num_class_embeds is not None:
332
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
333
- elif class_embed_type == "timestep":
334
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
335
- elif class_embed_type == "identity":
336
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
337
- elif class_embed_type == "projection":
338
- if projection_class_embeddings_input_dim is None:
339
- raise ValueError(
340
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
341
- )
342
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
343
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
344
- # 2. it projects from an arbitrary input dimension.
345
- #
346
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
347
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
348
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
349
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
350
- elif class_embed_type == "simple_projection":
351
- if projection_class_embeddings_input_dim is None:
352
- raise ValueError(
353
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
354
- )
355
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
356
- else:
357
- self.class_embedding = None
358
-
359
- if addition_embed_type == "text":
360
- if encoder_hid_dim is not None:
361
- text_time_embedding_from_dim = encoder_hid_dim
362
- else:
363
- text_time_embedding_from_dim = cross_attention_dim
364
-
365
- self.add_embedding = TextTimeEmbedding(
366
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
367
- )
368
- elif addition_embed_type == "text_image":
369
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
370
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
371
- # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
372
- self.add_embedding = TextImageTimeEmbedding(
373
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
374
- )
375
- elif addition_embed_type == "text_time":
376
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
377
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
- elif addition_embed_type == "image":
379
- # Kandinsky 2.2
380
- self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
381
- elif addition_embed_type == "image_hint":
382
- # Kandinsky 2.2 ControlNet
383
- self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
384
- elif addition_embed_type is not None:
385
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
386
-
387
- if time_embedding_act_fn is None:
388
- self.time_embed_act = None
389
- else:
390
- self.time_embed_act = get_activation(time_embedding_act_fn)
391
-
392
- self.down_blocks = nn.ModuleList([])
393
- self.up_blocks = nn.ModuleList([])
394
-
395
- if isinstance(only_cross_attention, bool):
396
- if mid_block_only_cross_attention is None:
397
- mid_block_only_cross_attention = only_cross_attention
398
-
399
- only_cross_attention = [only_cross_attention] * len(down_block_types)
400
-
401
- if mid_block_only_cross_attention is None:
402
- mid_block_only_cross_attention = False
403
-
404
- if isinstance(num_attention_heads, int):
405
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
406
-
407
- if isinstance(attention_head_dim, int):
408
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
409
-
410
- if isinstance(cross_attention_dim, int):
411
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
412
-
413
- if isinstance(layers_per_block, int):
414
- layers_per_block = [layers_per_block] * len(down_block_types)
415
-
416
- if isinstance(transformer_layers_per_block, int):
417
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
418
-
419
- if class_embeddings_concat:
420
- # The time embeddings are concatenated with the class embeddings. The dimension of the
421
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
422
- # regular time embeddings
423
- blocks_time_embed_dim = time_embed_dim * 2
424
- else:
425
- blocks_time_embed_dim = time_embed_dim
426
-
427
- # down
428
- output_channel = block_out_channels[0]
429
- for i, down_block_type in enumerate(down_block_types):
430
- input_channel = output_channel
431
- output_channel = block_out_channels[i]
432
- is_final_block = i == len(block_out_channels) - 1
433
-
434
- down_block = get_down_block(
435
- down_block_type,
436
- num_layers=layers_per_block[i],
437
- transformer_layers_per_block=transformer_layers_per_block[i],
438
- in_channels=input_channel,
439
- out_channels=output_channel,
440
- temb_channels=blocks_time_embed_dim,
441
- add_downsample=not is_final_block,
442
- resnet_eps=norm_eps,
443
- resnet_act_fn=act_fn,
444
- resnet_groups=norm_num_groups,
445
- cross_attention_dim=cross_attention_dim[i],
446
- num_attention_heads=num_attention_heads[i],
447
- downsample_padding=downsample_padding,
448
- dual_cross_attention=dual_cross_attention,
449
- use_linear_projection=use_linear_projection,
450
- only_cross_attention=only_cross_attention[i],
451
- upcast_attention=upcast_attention,
452
- resnet_time_scale_shift=resnet_time_scale_shift,
453
- resnet_skip_time_act=resnet_skip_time_act,
454
- resnet_out_scale_factor=resnet_out_scale_factor,
455
- cross_attention_norm=cross_attention_norm,
456
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
457
- )
458
- self.down_blocks.append(down_block)
459
-
460
- # mid
461
- if mid_block_type == "UNetMidBlock2DCrossAttn":
462
- self.mid_block = UNetMidBlock2DCrossAttn(
463
- transformer_layers_per_block=transformer_layers_per_block[-1],
464
- in_channels=block_out_channels[-1],
465
- temb_channels=blocks_time_embed_dim,
466
- resnet_eps=norm_eps,
467
- resnet_act_fn=act_fn,
468
- output_scale_factor=mid_block_scale_factor,
469
- resnet_time_scale_shift=resnet_time_scale_shift,
470
- cross_attention_dim=cross_attention_dim[-1],
471
- num_attention_heads=num_attention_heads[-1],
472
- resnet_groups=norm_num_groups,
473
- dual_cross_attention=dual_cross_attention,
474
- use_linear_projection=use_linear_projection,
475
- upcast_attention=upcast_attention,
476
- )
477
- elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
478
- self.mid_block = UNetMidBlock2DSimpleCrossAttn(
479
- in_channels=block_out_channels[-1],
480
- temb_channels=blocks_time_embed_dim,
481
- resnet_eps=norm_eps,
482
- resnet_act_fn=act_fn,
483
- output_scale_factor=mid_block_scale_factor,
484
- cross_attention_dim=cross_attention_dim[-1],
485
- attention_head_dim=attention_head_dim[-1],
486
- resnet_groups=norm_num_groups,
487
- resnet_time_scale_shift=resnet_time_scale_shift,
488
- skip_time_act=resnet_skip_time_act,
489
- only_cross_attention=mid_block_only_cross_attention,
490
- cross_attention_norm=cross_attention_norm,
491
- )
492
- elif mid_block_type is None:
493
- self.mid_block = None
494
- else:
495
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
496
-
497
- # count how many layers upsample the images
498
- self.num_upsamplers = 0
499
-
500
- # up
501
- reversed_block_out_channels = list(reversed(block_out_channels))
502
- reversed_num_attention_heads = list(reversed(num_attention_heads))
503
- reversed_layers_per_block = list(reversed(layers_per_block))
504
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
505
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
506
- only_cross_attention = list(reversed(only_cross_attention))
507
-
508
- output_channel = reversed_block_out_channels[0]
509
- for i, up_block_type in enumerate(up_block_types):
510
- is_final_block = i == len(block_out_channels) - 1
511
-
512
- prev_output_channel = output_channel
513
- output_channel = reversed_block_out_channels[i]
514
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
515
-
516
- # add upsample block for all BUT final layer
517
- if not is_final_block:
518
- add_upsample = True
519
- self.num_upsamplers += 1
520
- else:
521
- add_upsample = False
522
-
523
- up_block = get_up_block(
524
- up_block_type,
525
- num_layers=reversed_layers_per_block[i] + 1,
526
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
527
- in_channels=input_channel,
528
- out_channels=output_channel,
529
- prev_output_channel=prev_output_channel,
530
- temb_channels=blocks_time_embed_dim,
531
- add_upsample=add_upsample,
532
- resnet_eps=norm_eps,
533
- resnet_act_fn=act_fn,
534
- resnet_groups=norm_num_groups,
535
- cross_attention_dim=reversed_cross_attention_dim[i],
536
- num_attention_heads=reversed_num_attention_heads[i],
537
- dual_cross_attention=dual_cross_attention,
538
- use_linear_projection=use_linear_projection,
539
- only_cross_attention=only_cross_attention[i],
540
- upcast_attention=upcast_attention,
541
- resnet_time_scale_shift=resnet_time_scale_shift,
542
- resnet_skip_time_act=resnet_skip_time_act,
543
- resnet_out_scale_factor=resnet_out_scale_factor,
544
- cross_attention_norm=cross_attention_norm,
545
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
546
- )
547
- self.up_blocks.append(up_block)
548
- prev_output_channel = output_channel
549
-
550
- # out
551
- if norm_num_groups is not None:
552
- self.conv_norm_out = nn.GroupNorm(
553
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
554
- )
555
-
556
- self.conv_act = get_activation(act_fn)
557
-
558
- else:
559
- self.conv_norm_out = None
560
- self.conv_act = None
561
-
562
- conv_out_padding = (conv_out_kernel - 1) // 2
563
- self.conv_out = nn.Conv2d(
564
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
565
- )
566
-
567
- @property
568
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
569
- r"""
570
- Returns:
571
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
572
- indexed by its weight name.
573
- """
574
- # set recursively
575
- processors = {}
576
-
577
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
578
- if hasattr(module, "set_processor"):
579
- processors[f"{name}.processor"] = module.processor
580
-
581
- for sub_name, child in module.named_children():
582
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
583
-
584
- return processors
585
-
586
- for name, module in self.named_children():
587
- fn_recursive_add_processors(name, module, processors)
588
-
589
- return processors
590
-
591
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
592
- r"""
593
- Sets the attention processor to use to compute attention.
594
-
595
- Parameters:
596
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
597
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
598
- for **all** `Attention` layers.
599
-
600
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
601
- processor. This is strongly recommended when setting trainable attention processors.
602
-
603
- """
604
- count = len(self.attn_processors.keys())
605
-
606
- if isinstance(processor, dict) and len(processor) != count:
607
- raise ValueError(
608
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
609
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
610
- )
611
-
612
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
613
- if hasattr(module, "set_processor"):
614
- if not isinstance(processor, dict):
615
- module.set_processor(processor)
616
- else:
617
- module.set_processor(processor.pop(f"{name}.processor"))
618
-
619
- for sub_name, child in module.named_children():
620
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
621
-
622
- for name, module in self.named_children():
623
- fn_recursive_attn_processor(name, module, processor)
624
-
625
- def set_default_attn_processor(self):
626
- """
627
- Disables custom attention processors and sets the default attention implementation.
628
- """
629
- self.set_attn_processor(AttnProcessor())
630
-
631
- def set_attention_slice(self, slice_size):
632
- r"""
633
- Enable sliced attention computation.
634
-
635
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
636
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
637
-
638
- Args:
639
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
640
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
641
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
642
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
643
- must be a multiple of `slice_size`.
644
- """
645
- sliceable_head_dims = []
646
-
647
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
648
- if hasattr(module, "set_attention_slice"):
649
- sliceable_head_dims.append(module.sliceable_head_dim)
650
-
651
- for child in module.children():
652
- fn_recursive_retrieve_sliceable_dims(child)
653
-
654
- # retrieve number of attention layers
655
- for module in self.children():
656
- fn_recursive_retrieve_sliceable_dims(module)
657
-
658
- num_sliceable_layers = len(sliceable_head_dims)
659
-
660
- if slice_size == "auto":
661
- # half the attention head size is usually a good trade-off between
662
- # speed and memory
663
- slice_size = [dim // 2 for dim in sliceable_head_dims]
664
- elif slice_size == "max":
665
- # make smallest slice possible
666
- slice_size = num_sliceable_layers * [1]
667
-
668
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
669
-
670
- if len(slice_size) != len(sliceable_head_dims):
671
- raise ValueError(
672
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
673
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
674
- )
675
-
676
- for i in range(len(slice_size)):
677
- size = slice_size[i]
678
- dim = sliceable_head_dims[i]
679
- if size is not None and size > dim:
680
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
681
-
682
- # Recursively walk through all the children.
683
- # Any children which exposes the set_attention_slice method
684
- # gets the message
685
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
686
- if hasattr(module, "set_attention_slice"):
687
- module.set_attention_slice(slice_size.pop())
688
-
689
- for child in module.children():
690
- fn_recursive_set_attention_slice(child, slice_size)
691
-
692
- reversed_slice_size = list(reversed(slice_size))
693
- for module in self.children():
694
- fn_recursive_set_attention_slice(module, reversed_slice_size)
695
-
696
- def _set_gradient_checkpointing(self, module, value=False):
697
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
698
- module.gradient_checkpointing = value
699
-
700
- def forward(
701
- self,
702
- sample: torch.FloatTensor,
703
- timestep: Union[torch.Tensor, float, int],
704
- encoder_hidden_states: torch.Tensor,
705
- class_labels: Optional[torch.Tensor] = None,
706
- timestep_cond: Optional[torch.Tensor] = None,
707
- attention_mask: Optional[torch.Tensor] = None,
708
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
709
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
710
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
711
- mid_block_additional_residual: Optional[torch.Tensor] = None,
712
- encoder_attention_mask: Optional[torch.Tensor] = None,
713
- return_dict: bool = True,
714
- ) -> Union[UNet2DConditionOutput, Tuple]:
715
- r"""
716
- The [`UNet2DConditionModel`] forward method.
717
-
718
- Args:
719
- sample (`torch.FloatTensor`):
720
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
721
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
722
- encoder_hidden_states (`torch.FloatTensor`):
723
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
724
- encoder_attention_mask (`torch.Tensor`):
725
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
726
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
727
- which adds large negative values to the attention scores corresponding to "discard" tokens.
728
- return_dict (`bool`, *optional*, defaults to `True`):
729
- Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
730
- tuple.
731
- cross_attention_kwargs (`dict`, *optional*):
732
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
733
- added_cond_kwargs: (`dict`, *optional*):
734
- A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
735
- are passed along to the UNet blocks.
736
-
737
- Returns:
738
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
739
- If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
740
- a `tuple` is returned where the first element is the sample tensor.
741
- """
742
- # By default samples have to be AT least a multiple of the overall upsampling factor.
743
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
744
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
745
- # on the fly if necessary.
746
- default_overall_up_factor = 2**self.num_upsamplers
747
-
748
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
749
- forward_upsample_size = False
750
- upsample_size = None
751
-
752
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
753
- logger.info("Forward upsample size to force interpolation output size.")
754
- forward_upsample_size = True
755
-
756
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
757
- # expects mask of shape:
758
- # [batch, key_tokens]
759
- # adds singleton query_tokens dimension:
760
- # [batch, 1, key_tokens]
761
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
762
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
763
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
764
- if attention_mask is not None:
765
- # assume that mask is expressed as:
766
- # (1 = keep, 0 = discard)
767
- # convert mask into a bias that can be added to attention scores:
768
- # (keep = +0, discard = -10000.0)
769
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
770
- attention_mask = attention_mask.unsqueeze(1)
771
-
772
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
773
- if encoder_attention_mask is not None:
774
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
775
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
776
-
777
- # 0. center input if necessary
778
- if self.config.center_input_sample:
779
- sample = 2 * sample - 1.0
780
-
781
- # 1. time
782
- timesteps = timestep
783
- if not torch.is_tensor(timesteps):
784
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
785
- # This would be a good case for the `match` statement (Python 3.10+)
786
- is_mps = sample.device.type == "mps"
787
- if isinstance(timestep, float):
788
- dtype = torch.float32 if is_mps else torch.float64
789
- else:
790
- dtype = torch.int32 if is_mps else torch.int64
791
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
792
- elif len(timesteps.shape) == 0:
793
- timesteps = timesteps[None].to(sample.device)
794
-
795
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
796
- timesteps = timesteps.expand(sample.shape[0])
797
-
798
- t_emb = self.time_proj(timesteps)
799
-
800
- # `Timesteps` does not contain any weights and will always return f32 tensors
801
- # but time_embedding might actually be running in fp16. so we need to cast here.
802
- # there might be better ways to encapsulate this.
803
- t_emb = t_emb.to(dtype=sample.dtype)
804
-
805
- emb = self.time_embedding(t_emb, timestep_cond)
806
- aug_emb = None
807
-
808
- if self.class_embedding is not None:
809
- if class_labels is None:
810
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
811
-
812
- if self.config.class_embed_type == "timestep":
813
- class_labels = self.time_proj(class_labels)
814
-
815
- # `Timesteps` does not contain any weights and will always return f32 tensors
816
- # there might be better ways to encapsulate this.
817
- class_labels = class_labels.to(dtype=sample.dtype)
818
-
819
- class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
820
-
821
- if self.config.class_embeddings_concat:
822
- emb = torch.cat([emb, class_emb], dim=-1)
823
- else:
824
- emb = emb + class_emb
825
-
826
- if self.config.addition_embed_type == "text":
827
- aug_emb = self.add_embedding(encoder_hidden_states)
828
- elif self.config.addition_embed_type == "text_image":
829
- # Kandinsky 2.1 - style
830
- if "image_embeds" not in added_cond_kwargs:
831
- raise ValueError(
832
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
833
- )
834
-
835
- image_embs = added_cond_kwargs.get("image_embeds")
836
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
837
- aug_emb = self.add_embedding(text_embs, image_embs)
838
- elif self.config.addition_embed_type == "text_time":
839
- if "text_embeds" not in added_cond_kwargs:
840
- raise ValueError(
841
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
842
- )
843
- text_embeds = added_cond_kwargs.get("text_embeds")
844
- if "time_ids" not in added_cond_kwargs:
845
- raise ValueError(
846
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
847
- )
848
- time_ids = added_cond_kwargs.get("time_ids")
849
- time_embeds = self.add_time_proj(time_ids.flatten())
850
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
851
-
852
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
853
- add_embeds = add_embeds.to(emb.dtype)
854
- aug_emb = self.add_embedding(add_embeds)
855
- elif self.config.addition_embed_type == "image":
856
- # Kandinsky 2.2 - style
857
- if "image_embeds" not in added_cond_kwargs:
858
- raise ValueError(
859
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
860
- )
861
- image_embs = added_cond_kwargs.get("image_embeds")
862
- aug_emb = self.add_embedding(image_embs)
863
- elif self.config.addition_embed_type == "image_hint":
864
- # Kandinsky 2.2 - style
865
- if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
866
- raise ValueError(
867
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
868
- )
869
- image_embs = added_cond_kwargs.get("image_embeds")
870
- hint = added_cond_kwargs.get("hint")
871
- aug_emb, hint = self.add_embedding(image_embs, hint)
872
- sample = torch.cat([sample, hint], dim=1)
873
-
874
- emb = emb + aug_emb if aug_emb is not None else emb
875
-
876
- if self.time_embed_act is not None:
877
- emb = self.time_embed_act(emb)
878
-
879
- if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
880
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
881
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
882
- # Kadinsky 2.1 - style
883
- if "image_embeds" not in added_cond_kwargs:
884
- raise ValueError(
885
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
886
- )
887
-
888
- image_embeds = added_cond_kwargs.get("image_embeds")
889
- encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
890
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
891
- # Kandinsky 2.2 - style
892
- if "image_embeds" not in added_cond_kwargs:
893
- raise ValueError(
894
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
895
- )
896
- image_embeds = added_cond_kwargs.get("image_embeds")
897
- encoder_hidden_states = self.encoder_hid_proj(image_embeds)
898
- # 2. pre-process
899
- sample = self.conv_in(sample)
900
-
901
- # 3. down
902
- down_block_res_samples = (sample,)
903
- for downsample_block in self.down_blocks:
904
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
905
- sample, res_samples = downsample_block(
906
- hidden_states=sample,
907
- temb=emb,
908
- encoder_hidden_states=encoder_hidden_states,
909
- attention_mask=attention_mask,
910
- cross_attention_kwargs=cross_attention_kwargs,
911
- encoder_attention_mask=encoder_attention_mask,
912
- )
913
- else:
914
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
915
-
916
- down_block_res_samples += res_samples
917
-
918
- if down_block_additional_residuals is not None:
919
- new_down_block_res_samples = ()
920
-
921
- for down_block_res_sample, down_block_additional_residual in zip(
922
- down_block_res_samples, down_block_additional_residuals
923
- ):
924
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
925
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
926
-
927
- down_block_res_samples = new_down_block_res_samples
928
-
929
- # 4. mid
930
- if self.mid_block is not None:
931
- sample = self.mid_block(
932
- sample,
933
- emb,
934
- encoder_hidden_states=encoder_hidden_states,
935
- attention_mask=attention_mask,
936
- cross_attention_kwargs=cross_attention_kwargs,
937
- encoder_attention_mask=encoder_attention_mask,
938
- )
939
-
940
- if mid_block_additional_residual is not None:
941
- sample = sample + mid_block_additional_residual
942
-
943
- # 5. up
944
- for i, upsample_block in enumerate(self.up_blocks):
945
- is_final_block = i == len(self.up_blocks) - 1
946
-
947
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
948
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
949
-
950
- # if we have not reached the final block and need to forward the
951
- # upsample size, we do it here
952
- if not is_final_block and forward_upsample_size:
953
- upsample_size = down_block_res_samples[-1].shape[2:]
954
-
955
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
956
- sample = upsample_block(
957
- hidden_states=sample,
958
- temb=emb,
959
- res_hidden_states_tuple=res_samples,
960
- encoder_hidden_states=encoder_hidden_states,
961
- cross_attention_kwargs=cross_attention_kwargs,
962
- upsample_size=upsample_size,
963
- attention_mask=attention_mask,
964
- encoder_attention_mask=encoder_attention_mask,
965
- )
966
- else:
967
- sample = upsample_block(
968
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
969
- )
970
-
971
- # 6. post-process
972
- if self.conv_norm_out:
973
- sample = self.conv_norm_out(sample)
974
- sample = self.conv_act(sample)
975
- sample = self.conv_out(sample)
976
-
977
- if not return_dict:
978
- return (sample,)
979
-
980
- return UNet2DConditionOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_2d_condition_flax.py DELETED
@@ -1,357 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Optional, Tuple, Union
15
-
16
- import flax
17
- import flax.linen as nn
18
- import jax
19
- import jax.numpy as jnp
20
- from flax.core.frozen_dict import FrozenDict
21
-
22
- from ..configuration_utils import ConfigMixin, flax_register_to_config
23
- from ..utils import BaseOutput
24
- from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
- from .modeling_flax_utils import FlaxModelMixin
26
- from .unet_2d_blocks_flax import (
27
- FlaxCrossAttnDownBlock2D,
28
- FlaxCrossAttnUpBlock2D,
29
- FlaxDownBlock2D,
30
- FlaxUNetMidBlock2DCrossAttn,
31
- FlaxUpBlock2D,
32
- )
33
-
34
-
35
- @flax.struct.dataclass
36
- class FlaxUNet2DConditionOutput(BaseOutput):
37
- """
38
- The output of [`FlaxUNet2DConditionModel`].
39
-
40
- Args:
41
- sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
42
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
43
- """
44
-
45
- sample: jnp.ndarray
46
-
47
-
48
- @flax_register_to_config
49
- class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
50
- r"""
51
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
52
- shaped output.
53
-
54
- This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
55
- implemented for all models (such as downloading or saving).
56
-
57
- This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
58
- subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
59
- general usage and behavior.
60
-
61
- Inherent JAX features such as the following are supported:
62
- - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
63
- - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
64
- - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
65
- - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
66
-
67
- Parameters:
68
- sample_size (`int`, *optional*):
69
- The size of the input sample.
70
- in_channels (`int`, *optional*, defaults to 4):
71
- The number of channels in the input sample.
72
- out_channels (`int`, *optional*, defaults to 4):
73
- The number of channels in the output.
74
- down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
75
- The tuple of downsample blocks to use.
76
- up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
77
- The tuple of upsample blocks to use.
78
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
79
- The tuple of output channels for each block.
80
- layers_per_block (`int`, *optional*, defaults to 2):
81
- The number of layers per block.
82
- attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
83
- The dimension of the attention heads.
84
- num_attention_heads (`int` or `Tuple[int]`, *optional*):
85
- The number of attention heads.
86
- cross_attention_dim (`int`, *optional*, defaults to 768):
87
- The dimension of the cross attention features.
88
- dropout (`float`, *optional*, defaults to 0):
89
- Dropout probability for down, up and bottleneck blocks.
90
- flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
91
- Whether to flip the sin to cos in the time embedding.
92
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
93
- use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
94
- Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
95
- """
96
-
97
- sample_size: int = 32
98
- in_channels: int = 4
99
- out_channels: int = 4
100
- down_block_types: Tuple[str] = (
101
- "CrossAttnDownBlock2D",
102
- "CrossAttnDownBlock2D",
103
- "CrossAttnDownBlock2D",
104
- "DownBlock2D",
105
- )
106
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
107
- only_cross_attention: Union[bool, Tuple[bool]] = False
108
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
109
- layers_per_block: int = 2
110
- attention_head_dim: Union[int, Tuple[int]] = 8
111
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None
112
- cross_attention_dim: int = 1280
113
- dropout: float = 0.0
114
- use_linear_projection: bool = False
115
- dtype: jnp.dtype = jnp.float32
116
- flip_sin_to_cos: bool = True
117
- freq_shift: int = 0
118
- use_memory_efficient_attention: bool = False
119
-
120
- def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
121
- # init input tensors
122
- sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
123
- sample = jnp.zeros(sample_shape, dtype=jnp.float32)
124
- timesteps = jnp.ones((1,), dtype=jnp.int32)
125
- encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
126
-
127
- params_rng, dropout_rng = jax.random.split(rng)
128
- rngs = {"params": params_rng, "dropout": dropout_rng}
129
-
130
- return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
131
-
132
- def setup(self):
133
- block_out_channels = self.block_out_channels
134
- time_embed_dim = block_out_channels[0] * 4
135
-
136
- if self.num_attention_heads is not None:
137
- raise ValueError(
138
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
139
- )
140
-
141
- # If `num_attention_heads` is not defined (which is the case for most models)
142
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
143
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
144
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
145
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
146
- # which is why we correct for the naming here.
147
- num_attention_heads = self.num_attention_heads or self.attention_head_dim
148
-
149
- # input
150
- self.conv_in = nn.Conv(
151
- block_out_channels[0],
152
- kernel_size=(3, 3),
153
- strides=(1, 1),
154
- padding=((1, 1), (1, 1)),
155
- dtype=self.dtype,
156
- )
157
-
158
- # time
159
- self.time_proj = FlaxTimesteps(
160
- block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
161
- )
162
- self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
163
-
164
- only_cross_attention = self.only_cross_attention
165
- if isinstance(only_cross_attention, bool):
166
- only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
167
-
168
- if isinstance(num_attention_heads, int):
169
- num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
170
-
171
- # down
172
- down_blocks = []
173
- output_channel = block_out_channels[0]
174
- for i, down_block_type in enumerate(self.down_block_types):
175
- input_channel = output_channel
176
- output_channel = block_out_channels[i]
177
- is_final_block = i == len(block_out_channels) - 1
178
-
179
- if down_block_type == "CrossAttnDownBlock2D":
180
- down_block = FlaxCrossAttnDownBlock2D(
181
- in_channels=input_channel,
182
- out_channels=output_channel,
183
- dropout=self.dropout,
184
- num_layers=self.layers_per_block,
185
- num_attention_heads=num_attention_heads[i],
186
- add_downsample=not is_final_block,
187
- use_linear_projection=self.use_linear_projection,
188
- only_cross_attention=only_cross_attention[i],
189
- use_memory_efficient_attention=self.use_memory_efficient_attention,
190
- dtype=self.dtype,
191
- )
192
- else:
193
- down_block = FlaxDownBlock2D(
194
- in_channels=input_channel,
195
- out_channels=output_channel,
196
- dropout=self.dropout,
197
- num_layers=self.layers_per_block,
198
- add_downsample=not is_final_block,
199
- dtype=self.dtype,
200
- )
201
-
202
- down_blocks.append(down_block)
203
- self.down_blocks = down_blocks
204
-
205
- # mid
206
- self.mid_block = FlaxUNetMidBlock2DCrossAttn(
207
- in_channels=block_out_channels[-1],
208
- dropout=self.dropout,
209
- num_attention_heads=num_attention_heads[-1],
210
- use_linear_projection=self.use_linear_projection,
211
- use_memory_efficient_attention=self.use_memory_efficient_attention,
212
- dtype=self.dtype,
213
- )
214
-
215
- # up
216
- up_blocks = []
217
- reversed_block_out_channels = list(reversed(block_out_channels))
218
- reversed_num_attention_heads = list(reversed(num_attention_heads))
219
- only_cross_attention = list(reversed(only_cross_attention))
220
- output_channel = reversed_block_out_channels[0]
221
- for i, up_block_type in enumerate(self.up_block_types):
222
- prev_output_channel = output_channel
223
- output_channel = reversed_block_out_channels[i]
224
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
225
-
226
- is_final_block = i == len(block_out_channels) - 1
227
-
228
- if up_block_type == "CrossAttnUpBlock2D":
229
- up_block = FlaxCrossAttnUpBlock2D(
230
- in_channels=input_channel,
231
- out_channels=output_channel,
232
- prev_output_channel=prev_output_channel,
233
- num_layers=self.layers_per_block + 1,
234
- num_attention_heads=reversed_num_attention_heads[i],
235
- add_upsample=not is_final_block,
236
- dropout=self.dropout,
237
- use_linear_projection=self.use_linear_projection,
238
- only_cross_attention=only_cross_attention[i],
239
- use_memory_efficient_attention=self.use_memory_efficient_attention,
240
- dtype=self.dtype,
241
- )
242
- else:
243
- up_block = FlaxUpBlock2D(
244
- in_channels=input_channel,
245
- out_channels=output_channel,
246
- prev_output_channel=prev_output_channel,
247
- num_layers=self.layers_per_block + 1,
248
- add_upsample=not is_final_block,
249
- dropout=self.dropout,
250
- dtype=self.dtype,
251
- )
252
-
253
- up_blocks.append(up_block)
254
- prev_output_channel = output_channel
255
- self.up_blocks = up_blocks
256
-
257
- # out
258
- self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
259
- self.conv_out = nn.Conv(
260
- self.out_channels,
261
- kernel_size=(3, 3),
262
- strides=(1, 1),
263
- padding=((1, 1), (1, 1)),
264
- dtype=self.dtype,
265
- )
266
-
267
- def __call__(
268
- self,
269
- sample,
270
- timesteps,
271
- encoder_hidden_states,
272
- down_block_additional_residuals=None,
273
- mid_block_additional_residual=None,
274
- return_dict: bool = True,
275
- train: bool = False,
276
- ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
277
- r"""
278
- Args:
279
- sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
280
- timestep (`jnp.ndarray` or `float` or `int`): timesteps
281
- encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
282
- return_dict (`bool`, *optional*, defaults to `True`):
283
- Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
284
- plain tuple.
285
- train (`bool`, *optional*, defaults to `False`):
286
- Use deterministic functions and disable dropout when not training.
287
-
288
- Returns:
289
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
290
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
291
- When returning a tuple, the first element is the sample tensor.
292
- """
293
- # 1. time
294
- if not isinstance(timesteps, jnp.ndarray):
295
- timesteps = jnp.array([timesteps], dtype=jnp.int32)
296
- elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
297
- timesteps = timesteps.astype(dtype=jnp.float32)
298
- timesteps = jnp.expand_dims(timesteps, 0)
299
-
300
- t_emb = self.time_proj(timesteps)
301
- t_emb = self.time_embedding(t_emb)
302
-
303
- # 2. pre-process
304
- sample = jnp.transpose(sample, (0, 2, 3, 1))
305
- sample = self.conv_in(sample)
306
-
307
- # 3. down
308
- down_block_res_samples = (sample,)
309
- for down_block in self.down_blocks:
310
- if isinstance(down_block, FlaxCrossAttnDownBlock2D):
311
- sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
312
- else:
313
- sample, res_samples = down_block(sample, t_emb, deterministic=not train)
314
- down_block_res_samples += res_samples
315
-
316
- if down_block_additional_residuals is not None:
317
- new_down_block_res_samples = ()
318
-
319
- for down_block_res_sample, down_block_additional_residual in zip(
320
- down_block_res_samples, down_block_additional_residuals
321
- ):
322
- down_block_res_sample += down_block_additional_residual
323
- new_down_block_res_samples += (down_block_res_sample,)
324
-
325
- down_block_res_samples = new_down_block_res_samples
326
-
327
- # 4. mid
328
- sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
329
-
330
- if mid_block_additional_residual is not None:
331
- sample += mid_block_additional_residual
332
-
333
- # 5. up
334
- for up_block in self.up_blocks:
335
- res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
336
- down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
337
- if isinstance(up_block, FlaxCrossAttnUpBlock2D):
338
- sample = up_block(
339
- sample,
340
- temb=t_emb,
341
- encoder_hidden_states=encoder_hidden_states,
342
- res_hidden_states_tuple=res_samples,
343
- deterministic=not train,
344
- )
345
- else:
346
- sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
347
-
348
- # 6. post-process
349
- sample = self.conv_norm_out(sample)
350
- sample = nn.silu(sample)
351
- sample = self.conv_out(sample)
352
- sample = jnp.transpose(sample, (0, 3, 1, 2))
353
-
354
- if not return_dict:
355
- return (sample,)
356
-
357
- return FlaxUNet2DConditionOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_3d_blocks.py DELETED
@@ -1,679 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- from torch import nn
17
-
18
- from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
- from .transformer_2d import Transformer2DModel
20
- from .transformer_temporal import TransformerTemporalModel
21
-
22
-
23
- def get_down_block(
24
- down_block_type,
25
- num_layers,
26
- in_channels,
27
- out_channels,
28
- temb_channels,
29
- add_downsample,
30
- resnet_eps,
31
- resnet_act_fn,
32
- num_attention_heads,
33
- resnet_groups=None,
34
- cross_attention_dim=None,
35
- downsample_padding=None,
36
- dual_cross_attention=False,
37
- use_linear_projection=True,
38
- only_cross_attention=False,
39
- upcast_attention=False,
40
- resnet_time_scale_shift="default",
41
- ):
42
- if down_block_type == "DownBlock3D":
43
- return DownBlock3D(
44
- num_layers=num_layers,
45
- in_channels=in_channels,
46
- out_channels=out_channels,
47
- temb_channels=temb_channels,
48
- add_downsample=add_downsample,
49
- resnet_eps=resnet_eps,
50
- resnet_act_fn=resnet_act_fn,
51
- resnet_groups=resnet_groups,
52
- downsample_padding=downsample_padding,
53
- resnet_time_scale_shift=resnet_time_scale_shift,
54
- )
55
- elif down_block_type == "CrossAttnDownBlock3D":
56
- if cross_attention_dim is None:
57
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
- return CrossAttnDownBlock3D(
59
- num_layers=num_layers,
60
- in_channels=in_channels,
61
- out_channels=out_channels,
62
- temb_channels=temb_channels,
63
- add_downsample=add_downsample,
64
- resnet_eps=resnet_eps,
65
- resnet_act_fn=resnet_act_fn,
66
- resnet_groups=resnet_groups,
67
- downsample_padding=downsample_padding,
68
- cross_attention_dim=cross_attention_dim,
69
- num_attention_heads=num_attention_heads,
70
- dual_cross_attention=dual_cross_attention,
71
- use_linear_projection=use_linear_projection,
72
- only_cross_attention=only_cross_attention,
73
- upcast_attention=upcast_attention,
74
- resnet_time_scale_shift=resnet_time_scale_shift,
75
- )
76
- raise ValueError(f"{down_block_type} does not exist.")
77
-
78
-
79
- def get_up_block(
80
- up_block_type,
81
- num_layers,
82
- in_channels,
83
- out_channels,
84
- prev_output_channel,
85
- temb_channels,
86
- add_upsample,
87
- resnet_eps,
88
- resnet_act_fn,
89
- num_attention_heads,
90
- resnet_groups=None,
91
- cross_attention_dim=None,
92
- dual_cross_attention=False,
93
- use_linear_projection=True,
94
- only_cross_attention=False,
95
- upcast_attention=False,
96
- resnet_time_scale_shift="default",
97
- ):
98
- if up_block_type == "UpBlock3D":
99
- return UpBlock3D(
100
- num_layers=num_layers,
101
- in_channels=in_channels,
102
- out_channels=out_channels,
103
- prev_output_channel=prev_output_channel,
104
- temb_channels=temb_channels,
105
- add_upsample=add_upsample,
106
- resnet_eps=resnet_eps,
107
- resnet_act_fn=resnet_act_fn,
108
- resnet_groups=resnet_groups,
109
- resnet_time_scale_shift=resnet_time_scale_shift,
110
- )
111
- elif up_block_type == "CrossAttnUpBlock3D":
112
- if cross_attention_dim is None:
113
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
114
- return CrossAttnUpBlock3D(
115
- num_layers=num_layers,
116
- in_channels=in_channels,
117
- out_channels=out_channels,
118
- prev_output_channel=prev_output_channel,
119
- temb_channels=temb_channels,
120
- add_upsample=add_upsample,
121
- resnet_eps=resnet_eps,
122
- resnet_act_fn=resnet_act_fn,
123
- resnet_groups=resnet_groups,
124
- cross_attention_dim=cross_attention_dim,
125
- num_attention_heads=num_attention_heads,
126
- dual_cross_attention=dual_cross_attention,
127
- use_linear_projection=use_linear_projection,
128
- only_cross_attention=only_cross_attention,
129
- upcast_attention=upcast_attention,
130
- resnet_time_scale_shift=resnet_time_scale_shift,
131
- )
132
- raise ValueError(f"{up_block_type} does not exist.")
133
-
134
-
135
- class UNetMidBlock3DCrossAttn(nn.Module):
136
- def __init__(
137
- self,
138
- in_channels: int,
139
- temb_channels: int,
140
- dropout: float = 0.0,
141
- num_layers: int = 1,
142
- resnet_eps: float = 1e-6,
143
- resnet_time_scale_shift: str = "default",
144
- resnet_act_fn: str = "swish",
145
- resnet_groups: int = 32,
146
- resnet_pre_norm: bool = True,
147
- num_attention_heads=1,
148
- output_scale_factor=1.0,
149
- cross_attention_dim=1280,
150
- dual_cross_attention=False,
151
- use_linear_projection=True,
152
- upcast_attention=False,
153
- ):
154
- super().__init__()
155
-
156
- self.has_cross_attention = True
157
- self.num_attention_heads = num_attention_heads
158
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
159
-
160
- # there is always at least one resnet
161
- resnets = [
162
- ResnetBlock2D(
163
- in_channels=in_channels,
164
- out_channels=in_channels,
165
- temb_channels=temb_channels,
166
- eps=resnet_eps,
167
- groups=resnet_groups,
168
- dropout=dropout,
169
- time_embedding_norm=resnet_time_scale_shift,
170
- non_linearity=resnet_act_fn,
171
- output_scale_factor=output_scale_factor,
172
- pre_norm=resnet_pre_norm,
173
- )
174
- ]
175
- temp_convs = [
176
- TemporalConvLayer(
177
- in_channels,
178
- in_channels,
179
- dropout=0.1,
180
- )
181
- ]
182
- attentions = []
183
- temp_attentions = []
184
-
185
- for _ in range(num_layers):
186
- attentions.append(
187
- Transformer2DModel(
188
- in_channels // num_attention_heads,
189
- num_attention_heads,
190
- in_channels=in_channels,
191
- num_layers=1,
192
- cross_attention_dim=cross_attention_dim,
193
- norm_num_groups=resnet_groups,
194
- use_linear_projection=use_linear_projection,
195
- upcast_attention=upcast_attention,
196
- )
197
- )
198
- temp_attentions.append(
199
- TransformerTemporalModel(
200
- in_channels // num_attention_heads,
201
- num_attention_heads,
202
- in_channels=in_channels,
203
- num_layers=1,
204
- cross_attention_dim=cross_attention_dim,
205
- norm_num_groups=resnet_groups,
206
- )
207
- )
208
- resnets.append(
209
- ResnetBlock2D(
210
- in_channels=in_channels,
211
- out_channels=in_channels,
212
- temb_channels=temb_channels,
213
- eps=resnet_eps,
214
- groups=resnet_groups,
215
- dropout=dropout,
216
- time_embedding_norm=resnet_time_scale_shift,
217
- non_linearity=resnet_act_fn,
218
- output_scale_factor=output_scale_factor,
219
- pre_norm=resnet_pre_norm,
220
- )
221
- )
222
- temp_convs.append(
223
- TemporalConvLayer(
224
- in_channels,
225
- in_channels,
226
- dropout=0.1,
227
- )
228
- )
229
-
230
- self.resnets = nn.ModuleList(resnets)
231
- self.temp_convs = nn.ModuleList(temp_convs)
232
- self.attentions = nn.ModuleList(attentions)
233
- self.temp_attentions = nn.ModuleList(temp_attentions)
234
-
235
- def forward(
236
- self,
237
- hidden_states,
238
- temb=None,
239
- encoder_hidden_states=None,
240
- attention_mask=None,
241
- num_frames=1,
242
- cross_attention_kwargs=None,
243
- ):
244
- hidden_states = self.resnets[0](hidden_states, temb)
245
- hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
246
- for attn, temp_attn, resnet, temp_conv in zip(
247
- self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
248
- ):
249
- hidden_states = attn(
250
- hidden_states,
251
- encoder_hidden_states=encoder_hidden_states,
252
- cross_attention_kwargs=cross_attention_kwargs,
253
- return_dict=False,
254
- )[0]
255
- hidden_states = temp_attn(
256
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
257
- )[0]
258
- hidden_states = resnet(hidden_states, temb)
259
- hidden_states = temp_conv(hidden_states, num_frames=num_frames)
260
-
261
- return hidden_states
262
-
263
-
264
- class CrossAttnDownBlock3D(nn.Module):
265
- def __init__(
266
- self,
267
- in_channels: int,
268
- out_channels: int,
269
- temb_channels: int,
270
- dropout: float = 0.0,
271
- num_layers: int = 1,
272
- resnet_eps: float = 1e-6,
273
- resnet_time_scale_shift: str = "default",
274
- resnet_act_fn: str = "swish",
275
- resnet_groups: int = 32,
276
- resnet_pre_norm: bool = True,
277
- num_attention_heads=1,
278
- cross_attention_dim=1280,
279
- output_scale_factor=1.0,
280
- downsample_padding=1,
281
- add_downsample=True,
282
- dual_cross_attention=False,
283
- use_linear_projection=False,
284
- only_cross_attention=False,
285
- upcast_attention=False,
286
- ):
287
- super().__init__()
288
- resnets = []
289
- attentions = []
290
- temp_attentions = []
291
- temp_convs = []
292
-
293
- self.has_cross_attention = True
294
- self.num_attention_heads = num_attention_heads
295
-
296
- for i in range(num_layers):
297
- in_channels = in_channels if i == 0 else out_channels
298
- resnets.append(
299
- ResnetBlock2D(
300
- in_channels=in_channels,
301
- out_channels=out_channels,
302
- temb_channels=temb_channels,
303
- eps=resnet_eps,
304
- groups=resnet_groups,
305
- dropout=dropout,
306
- time_embedding_norm=resnet_time_scale_shift,
307
- non_linearity=resnet_act_fn,
308
- output_scale_factor=output_scale_factor,
309
- pre_norm=resnet_pre_norm,
310
- )
311
- )
312
- temp_convs.append(
313
- TemporalConvLayer(
314
- out_channels,
315
- out_channels,
316
- dropout=0.1,
317
- )
318
- )
319
- attentions.append(
320
- Transformer2DModel(
321
- out_channels // num_attention_heads,
322
- num_attention_heads,
323
- in_channels=out_channels,
324
- num_layers=1,
325
- cross_attention_dim=cross_attention_dim,
326
- norm_num_groups=resnet_groups,
327
- use_linear_projection=use_linear_projection,
328
- only_cross_attention=only_cross_attention,
329
- upcast_attention=upcast_attention,
330
- )
331
- )
332
- temp_attentions.append(
333
- TransformerTemporalModel(
334
- out_channels // num_attention_heads,
335
- num_attention_heads,
336
- in_channels=out_channels,
337
- num_layers=1,
338
- cross_attention_dim=cross_attention_dim,
339
- norm_num_groups=resnet_groups,
340
- )
341
- )
342
- self.resnets = nn.ModuleList(resnets)
343
- self.temp_convs = nn.ModuleList(temp_convs)
344
- self.attentions = nn.ModuleList(attentions)
345
- self.temp_attentions = nn.ModuleList(temp_attentions)
346
-
347
- if add_downsample:
348
- self.downsamplers = nn.ModuleList(
349
- [
350
- Downsample2D(
351
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
352
- )
353
- ]
354
- )
355
- else:
356
- self.downsamplers = None
357
-
358
- self.gradient_checkpointing = False
359
-
360
- def forward(
361
- self,
362
- hidden_states,
363
- temb=None,
364
- encoder_hidden_states=None,
365
- attention_mask=None,
366
- num_frames=1,
367
- cross_attention_kwargs=None,
368
- ):
369
- # TODO(Patrick, William) - attention mask is not used
370
- output_states = ()
371
-
372
- for resnet, temp_conv, attn, temp_attn in zip(
373
- self.resnets, self.temp_convs, self.attentions, self.temp_attentions
374
- ):
375
- hidden_states = resnet(hidden_states, temb)
376
- hidden_states = temp_conv(hidden_states, num_frames=num_frames)
377
- hidden_states = attn(
378
- hidden_states,
379
- encoder_hidden_states=encoder_hidden_states,
380
- cross_attention_kwargs=cross_attention_kwargs,
381
- return_dict=False,
382
- )[0]
383
- hidden_states = temp_attn(
384
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
385
- )[0]
386
-
387
- output_states += (hidden_states,)
388
-
389
- if self.downsamplers is not None:
390
- for downsampler in self.downsamplers:
391
- hidden_states = downsampler(hidden_states)
392
-
393
- output_states += (hidden_states,)
394
-
395
- return hidden_states, output_states
396
-
397
-
398
- class DownBlock3D(nn.Module):
399
- def __init__(
400
- self,
401
- in_channels: int,
402
- out_channels: int,
403
- temb_channels: int,
404
- dropout: float = 0.0,
405
- num_layers: int = 1,
406
- resnet_eps: float = 1e-6,
407
- resnet_time_scale_shift: str = "default",
408
- resnet_act_fn: str = "swish",
409
- resnet_groups: int = 32,
410
- resnet_pre_norm: bool = True,
411
- output_scale_factor=1.0,
412
- add_downsample=True,
413
- downsample_padding=1,
414
- ):
415
- super().__init__()
416
- resnets = []
417
- temp_convs = []
418
-
419
- for i in range(num_layers):
420
- in_channels = in_channels if i == 0 else out_channels
421
- resnets.append(
422
- ResnetBlock2D(
423
- in_channels=in_channels,
424
- out_channels=out_channels,
425
- temb_channels=temb_channels,
426
- eps=resnet_eps,
427
- groups=resnet_groups,
428
- dropout=dropout,
429
- time_embedding_norm=resnet_time_scale_shift,
430
- non_linearity=resnet_act_fn,
431
- output_scale_factor=output_scale_factor,
432
- pre_norm=resnet_pre_norm,
433
- )
434
- )
435
- temp_convs.append(
436
- TemporalConvLayer(
437
- out_channels,
438
- out_channels,
439
- dropout=0.1,
440
- )
441
- )
442
-
443
- self.resnets = nn.ModuleList(resnets)
444
- self.temp_convs = nn.ModuleList(temp_convs)
445
-
446
- if add_downsample:
447
- self.downsamplers = nn.ModuleList(
448
- [
449
- Downsample2D(
450
- out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
451
- )
452
- ]
453
- )
454
- else:
455
- self.downsamplers = None
456
-
457
- self.gradient_checkpointing = False
458
-
459
- def forward(self, hidden_states, temb=None, num_frames=1):
460
- output_states = ()
461
-
462
- for resnet, temp_conv in zip(self.resnets, self.temp_convs):
463
- hidden_states = resnet(hidden_states, temb)
464
- hidden_states = temp_conv(hidden_states, num_frames=num_frames)
465
-
466
- output_states += (hidden_states,)
467
-
468
- if self.downsamplers is not None:
469
- for downsampler in self.downsamplers:
470
- hidden_states = downsampler(hidden_states)
471
-
472
- output_states += (hidden_states,)
473
-
474
- return hidden_states, output_states
475
-
476
-
477
- class CrossAttnUpBlock3D(nn.Module):
478
- def __init__(
479
- self,
480
- in_channels: int,
481
- out_channels: int,
482
- prev_output_channel: int,
483
- temb_channels: int,
484
- dropout: float = 0.0,
485
- num_layers: int = 1,
486
- resnet_eps: float = 1e-6,
487
- resnet_time_scale_shift: str = "default",
488
- resnet_act_fn: str = "swish",
489
- resnet_groups: int = 32,
490
- resnet_pre_norm: bool = True,
491
- num_attention_heads=1,
492
- cross_attention_dim=1280,
493
- output_scale_factor=1.0,
494
- add_upsample=True,
495
- dual_cross_attention=False,
496
- use_linear_projection=False,
497
- only_cross_attention=False,
498
- upcast_attention=False,
499
- ):
500
- super().__init__()
501
- resnets = []
502
- temp_convs = []
503
- attentions = []
504
- temp_attentions = []
505
-
506
- self.has_cross_attention = True
507
- self.num_attention_heads = num_attention_heads
508
-
509
- for i in range(num_layers):
510
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
511
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
512
-
513
- resnets.append(
514
- ResnetBlock2D(
515
- in_channels=resnet_in_channels + res_skip_channels,
516
- out_channels=out_channels,
517
- temb_channels=temb_channels,
518
- eps=resnet_eps,
519
- groups=resnet_groups,
520
- dropout=dropout,
521
- time_embedding_norm=resnet_time_scale_shift,
522
- non_linearity=resnet_act_fn,
523
- output_scale_factor=output_scale_factor,
524
- pre_norm=resnet_pre_norm,
525
- )
526
- )
527
- temp_convs.append(
528
- TemporalConvLayer(
529
- out_channels,
530
- out_channels,
531
- dropout=0.1,
532
- )
533
- )
534
- attentions.append(
535
- Transformer2DModel(
536
- out_channels // num_attention_heads,
537
- num_attention_heads,
538
- in_channels=out_channels,
539
- num_layers=1,
540
- cross_attention_dim=cross_attention_dim,
541
- norm_num_groups=resnet_groups,
542
- use_linear_projection=use_linear_projection,
543
- only_cross_attention=only_cross_attention,
544
- upcast_attention=upcast_attention,
545
- )
546
- )
547
- temp_attentions.append(
548
- TransformerTemporalModel(
549
- out_channels // num_attention_heads,
550
- num_attention_heads,
551
- in_channels=out_channels,
552
- num_layers=1,
553
- cross_attention_dim=cross_attention_dim,
554
- norm_num_groups=resnet_groups,
555
- )
556
- )
557
- self.resnets = nn.ModuleList(resnets)
558
- self.temp_convs = nn.ModuleList(temp_convs)
559
- self.attentions = nn.ModuleList(attentions)
560
- self.temp_attentions = nn.ModuleList(temp_attentions)
561
-
562
- if add_upsample:
563
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
564
- else:
565
- self.upsamplers = None
566
-
567
- self.gradient_checkpointing = False
568
-
569
- def forward(
570
- self,
571
- hidden_states,
572
- res_hidden_states_tuple,
573
- temb=None,
574
- encoder_hidden_states=None,
575
- upsample_size=None,
576
- attention_mask=None,
577
- num_frames=1,
578
- cross_attention_kwargs=None,
579
- ):
580
- # TODO(Patrick, William) - attention mask is not used
581
- for resnet, temp_conv, attn, temp_attn in zip(
582
- self.resnets, self.temp_convs, self.attentions, self.temp_attentions
583
- ):
584
- # pop res hidden states
585
- res_hidden_states = res_hidden_states_tuple[-1]
586
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
587
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
588
-
589
- hidden_states = resnet(hidden_states, temb)
590
- hidden_states = temp_conv(hidden_states, num_frames=num_frames)
591
- hidden_states = attn(
592
- hidden_states,
593
- encoder_hidden_states=encoder_hidden_states,
594
- cross_attention_kwargs=cross_attention_kwargs,
595
- return_dict=False,
596
- )[0]
597
- hidden_states = temp_attn(
598
- hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
599
- )[0]
600
-
601
- if self.upsamplers is not None:
602
- for upsampler in self.upsamplers:
603
- hidden_states = upsampler(hidden_states, upsample_size)
604
-
605
- return hidden_states
606
-
607
-
608
- class UpBlock3D(nn.Module):
609
- def __init__(
610
- self,
611
- in_channels: int,
612
- prev_output_channel: int,
613
- out_channels: int,
614
- temb_channels: int,
615
- dropout: float = 0.0,
616
- num_layers: int = 1,
617
- resnet_eps: float = 1e-6,
618
- resnet_time_scale_shift: str = "default",
619
- resnet_act_fn: str = "swish",
620
- resnet_groups: int = 32,
621
- resnet_pre_norm: bool = True,
622
- output_scale_factor=1.0,
623
- add_upsample=True,
624
- ):
625
- super().__init__()
626
- resnets = []
627
- temp_convs = []
628
-
629
- for i in range(num_layers):
630
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
631
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
632
-
633
- resnets.append(
634
- ResnetBlock2D(
635
- in_channels=resnet_in_channels + res_skip_channels,
636
- out_channels=out_channels,
637
- temb_channels=temb_channels,
638
- eps=resnet_eps,
639
- groups=resnet_groups,
640
- dropout=dropout,
641
- time_embedding_norm=resnet_time_scale_shift,
642
- non_linearity=resnet_act_fn,
643
- output_scale_factor=output_scale_factor,
644
- pre_norm=resnet_pre_norm,
645
- )
646
- )
647
- temp_convs.append(
648
- TemporalConvLayer(
649
- out_channels,
650
- out_channels,
651
- dropout=0.1,
652
- )
653
- )
654
-
655
- self.resnets = nn.ModuleList(resnets)
656
- self.temp_convs = nn.ModuleList(temp_convs)
657
-
658
- if add_upsample:
659
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
660
- else:
661
- self.upsamplers = None
662
-
663
- self.gradient_checkpointing = False
664
-
665
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
666
- for resnet, temp_conv in zip(self.resnets, self.temp_convs):
667
- # pop res hidden states
668
- res_hidden_states = res_hidden_states_tuple[-1]
669
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
670
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
671
-
672
- hidden_states = resnet(hidden_states, temb)
673
- hidden_states = temp_conv(hidden_states, num_frames=num_frames)
674
-
675
- if self.upsamplers is not None:
676
- for upsampler in self.upsamplers:
677
- hidden_states = upsampler(hidden_states, upsample_size)
678
-
679
- return hidden_states
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/unet_3d_condition.py DELETED
@@ -1,627 +0,0 @@
1
- # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
- # Copyright 2023 The ModelScope Team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- from dataclasses import dataclass
16
- from typing import Any, Dict, List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
- import torch.utils.checkpoint
21
-
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..loaders import UNet2DConditionLoadersMixin
24
- from ..utils import BaseOutput, logging
25
- from .attention_processor import AttentionProcessor, AttnProcessor
26
- from .embeddings import TimestepEmbedding, Timesteps
27
- from .modeling_utils import ModelMixin
28
- from .transformer_temporal import TransformerTemporalModel
29
- from .unet_3d_blocks import (
30
- CrossAttnDownBlock3D,
31
- CrossAttnUpBlock3D,
32
- DownBlock3D,
33
- UNetMidBlock3DCrossAttn,
34
- UpBlock3D,
35
- get_down_block,
36
- get_up_block,
37
- )
38
-
39
-
40
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
-
42
-
43
- @dataclass
44
- class UNet3DConditionOutput(BaseOutput):
45
- """
46
- The output of [`UNet3DConditionModel`].
47
-
48
- Args:
49
- sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
50
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
51
- """
52
-
53
- sample: torch.FloatTensor
54
-
55
-
56
- class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
57
- r"""
58
- A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
59
- shaped output.
60
-
61
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
62
- for all models (such as downloading or saving).
63
-
64
- Parameters:
65
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
66
- Height and width of input/output sample.
67
- in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
68
- out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
69
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
70
- The tuple of downsample blocks to use.
71
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
72
- The tuple of upsample blocks to use.
73
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
74
- The tuple of output channels for each block.
75
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
76
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
77
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
78
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
79
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
80
- If `None`, normalization and activation layers is skipped in post-processing.
81
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
82
- cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
83
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
84
- num_attention_heads (`int`, *optional*): The number of attention heads.
85
- """
86
-
87
- _supports_gradient_checkpointing = False
88
-
89
- @register_to_config
90
- def __init__(
91
- self,
92
- sample_size: Optional[int] = None,
93
- in_channels: int = 4,
94
- out_channels: int = 4,
95
- down_block_types: Tuple[str] = (
96
- "CrossAttnDownBlock3D",
97
- "CrossAttnDownBlock3D",
98
- "CrossAttnDownBlock3D",
99
- "DownBlock3D",
100
- ),
101
- up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
102
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
103
- layers_per_block: int = 2,
104
- downsample_padding: int = 1,
105
- mid_block_scale_factor: float = 1,
106
- act_fn: str = "silu",
107
- norm_num_groups: Optional[int] = 32,
108
- norm_eps: float = 1e-5,
109
- cross_attention_dim: int = 1024,
110
- attention_head_dim: Union[int, Tuple[int]] = 64,
111
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
112
- ):
113
- super().__init__()
114
-
115
- self.sample_size = sample_size
116
-
117
- if num_attention_heads is not None:
118
- raise NotImplementedError(
119
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
120
- )
121
-
122
- # If `num_attention_heads` is not defined (which is the case for most models)
123
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
124
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
125
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
126
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
127
- # which is why we correct for the naming here.
128
- num_attention_heads = num_attention_heads or attention_head_dim
129
-
130
- # Check inputs
131
- if len(down_block_types) != len(up_block_types):
132
- raise ValueError(
133
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
134
- )
135
-
136
- if len(block_out_channels) != len(down_block_types):
137
- raise ValueError(
138
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
139
- )
140
-
141
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
142
- raise ValueError(
143
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
144
- )
145
-
146
- # input
147
- conv_in_kernel = 3
148
- conv_out_kernel = 3
149
- conv_in_padding = (conv_in_kernel - 1) // 2
150
- self.conv_in = nn.Conv2d(
151
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
152
- )
153
-
154
- # time
155
- time_embed_dim = block_out_channels[0] * 4
156
- self.time_proj = Timesteps(block_out_channels[0], True, 0)
157
- timestep_input_dim = block_out_channels[0]
158
-
159
- self.time_embedding = TimestepEmbedding(
160
- timestep_input_dim,
161
- time_embed_dim,
162
- act_fn=act_fn,
163
- )
164
-
165
- self.transformer_in = TransformerTemporalModel(
166
- num_attention_heads=8,
167
- attention_head_dim=attention_head_dim,
168
- in_channels=block_out_channels[0],
169
- num_layers=1,
170
- )
171
-
172
- # class embedding
173
- self.down_blocks = nn.ModuleList([])
174
- self.up_blocks = nn.ModuleList([])
175
-
176
- if isinstance(num_attention_heads, int):
177
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
178
-
179
- # down
180
- output_channel = block_out_channels[0]
181
- for i, down_block_type in enumerate(down_block_types):
182
- input_channel = output_channel
183
- output_channel = block_out_channels[i]
184
- is_final_block = i == len(block_out_channels) - 1
185
-
186
- down_block = get_down_block(
187
- down_block_type,
188
- num_layers=layers_per_block,
189
- in_channels=input_channel,
190
- out_channels=output_channel,
191
- temb_channels=time_embed_dim,
192
- add_downsample=not is_final_block,
193
- resnet_eps=norm_eps,
194
- resnet_act_fn=act_fn,
195
- resnet_groups=norm_num_groups,
196
- cross_attention_dim=cross_attention_dim,
197
- num_attention_heads=num_attention_heads[i],
198
- downsample_padding=downsample_padding,
199
- dual_cross_attention=False,
200
- )
201
- self.down_blocks.append(down_block)
202
-
203
- # mid
204
- self.mid_block = UNetMidBlock3DCrossAttn(
205
- in_channels=block_out_channels[-1],
206
- temb_channels=time_embed_dim,
207
- resnet_eps=norm_eps,
208
- resnet_act_fn=act_fn,
209
- output_scale_factor=mid_block_scale_factor,
210
- cross_attention_dim=cross_attention_dim,
211
- num_attention_heads=num_attention_heads[-1],
212
- resnet_groups=norm_num_groups,
213
- dual_cross_attention=False,
214
- )
215
-
216
- # count how many layers upsample the images
217
- self.num_upsamplers = 0
218
-
219
- # up
220
- reversed_block_out_channels = list(reversed(block_out_channels))
221
- reversed_num_attention_heads = list(reversed(num_attention_heads))
222
-
223
- output_channel = reversed_block_out_channels[0]
224
- for i, up_block_type in enumerate(up_block_types):
225
- is_final_block = i == len(block_out_channels) - 1
226
-
227
- prev_output_channel = output_channel
228
- output_channel = reversed_block_out_channels[i]
229
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
230
-
231
- # add upsample block for all BUT final layer
232
- if not is_final_block:
233
- add_upsample = True
234
- self.num_upsamplers += 1
235
- else:
236
- add_upsample = False
237
-
238
- up_block = get_up_block(
239
- up_block_type,
240
- num_layers=layers_per_block + 1,
241
- in_channels=input_channel,
242
- out_channels=output_channel,
243
- prev_output_channel=prev_output_channel,
244
- temb_channels=time_embed_dim,
245
- add_upsample=add_upsample,
246
- resnet_eps=norm_eps,
247
- resnet_act_fn=act_fn,
248
- resnet_groups=norm_num_groups,
249
- cross_attention_dim=cross_attention_dim,
250
- num_attention_heads=reversed_num_attention_heads[i],
251
- dual_cross_attention=False,
252
- )
253
- self.up_blocks.append(up_block)
254
- prev_output_channel = output_channel
255
-
256
- # out
257
- if norm_num_groups is not None:
258
- self.conv_norm_out = nn.GroupNorm(
259
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
260
- )
261
- self.conv_act = nn.SiLU()
262
- else:
263
- self.conv_norm_out = None
264
- self.conv_act = None
265
-
266
- conv_out_padding = (conv_out_kernel - 1) // 2
267
- self.conv_out = nn.Conv2d(
268
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
269
- )
270
-
271
- @property
272
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
273
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
274
- r"""
275
- Returns:
276
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
277
- indexed by its weight name.
278
- """
279
- # set recursively
280
- processors = {}
281
-
282
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
283
- if hasattr(module, "set_processor"):
284
- processors[f"{name}.processor"] = module.processor
285
-
286
- for sub_name, child in module.named_children():
287
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
288
-
289
- return processors
290
-
291
- for name, module in self.named_children():
292
- fn_recursive_add_processors(name, module, processors)
293
-
294
- return processors
295
-
296
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
297
- def set_attention_slice(self, slice_size):
298
- r"""
299
- Enable sliced attention computation.
300
-
301
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
302
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
303
-
304
- Args:
305
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
306
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
307
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
308
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
309
- must be a multiple of `slice_size`.
310
- """
311
- sliceable_head_dims = []
312
-
313
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
314
- if hasattr(module, "set_attention_slice"):
315
- sliceable_head_dims.append(module.sliceable_head_dim)
316
-
317
- for child in module.children():
318
- fn_recursive_retrieve_sliceable_dims(child)
319
-
320
- # retrieve number of attention layers
321
- for module in self.children():
322
- fn_recursive_retrieve_sliceable_dims(module)
323
-
324
- num_sliceable_layers = len(sliceable_head_dims)
325
-
326
- if slice_size == "auto":
327
- # half the attention head size is usually a good trade-off between
328
- # speed and memory
329
- slice_size = [dim // 2 for dim in sliceable_head_dims]
330
- elif slice_size == "max":
331
- # make smallest slice possible
332
- slice_size = num_sliceable_layers * [1]
333
-
334
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
335
-
336
- if len(slice_size) != len(sliceable_head_dims):
337
- raise ValueError(
338
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
339
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
340
- )
341
-
342
- for i in range(len(slice_size)):
343
- size = slice_size[i]
344
- dim = sliceable_head_dims[i]
345
- if size is not None and size > dim:
346
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
347
-
348
- # Recursively walk through all the children.
349
- # Any children which exposes the set_attention_slice method
350
- # gets the message
351
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
352
- if hasattr(module, "set_attention_slice"):
353
- module.set_attention_slice(slice_size.pop())
354
-
355
- for child in module.children():
356
- fn_recursive_set_attention_slice(child, slice_size)
357
-
358
- reversed_slice_size = list(reversed(slice_size))
359
- for module in self.children():
360
- fn_recursive_set_attention_slice(module, reversed_slice_size)
361
-
362
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
363
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
364
- r"""
365
- Sets the attention processor to use to compute attention.
366
-
367
- Parameters:
368
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
369
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
370
- for **all** `Attention` layers.
371
-
372
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
373
- processor. This is strongly recommended when setting trainable attention processors.
374
-
375
- """
376
- count = len(self.attn_processors.keys())
377
-
378
- if isinstance(processor, dict) and len(processor) != count:
379
- raise ValueError(
380
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
381
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
382
- )
383
-
384
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
385
- if hasattr(module, "set_processor"):
386
- if not isinstance(processor, dict):
387
- module.set_processor(processor)
388
- else:
389
- module.set_processor(processor.pop(f"{name}.processor"))
390
-
391
- for sub_name, child in module.named_children():
392
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
393
-
394
- for name, module in self.named_children():
395
- fn_recursive_attn_processor(name, module, processor)
396
-
397
- def enable_forward_chunking(self, chunk_size=None, dim=0):
398
- """
399
- Sets the attention processor to use [feed forward
400
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
401
-
402
- Parameters:
403
- chunk_size (`int`, *optional*):
404
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
405
- over each tensor of dim=`dim`.
406
- dim (`int`, *optional*, defaults to `0`):
407
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
408
- or dim=1 (sequence length).
409
- """
410
- if dim not in [0, 1]:
411
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
412
-
413
- # By default chunk size is 1
414
- chunk_size = chunk_size or 1
415
-
416
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
417
- if hasattr(module, "set_chunk_feed_forward"):
418
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
419
-
420
- for child in module.children():
421
- fn_recursive_feed_forward(child, chunk_size, dim)
422
-
423
- for module in self.children():
424
- fn_recursive_feed_forward(module, chunk_size, dim)
425
-
426
- def disable_forward_chunking(self):
427
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
428
- if hasattr(module, "set_chunk_feed_forward"):
429
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
430
-
431
- for child in module.children():
432
- fn_recursive_feed_forward(child, chunk_size, dim)
433
-
434
- for module in self.children():
435
- fn_recursive_feed_forward(module, None, 0)
436
-
437
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
438
- def set_default_attn_processor(self):
439
- """
440
- Disables custom attention processors and sets the default attention implementation.
441
- """
442
- self.set_attn_processor(AttnProcessor())
443
-
444
- def _set_gradient_checkpointing(self, module, value=False):
445
- if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
446
- module.gradient_checkpointing = value
447
-
448
- def forward(
449
- self,
450
- sample: torch.FloatTensor,
451
- timestep: Union[torch.Tensor, float, int],
452
- encoder_hidden_states: torch.Tensor,
453
- class_labels: Optional[torch.Tensor] = None,
454
- timestep_cond: Optional[torch.Tensor] = None,
455
- attention_mask: Optional[torch.Tensor] = None,
456
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
457
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
458
- mid_block_additional_residual: Optional[torch.Tensor] = None,
459
- return_dict: bool = True,
460
- ) -> Union[UNet3DConditionOutput, Tuple]:
461
- r"""
462
- The [`UNet3DConditionModel`] forward method.
463
-
464
- Args:
465
- sample (`torch.FloatTensor`):
466
- The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
467
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
468
- encoder_hidden_states (`torch.FloatTensor`):
469
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
470
- return_dict (`bool`, *optional*, defaults to `True`):
471
- Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
472
- tuple.
473
- cross_attention_kwargs (`dict`, *optional*):
474
- A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
475
-
476
- Returns:
477
- [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
478
- If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
479
- a `tuple` is returned where the first element is the sample tensor.
480
- """
481
- # By default samples have to be AT least a multiple of the overall upsampling factor.
482
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
483
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
484
- # on the fly if necessary.
485
- default_overall_up_factor = 2**self.num_upsamplers
486
-
487
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
488
- forward_upsample_size = False
489
- upsample_size = None
490
-
491
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
492
- logger.info("Forward upsample size to force interpolation output size.")
493
- forward_upsample_size = True
494
-
495
- # prepare attention_mask
496
- if attention_mask is not None:
497
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
498
- attention_mask = attention_mask.unsqueeze(1)
499
-
500
- # 1. time
501
- timesteps = timestep
502
- if not torch.is_tensor(timesteps):
503
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
504
- # This would be a good case for the `match` statement (Python 3.10+)
505
- is_mps = sample.device.type == "mps"
506
- if isinstance(timestep, float):
507
- dtype = torch.float32 if is_mps else torch.float64
508
- else:
509
- dtype = torch.int32 if is_mps else torch.int64
510
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
511
- elif len(timesteps.shape) == 0:
512
- timesteps = timesteps[None].to(sample.device)
513
-
514
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
515
- num_frames = sample.shape[2]
516
- timesteps = timesteps.expand(sample.shape[0])
517
-
518
- t_emb = self.time_proj(timesteps)
519
-
520
- # timesteps does not contain any weights and will always return f32 tensors
521
- # but time_embedding might actually be running in fp16. so we need to cast here.
522
- # there might be better ways to encapsulate this.
523
- t_emb = t_emb.to(dtype=self.dtype)
524
-
525
- emb = self.time_embedding(t_emb, timestep_cond)
526
- emb = emb.repeat_interleave(repeats=num_frames, dim=0)
527
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
528
-
529
- # 2. pre-process
530
- sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
531
- sample = self.conv_in(sample)
532
-
533
- sample = self.transformer_in(
534
- sample,
535
- num_frames=num_frames,
536
- cross_attention_kwargs=cross_attention_kwargs,
537
- return_dict=False,
538
- )[0]
539
-
540
- # 3. down
541
- down_block_res_samples = (sample,)
542
- for downsample_block in self.down_blocks:
543
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
544
- sample, res_samples = downsample_block(
545
- hidden_states=sample,
546
- temb=emb,
547
- encoder_hidden_states=encoder_hidden_states,
548
- attention_mask=attention_mask,
549
- num_frames=num_frames,
550
- cross_attention_kwargs=cross_attention_kwargs,
551
- )
552
- else:
553
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
554
-
555
- down_block_res_samples += res_samples
556
-
557
- if down_block_additional_residuals is not None:
558
- new_down_block_res_samples = ()
559
-
560
- for down_block_res_sample, down_block_additional_residual in zip(
561
- down_block_res_samples, down_block_additional_residuals
562
- ):
563
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
564
- new_down_block_res_samples += (down_block_res_sample,)
565
-
566
- down_block_res_samples = new_down_block_res_samples
567
-
568
- # 4. mid
569
- if self.mid_block is not None:
570
- sample = self.mid_block(
571
- sample,
572
- emb,
573
- encoder_hidden_states=encoder_hidden_states,
574
- attention_mask=attention_mask,
575
- num_frames=num_frames,
576
- cross_attention_kwargs=cross_attention_kwargs,
577
- )
578
-
579
- if mid_block_additional_residual is not None:
580
- sample = sample + mid_block_additional_residual
581
-
582
- # 5. up
583
- for i, upsample_block in enumerate(self.up_blocks):
584
- is_final_block = i == len(self.up_blocks) - 1
585
-
586
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
587
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
588
-
589
- # if we have not reached the final block and need to forward the
590
- # upsample size, we do it here
591
- if not is_final_block and forward_upsample_size:
592
- upsample_size = down_block_res_samples[-1].shape[2:]
593
-
594
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
595
- sample = upsample_block(
596
- hidden_states=sample,
597
- temb=emb,
598
- res_hidden_states_tuple=res_samples,
599
- encoder_hidden_states=encoder_hidden_states,
600
- upsample_size=upsample_size,
601
- attention_mask=attention_mask,
602
- num_frames=num_frames,
603
- cross_attention_kwargs=cross_attention_kwargs,
604
- )
605
- else:
606
- sample = upsample_block(
607
- hidden_states=sample,
608
- temb=emb,
609
- res_hidden_states_tuple=res_samples,
610
- upsample_size=upsample_size,
611
- num_frames=num_frames,
612
- )
613
-
614
- # 6. post-process
615
- if self.conv_norm_out:
616
- sample = self.conv_norm_out(sample)
617
- sample = self.conv_act(sample)
618
-
619
- sample = self.conv_out(sample)
620
-
621
- # reshape to (batch, channel, framerate, width, height)
622
- sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
623
-
624
- if not return_dict:
625
- return (sample,)
626
-
627
- return UNet3DConditionOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/vae.py DELETED
@@ -1,441 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional
16
-
17
- import numpy as np
18
- import torch
19
- import torch.nn as nn
20
-
21
- from ..utils import BaseOutput, is_torch_version, randn_tensor
22
- from .attention_processor import SpatialNorm
23
- from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
24
-
25
-
26
@dataclass
class DecoderOutput(BaseOutput):
    """
    Container for the result of a decode call.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded sample produced by the final layer of the model.
    """

    sample: torch.FloatTensor
37
-
38
-
39
class Encoder(nn.Module):
    """
    Convolutional encoder: an input convolution, a stack of down blocks, one mid block and a
    norm / activation / convolution output head.

    Args:
        in_channels: Number of channels accepted by `conv_in`.
        out_channels: Number of output channels (doubled when `double_z` is true).
        down_block_types: One block type name per down block, resolved through `get_down_block`.
        block_out_channels: Output channel count of each down block.
        layers_per_block: `num_layers` handed to every down block.
        norm_num_groups: Group count used by the group norms.
        act_fn: Activation name forwarded to the resnets inside the blocks.
        double_z: When true, `conv_out` produces `2 * out_channels` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        first_channels = block_out_channels[0]
        last_channels = block_out_channels[-1]

        self.conv_in = torch.nn.Conv2d(in_channels, first_channels, kernel_size=3, stride=1, padding=1)

        # down: each block consumes the previous block's output channels
        self.down_blocks = nn.ModuleList([])
        final_index = len(block_out_channels) - 1
        channels_out = first_channels
        for index, block_type in enumerate(down_block_types):
            channels_in, channels_out = channels_out, block_out_channels[index]
            self.down_blocks.append(
                get_down_block(
                    block_type,
                    num_layers=self.layers_per_block,
                    in_channels=channels_in,
                    out_channels=channels_out,
                    # the last resolution level is not downsampled
                    add_downsample=index != final_index,
                    resnet_eps=1e-6,
                    downsample_padding=0,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attention_head_dim=channels_out,
                    temb_channels=None,
                )
            )

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=last_channels,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=last_channels,
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=last_channels, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(
            last_channels,
            2 * out_channels if double_z else out_channels,
            3,
            padding=1,
        )

        self.gradient_checkpointing = False

    def forward(self, x):
        """Run `x` through conv_in, the down blocks, the mid block and the output head."""
        sample = self.conv_in(x)

        # down blocks followed by the mid block, in execution order
        stages = [*self.down_blocks, self.mid_block]

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0
            extra_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            for stage in stages:
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(stage), sample, **extra_kwargs)
        else:
            for stage in stages:
                sample = stage(sample)

        # post-process
        return self.conv_out(self.conv_act(self.conv_norm_out(sample)))
150
-
151
-
152
class Decoder(nn.Module):
    """
    Convolutional decoder: an input convolution, one mid block, a stack of up blocks and a
    norm / activation / convolution output head.

    Args:
        in_channels: Number of channels accepted by `conv_in`.
        out_channels: Number of channels produced by `conv_out`.
        up_block_types: One block type name per up block, resolved through `get_up_block`.
        block_out_channels: Channel counts; they are walked in reverse order by the up blocks.
        layers_per_block: Every up block is built with `layers_per_block + 1` layers.
        norm_num_groups: Group count used by the group norms.
        act_fn: Activation name forwarded to the resnets inside the blocks.
        norm_type: `"group"` or `"spatial"`. With `"spatial"` the output norm is a `SpatialNorm`
            and `in_channels` is used as `temb_channels` of the blocks.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        first_channels = block_out_channels[0]
        last_channels = block_out_channels[-1]
        use_spatial_norm = norm_type == "spatial"

        self.conv_in = nn.Conv2d(in_channels, last_channels, kernel_size=3, stride=1, padding=1)

        # registered before the mid block so the submodule order stays conv_in, up_blocks, mid_block
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if use_spatial_norm else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=last_channels,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=last_channels,
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up: walk the channel list from the deepest level to the shallowest
        channels_reversed = list(reversed(block_out_channels))
        final_index = len(block_out_channels) - 1
        channels_out = channels_reversed[0]
        for index, block_type in enumerate(up_block_types):
            channels_in, channels_out = channels_out, channels_reversed[index]
            self.up_blocks.append(
                get_up_block(
                    block_type,
                    num_layers=self.layers_per_block + 1,
                    in_channels=channels_in,
                    out_channels=channels_out,
                    prev_output_channel=None,
                    # the last block keeps its resolution
                    add_upsample=index != final_index,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attention_head_dim=channels_out,
                    temb_channels=temb_channels,
                    resnet_time_scale_shift=norm_type,
                )
            )

        # out
        if use_spatial_norm:
            self.conv_norm_out = SpatialNorm(first_channels, temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=first_channels, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(first_channels, out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, latent_embeds=None):
        """
        Decode `z`. `latent_embeds` is forwarded to the mid block and the up blocks, and to the
        output norm only when it is not `None`.
        """
        sample = self.conv_in(z)

        # the mid block output is cast to the dtype of the up block parameters
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0
            extra_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, latent_embeds, **extra_kwargs
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), sample, latent_embeds, **extra_kwargs
                )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds).to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        norm_args = (sample,) if latent_embeds is None else (sample, latent_embeds)
        sample = self.conv_norm_out(*norm_args)
        return self.conv_out(self.conv_act(sample))
281
-
282
-
283
class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.

    Args:
        n_e: Number of codebook entries.
        vq_embed_dim: Dimension of each codebook entry.
        beta: Weight of one of the two loss terms (which one depends on `legacy`).
        remap: Optional path passed to `np.load`; the loaded array is stored as the `used` buffer.
        unknown_index: `"random"`, `"extra"` or an integer; what indices missing from `used` map to.
        sane_index_shape: When true, `forward` returns indices shaped `(batch, height, width)`.
        legacy: When true (default) `beta` multiplies the second loss term, see the note below.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
    ):
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # reserve one additional slot, placed after all used indices
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """
        Replace every index by its position inside the `used` buffer.

        Indices that do not occur in `used` become a random value in `[0, re_embed)` when
        `unknown_index` is `"random"`, otherwise `unknown_index`. `inds` needs at least two
        dimensions; the result has the shape of `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """
        Inverse of `remap_to_used`: look up each position in the `used` buffer.

        Positions at or beyond `len(used)` (the extra token) are read as position 0. `inds` needs
        at least two dimensions and is left unmodified; the result has the shape of `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # simply set to zero; masked_fill builds a new tensor, whereas an in-place assignment
            # would write through the reshape view into the caller's tensor
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        """
        Quantize `z` (channels in dimension 1) to its nearest codebook entries.

        Returns:
            `(z_q, loss, (perplexity, min_encodings, min_encoding_indices))` where `perplexity` and
            `min_encodings` are always `None`.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # index of the closest codebook entry for every flattened vector
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """
        Look up codebook vectors for `indices`.

        Args:
            indices: Codebook indices (remapped indices when `remap` is set).
            shape: `(batch, height, width, channel)` or `None`. When given, the result is viewed to
                this shape and permuted to channels-first; when `None` the raw lookup is returned.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            # unmapping is element-wise, so a single batch row is enough when no shape is given
            batch = shape[0] if shape is not None else 1
            indices = indices.reshape(batch, -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
395
-
396
-
397
class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian whose mean and log-variance are the two halves of `parameters` along dim 1.

    Args:
        parameters: Tensor that is split into `mean` and `logvar` with `torch.chunk(..., 2, dim=1)`.
        deterministic: When true, `std` and `var` are zero tensors and `kl` / `nll` return zero.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # clamp before exponentiating to keep std / var in a finite range
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self):
        """One-element zero tensor on the device and in the dtype of `parameters`."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw `mean + std * noise`, with noise created by `randn_tensor`."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal when `other` is `None`, summed over
        dims 1-3. Returns a one-element zero tensor when the distribution is deterministic.
        """
        if self.deterministic:
            # previously a CPU float32 tensor regardless of where the parameters live
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample`, summed over `dims`. Returns a one-element zero tensor
        when the distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/vae_flax.py DELETED
@@ -1,869 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
16
-
17
- import math
18
- from functools import partial
19
- from typing import Tuple
20
-
21
- import flax
22
- import flax.linen as nn
23
- import jax
24
- import jax.numpy as jnp
25
- from flax.core.frozen_dict import FrozenDict
26
-
27
- from ..configuration_utils import ConfigMixin, flax_register_to_config
28
- from ..utils import BaseOutput
29
- from .modeling_flax_utils import FlaxModelMixin
30
-
31
-
32
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
    """
    Container for the result of a decode call.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Decoded sample produced by the final layer of the model.
    """

    sample: jnp.ndarray
45
-
46
-
47
@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
    """
    Container for the result of the AutoencoderKL encode call.

    Args:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            Distribution built from the encoder output (its mean and logvar); latents can be
            sampled from it.
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"
59
-
60
-
61
class FlaxUpsample2D(nn.Module):
    """
    Flax 2D upsampling layer: nearest-neighbour resize to twice the height and width,
    followed by a 3x3 convolution.

    Args:
        in_channels (`int`):
            Number of features produced by the convolution
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the convolution
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # input is unpacked as (batch, height, width, channels)
        n, h, w, c = hidden_states.shape
        upsampled = jax.image.resize(
            hidden_states,
            shape=(n, h * 2, w * 2, c),
            method="nearest",
        )
        return self.conv(upsampled)
93
-
94
-
95
class FlaxDownsample2D(nn.Module):
    """
    Flax 2D downsampling layer: pads height and width by one at the end, then applies an
    unpadded 3x3 convolution with stride 2.

    Args:
        in_channels (`int`):
            Number of features produced by the convolution
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the convolution
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # one trailing pad element on the height and width axes only
        padded = jnp.pad(hidden_states, pad_width=((0, 0), (0, 1), (0, 1), (0, 0)))
        return self.conv(padded)
123
-
124
-
125
class FlaxResnetBlock2D(nn.Module):
    """
    Flax 2D residual block: two (group norm, swish, 3x3 conv) stages with dropout before the
    second convolution, added to an optionally 1x1-projected copy of the input.

    Args:
        in_channels (`int`):
            Input channels
        out_channels (`int`):
            Output channels; falls back to `in_channels` when `None`
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm.
        use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
            Whether to project the residual with a 1x1 convolution. When `None`, the projection
            is used exactly when input and output channels differ
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the convolutions
    """

    in_channels: int
    out_channels: int = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        features = self.out_channels if self.out_channels is not None else self.in_channels

        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.conv1 = nn.Conv(
            features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.conv2 = nn.Conv(
            features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        if self.use_nin_shortcut is None:
            needs_projection = self.in_channels != features
        else:
            needs_projection = self.use_nin_shortcut

        if needs_projection:
            self.conv_shortcut = nn.Conv(
                features,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )
        else:
            self.conv_shortcut = None

    def __call__(self, hidden_states, deterministic=True):
        skip = hidden_states

        out = self.conv1(nn.swish(self.norm1(hidden_states)))

        out = nn.swish(self.norm2(out))
        out = self.dropout_layer(out, deterministic)
        out = self.conv2(out)

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return out + skip
200
-
201
-
202
class FlaxAttentionBlock(nn.Module):
    r"""
    Flax multi-head self-attention block over the spatial positions of a feature map, with a
    residual connection.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Channels per attention head; the head count is `channels // num_head_channels`, or a
            single head when `None`
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the dense layers

    """
    channels: int
    num_head_channels: int = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.num_head_channels is not None:
            self.num_heads = self.channels // self.num_head_channels
        else:
            self.num_heads = 1

        make_dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query = make_dense()
        self.key = make_dense()
        self.value = make_dense()
        self.proj_attn = make_dense()

    def transpose_for_scores(self, projection):
        """Split the last axis into heads: (B, T, H * D) -> (B, H, T, D)."""
        split_shape = projection.shape[:-1] + (self.num_heads, -1)
        # (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        return jnp.transpose(projection.reshape(split_shape), (0, 2, 1, 3))

    def __call__(self, hidden_states):
        skip = hidden_states
        batch, height, width, channels = hidden_states.shape

        # normalise, then flatten the spatial grid into a token axis
        tokens = self.group_norm(hidden_states).reshape((batch, height * width, channels))

        # project and move the heads in front of the token axis
        query = self.transpose_for_scores(self.query(tokens))
        key = self.transpose_for_scores(self.key(tokens))
        value = self.transpose_for_scores(self.value(tokens))

        # compute attentions; the scale is applied to both query and key
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)

        # attend to values
        attended = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, channels)
        attended = jnp.transpose(attended, (0, 2, 1, 3))
        attended = attended.reshape(attended.shape[:-2] + (self.channels,))

        attended = self.proj_attn(attended)
        attended = attended.reshape((batch, height, width, channels))
        return attended + skip
272
-
273
-
274
class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax encoder block: a chain of resnet blocks, optionally followed by a downsample layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels (used by the first resnet only)
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet blocks
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the sub-layers
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # only the first resnet sees the block's input channel count
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if layer == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            for layer in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        out = hidden_states
        for block in self.resnets:
            out = block(out, deterministic=deterministic)

        if not self.add_downsample:
            return out
        return self.downsamplers_0(out)
328
-
329
-
330
class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax decoder block: a chain of resnet blocks, optionally followed by an upsample layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels (used by the first resnet only)
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet blocks
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the sub-layers
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # only the first resnet sees the block's input channel count
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if layer == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            for layer in range(self.num_layers)
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        out = hidden_states
        for block in self.resnets:
            out = block(out, deterministic=deterministic)

        if not self.add_upsample:
            return out
        return self.upsamplers_0(out)
384
-
385
-
386
class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax UNet mid block: one resnet, then `num_layers` pairs of (attention, resnet).

    Parameters:
        in_channels (:obj:`int`):
            Input channels; every sub-block keeps this channel count
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (attention, resnet) pairs after the first resnet
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet and Attention block group norm; when
            `None`, `min(in_channels // 4, 32)` is used
        num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
            Value handed to each attention block as its `num_head_channels`
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` passed to the sub-layers
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    num_attention_heads: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.resnet_groups is not None:
            groups = self.resnet_groups
        else:
            groups = min(self.in_channels // 4, 32)

        # there is always at least one resnet, plus one after every attention block
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=groups,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers + 1)
        ]
        self.attentions = [
            FlaxAttentionBlock(
                channels=self.in_channels,
                num_head_channels=self.num_attention_heads,
                num_groups=groups,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, hidden_states, deterministic=True):
        first_resnet, *later_resnets = self.resnets
        out = first_resnet(hidden_states, deterministic=deterministic)
        for attention, block in zip(self.attentions, later_resnets):
            out = attention(out)
            out = block(out, deterministic=deterministic)

        return out
455
-
456
-
457
class FlaxEncoder(nn.Module):
    r"""
    Flax Implementation of VAE Encoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type. Only the length of this tuple is used here: one `FlaxDownEncoderBlock2D` is
            created per entry.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. Not read by this class; `nn.swish` is always applied before the output conv.
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    double_z: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        channels = self.block_out_channels
        # Both convolutions in this module share the same 3x3 / stride 1 / pad 1 geometry.
        conv_kwargs = dict(
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # in
        self.conv_in = nn.Conv(channels[0], **conv_kwargs)

        # downsampling: block `idx` consumes the previous block's width (the first block consumes
        # channels[0]) and every block except the one at the last channel index downsamples.
        final_index = len(channels) - 1
        self.down_blocks = [
            FlaxDownEncoderBlock2D(
                in_channels=channels[max(idx - 1, 0)],
                out_channels=channels[idx],
                num_layers=self.layers_per_block,
                resnet_groups=self.norm_num_groups,
                add_downsample=idx != final_index,
                dtype=self.dtype,
            )
            for idx in range(len(self.down_block_types))
        ]

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # end
        if self.double_z:
            final_channels = 2 * self.out_channels
        else:
            final_channels = self.out_channels
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(final_channels, **conv_kwargs)

    def __call__(self, sample, deterministic: bool = True):
        # in
        hidden = self.conv_in(sample)

        # downsampling
        for down_block in self.down_blocks:
            hidden = down_block(hidden, deterministic=deterministic)

        # middle
        hidden = self.mid_block(hidden, deterministic=deterministic)

        # end: norm -> swish -> conv
        return self.conv_out(nn.swish(self.conv_norm_out(hidden)))
567
-
568
-
569
class FlaxDecoder(nn.Module):
    r"""
    Flax Implementation of VAE Decoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels. Not referenced in `setup`.
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            UpDecoder block type. Only the length of this tuple is used here: one `FlaxUpDecoderBlock2D` is
            created per entry.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block. The blocks are built from this tuple
            in reverse order.
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block (each up block is built with `layers_per_block + 1` layers)
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. Not read by this class; `nn.swish` is always applied before the output conv.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            parameters `dtype`
    """
    in_channels: int = 3
    out_channels: int = 3
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # z to block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # upsampling: walk the channel widths from the widest (last) to the first.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        up_blocks = []
        for i, _ in enumerate(self.up_block_types):
            # Each block consumes the width produced by the previous one.
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # The block at the last channel index keeps the resolution (no upsampler).
            is_final_block = i == len(block_out_channels) - 1

            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block + 1,
                resnet_groups=self.norm_num_groups,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)

        self.up_blocks = up_blocks

        # end
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # z to block_in
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # upsampling
        for block in self.up_blocks:
            sample = block(sample, deterministic=deterministic)

        # end: norm -> swish -> conv
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample
681
-
682
-
683
class FlaxDiagonalGaussianDistribution:
    """Diagonal Gaussian built from an array holding mean and log-variance side by side.

    `parameters` is split into two equal halves along its last axis: the first half is the mean and the
    second half the log-variance (clipped to `[-30, 20]` before being exponentiated).

    Args:
        parameters: Array whose last axis has an even size (mean and log-variance halves).
        deterministic: When truthy, the variance and standard deviation are replaced with zeros, so
            `sample` returns the mean and `kl` / `nll` return `[0.0]`.
    """

    def __init__(self, parameters, deterministic=False):
        # Last axis to account for channels-last
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    def sample(self, key):
        """Draw `mean + std * eps`, with `eps` standard-normal noise generated from PRNG `key`."""
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def kl(self, other=None):
        """Return the KL divergence summed over axes 1, 2 and 3.

        With `other=None` the divergence is taken against a standard normal; otherwise against `other`,
        which must expose `mean`, `var` and `logvar`. Returns `jnp.array([0.0])` in deterministic mode.
        """
        if self.deterministic:
            return jnp.array([0.0])

        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))

        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=(1, 2, 3),
        )

    def nll(self, sample, axis=(1, 2, 3)):
        """Return the negative log-likelihood of `sample`, summed over `axis`.

        The default is an immutable tuple rather than a shared list; any axis specification that
        `jnp.sum` accepts (including a list) can still be passed explicitly. Returns `jnp.array([0.0])`
        in deterministic mode.
        """
        if self.deterministic:
            return jnp.array([0.0])

        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
718
-
719
-
720
@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax implementation of a VAE model with KL loss for decoding latent representations.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3):
            Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Number of ResNet layer for each block.
        act_fn (`str`, *optional*, defaults to `silu`):
            The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`):
            Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups for normalization.
        sample_size (`int`, *optional*, defaults to 32):
            Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """
    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    scaling_factor: float = 0.18215
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The encoder emits 2 * latent_channels (double_z=True): one half for the mean, one for the
        # log-variance consumed by FlaxDiagonalGaussianDistribution.
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    # The annotation is a string on purpose: `jax.random.KeyArray` no longer exists in recent JAX
    # releases, and an eagerly evaluated annotation would make the class body fail there.
    def init_weights(self, rng: "jax.Array") -> FrozenDict:
        """Initialize the parameters by tracing the model on an all-zeros sample.

        Args:
            rng: PRNG key; split into the "params", "dropout" and "gaussian" streams.

        Returns:
            The "params" collection produced by `self.init`.
        """
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        return self.init(rngs, sample)["params"]

    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        """Encode `sample` into a posterior distribution over latents.

        Args:
            sample: Input that is transposed with `(0, 2, 3, 1)` before the encoder runs
                (presumably channels-first images - confirm against callers).
            deterministic: Forwarded to the encoder.
            return_dict: When falsy, return a one-element tuple instead of the output object.

        Returns:
            `FlaxAutoencoderKLOutput` with `latent_dist` set to the posterior, or `(posterior,)`.
        """
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        hidden_states = self.encoder(sample, deterministic=deterministic)
        moments = self.quant_conv(hidden_states)
        posterior = FlaxDiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        """Decode `latents` and return the result transposed with `(0, 3, 1, 2)`.

        Args:
            latents: Latent array. If its last axis does not have `latent_channels` entries it is first
                transposed with `(0, 2, 3, 1)`.
            deterministic: Forwarded to the decoder.
            return_dict: When falsy, return a one-element tuple instead of the output object.

        Returns:
            `FlaxDecoderOutput` with `sample` set to the decoded array, or `(decoded,)`.
        """
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        hidden_states = self.post_quant_conv(latents)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)

    def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
        """Encode `sample`, pick a latent from the posterior and decode it.

        Args:
            sample: Input forwarded to `encode`.
            sample_posterior: When truthy, draw the latent from the posterior using the "gaussian" RNG
                stream; otherwise use the posterior mode.
            deterministic: Forwarded to `encode`.
            return_dict: When falsy, return a one-element tuple instead of the output object.

        Returns:
            `FlaxDecoderOutput` with the reconstruction, or `(reconstruction,)`.
        """
        # Always ask the helpers for their structured outputs: forwarding `return_dict=False` would hand
        # back plain tuples, and the attribute accesses below would raise AttributeError.
        posterior = self.encode(sample, deterministic=deterministic, return_dict=True).latent_dist
        if sample_posterior:
            rng = self.make_rng("gaussian")
            hidden_states = posterior.sample(rng)
        else:
            hidden_states = posterior.mode()

        sample = self.decode(hidden_states, return_dict=True).sample

        if not return_dict:
            return (sample,)

        return FlaxDecoderOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/models/vq_model.py DELETED
@@ -1,167 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from ..configuration_utils import ConfigMixin, register_to_config
21
- from ..utils import BaseOutput, apply_forward_hook
22
- from .modeling_utils import ModelMixin
23
- from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
24
-
25
-
26
@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Result container for the encoding step of `VQModel`.

    Args:
        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded output sample from the last layer of the model.
    """

    latents: torch.FloatTensor
37
-
38
-
39
class VQModel(ModelMixin, ConfigMixin):
    r"""
    A VQ-VAE model for decoding latent representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Forwarded to both the encoder and the decoder as `layers_per_block`.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            Forwarded to both the encoder and the decoder as `norm_num_groups`.
        vq_embed_dim (`int`, *optional*):
            Hidden dim of codebook vectors in the VQ-VAE. Falls back to `latent_channels` when `None`.
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        norm_type (`str`, *optional*, defaults to `"group"`):
            Either `"group"` or `"spatial"`; forwarded to the decoder. With `"spatial"`, `decode` additionally
            hands the quantized latents to the decoder as its second argument.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 3,
        sample_size: int = 32,
        num_vq_embeddings: int = 256,
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "group",  # group, spatial
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=False,
        )

        # The codebook dimension defaults to the latent width.
        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels

        # 1x1 convolutions map between the latent width and the codebook dimension.
        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
        )

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
        r"""
        Run the encoder followed by `quant_conv`. No quantization happens here; that is done in `decode`.

        Args:
            x (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`VQEncoderOutput`] instead of a plain tuple.

        Returns:
            [`VQEncoderOutput`] or `tuple`:
                If return_dict is True, a [`VQEncoderOutput`] holding the latents is returned, otherwise a
                one-element `tuple` with the latents.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Quantize the latents (unless told not to) and decode them.

        Args:
            h (`torch.FloatTensor`): Latents, e.g. as produced by `encode`.
            force_not_quantize (`bool`, *optional*, defaults to `False`):
                When `True`, skip the quantization layer and decode `h` as is.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            [`DecoderOutput`] or `tuple`:
                If return_dict is True, a [`DecoderOutput`] holding the decoded sample is returned, otherwise a
                one-element `tuple` with the decoded sample.
        """
        # also go through quantization layer
        if not force_not_quantize:
            # The quantizer's two extra return values are not needed for decoding.
            quant, _, _ = self.quantize(h)
        else:
            quant = h
        quant2 = self.post_quant_conv(quant)
        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        The [`VQModel`] forward method: encode `sample`, then decode the resulting latents.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            [`DecoderOutput`] or `tuple`:
                If return_dict is True, a [`DecoderOutput`] holding the reconstruction is returned, otherwise a
                one-element `tuple` with the reconstruction.
        """
        x = sample
        h = self.encode(x).latents
        dec = self.decode(h).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/optimization.py DELETED
@@ -1,354 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch optimization for diffusion models."""
16
-
17
- import math
18
- from enum import Enum
19
- from typing import Optional, Union
20
-
21
- from torch.optim import Optimizer
22
- from torch.optim.lr_scheduler import LambdaLR
23
-
24
- from .utils import logging
25
-
26
-
27
- logger = logging.get_logger(__name__)
28
-
29
-
30
class SchedulerType(Enum):
    """String identifiers of the learning-rate schedules that `get_scheduler` accepts."""

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"
38
-
39
-
40
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Build a schedule that keeps the optimizer's learning rate unchanged at every step.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(_step):
        # The multiplier is 1 regardless of the step, so the base learning rate is used as is.
        return 1

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
54
-
55
-
56
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a schedule that ramps the learning rate linearly from 0 to the optimizer's initial lr during warmup and
    then holds it constant.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Clamp to at least 1 so the warmup fraction below never divides by zero.
    warmup_span = float(max(1.0, num_warmup_steps))

    def lr_lambda(current_step: int):
        if current_step >= num_warmup_steps:
            return 1.0
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
79
-
80
-
81
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
    """
    Create a schedule whose learning rate multiplier is piecewise constant, applied to the learning rate set in
    optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            Comma-separated `multiplier:step` rules followed by one final bare multiplier, e.g.
            `step_rules="1:10,0.1:20,0.01:30,0.005"`. Every step is an absolute boundary: the learning rate is
            multiplied by 1 while the step is below 10, by 0.1 while it is below 20, by 0.01 while it is below 30,
            and by 0.005 from step 30 onwards.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        `ValueError`: If a rule is not of the form `multiplier:step` or its numbers cannot be parsed.
    """
    *bounded_rules, final_rule = step_rules.split(",")

    # Keyed by boundary so that a boundary listed twice keeps its last multiplier.
    rules_dict = {}
    for rule_str in bounded_rules:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        rules_dict[steps] = float(value_str)
    last_lr_multiple = float(final_rule)

    # Sort once here; the scheduler calls `rule_func` at every step.
    boundaries = sorted(rules_dict.items())

    def rule_func(steps: int) -> float:
        # The first boundary strictly above the current step decides the multiplier.
        for boundary, multiple in boundaries:
            if steps < boundary:
                return multiple
        return last_lr_multiple

    return LambdaLR(optimizer, rule_func, last_epoch=last_epoch)
121
-
122
-
123
def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Linear decay towards 0; the outer `max` keeps the multiplier at 0 past the last training step and the
        # inner `max` avoids a zero division when there is no decay phase.
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
150
-
151
-
152
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Fraction of the post-warmup phase completed so far; the `max` guards against a zero-length phase.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
184
-
185
-
186
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a schedule that, after a linear warmup from 0 to the optimizer's initial lr, follows a cosine decay from
    the initial lr to 0 and jumps back to the initial lr at the start of every cycle (hard restarts).

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / decay_span
        if progress >= 1.0:
            return 0.0

        # Position inside the current cycle, in [0, 1); wrapping back to 0 is the hard restart.
        cycle_position = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
219
-
220
-
221
def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        `ValueError`: If *lr_end* is not strictly smaller than the optimizer's initial lr.
    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    # Both values are fixed for the lifetime of the schedule, so compute them once.
    lr_range = lr_init - lr_end
    decay_steps = num_training_steps - num_warmup_steps

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # `decay_steps <= 0` means there is no decay phase at all; go straight to the end lr instead of dividing
        # by zero below.
        if current_step > num_training_steps or decay_steps <= 0:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
269
-
270
-
271
# Dispatch table used by `get_scheduler`: maps every `SchedulerType` member to the factory that builds it.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}
280
-
281
-
282
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler,
            which requires it.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        The scheduler built by the factory registered for `name` in `TYPE_TO_SCHEDULER_FUNCTION`.

    Raises:
        `ValueError`: If `name` is not a valid `SchedulerType`, or if an argument required by the selected scheduler
            (`step_rules`, `num_warmup_steps` or `num_training_steps`) is missing.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, last_epoch=last_epoch)

    if name == SchedulerType.PIECEWISE_CONSTANT:
        # Fail with the same kind of message as the other required arguments instead of an AttributeError raised
        # while parsing `None`.
        if step_rules is None:
            raise ValueError(f"{name} requires `step_rules`, please provide that argument.")
        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
            last_epoch=last_epoch,
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            power=power,
            last_epoch=last_epoch,
        )

    # LINEAR and COSINE take only the warmup and training step counts; `num_cycles` is not forwarded to them.
    return schedule_func(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gradio_demo/6DoF/diffusers/pipeline_utils.py DELETED
@@ -1,29 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
-
14
- # limitations under the License.
15
-
16
- # NOTE: This file is deprecated and will be removed in a future version.
17
- # It only exists so that `from diffusers.pipeline_utils import DiffusionPipeline` temporarily keeps working
18
-
19
- from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
20
- from .utils import deprecate
21
-
22
-
23
# Module-level call: it runs every time this legacy module is imported.
deprecate(
    "pipelines_utils",
    "0.22.0",
    "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
    standard_warn=False,
    stacklevel=3,
)
- )