Spaces:
Running
on
Zero
Running
on
Zero
kxhit
commited on
Commit
·
5f093a6
1
Parent(s):
6d86936
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- 6DoF/CN_encoder.py +36 -0
- 6DoF/dataset.py +176 -0
- 6DoF/diffusers/__init__.py +281 -0
- 6DoF/diffusers/commands/__init__.py +27 -0
- 6DoF/diffusers/commands/diffusers_cli.py +41 -0
- 6DoF/diffusers/commands/env.py +84 -0
- 6DoF/diffusers/configuration_utils.py +664 -0
- 6DoF/diffusers/dependency_versions_check.py +47 -0
- 6DoF/diffusers/dependency_versions_table.py +44 -0
- 6DoF/diffusers/experimental/__init__.py +1 -0
- 6DoF/diffusers/experimental/rl/__init__.py +1 -0
- 6DoF/diffusers/experimental/rl/value_guided_sampling.py +152 -0
- 6DoF/diffusers/image_processor.py +366 -0
- 6DoF/diffusers/loaders.py +1492 -0
- 6DoF/diffusers/models/__init__.py +35 -0
- 6DoF/diffusers/models/activations.py +12 -0
- 6DoF/diffusers/models/attention.py +392 -0
- 6DoF/diffusers/models/attention_flax.py +446 -0
- 6DoF/diffusers/models/attention_processor.py +1684 -0
- 6DoF/diffusers/models/autoencoder_kl.py +411 -0
- 6DoF/diffusers/models/controlnet.py +705 -0
- 6DoF/diffusers/models/controlnet_flax.py +394 -0
- 6DoF/diffusers/models/cross_attention.py +94 -0
- 6DoF/diffusers/models/dual_transformer_2d.py +151 -0
- 6DoF/diffusers/models/embeddings.py +546 -0
- 6DoF/diffusers/models/embeddings_flax.py +95 -0
- 6DoF/diffusers/models/modeling_flax_pytorch_utils.py +118 -0
- 6DoF/diffusers/models/modeling_flax_utils.py +534 -0
- 6DoF/diffusers/models/modeling_pytorch_flax_utils.py +161 -0
- 6DoF/diffusers/models/modeling_utils.py +980 -0
- 6DoF/diffusers/models/prior_transformer.py +364 -0
- 6DoF/diffusers/models/resnet.py +877 -0
- 6DoF/diffusers/models/resnet_flax.py +124 -0
- 6DoF/diffusers/models/t5_film_transformer.py +321 -0
- 6DoF/diffusers/models/transformer_2d.py +343 -0
- 6DoF/diffusers/models/transformer_temporal.py +179 -0
- 6DoF/diffusers/models/unet_1d.py +255 -0
- 6DoF/diffusers/models/unet_1d_blocks.py +656 -0
- 6DoF/diffusers/models/unet_2d.py +329 -0
- 6DoF/diffusers/models/unet_2d_blocks.py +0 -0
- 6DoF/diffusers/models/unet_2d_blocks_flax.py +377 -0
- 6DoF/diffusers/models/unet_2d_condition.py +980 -0
- 6DoF/diffusers/models/unet_2d_condition_flax.py +357 -0
- 6DoF/diffusers/models/unet_3d_blocks.py +679 -0
- 6DoF/diffusers/models/unet_3d_condition.py +627 -0
- 6DoF/diffusers/models/vae.py +441 -0
- 6DoF/diffusers/models/vae_flax.py +869 -0
- 6DoF/diffusers/models/vq_model.py +167 -0
- 6DoF/diffusers/optimization.py +354 -0
- 6DoF/diffusers/pipeline_utils.py +29 -0
6DoF/CN_encoder.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import ConvNextV2Model
|
2 |
+
import torch
|
3 |
+
from typing import Optional
|
4 |
+
import einops
|
5 |
+
|
6 |
+
class CN_encoder(ConvNextV2Model):
|
7 |
+
def __init__(self, config):
|
8 |
+
super().__init__(config)
|
9 |
+
|
10 |
+
def forward(
|
11 |
+
self,
|
12 |
+
pixel_values: torch.FloatTensor = None,
|
13 |
+
output_hidden_states: Optional[bool] = None,
|
14 |
+
return_dict: Optional[bool] = None,
|
15 |
+
):
|
16 |
+
output_hidden_states = (
|
17 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
18 |
+
)
|
19 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
20 |
+
|
21 |
+
if pixel_values is None:
|
22 |
+
raise ValueError("You have to specify pixel_values")
|
23 |
+
|
24 |
+
embedding_output = self.embeddings(pixel_values)
|
25 |
+
|
26 |
+
encoder_outputs = self.encoder(
|
27 |
+
embedding_output,
|
28 |
+
output_hidden_states=output_hidden_states,
|
29 |
+
return_dict=return_dict,
|
30 |
+
)
|
31 |
+
|
32 |
+
last_hidden_state = encoder_outputs[0]
|
33 |
+
image_embeddings = einops.rearrange(last_hidden_state, 'b c h w -> b (h w) c')
|
34 |
+
image_embeddings = self.layernorm(image_embeddings)
|
35 |
+
|
36 |
+
return image_embeddings
|
6DoF/dataset.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from pathlib import Path
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
from torch.utils.data import Dataset, DataLoader
|
7 |
+
from torchvision import transforms
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
import webdataset as wds
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import sys
|
14 |
+
|
15 |
+
class ObjaverseDataLoader():
|
16 |
+
def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
|
17 |
+
self.root_dir = root_dir
|
18 |
+
self.batch_size = batch_size
|
19 |
+
self.num_workers = num_workers
|
20 |
+
self.total_view = total_view
|
21 |
+
|
22 |
+
image_transforms = [torchvision.transforms.Resize((256, 256)),
|
23 |
+
transforms.ToTensor(),
|
24 |
+
transforms.Normalize([0.5], [0.5])]
|
25 |
+
self.image_transforms = torchvision.transforms.Compose(image_transforms)
|
26 |
+
|
27 |
+
def train_dataloader(self):
|
28 |
+
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
|
29 |
+
image_transforms=self.image_transforms)
|
30 |
+
# sampler = DistributedSampler(dataset)
|
31 |
+
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
32 |
+
# sampler=sampler)
|
33 |
+
|
34 |
+
def val_dataloader(self):
|
35 |
+
dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
|
36 |
+
image_transforms=self.image_transforms)
|
37 |
+
sampler = DistributedSampler(dataset)
|
38 |
+
return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
39 |
+
|
40 |
+
def get_pose(transformation):
|
41 |
+
# transformation: 4x4
|
42 |
+
return transformation
|
43 |
+
|
44 |
+
class ObjaverseData(Dataset):
|
45 |
+
def __init__(self,
|
46 |
+
root_dir='.objaverse/hf-objaverse-v1/views',
|
47 |
+
image_transforms=None,
|
48 |
+
total_view=12,
|
49 |
+
validation=False,
|
50 |
+
T_in=1,
|
51 |
+
T_out=1,
|
52 |
+
fix_sample=False,
|
53 |
+
) -> None:
|
54 |
+
"""Create a dataset from a folder of images.
|
55 |
+
If you pass in a root directory it will be searched for images
|
56 |
+
ending in ext (ext can be a list)
|
57 |
+
"""
|
58 |
+
self.root_dir = Path(root_dir)
|
59 |
+
self.total_view = total_view
|
60 |
+
self.T_in = T_in
|
61 |
+
self.T_out = T_out
|
62 |
+
self.fix_sample = fix_sample
|
63 |
+
|
64 |
+
self.paths = []
|
65 |
+
# # include all folders
|
66 |
+
# for folder in os.listdir(self.root_dir):
|
67 |
+
# if os.path.isdir(os.path.join(self.root_dir, folder)):
|
68 |
+
# self.paths.append(folder)
|
69 |
+
# load ids from .npy so we have exactly the same ids/order
|
70 |
+
self.paths = np.load("../scripts/obj_ids.npy")
|
71 |
+
# # only use 100K objects for ablation study
|
72 |
+
# self.paths = self.paths[:100000]
|
73 |
+
total_objects = len(self.paths)
|
74 |
+
assert total_objects == 790152, 'total objects %d' % total_objects
|
75 |
+
if validation:
|
76 |
+
self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
|
77 |
+
else:
|
78 |
+
self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
|
79 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
80 |
+
self.tform = image_transforms
|
81 |
+
|
82 |
+
downscale = 512 / 256.
|
83 |
+
self.fx = 560. / downscale
|
84 |
+
self.fy = 560. / downscale
|
85 |
+
self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.paths)
|
89 |
+
|
90 |
+
def get_pose(self, transformation):
|
91 |
+
# transformation: 4x4
|
92 |
+
return transformation
|
93 |
+
|
94 |
+
|
95 |
+
def load_im(self, path, color):
|
96 |
+
'''
|
97 |
+
replace background pixel with random color in rendering
|
98 |
+
'''
|
99 |
+
try:
|
100 |
+
img = plt.imread(path)
|
101 |
+
except:
|
102 |
+
print(path)
|
103 |
+
sys.exit()
|
104 |
+
img[img[:, :, -1] == 0.] = color
|
105 |
+
img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
|
106 |
+
return img
|
107 |
+
|
108 |
+
def __getitem__(self, index):
|
109 |
+
data = {}
|
110 |
+
total_view = 12
|
111 |
+
|
112 |
+
if self.fix_sample:
|
113 |
+
if self.T_out > 1:
|
114 |
+
indexes = range(total_view)
|
115 |
+
index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
|
116 |
+
index_inputs = indexes[1:self.T_in+1] # one overlap identity
|
117 |
+
else:
|
118 |
+
indexes = range(total_view)
|
119 |
+
index_targets = indexes[:self.T_out]
|
120 |
+
index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1] # one overlap identity
|
121 |
+
else:
|
122 |
+
assert self.T_in + self.T_out <= total_view
|
123 |
+
# training with replace, including identity
|
124 |
+
indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
|
125 |
+
index_inputs = indexes[:self.T_in]
|
126 |
+
index_targets = indexes[self.T_in:]
|
127 |
+
filename = os.path.join(self.root_dir, self.paths[index])
|
128 |
+
|
129 |
+
color = [1., 1., 1., 1.]
|
130 |
+
|
131 |
+
try:
|
132 |
+
input_ims = []
|
133 |
+
target_ims = []
|
134 |
+
target_Ts = []
|
135 |
+
cond_Ts = []
|
136 |
+
for i, index_input in enumerate(index_inputs):
|
137 |
+
input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
|
138 |
+
input_ims.append(input_im)
|
139 |
+
input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
|
140 |
+
cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
|
141 |
+
for i, index_target in enumerate(index_targets):
|
142 |
+
target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
|
143 |
+
target_ims.append(target_im)
|
144 |
+
target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
|
145 |
+
target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
|
146 |
+
except:
|
147 |
+
print('error loading data ', filename)
|
148 |
+
filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8') # this one we know is valid
|
149 |
+
input_ims = []
|
150 |
+
target_ims = []
|
151 |
+
target_Ts = []
|
152 |
+
cond_Ts = []
|
153 |
+
# very hacky solution, sorry about this
|
154 |
+
for i, index_input in enumerate(index_inputs):
|
155 |
+
input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
|
156 |
+
input_ims.append(input_im)
|
157 |
+
input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
|
158 |
+
cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
|
159 |
+
for i, index_target in enumerate(index_targets):
|
160 |
+
target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
|
161 |
+
target_ims.append(target_im)
|
162 |
+
target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
|
163 |
+
target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
|
164 |
+
|
165 |
+
# stack to batch
|
166 |
+
data['image_input'] = torch.stack(input_ims, dim=0)
|
167 |
+
data['image_target'] = torch.stack(target_ims, dim=0)
|
168 |
+
data['pose_out'] = np.stack(target_Ts)
|
169 |
+
data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
|
170 |
+
data['pose_in'] = np.stack(cond_Ts)
|
171 |
+
data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
|
172 |
+
return data
|
173 |
+
|
174 |
+
def process_im(self, im):
|
175 |
+
im = im.convert("RGB")
|
176 |
+
return self.tform(im)
|
6DoF/diffusers/__init__.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.18.2"
|
2 |
+
|
3 |
+
from .configuration_utils import ConfigMixin
|
4 |
+
from .utils import (
|
5 |
+
OptionalDependencyNotAvailable,
|
6 |
+
is_flax_available,
|
7 |
+
is_inflect_available,
|
8 |
+
is_invisible_watermark_available,
|
9 |
+
is_k_diffusion_available,
|
10 |
+
is_k_diffusion_version,
|
11 |
+
is_librosa_available,
|
12 |
+
is_note_seq_available,
|
13 |
+
is_onnx_available,
|
14 |
+
is_scipy_available,
|
15 |
+
is_torch_available,
|
16 |
+
is_torchsde_available,
|
17 |
+
is_transformers_available,
|
18 |
+
is_transformers_version,
|
19 |
+
is_unidecode_available,
|
20 |
+
logging,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
try:
|
25 |
+
if not is_onnx_available():
|
26 |
+
raise OptionalDependencyNotAvailable()
|
27 |
+
except OptionalDependencyNotAvailable:
|
28 |
+
from .utils.dummy_onnx_objects import * # noqa F403
|
29 |
+
else:
|
30 |
+
from .pipelines import OnnxRuntimeModel
|
31 |
+
|
32 |
+
try:
|
33 |
+
if not is_torch_available():
|
34 |
+
raise OptionalDependencyNotAvailable()
|
35 |
+
except OptionalDependencyNotAvailable:
|
36 |
+
from .utils.dummy_pt_objects import * # noqa F403
|
37 |
+
else:
|
38 |
+
from .models import (
|
39 |
+
AutoencoderKL,
|
40 |
+
ControlNetModel,
|
41 |
+
ModelMixin,
|
42 |
+
PriorTransformer,
|
43 |
+
T5FilmDecoder,
|
44 |
+
Transformer2DModel,
|
45 |
+
UNet1DModel,
|
46 |
+
UNet2DConditionModel,
|
47 |
+
UNet2DModel,
|
48 |
+
UNet3DConditionModel,
|
49 |
+
VQModel,
|
50 |
+
)
|
51 |
+
from .optimization import (
|
52 |
+
get_constant_schedule,
|
53 |
+
get_constant_schedule_with_warmup,
|
54 |
+
get_cosine_schedule_with_warmup,
|
55 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
56 |
+
get_linear_schedule_with_warmup,
|
57 |
+
get_polynomial_decay_schedule_with_warmup,
|
58 |
+
get_scheduler,
|
59 |
+
)
|
60 |
+
from .pipelines import (
|
61 |
+
AudioPipelineOutput,
|
62 |
+
ConsistencyModelPipeline,
|
63 |
+
DanceDiffusionPipeline,
|
64 |
+
DDIMPipeline,
|
65 |
+
DDPMPipeline,
|
66 |
+
DiffusionPipeline,
|
67 |
+
DiTPipeline,
|
68 |
+
ImagePipelineOutput,
|
69 |
+
KarrasVePipeline,
|
70 |
+
LDMPipeline,
|
71 |
+
LDMSuperResolutionPipeline,
|
72 |
+
PNDMPipeline,
|
73 |
+
RePaintPipeline,
|
74 |
+
ScoreSdeVePipeline,
|
75 |
+
)
|
76 |
+
from .schedulers import (
|
77 |
+
CMStochasticIterativeScheduler,
|
78 |
+
DDIMInverseScheduler,
|
79 |
+
DDIMParallelScheduler,
|
80 |
+
DDIMScheduler,
|
81 |
+
DDPMParallelScheduler,
|
82 |
+
DDPMScheduler,
|
83 |
+
DEISMultistepScheduler,
|
84 |
+
DPMSolverMultistepInverseScheduler,
|
85 |
+
DPMSolverMultistepScheduler,
|
86 |
+
DPMSolverSinglestepScheduler,
|
87 |
+
EulerAncestralDiscreteScheduler,
|
88 |
+
EulerDiscreteScheduler,
|
89 |
+
HeunDiscreteScheduler,
|
90 |
+
IPNDMScheduler,
|
91 |
+
KarrasVeScheduler,
|
92 |
+
KDPM2AncestralDiscreteScheduler,
|
93 |
+
KDPM2DiscreteScheduler,
|
94 |
+
PNDMScheduler,
|
95 |
+
RePaintScheduler,
|
96 |
+
SchedulerMixin,
|
97 |
+
ScoreSdeVeScheduler,
|
98 |
+
UnCLIPScheduler,
|
99 |
+
UniPCMultistepScheduler,
|
100 |
+
VQDiffusionScheduler,
|
101 |
+
)
|
102 |
+
from .training_utils import EMAModel
|
103 |
+
|
104 |
+
try:
|
105 |
+
if not (is_torch_available() and is_scipy_available()):
|
106 |
+
raise OptionalDependencyNotAvailable()
|
107 |
+
except OptionalDependencyNotAvailable:
|
108 |
+
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
|
109 |
+
else:
|
110 |
+
from .schedulers import LMSDiscreteScheduler
|
111 |
+
|
112 |
+
try:
|
113 |
+
if not (is_torch_available() and is_torchsde_available()):
|
114 |
+
raise OptionalDependencyNotAvailable()
|
115 |
+
except OptionalDependencyNotAvailable:
|
116 |
+
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
|
117 |
+
else:
|
118 |
+
from .schedulers import DPMSolverSDEScheduler
|
119 |
+
|
120 |
+
try:
|
121 |
+
if not (is_torch_available() and is_transformers_available()):
|
122 |
+
raise OptionalDependencyNotAvailable()
|
123 |
+
except OptionalDependencyNotAvailable:
|
124 |
+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
125 |
+
else:
|
126 |
+
from .pipelines import (
|
127 |
+
AltDiffusionImg2ImgPipeline,
|
128 |
+
AltDiffusionPipeline,
|
129 |
+
AudioLDMPipeline,
|
130 |
+
CycleDiffusionPipeline,
|
131 |
+
IFImg2ImgPipeline,
|
132 |
+
IFImg2ImgSuperResolutionPipeline,
|
133 |
+
IFInpaintingPipeline,
|
134 |
+
IFInpaintingSuperResolutionPipeline,
|
135 |
+
IFPipeline,
|
136 |
+
IFSuperResolutionPipeline,
|
137 |
+
ImageTextPipelineOutput,
|
138 |
+
KandinskyImg2ImgPipeline,
|
139 |
+
KandinskyInpaintPipeline,
|
140 |
+
KandinskyPipeline,
|
141 |
+
KandinskyPriorPipeline,
|
142 |
+
KandinskyV22ControlnetImg2ImgPipeline,
|
143 |
+
KandinskyV22ControlnetPipeline,
|
144 |
+
KandinskyV22Img2ImgPipeline,
|
145 |
+
KandinskyV22InpaintPipeline,
|
146 |
+
KandinskyV22Pipeline,
|
147 |
+
KandinskyV22PriorEmb2EmbPipeline,
|
148 |
+
KandinskyV22PriorPipeline,
|
149 |
+
LDMTextToImagePipeline,
|
150 |
+
PaintByExamplePipeline,
|
151 |
+
SemanticStableDiffusionPipeline,
|
152 |
+
ShapEImg2ImgPipeline,
|
153 |
+
ShapEPipeline,
|
154 |
+
StableDiffusionAttendAndExcitePipeline,
|
155 |
+
StableDiffusionControlNetImg2ImgPipeline,
|
156 |
+
StableDiffusionControlNetInpaintPipeline,
|
157 |
+
StableDiffusionControlNetPipeline,
|
158 |
+
StableDiffusionDepth2ImgPipeline,
|
159 |
+
StableDiffusionDiffEditPipeline,
|
160 |
+
StableDiffusionImageVariationPipeline,
|
161 |
+
StableDiffusionImg2ImgPipeline,
|
162 |
+
StableDiffusionInpaintPipeline,
|
163 |
+
StableDiffusionInpaintPipelineLegacy,
|
164 |
+
StableDiffusionInstructPix2PixPipeline,
|
165 |
+
StableDiffusionLatentUpscalePipeline,
|
166 |
+
StableDiffusionLDM3DPipeline,
|
167 |
+
StableDiffusionModelEditingPipeline,
|
168 |
+
StableDiffusionPanoramaPipeline,
|
169 |
+
StableDiffusionParadigmsPipeline,
|
170 |
+
StableDiffusionPipeline,
|
171 |
+
StableDiffusionPipelineSafe,
|
172 |
+
StableDiffusionPix2PixZeroPipeline,
|
173 |
+
StableDiffusionSAGPipeline,
|
174 |
+
StableDiffusionUpscalePipeline,
|
175 |
+
StableUnCLIPImg2ImgPipeline,
|
176 |
+
StableUnCLIPPipeline,
|
177 |
+
TextToVideoSDPipeline,
|
178 |
+
TextToVideoZeroPipeline,
|
179 |
+
UnCLIPImageVariationPipeline,
|
180 |
+
UnCLIPPipeline,
|
181 |
+
UniDiffuserModel,
|
182 |
+
UniDiffuserPipeline,
|
183 |
+
UniDiffuserTextDecoder,
|
184 |
+
VersatileDiffusionDualGuidedPipeline,
|
185 |
+
VersatileDiffusionImageVariationPipeline,
|
186 |
+
VersatileDiffusionPipeline,
|
187 |
+
VersatileDiffusionTextToImagePipeline,
|
188 |
+
VideoToVideoSDPipeline,
|
189 |
+
VQDiffusionPipeline,
|
190 |
+
)
|
191 |
+
|
192 |
+
try:
|
193 |
+
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
194 |
+
raise OptionalDependencyNotAvailable()
|
195 |
+
except OptionalDependencyNotAvailable:
|
196 |
+
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
197 |
+
else:
|
198 |
+
from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
|
199 |
+
|
200 |
+
try:
|
201 |
+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
202 |
+
raise OptionalDependencyNotAvailable()
|
203 |
+
except OptionalDependencyNotAvailable:
|
204 |
+
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
205 |
+
else:
|
206 |
+
from .pipelines import StableDiffusionKDiffusionPipeline
|
207 |
+
|
208 |
+
try:
|
209 |
+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
210 |
+
raise OptionalDependencyNotAvailable()
|
211 |
+
except OptionalDependencyNotAvailable:
|
212 |
+
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
|
213 |
+
else:
|
214 |
+
from .pipelines import (
|
215 |
+
OnnxStableDiffusionImg2ImgPipeline,
|
216 |
+
OnnxStableDiffusionInpaintPipeline,
|
217 |
+
OnnxStableDiffusionInpaintPipelineLegacy,
|
218 |
+
OnnxStableDiffusionPipeline,
|
219 |
+
OnnxStableDiffusionUpscalePipeline,
|
220 |
+
StableDiffusionOnnxPipeline,
|
221 |
+
)
|
222 |
+
|
223 |
+
try:
|
224 |
+
if not (is_torch_available() and is_librosa_available()):
|
225 |
+
raise OptionalDependencyNotAvailable()
|
226 |
+
except OptionalDependencyNotAvailable:
|
227 |
+
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
|
228 |
+
else:
|
229 |
+
from .pipelines import AudioDiffusionPipeline, Mel
|
230 |
+
|
231 |
+
try:
|
232 |
+
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
233 |
+
raise OptionalDependencyNotAvailable()
|
234 |
+
except OptionalDependencyNotAvailable:
|
235 |
+
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
236 |
+
else:
|
237 |
+
from .pipelines import SpectrogramDiffusionPipeline
|
238 |
+
|
239 |
+
try:
|
240 |
+
if not is_flax_available():
|
241 |
+
raise OptionalDependencyNotAvailable()
|
242 |
+
except OptionalDependencyNotAvailable:
|
243 |
+
from .utils.dummy_flax_objects import * # noqa F403
|
244 |
+
else:
|
245 |
+
from .models.controlnet_flax import FlaxControlNetModel
|
246 |
+
from .models.modeling_flax_utils import FlaxModelMixin
|
247 |
+
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
248 |
+
from .models.vae_flax import FlaxAutoencoderKL
|
249 |
+
from .pipelines import FlaxDiffusionPipeline
|
250 |
+
from .schedulers import (
|
251 |
+
FlaxDDIMScheduler,
|
252 |
+
FlaxDDPMScheduler,
|
253 |
+
FlaxDPMSolverMultistepScheduler,
|
254 |
+
FlaxKarrasVeScheduler,
|
255 |
+
FlaxLMSDiscreteScheduler,
|
256 |
+
FlaxPNDMScheduler,
|
257 |
+
FlaxSchedulerMixin,
|
258 |
+
FlaxScoreSdeVeScheduler,
|
259 |
+
)
|
260 |
+
|
261 |
+
|
262 |
+
try:
|
263 |
+
if not (is_flax_available() and is_transformers_available()):
|
264 |
+
raise OptionalDependencyNotAvailable()
|
265 |
+
except OptionalDependencyNotAvailable:
|
266 |
+
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
|
267 |
+
else:
|
268 |
+
from .pipelines import (
|
269 |
+
FlaxStableDiffusionControlNetPipeline,
|
270 |
+
FlaxStableDiffusionImg2ImgPipeline,
|
271 |
+
FlaxStableDiffusionInpaintPipeline,
|
272 |
+
FlaxStableDiffusionPipeline,
|
273 |
+
)
|
274 |
+
|
275 |
+
try:
|
276 |
+
if not (is_note_seq_available()):
|
277 |
+
raise OptionalDependencyNotAvailable()
|
278 |
+
except OptionalDependencyNotAvailable:
|
279 |
+
from .utils.dummy_note_seq_objects import * # noqa F403
|
280 |
+
else:
|
281 |
+
from .pipelines import MidiProcessor
|
6DoF/diffusers/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
|
19 |
+
class BaseDiffusersCLICommand(ABC):
|
20 |
+
@staticmethod
|
21 |
+
@abstractmethod
|
22 |
+
def register_subcommand(parser: ArgumentParser):
|
23 |
+
raise NotImplementedError()
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def run(self):
|
27 |
+
raise NotImplementedError()
|
6DoF/diffusers/commands/diffusers_cli.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
from .env import EnvironmentCommand
|
19 |
+
|
20 |
+
|
21 |
+
def main():
|
22 |
+
parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
|
23 |
+
commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
|
24 |
+
|
25 |
+
# Register commands
|
26 |
+
EnvironmentCommand.register_subcommand(commands_parser)
|
27 |
+
|
28 |
+
# Let's go
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
if not hasattr(args, "func"):
|
32 |
+
parser.print_help()
|
33 |
+
exit(1)
|
34 |
+
|
35 |
+
# Run
|
36 |
+
service = args.func(args)
|
37 |
+
service.run()
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
main()
|
6DoF/diffusers/commands/env.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import platform
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
import huggingface_hub
|
19 |
+
|
20 |
+
from .. import __version__ as version
|
21 |
+
from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
|
22 |
+
from . import BaseDiffusersCLICommand
|
23 |
+
|
24 |
+
|
25 |
+
def info_command_factory(_):
|
26 |
+
return EnvironmentCommand()
|
27 |
+
|
28 |
+
|
29 |
+
class EnvironmentCommand(BaseDiffusersCLICommand):
|
30 |
+
@staticmethod
|
31 |
+
def register_subcommand(parser: ArgumentParser):
|
32 |
+
download_parser = parser.add_parser("env")
|
33 |
+
download_parser.set_defaults(func=info_command_factory)
|
34 |
+
|
35 |
+
def run(self):
|
36 |
+
hub_version = huggingface_hub.__version__
|
37 |
+
|
38 |
+
pt_version = "not installed"
|
39 |
+
pt_cuda_available = "NA"
|
40 |
+
if is_torch_available():
|
41 |
+
import torch
|
42 |
+
|
43 |
+
pt_version = torch.__version__
|
44 |
+
pt_cuda_available = torch.cuda.is_available()
|
45 |
+
|
46 |
+
transformers_version = "not installed"
|
47 |
+
if is_transformers_available():
|
48 |
+
import transformers
|
49 |
+
|
50 |
+
transformers_version = transformers.__version__
|
51 |
+
|
52 |
+
accelerate_version = "not installed"
|
53 |
+
if is_accelerate_available():
|
54 |
+
import accelerate
|
55 |
+
|
56 |
+
accelerate_version = accelerate.__version__
|
57 |
+
|
58 |
+
xformers_version = "not installed"
|
59 |
+
if is_xformers_available():
|
60 |
+
import xformers
|
61 |
+
|
62 |
+
xformers_version = xformers.__version__
|
63 |
+
|
64 |
+
info = {
|
65 |
+
"`diffusers` version": version,
|
66 |
+
"Platform": platform.platform(),
|
67 |
+
"Python version": platform.python_version(),
|
68 |
+
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
69 |
+
"Huggingface_hub version": hub_version,
|
70 |
+
"Transformers version": transformers_version,
|
71 |
+
"Accelerate version": accelerate_version,
|
72 |
+
"xFormers version": xformers_version,
|
73 |
+
"Using GPU in script?": "<fill in>",
|
74 |
+
"Using distributed or parallel set-up in script?": "<fill in>",
|
75 |
+
}
|
76 |
+
|
77 |
+
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
78 |
+
print(self.format_dict(info))
|
79 |
+
|
80 |
+
return info
|
81 |
+
|
82 |
+
@staticmethod
|
83 |
+
def format_dict(d):
|
84 |
+
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
6DoF/diffusers/configuration_utils.py
ADDED
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
import dataclasses
|
18 |
+
import functools
|
19 |
+
import importlib
|
20 |
+
import inspect
|
21 |
+
import json
|
22 |
+
import os
|
23 |
+
import re
|
24 |
+
from collections import OrderedDict
|
25 |
+
from pathlib import PosixPath
|
26 |
+
from typing import Any, Dict, Tuple, Union
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
31 |
+
from requests import HTTPError
|
32 |
+
|
33 |
+
from . import __version__
|
34 |
+
from .utils import (
|
35 |
+
DIFFUSERS_CACHE,
|
36 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
37 |
+
DummyObject,
|
38 |
+
deprecate,
|
39 |
+
extract_commit_hash,
|
40 |
+
http_user_agent,
|
41 |
+
logging,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
48 |
+
|
49 |
+
|
50 |
+
class FrozenDict(OrderedDict):
|
51 |
+
def __init__(self, *args, **kwargs):
|
52 |
+
super().__init__(*args, **kwargs)
|
53 |
+
|
54 |
+
for key, value in self.items():
|
55 |
+
setattr(self, key, value)
|
56 |
+
|
57 |
+
self.__frozen = True
|
58 |
+
|
59 |
+
def __delitem__(self, *args, **kwargs):
|
60 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
61 |
+
|
62 |
+
def setdefault(self, *args, **kwargs):
|
63 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
64 |
+
|
65 |
+
def pop(self, *args, **kwargs):
|
66 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
67 |
+
|
68 |
+
def update(self, *args, **kwargs):
|
69 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
70 |
+
|
71 |
+
def __setattr__(self, name, value):
|
72 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
73 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
74 |
+
super().__setattr__(name, value)
|
75 |
+
|
76 |
+
def __setitem__(self, name, value):
|
77 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
78 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
79 |
+
super().__setitem__(name, value)
|
80 |
+
|
81 |
+
|
82 |
+
class ConfigMixin:
|
83 |
+
r"""
|
84 |
+
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
|
85 |
+
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
|
86 |
+
saving classes that inherit from [`ConfigMixin`].
|
87 |
+
|
88 |
+
Class attributes:
|
89 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
90 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
91 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
92 |
+
overridden by subclass).
|
93 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
94 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
95 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
96 |
+
subclass).
|
97 |
+
"""
|
98 |
+
config_name = None
|
99 |
+
ignore_for_config = []
|
100 |
+
has_compatibles = False
|
101 |
+
|
102 |
+
_deprecated_kwargs = []
|
103 |
+
|
104 |
+
def register_to_config(self, **kwargs):
|
105 |
+
if self.config_name is None:
|
106 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
107 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
108 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
109 |
+
# or solve in a more general way.
|
110 |
+
kwargs.pop("kwargs", None)
|
111 |
+
|
112 |
+
if not hasattr(self, "_internal_dict"):
|
113 |
+
internal_dict = kwargs
|
114 |
+
else:
|
115 |
+
previous_dict = dict(self._internal_dict)
|
116 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
117 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
118 |
+
|
119 |
+
self._internal_dict = FrozenDict(internal_dict)
|
120 |
+
|
121 |
+
def __getattr__(self, name: str) -> Any:
|
122 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
123 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
124 |
+
|
125 |
+
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
126 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
127 |
+
"""
|
128 |
+
|
129 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
130 |
+
is_attribute = name in self.__dict__
|
131 |
+
|
132 |
+
if is_in_config and not is_attribute:
|
133 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
|
134 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
|
135 |
+
return self._internal_dict[name]
|
136 |
+
|
137 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
138 |
+
|
139 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
140 |
+
"""
|
141 |
+
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
142 |
+
[`~ConfigMixin.from_config`] class method.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
save_directory (`str` or `os.PathLike`):
|
146 |
+
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
147 |
+
"""
|
148 |
+
if os.path.isfile(save_directory):
|
149 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
150 |
+
|
151 |
+
os.makedirs(save_directory, exist_ok=True)
|
152 |
+
|
153 |
+
# If we save using the predefined names, we can load using `from_config`
|
154 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
155 |
+
|
156 |
+
self.to_json_file(output_config_file)
|
157 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
161 |
+
r"""
|
162 |
+
Instantiate a Python class from a config dictionary.
|
163 |
+
|
164 |
+
Parameters:
|
165 |
+
config (`Dict[str, Any]`):
|
166 |
+
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
167 |
+
files of compatible classes.
|
168 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
169 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
170 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
171 |
+
Can be used to update the configuration object (after it is loaded) and initiate the Python class.
|
172 |
+
`**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
|
173 |
+
overwrite the same named arguments in `config`.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
[`ModelMixin`] or [`SchedulerMixin`]:
|
177 |
+
A model or scheduler object instantiated from a config dictionary.
|
178 |
+
|
179 |
+
Examples:
|
180 |
+
|
181 |
+
```python
|
182 |
+
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
183 |
+
|
184 |
+
>>> # Download scheduler from huggingface.co and cache.
|
185 |
+
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
186 |
+
|
187 |
+
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
188 |
+
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
189 |
+
|
190 |
+
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
191 |
+
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
192 |
+
```
|
193 |
+
"""
|
194 |
+
# <===== TO BE REMOVED WITH DEPRECATION
|
195 |
+
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
196 |
+
if "pretrained_model_name_or_path" in kwargs:
|
197 |
+
config = kwargs.pop("pretrained_model_name_or_path")
|
198 |
+
|
199 |
+
if config is None:
|
200 |
+
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
201 |
+
# ======>
|
202 |
+
|
203 |
+
if not isinstance(config, dict):
|
204 |
+
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
205 |
+
if "Scheduler" in cls.__name__:
|
206 |
+
deprecation_message += (
|
207 |
+
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
208 |
+
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
209 |
+
" be removed in v1.0.0."
|
210 |
+
)
|
211 |
+
elif "Model" in cls.__name__:
|
212 |
+
deprecation_message += (
|
213 |
+
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
214 |
+
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
215 |
+
" instead. This functionality will be removed in v1.0.0."
|
216 |
+
)
|
217 |
+
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
218 |
+
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
219 |
+
|
220 |
+
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
221 |
+
|
222 |
+
# Allow dtype to be specified on initialization
|
223 |
+
if "dtype" in unused_kwargs:
|
224 |
+
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
225 |
+
|
226 |
+
# add possible deprecated kwargs
|
227 |
+
for deprecated_kwarg in cls._deprecated_kwargs:
|
228 |
+
if deprecated_kwarg in unused_kwargs:
|
229 |
+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
230 |
+
|
231 |
+
# Return model and optionally state and/or unused_kwargs
|
232 |
+
model = cls(**init_dict)
|
233 |
+
|
234 |
+
# make sure to also save config parameters that might be used for compatible classes
|
235 |
+
model.register_to_config(**hidden_dict)
|
236 |
+
|
237 |
+
# add hidden kwargs of compatible classes to unused_kwargs
|
238 |
+
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
239 |
+
|
240 |
+
if return_unused_kwargs:
|
241 |
+
return (model, unused_kwargs)
|
242 |
+
else:
|
243 |
+
return model
|
244 |
+
|
245 |
+
@classmethod
|
246 |
+
def get_config_dict(cls, *args, **kwargs):
|
247 |
+
deprecation_message = (
|
248 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
249 |
+
" removed in version v1.0.0"
|
250 |
+
)
|
251 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
252 |
+
return cls.load_config(*args, **kwargs)
|
253 |
+
|
254 |
+
@classmethod
|
255 |
+
def load_config(
|
256 |
+
cls,
|
257 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
258 |
+
return_unused_kwargs=False,
|
259 |
+
return_commit_hash=False,
|
260 |
+
**kwargs,
|
261 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
262 |
+
r"""
|
263 |
+
Load a model or scheduler configuration.
|
264 |
+
|
265 |
+
Parameters:
|
266 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
267 |
+
Can be either:
|
268 |
+
|
269 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
270 |
+
the Hub.
|
271 |
+
- A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
|
272 |
+
[`~ConfigMixin.save_config`].
|
273 |
+
|
274 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
275 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
276 |
+
is not used.
|
277 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
278 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
279 |
+
cached versions if they exist.
|
280 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
281 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
282 |
+
incompletely downloaded files are deleted.
|
283 |
+
proxies (`Dict[str, str]`, *optional*):
|
284 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
285 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
286 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
287 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
288 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
289 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
290 |
+
won't be downloaded from the Hub.
|
291 |
+
use_auth_token (`str` or *bool*, *optional*):
|
292 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
293 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
294 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
295 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
296 |
+
allowed by Git.
|
297 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
298 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
299 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
300 |
+
Whether unused keyword arguments of the config are returned.
|
301 |
+
return_commit_hash (`bool`, *optional*, defaults to `False):
|
302 |
+
Whether the `commit_hash` of the loaded configuration are returned.
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
`dict`:
|
306 |
+
A dictionary of all the parameters stored in a JSON configuration file.
|
307 |
+
|
308 |
+
"""
|
309 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
310 |
+
force_download = kwargs.pop("force_download", False)
|
311 |
+
resume_download = kwargs.pop("resume_download", False)
|
312 |
+
proxies = kwargs.pop("proxies", None)
|
313 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
314 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
315 |
+
revision = kwargs.pop("revision", None)
|
316 |
+
_ = kwargs.pop("mirror", None)
|
317 |
+
subfolder = kwargs.pop("subfolder", None)
|
318 |
+
user_agent = kwargs.pop("user_agent", {})
|
319 |
+
|
320 |
+
user_agent = {**user_agent, "file_type": "config"}
|
321 |
+
user_agent = http_user_agent(user_agent)
|
322 |
+
|
323 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
324 |
+
|
325 |
+
if cls.config_name is None:
|
326 |
+
raise ValueError(
|
327 |
+
"`self.config_name` is not defined. Note that one should not load a config from "
|
328 |
+
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
329 |
+
)
|
330 |
+
|
331 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
332 |
+
config_file = pretrained_model_name_or_path
|
333 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
334 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
335 |
+
# Load from a PyTorch checkpoint
|
336 |
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
337 |
+
elif subfolder is not None and os.path.isfile(
|
338 |
+
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
339 |
+
):
|
340 |
+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
341 |
+
else:
|
342 |
+
raise EnvironmentError(
|
343 |
+
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
try:
|
347 |
+
# Load from URL or cache if already cached
|
348 |
+
config_file = hf_hub_download(
|
349 |
+
pretrained_model_name_or_path,
|
350 |
+
filename=cls.config_name,
|
351 |
+
cache_dir=cache_dir,
|
352 |
+
force_download=force_download,
|
353 |
+
proxies=proxies,
|
354 |
+
resume_download=resume_download,
|
355 |
+
local_files_only=local_files_only,
|
356 |
+
use_auth_token=use_auth_token,
|
357 |
+
user_agent=user_agent,
|
358 |
+
subfolder=subfolder,
|
359 |
+
revision=revision,
|
360 |
+
)
|
361 |
+
except RepositoryNotFoundError:
|
362 |
+
raise EnvironmentError(
|
363 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
364 |
+
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
365 |
+
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
366 |
+
" login`."
|
367 |
+
)
|
368 |
+
except RevisionNotFoundError:
|
369 |
+
raise EnvironmentError(
|
370 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
371 |
+
" this model name. Check the model page at"
|
372 |
+
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
373 |
+
)
|
374 |
+
except EntryNotFoundError:
|
375 |
+
raise EnvironmentError(
|
376 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
377 |
+
)
|
378 |
+
except HTTPError as err:
|
379 |
+
raise EnvironmentError(
|
380 |
+
"There was a specific connection error when trying to load"
|
381 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
382 |
+
)
|
383 |
+
except ValueError:
|
384 |
+
raise EnvironmentError(
|
385 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
386 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
387 |
+
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
388 |
+
" run the library in offline mode at"
|
389 |
+
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
390 |
+
)
|
391 |
+
except EnvironmentError:
|
392 |
+
raise EnvironmentError(
|
393 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
394 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
395 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
396 |
+
f"containing a {cls.config_name} file"
|
397 |
+
)
|
398 |
+
|
399 |
+
try:
|
400 |
+
# Load config dict
|
401 |
+
config_dict = cls._dict_from_json_file(config_file)
|
402 |
+
|
403 |
+
commit_hash = extract_commit_hash(config_file)
|
404 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
405 |
+
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
406 |
+
|
407 |
+
if not (return_unused_kwargs or return_commit_hash):
|
408 |
+
return config_dict
|
409 |
+
|
410 |
+
outputs = (config_dict,)
|
411 |
+
|
412 |
+
if return_unused_kwargs:
|
413 |
+
outputs += (kwargs,)
|
414 |
+
|
415 |
+
if return_commit_hash:
|
416 |
+
outputs += (commit_hash,)
|
417 |
+
|
418 |
+
return outputs
|
419 |
+
|
420 |
+
@staticmethod
|
421 |
+
def _get_init_keys(cls):
|
422 |
+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
423 |
+
|
424 |
+
@classmethod
|
425 |
+
def extract_init_dict(cls, config_dict, **kwargs):
|
426 |
+
# Skip keys that were not present in the original config, so default __init__ values were used
|
427 |
+
used_defaults = config_dict.get("_use_default_values", [])
|
428 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
|
429 |
+
|
430 |
+
# 0. Copy origin config dict
|
431 |
+
original_dict = dict(config_dict.items())
|
432 |
+
|
433 |
+
# 1. Retrieve expected config attributes from __init__ signature
|
434 |
+
expected_keys = cls._get_init_keys(cls)
|
435 |
+
expected_keys.remove("self")
|
436 |
+
# remove general kwargs if present in dict
|
437 |
+
if "kwargs" in expected_keys:
|
438 |
+
expected_keys.remove("kwargs")
|
439 |
+
# remove flax internal keys
|
440 |
+
if hasattr(cls, "_flax_internal_args"):
|
441 |
+
for arg in cls._flax_internal_args:
|
442 |
+
expected_keys.remove(arg)
|
443 |
+
|
444 |
+
# 2. Remove attributes that cannot be expected from expected config attributes
|
445 |
+
# remove keys to be ignored
|
446 |
+
if len(cls.ignore_for_config) > 0:
|
447 |
+
expected_keys = expected_keys - set(cls.ignore_for_config)
|
448 |
+
|
449 |
+
# load diffusers library to import compatible and original scheduler
|
450 |
+
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
451 |
+
|
452 |
+
if cls.has_compatibles:
|
453 |
+
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
454 |
+
else:
|
455 |
+
compatible_classes = []
|
456 |
+
|
457 |
+
expected_keys_comp_cls = set()
|
458 |
+
for c in compatible_classes:
|
459 |
+
expected_keys_c = cls._get_init_keys(c)
|
460 |
+
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
461 |
+
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
462 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
463 |
+
|
464 |
+
# remove attributes from orig class that cannot be expected
|
465 |
+
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
466 |
+
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
467 |
+
orig_cls = getattr(diffusers_library, orig_cls_name)
|
468 |
+
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
469 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
470 |
+
|
471 |
+
# remove private attributes
|
472 |
+
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
473 |
+
|
474 |
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
475 |
+
init_dict = {}
|
476 |
+
for key in expected_keys:
|
477 |
+
# if config param is passed to kwarg and is present in config dict
|
478 |
+
# it should overwrite existing config dict key
|
479 |
+
if key in kwargs and key in config_dict:
|
480 |
+
config_dict[key] = kwargs.pop(key)
|
481 |
+
|
482 |
+
if key in kwargs:
|
483 |
+
# overwrite key
|
484 |
+
init_dict[key] = kwargs.pop(key)
|
485 |
+
elif key in config_dict:
|
486 |
+
# use value from config dict
|
487 |
+
init_dict[key] = config_dict.pop(key)
|
488 |
+
|
489 |
+
# 4. Give nice warning if unexpected values have been passed
|
490 |
+
if len(config_dict) > 0:
|
491 |
+
logger.warning(
|
492 |
+
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
493 |
+
"but are not expected and will be ignored. Please verify your "
|
494 |
+
f"{cls.config_name} configuration file."
|
495 |
+
)
|
496 |
+
|
497 |
+
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
498 |
+
passed_keys = set(init_dict.keys())
|
499 |
+
if len(expected_keys - passed_keys) > 0:
|
500 |
+
logger.info(
|
501 |
+
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
502 |
+
)
|
503 |
+
|
504 |
+
# 6. Define unused keyword arguments
|
505 |
+
unused_kwargs = {**config_dict, **kwargs}
|
506 |
+
|
507 |
+
# 7. Define "hidden" config parameters that were saved for compatible classes
|
508 |
+
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
509 |
+
|
510 |
+
return init_dict, unused_kwargs, hidden_config_dict
|
511 |
+
|
512 |
+
@classmethod
|
513 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
514 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
515 |
+
text = reader.read()
|
516 |
+
return json.loads(text)
|
517 |
+
|
518 |
+
def __repr__(self):
|
519 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
520 |
+
|
521 |
+
@property
|
522 |
+
def config(self) -> Dict[str, Any]:
|
523 |
+
"""
|
524 |
+
Returns the config of the class as a frozen dictionary
|
525 |
+
|
526 |
+
Returns:
|
527 |
+
`Dict[str, Any]`: Config of the class.
|
528 |
+
"""
|
529 |
+
return self._internal_dict
|
530 |
+
|
531 |
+
def to_json_string(self) -> str:
|
532 |
+
"""
|
533 |
+
Serializes the configuration instance to a JSON string.
|
534 |
+
|
535 |
+
Returns:
|
536 |
+
`str`:
|
537 |
+
String containing all the attributes that make up the configuration instance in JSON format.
|
538 |
+
"""
|
539 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
540 |
+
config_dict["_class_name"] = self.__class__.__name__
|
541 |
+
config_dict["_diffusers_version"] = __version__
|
542 |
+
|
543 |
+
def to_json_saveable(value):
|
544 |
+
if isinstance(value, np.ndarray):
|
545 |
+
value = value.tolist()
|
546 |
+
elif isinstance(value, PosixPath):
|
547 |
+
value = str(value)
|
548 |
+
return value
|
549 |
+
|
550 |
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
551 |
+
# Don't save "_ignore_files" or "_use_default_values"
|
552 |
+
config_dict.pop("_ignore_files", None)
|
553 |
+
config_dict.pop("_use_default_values", None)
|
554 |
+
|
555 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
556 |
+
|
557 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
558 |
+
"""
|
559 |
+
Save the configuration instance's parameters to a JSON file.
|
560 |
+
|
561 |
+
Args:
|
562 |
+
json_file_path (`str` or `os.PathLike`):
|
563 |
+
Path to the JSON file to save a configuration instance's parameters.
|
564 |
+
"""
|
565 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
566 |
+
writer.write(self.to_json_string())
|
567 |
+
|
568 |
+
|
569 |
+
def register_to_config(init):
|
570 |
+
r"""
|
571 |
+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
572 |
+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
573 |
+
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
574 |
+
|
575 |
+
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
576 |
+
"""
|
577 |
+
|
578 |
+
@functools.wraps(init)
|
579 |
+
def inner_init(self, *args, **kwargs):
|
580 |
+
# Ignore private kwargs in the init.
|
581 |
+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
582 |
+
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
583 |
+
if not isinstance(self, ConfigMixin):
|
584 |
+
raise RuntimeError(
|
585 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
586 |
+
"not inherit from `ConfigMixin`."
|
587 |
+
)
|
588 |
+
|
589 |
+
ignore = getattr(self, "ignore_for_config", [])
|
590 |
+
# Get positional arguments aligned with kwargs
|
591 |
+
new_kwargs = {}
|
592 |
+
signature = inspect.signature(init)
|
593 |
+
parameters = {
|
594 |
+
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
595 |
+
}
|
596 |
+
for arg, name in zip(args, parameters.keys()):
|
597 |
+
new_kwargs[name] = arg
|
598 |
+
|
599 |
+
# Then add all kwargs
|
600 |
+
new_kwargs.update(
|
601 |
+
{
|
602 |
+
k: init_kwargs.get(k, default)
|
603 |
+
for k, default in parameters.items()
|
604 |
+
if k not in ignore and k not in new_kwargs
|
605 |
+
}
|
606 |
+
)
|
607 |
+
|
608 |
+
# Take note of the parameters that were not present in the loaded config
|
609 |
+
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
610 |
+
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
611 |
+
|
612 |
+
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
613 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
614 |
+
init(self, *args, **init_kwargs)
|
615 |
+
|
616 |
+
return inner_init
|
617 |
+
|
618 |
+
|
619 |
+
def flax_register_to_config(cls):
|
620 |
+
original_init = cls.__init__
|
621 |
+
|
622 |
+
@functools.wraps(original_init)
|
623 |
+
def init(self, *args, **kwargs):
|
624 |
+
if not isinstance(self, ConfigMixin):
|
625 |
+
raise RuntimeError(
|
626 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
627 |
+
"not inherit from `ConfigMixin`."
|
628 |
+
)
|
629 |
+
|
630 |
+
# Ignore private kwargs in the init. Retrieve all passed attributes
|
631 |
+
init_kwargs = dict(kwargs.items())
|
632 |
+
|
633 |
+
# Retrieve default values
|
634 |
+
fields = dataclasses.fields(self)
|
635 |
+
default_kwargs = {}
|
636 |
+
for field in fields:
|
637 |
+
# ignore flax specific attributes
|
638 |
+
if field.name in self._flax_internal_args:
|
639 |
+
continue
|
640 |
+
if type(field.default) == dataclasses._MISSING_TYPE:
|
641 |
+
default_kwargs[field.name] = None
|
642 |
+
else:
|
643 |
+
default_kwargs[field.name] = getattr(self, field.name)
|
644 |
+
|
645 |
+
# Make sure init_kwargs override default kwargs
|
646 |
+
new_kwargs = {**default_kwargs, **init_kwargs}
|
647 |
+
# dtype should be part of `init_kwargs`, but not `new_kwargs`
|
648 |
+
if "dtype" in new_kwargs:
|
649 |
+
new_kwargs.pop("dtype")
|
650 |
+
|
651 |
+
# Get positional arguments aligned with kwargs
|
652 |
+
for i, arg in enumerate(args):
|
653 |
+
name = fields[i].name
|
654 |
+
new_kwargs[name] = arg
|
655 |
+
|
656 |
+
# Take note of the parameters that were not present in the loaded config
|
657 |
+
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
658 |
+
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
659 |
+
|
660 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
661 |
+
original_init(self, *args, **kwargs)
|
662 |
+
|
663 |
+
cls.__init__ = init
|
664 |
+
return cls
|
6DoF/diffusers/dependency_versions_check.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import sys
|
15 |
+
|
16 |
+
from .dependency_versions_table import deps
|
17 |
+
from .utils.versions import require_version, require_version_core
|
18 |
+
|
19 |
+
|
20 |
+
# define which module versions we always want to check at run time
|
21 |
+
# (usually the ones defined in `install_requires` in setup.py)
|
22 |
+
#
|
23 |
+
# order specific notes:
|
24 |
+
# - tqdm must be checked before tokenizers
|
25 |
+
|
26 |
+
pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
|
27 |
+
if sys.version_info < (3, 7):
|
28 |
+
pkgs_to_check_at_runtime.append("dataclasses")
|
29 |
+
if sys.version_info < (3, 8):
|
30 |
+
pkgs_to_check_at_runtime.append("importlib_metadata")
|
31 |
+
|
32 |
+
for pkg in pkgs_to_check_at_runtime:
|
33 |
+
if pkg in deps:
|
34 |
+
if pkg == "tokenizers":
|
35 |
+
# must be loaded here, or else tqdm check may fail
|
36 |
+
from .utils import is_tokenizers_available
|
37 |
+
|
38 |
+
if not is_tokenizers_available():
|
39 |
+
continue # not required, check version only if installed
|
40 |
+
|
41 |
+
require_version_core(deps[pkg])
|
42 |
+
else:
|
43 |
+
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
44 |
+
|
45 |
+
|
46 |
+
def dep_version_check(pkg, hint=None):
|
47 |
+
require_version(deps[pkg], hint)
|
6DoF/diffusers/dependency_versions_table.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# THIS FILE HAS BEEN AUTOGENERATED. To update:
|
2 |
+
# 1. modify the `_deps` dict in setup.py
|
3 |
+
# 2. run `make deps_table_update``
|
4 |
+
deps = {
|
5 |
+
"Pillow": "Pillow",
|
6 |
+
"accelerate": "accelerate>=0.11.0",
|
7 |
+
"compel": "compel==0.1.8",
|
8 |
+
"black": "black~=23.1",
|
9 |
+
"datasets": "datasets",
|
10 |
+
"filelock": "filelock",
|
11 |
+
"flax": "flax>=0.4.1",
|
12 |
+
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
13 |
+
"huggingface-hub": "huggingface-hub>=0.13.2",
|
14 |
+
"requests-mock": "requests-mock==1.10.0",
|
15 |
+
"importlib_metadata": "importlib_metadata",
|
16 |
+
"invisible-watermark": "invisible-watermark",
|
17 |
+
"isort": "isort>=5.5.4",
|
18 |
+
"jax": "jax>=0.2.8,!=0.3.2",
|
19 |
+
"jaxlib": "jaxlib>=0.1.65",
|
20 |
+
"Jinja2": "Jinja2",
|
21 |
+
"k-diffusion": "k-diffusion>=0.0.12",
|
22 |
+
"torchsde": "torchsde",
|
23 |
+
"note_seq": "note_seq",
|
24 |
+
"librosa": "librosa",
|
25 |
+
"numpy": "numpy",
|
26 |
+
"omegaconf": "omegaconf",
|
27 |
+
"parameterized": "parameterized",
|
28 |
+
"protobuf": "protobuf>=3.20.3,<4",
|
29 |
+
"pytest": "pytest",
|
30 |
+
"pytest-timeout": "pytest-timeout",
|
31 |
+
"pytest-xdist": "pytest-xdist",
|
32 |
+
"ruff": "ruff>=0.0.241",
|
33 |
+
"safetensors": "safetensors",
|
34 |
+
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
35 |
+
"scipy": "scipy",
|
36 |
+
"onnx": "onnx",
|
37 |
+
"regex": "regex!=2019.12.17",
|
38 |
+
"requests": "requests",
|
39 |
+
"tensorboard": "tensorboard",
|
40 |
+
"torch": "torch>=1.4",
|
41 |
+
"torchvision": "torchvision",
|
42 |
+
"transformers": "transformers>=4.25.1",
|
43 |
+
"urllib3": "urllib3<=2.0.0",
|
44 |
+
}
|
6DoF/diffusers/experimental/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .rl import ValueGuidedRLPipeline
|
6DoF/diffusers/experimental/rl/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .value_guided_sampling import ValueGuidedRLPipeline
|
6DoF/diffusers/experimental/rl/value_guided_sampling.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import tqdm
|
18 |
+
|
19 |
+
from ...models.unet_1d import UNet1DModel
|
20 |
+
from ...pipelines import DiffusionPipeline
|
21 |
+
from ...utils import randn_tensor
|
22 |
+
from ...utils.dummy_pt_objects import DDPMScheduler
|
23 |
+
|
24 |
+
|
25 |
+
class ValueGuidedRLPipeline(DiffusionPipeline):
|
26 |
+
r"""
|
27 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
28 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
29 |
+
Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
|
30 |
+
|
31 |
+
Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
|
35 |
+
unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
|
36 |
+
scheduler ([`SchedulerMixin`]):
|
37 |
+
A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
|
38 |
+
application is [`DDPMScheduler`].
|
39 |
+
env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
value_function: UNet1DModel,
|
45 |
+
unet: UNet1DModel,
|
46 |
+
scheduler: DDPMScheduler,
|
47 |
+
env,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.value_function = value_function
|
51 |
+
self.unet = unet
|
52 |
+
self.scheduler = scheduler
|
53 |
+
self.env = env
|
54 |
+
self.data = env.get_dataset()
|
55 |
+
self.means = {}
|
56 |
+
for key in self.data.keys():
|
57 |
+
try:
|
58 |
+
self.means[key] = self.data[key].mean()
|
59 |
+
except: # noqa: E722
|
60 |
+
pass
|
61 |
+
self.stds = {}
|
62 |
+
for key in self.data.keys():
|
63 |
+
try:
|
64 |
+
self.stds[key] = self.data[key].std()
|
65 |
+
except: # noqa: E722
|
66 |
+
pass
|
67 |
+
self.state_dim = env.observation_space.shape[0]
|
68 |
+
self.action_dim = env.action_space.shape[0]
|
69 |
+
|
70 |
+
def normalize(self, x_in, key):
|
71 |
+
return (x_in - self.means[key]) / self.stds[key]
|
72 |
+
|
73 |
+
def de_normalize(self, x_in, key):
|
74 |
+
return x_in * self.stds[key] + self.means[key]
|
75 |
+
|
76 |
+
def to_torch(self, x_in):
|
77 |
+
if type(x_in) is dict:
|
78 |
+
return {k: self.to_torch(v) for k, v in x_in.items()}
|
79 |
+
elif torch.is_tensor(x_in):
|
80 |
+
return x_in.to(self.unet.device)
|
81 |
+
return torch.tensor(x_in, device=self.unet.device)
|
82 |
+
|
83 |
+
def reset_x0(self, x_in, cond, act_dim):
|
84 |
+
for key, val in cond.items():
|
85 |
+
x_in[:, key, act_dim:] = val.clone()
|
86 |
+
return x_in
|
87 |
+
|
88 |
+
def run_diffusion(self, x, conditions, n_guide_steps, scale):
|
89 |
+
batch_size = x.shape[0]
|
90 |
+
y = None
|
91 |
+
for i in tqdm.tqdm(self.scheduler.timesteps):
|
92 |
+
# create batch of timesteps to pass into model
|
93 |
+
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
|
94 |
+
for _ in range(n_guide_steps):
|
95 |
+
with torch.enable_grad():
|
96 |
+
x.requires_grad_()
|
97 |
+
|
98 |
+
# permute to match dimension for pre-trained models
|
99 |
+
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
|
100 |
+
grad = torch.autograd.grad([y.sum()], [x])[0]
|
101 |
+
|
102 |
+
posterior_variance = self.scheduler._get_variance(i)
|
103 |
+
model_std = torch.exp(0.5 * posterior_variance)
|
104 |
+
grad = model_std * grad
|
105 |
+
|
106 |
+
grad[timesteps < 2] = 0
|
107 |
+
x = x.detach()
|
108 |
+
x = x + scale * grad
|
109 |
+
x = self.reset_x0(x, conditions, self.action_dim)
|
110 |
+
|
111 |
+
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
112 |
+
|
113 |
+
# TODO: verify deprecation of this kwarg
|
114 |
+
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
115 |
+
|
116 |
+
# apply conditions to the trajectory (set the initial state)
|
117 |
+
x = self.reset_x0(x, conditions, self.action_dim)
|
118 |
+
x = self.to_torch(x)
|
119 |
+
return x, y
|
120 |
+
|
121 |
+
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
|
122 |
+
# normalize the observations and create batch dimension
|
123 |
+
obs = self.normalize(obs, "observations")
|
124 |
+
obs = obs[None].repeat(batch_size, axis=0)
|
125 |
+
|
126 |
+
conditions = {0: self.to_torch(obs)}
|
127 |
+
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
|
128 |
+
|
129 |
+
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
130 |
+
x1 = randn_tensor(shape, device=self.unet.device)
|
131 |
+
x = self.reset_x0(x1, conditions, self.action_dim)
|
132 |
+
x = self.to_torch(x)
|
133 |
+
|
134 |
+
# run the diffusion process
|
135 |
+
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
|
136 |
+
|
137 |
+
# sort output trajectories by value
|
138 |
+
sorted_idx = y.argsort(0, descending=True).squeeze()
|
139 |
+
sorted_values = x[sorted_idx]
|
140 |
+
actions = sorted_values[:, :, : self.action_dim]
|
141 |
+
actions = actions.detach().cpu().numpy()
|
142 |
+
denorm_actions = self.de_normalize(actions, key="actions")
|
143 |
+
|
144 |
+
# select the action with the highest value
|
145 |
+
if y is not None:
|
146 |
+
selected_index = 0
|
147 |
+
else:
|
148 |
+
# if we didn't run value guiding, select a random action
|
149 |
+
selected_index = np.random.randint(0, batch_size)
|
150 |
+
|
151 |
+
denorm_actions = denorm_actions[selected_index, 0]
|
152 |
+
return denorm_actions
|
6DoF/diffusers/image_processor.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
from typing import List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import PIL
|
20 |
+
import torch
|
21 |
+
from PIL import Image
|
22 |
+
|
23 |
+
from .configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
25 |
+
|
26 |
+
|
27 |
+
class VaeImageProcessor(ConfigMixin):
|
28 |
+
"""
|
29 |
+
Image processor for VAE.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
33 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
34 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
35 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
36 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
37 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
38 |
+
Resampling filter to use when resizing the image.
|
39 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
40 |
+
Whether to normalize the image to [-1,1].
|
41 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
42 |
+
Whether to convert the images to RGB format.
|
43 |
+
"""
|
44 |
+
|
45 |
+
config_name = CONFIG_NAME
|
46 |
+
|
47 |
+
@register_to_config
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
do_resize: bool = True,
|
51 |
+
vae_scale_factor: int = 8,
|
52 |
+
resample: str = "lanczos",
|
53 |
+
do_normalize: bool = True,
|
54 |
+
do_convert_rgb: bool = False,
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
|
60 |
+
"""
|
61 |
+
Convert a numpy image or a batch of images to a PIL image.
|
62 |
+
"""
|
63 |
+
if images.ndim == 3:
|
64 |
+
images = images[None, ...]
|
65 |
+
images = (images * 255).round().astype("uint8")
|
66 |
+
if images.shape[-1] == 1:
|
67 |
+
# special case for grayscale (single channel) images
|
68 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
69 |
+
else:
|
70 |
+
pil_images = [Image.fromarray(image) for image in images]
|
71 |
+
|
72 |
+
return pil_images
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
76 |
+
"""
|
77 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
78 |
+
"""
|
79 |
+
if not isinstance(images, list):
|
80 |
+
images = [images]
|
81 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
82 |
+
images = np.stack(images, axis=0)
|
83 |
+
|
84 |
+
return images
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
|
88 |
+
"""
|
89 |
+
Convert a NumPy image to a PyTorch tensor.
|
90 |
+
"""
|
91 |
+
if images.ndim == 3:
|
92 |
+
images = images[..., None]
|
93 |
+
|
94 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
95 |
+
return images
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
|
99 |
+
"""
|
100 |
+
Convert a PyTorch tensor to a NumPy image.
|
101 |
+
"""
|
102 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
103 |
+
return images
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def normalize(images):
|
107 |
+
"""
|
108 |
+
Normalize an image array to [-1,1].
|
109 |
+
"""
|
110 |
+
return 2.0 * images - 1.0
|
111 |
+
|
112 |
+
@staticmethod
|
113 |
+
def denormalize(images):
|
114 |
+
"""
|
115 |
+
Denormalize an image array to [0,1].
|
116 |
+
"""
|
117 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
121 |
+
"""
|
122 |
+
Converts an image to RGB format.
|
123 |
+
"""
|
124 |
+
image = image.convert("RGB")
|
125 |
+
return image
|
126 |
+
|
127 |
+
def resize(
|
128 |
+
self,
|
129 |
+
image: PIL.Image.Image,
|
130 |
+
height: Optional[int] = None,
|
131 |
+
width: Optional[int] = None,
|
132 |
+
) -> PIL.Image.Image:
|
133 |
+
"""
|
134 |
+
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
|
135 |
+
"""
|
136 |
+
if height is None:
|
137 |
+
height = image.height
|
138 |
+
if width is None:
|
139 |
+
width = image.width
|
140 |
+
|
141 |
+
width, height = (
|
142 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
143 |
+
) # resize to integer multiple of vae_scale_factor
|
144 |
+
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
145 |
+
return image
|
146 |
+
|
147 |
+
def preprocess(
|
148 |
+
self,
|
149 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
150 |
+
height: Optional[int] = None,
|
151 |
+
width: Optional[int] = None,
|
152 |
+
) -> torch.Tensor:
|
153 |
+
"""
|
154 |
+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
155 |
+
"""
|
156 |
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
157 |
+
if isinstance(image, supported_formats):
|
158 |
+
image = [image]
|
159 |
+
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
|
160 |
+
raise ValueError(
|
161 |
+
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
|
162 |
+
)
|
163 |
+
|
164 |
+
if isinstance(image[0], PIL.Image.Image):
|
165 |
+
if self.config.do_convert_rgb:
|
166 |
+
image = [self.convert_to_rgb(i) for i in image]
|
167 |
+
if self.config.do_resize:
|
168 |
+
image = [self.resize(i, height, width) for i in image]
|
169 |
+
image = self.pil_to_numpy(image) # to np
|
170 |
+
image = self.numpy_to_pt(image) # to pt
|
171 |
+
|
172 |
+
elif isinstance(image[0], np.ndarray):
|
173 |
+
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
|
174 |
+
image = self.numpy_to_pt(image)
|
175 |
+
_, _, height, width = image.shape
|
176 |
+
if self.config.do_resize and (
|
177 |
+
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
178 |
+
):
|
179 |
+
raise ValueError(
|
180 |
+
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
|
181 |
+
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
182 |
+
)
|
183 |
+
|
184 |
+
elif isinstance(image[0], torch.Tensor):
|
185 |
+
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
|
186 |
+
_, channel, height, width = image.shape
|
187 |
+
|
188 |
+
# don't need any preprocess if the image is latents
|
189 |
+
if channel == 4:
|
190 |
+
return image
|
191 |
+
|
192 |
+
if self.config.do_resize and (
|
193 |
+
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
|
194 |
+
):
|
195 |
+
raise ValueError(
|
196 |
+
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
|
197 |
+
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
|
198 |
+
)
|
199 |
+
|
200 |
+
# expected range [0,1], normalize to [-1,1]
|
201 |
+
do_normalize = self.config.do_normalize
|
202 |
+
if image.min() < 0:
|
203 |
+
warnings.warn(
|
204 |
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
205 |
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
206 |
+
FutureWarning,
|
207 |
+
)
|
208 |
+
do_normalize = False
|
209 |
+
|
210 |
+
if do_normalize:
|
211 |
+
image = self.normalize(image)
|
212 |
+
|
213 |
+
return image
|
214 |
+
|
215 |
+
def postprocess(
|
216 |
+
self,
|
217 |
+
image: torch.FloatTensor,
|
218 |
+
output_type: str = "pil",
|
219 |
+
do_denormalize: Optional[List[bool]] = None,
|
220 |
+
):
|
221 |
+
if not isinstance(image, torch.Tensor):
|
222 |
+
raise ValueError(
|
223 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
224 |
+
)
|
225 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
226 |
+
deprecation_message = (
|
227 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
228 |
+
"`pil`, `np`, `pt`, `latent`"
|
229 |
+
)
|
230 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
231 |
+
output_type = "np"
|
232 |
+
|
233 |
+
if output_type == "latent":
|
234 |
+
return image
|
235 |
+
|
236 |
+
if do_denormalize is None:
|
237 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
238 |
+
|
239 |
+
image = torch.stack(
|
240 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
241 |
+
)
|
242 |
+
|
243 |
+
if output_type == "pt":
|
244 |
+
return image
|
245 |
+
|
246 |
+
image = self.pt_to_numpy(image)
|
247 |
+
|
248 |
+
if output_type == "np":
|
249 |
+
return image
|
250 |
+
|
251 |
+
if output_type == "pil":
|
252 |
+
return self.numpy_to_pil(image)
|
253 |
+
|
254 |
+
|
255 |
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
256 |
+
"""
|
257 |
+
Image processor for VAE LDM3D.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
261 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
262 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
263 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
264 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
265 |
+
Resampling filter to use when resizing the image.
|
266 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
267 |
+
Whether to normalize the image to [-1,1].
|
268 |
+
"""
|
269 |
+
|
270 |
+
config_name = CONFIG_NAME
|
271 |
+
|
272 |
+
@register_to_config
|
273 |
+
def __init__(
|
274 |
+
self,
|
275 |
+
do_resize: bool = True,
|
276 |
+
vae_scale_factor: int = 8,
|
277 |
+
resample: str = "lanczos",
|
278 |
+
do_normalize: bool = True,
|
279 |
+
):
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
@staticmethod
|
283 |
+
def numpy_to_pil(images):
|
284 |
+
"""
|
285 |
+
Convert a NumPy image or a batch of images to a PIL image.
|
286 |
+
"""
|
287 |
+
if images.ndim == 3:
|
288 |
+
images = images[None, ...]
|
289 |
+
images = (images * 255).round().astype("uint8")
|
290 |
+
if images.shape[-1] == 1:
|
291 |
+
# special case for grayscale (single channel) images
|
292 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
293 |
+
else:
|
294 |
+
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
295 |
+
|
296 |
+
return pil_images
|
297 |
+
|
298 |
+
@staticmethod
|
299 |
+
def rgblike_to_depthmap(image):
|
300 |
+
"""
|
301 |
+
Args:
|
302 |
+
image: RGB-like depth image
|
303 |
+
|
304 |
+
Returns: depth map
|
305 |
+
|
306 |
+
"""
|
307 |
+
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
308 |
+
|
309 |
+
def numpy_to_depth(self, images):
|
310 |
+
"""
|
311 |
+
Convert a NumPy depth image or a batch of images to a PIL image.
|
312 |
+
"""
|
313 |
+
if images.ndim == 3:
|
314 |
+
images = images[None, ...]
|
315 |
+
images_depth = images[:, :, :, 3:]
|
316 |
+
if images.shape[-1] == 6:
|
317 |
+
images_depth = (images_depth * 255).round().astype("uint8")
|
318 |
+
pil_images = [
|
319 |
+
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
|
320 |
+
]
|
321 |
+
elif images.shape[-1] == 4:
|
322 |
+
images_depth = (images_depth * 65535.0).astype(np.uint16)
|
323 |
+
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
|
324 |
+
else:
|
325 |
+
raise Exception("Not supported")
|
326 |
+
|
327 |
+
return pil_images
|
328 |
+
|
329 |
+
def postprocess(
|
330 |
+
self,
|
331 |
+
image: torch.FloatTensor,
|
332 |
+
output_type: str = "pil",
|
333 |
+
do_denormalize: Optional[List[bool]] = None,
|
334 |
+
):
|
335 |
+
if not isinstance(image, torch.Tensor):
|
336 |
+
raise ValueError(
|
337 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
338 |
+
)
|
339 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
340 |
+
deprecation_message = (
|
341 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
342 |
+
"`pil`, `np`, `pt`, `latent`"
|
343 |
+
)
|
344 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
345 |
+
output_type = "np"
|
346 |
+
|
347 |
+
if do_denormalize is None:
|
348 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
349 |
+
|
350 |
+
image = torch.stack(
|
351 |
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
352 |
+
)
|
353 |
+
|
354 |
+
image = self.pt_to_numpy(image)
|
355 |
+
|
356 |
+
if output_type == "np":
|
357 |
+
if image.shape[-1] == 6:
|
358 |
+
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
359 |
+
else:
|
360 |
+
image_depth = image[:, :, :, 3:]
|
361 |
+
return image[:, :, :, :3], image_depth
|
362 |
+
|
363 |
+
if output_type == "pil":
|
364 |
+
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
365 |
+
else:
|
366 |
+
raise Exception(f"This type {output_type} is not supported")
|
6DoF/diffusers/loaders.py
ADDED
@@ -0,0 +1,1492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import os
|
15 |
+
import warnings
|
16 |
+
from collections import defaultdict
|
17 |
+
from pathlib import Path
|
18 |
+
from typing import Callable, Dict, List, Optional, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
from .models.attention_processor import (
|
25 |
+
AttnAddedKVProcessor,
|
26 |
+
AttnAddedKVProcessor2_0,
|
27 |
+
CustomDiffusionAttnProcessor,
|
28 |
+
CustomDiffusionXFormersAttnProcessor,
|
29 |
+
LoRAAttnAddedKVProcessor,
|
30 |
+
LoRAAttnProcessor,
|
31 |
+
LoRAAttnProcessor2_0,
|
32 |
+
LoRAXFormersAttnProcessor,
|
33 |
+
SlicedAttnAddedKVProcessor,
|
34 |
+
XFormersAttnProcessor,
|
35 |
+
)
|
36 |
+
from .utils import (
|
37 |
+
DIFFUSERS_CACHE,
|
38 |
+
HF_HUB_OFFLINE,
|
39 |
+
TEXT_ENCODER_ATTN_MODULE,
|
40 |
+
_get_model_file,
|
41 |
+
deprecate,
|
42 |
+
is_safetensors_available,
|
43 |
+
is_transformers_available,
|
44 |
+
logging,
|
45 |
+
)
|
46 |
+
|
47 |
+
|
48 |
+
if is_safetensors_available():
|
49 |
+
import safetensors
|
50 |
+
|
51 |
+
if is_transformers_available():
|
52 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__)
|
56 |
+
|
57 |
+
TEXT_ENCODER_NAME = "text_encoder"
|
58 |
+
UNET_NAME = "unet"
|
59 |
+
|
60 |
+
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
61 |
+
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
62 |
+
|
63 |
+
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
64 |
+
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
65 |
+
|
66 |
+
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
67 |
+
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
68 |
+
|
69 |
+
|
70 |
+
class AttnProcsLayers(torch.nn.Module):
|
71 |
+
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
72 |
+
super().__init__()
|
73 |
+
self.layers = torch.nn.ModuleList(state_dict.values())
|
74 |
+
self.mapping = dict(enumerate(state_dict.keys()))
|
75 |
+
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
|
76 |
+
|
77 |
+
# .processor for unet, .self_attn for text encoder
|
78 |
+
self.split_keys = [".processor", ".self_attn"]
|
79 |
+
|
80 |
+
# we add a hook to state_dict() and load_state_dict() so that the
|
81 |
+
# naming fits with `unet.attn_processors`
|
82 |
+
def map_to(module, state_dict, *args, **kwargs):
|
83 |
+
new_state_dict = {}
|
84 |
+
for key, value in state_dict.items():
|
85 |
+
num = int(key.split(".")[1]) # 0 is always "layers"
|
86 |
+
new_key = key.replace(f"layers.{num}", module.mapping[num])
|
87 |
+
new_state_dict[new_key] = value
|
88 |
+
|
89 |
+
return new_state_dict
|
90 |
+
|
91 |
+
def remap_key(key, state_dict):
|
92 |
+
for k in self.split_keys:
|
93 |
+
if k in key:
|
94 |
+
return key.split(k)[0] + k
|
95 |
+
|
96 |
+
raise ValueError(
|
97 |
+
f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
|
98 |
+
)
|
99 |
+
|
100 |
+
def map_from(module, state_dict, *args, **kwargs):
|
101 |
+
all_keys = list(state_dict.keys())
|
102 |
+
for key in all_keys:
|
103 |
+
replace_key = remap_key(key, state_dict)
|
104 |
+
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
|
105 |
+
state_dict[new_key] = state_dict[key]
|
106 |
+
del state_dict[key]
|
107 |
+
|
108 |
+
self._register_state_dict_hook(map_to)
|
109 |
+
self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
110 |
+
|
111 |
+
|
112 |
+
class UNet2DConditionLoadersMixin:
|
113 |
+
text_encoder_name = TEXT_ENCODER_NAME
|
114 |
+
unet_name = UNET_NAME
|
115 |
+
|
116 |
+
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
117 |
+
r"""
|
118 |
+
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
119 |
+
defined in
|
120 |
+
[`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
|
121 |
+
and be a `torch.nn.Module` class.
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
125 |
+
Can be either:
|
126 |
+
|
127 |
+
- A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
128 |
+
the Hub.
|
129 |
+
- A path to a directory (for example `./my_model_directory`) containing the model weights saved
|
130 |
+
with [`ModelMixin.save_pretrained`].
|
131 |
+
- A [torch state
|
132 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
133 |
+
|
134 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
135 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
136 |
+
is not used.
|
137 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
138 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
139 |
+
cached versions if they exist.
|
140 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
141 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
142 |
+
incompletely downloaded files are deleted.
|
143 |
+
proxies (`Dict[str, str]`, *optional*):
|
144 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
145 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
146 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
147 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
148 |
+
won't be downloaded from the Hub.
|
149 |
+
use_auth_token (`str` or *bool*, *optional*):
|
150 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
151 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
152 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
153 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
154 |
+
allowed by Git.
|
155 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
156 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
157 |
+
mirror (`str`, *optional*):
|
158 |
+
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
159 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
160 |
+
information.
|
161 |
+
|
162 |
+
"""
|
163 |
+
|
164 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
165 |
+
force_download = kwargs.pop("force_download", False)
|
166 |
+
resume_download = kwargs.pop("resume_download", False)
|
167 |
+
proxies = kwargs.pop("proxies", None)
|
168 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
169 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
170 |
+
revision = kwargs.pop("revision", None)
|
171 |
+
subfolder = kwargs.pop("subfolder", None)
|
172 |
+
weight_name = kwargs.pop("weight_name", None)
|
173 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
174 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
175 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
176 |
+
network_alpha = kwargs.pop("network_alpha", None)
|
177 |
+
|
178 |
+
if use_safetensors and not is_safetensors_available():
|
179 |
+
raise ValueError(
|
180 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
181 |
+
)
|
182 |
+
|
183 |
+
allow_pickle = False
|
184 |
+
if use_safetensors is None:
|
185 |
+
use_safetensors = is_safetensors_available()
|
186 |
+
allow_pickle = True
|
187 |
+
|
188 |
+
user_agent = {
|
189 |
+
"file_type": "attn_procs_weights",
|
190 |
+
"framework": "pytorch",
|
191 |
+
}
|
192 |
+
|
193 |
+
model_file = None
|
194 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
195 |
+
# Let's first try to load .safetensors weights
|
196 |
+
if (use_safetensors and weight_name is None) or (
|
197 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
198 |
+
):
|
199 |
+
try:
|
200 |
+
model_file = _get_model_file(
|
201 |
+
pretrained_model_name_or_path_or_dict,
|
202 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
203 |
+
cache_dir=cache_dir,
|
204 |
+
force_download=force_download,
|
205 |
+
resume_download=resume_download,
|
206 |
+
proxies=proxies,
|
207 |
+
local_files_only=local_files_only,
|
208 |
+
use_auth_token=use_auth_token,
|
209 |
+
revision=revision,
|
210 |
+
subfolder=subfolder,
|
211 |
+
user_agent=user_agent,
|
212 |
+
)
|
213 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
214 |
+
except IOError as e:
|
215 |
+
if not allow_pickle:
|
216 |
+
raise e
|
217 |
+
# try loading non-safetensors weights
|
218 |
+
pass
|
219 |
+
if model_file is None:
|
220 |
+
model_file = _get_model_file(
|
221 |
+
pretrained_model_name_or_path_or_dict,
|
222 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
223 |
+
cache_dir=cache_dir,
|
224 |
+
force_download=force_download,
|
225 |
+
resume_download=resume_download,
|
226 |
+
proxies=proxies,
|
227 |
+
local_files_only=local_files_only,
|
228 |
+
use_auth_token=use_auth_token,
|
229 |
+
revision=revision,
|
230 |
+
subfolder=subfolder,
|
231 |
+
user_agent=user_agent,
|
232 |
+
)
|
233 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
234 |
+
else:
|
235 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
236 |
+
|
237 |
+
# fill attn processors
|
238 |
+
attn_processors = {}
|
239 |
+
|
240 |
+
is_lora = all("lora" in k for k in state_dict.keys())
|
241 |
+
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
242 |
+
|
243 |
+
if is_lora:
|
244 |
+
is_new_lora_format = all(
|
245 |
+
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
246 |
+
)
|
247 |
+
if is_new_lora_format:
|
248 |
+
# Strip the `"unet"` prefix.
|
249 |
+
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
250 |
+
if is_text_encoder_present:
|
251 |
+
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
252 |
+
warnings.warn(warn_message)
|
253 |
+
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
254 |
+
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
255 |
+
|
256 |
+
lora_grouped_dict = defaultdict(dict)
|
257 |
+
for key, value in state_dict.items():
|
258 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
259 |
+
lora_grouped_dict[attn_processor_key][sub_key] = value
|
260 |
+
|
261 |
+
for key, value_dict in lora_grouped_dict.items():
|
262 |
+
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
263 |
+
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
264 |
+
|
265 |
+
attn_processor = self
|
266 |
+
for sub_key in key.split("."):
|
267 |
+
attn_processor = getattr(attn_processor, sub_key)
|
268 |
+
|
269 |
+
if isinstance(
|
270 |
+
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
271 |
+
):
|
272 |
+
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
273 |
+
attn_processor_class = LoRAAttnAddedKVProcessor
|
274 |
+
else:
|
275 |
+
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
276 |
+
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
277 |
+
attn_processor_class = LoRAXFormersAttnProcessor
|
278 |
+
else:
|
279 |
+
attn_processor_class = (
|
280 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
281 |
+
)
|
282 |
+
|
283 |
+
attn_processors[key] = attn_processor_class(
|
284 |
+
hidden_size=hidden_size,
|
285 |
+
cross_attention_dim=cross_attention_dim,
|
286 |
+
rank=rank,
|
287 |
+
network_alpha=network_alpha,
|
288 |
+
)
|
289 |
+
attn_processors[key].load_state_dict(value_dict)
|
290 |
+
elif is_custom_diffusion:
|
291 |
+
custom_diffusion_grouped_dict = defaultdict(dict)
|
292 |
+
for key, value in state_dict.items():
|
293 |
+
if len(value) == 0:
|
294 |
+
custom_diffusion_grouped_dict[key] = {}
|
295 |
+
else:
|
296 |
+
if "to_out" in key:
|
297 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
298 |
+
else:
|
299 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
300 |
+
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
301 |
+
|
302 |
+
for key, value_dict in custom_diffusion_grouped_dict.items():
|
303 |
+
if len(value_dict) == 0:
|
304 |
+
attn_processors[key] = CustomDiffusionAttnProcessor(
|
305 |
+
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
306 |
+
)
|
307 |
+
else:
|
308 |
+
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
309 |
+
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
310 |
+
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
311 |
+
attn_processors[key] = CustomDiffusionAttnProcessor(
|
312 |
+
train_kv=True,
|
313 |
+
train_q_out=train_q_out,
|
314 |
+
hidden_size=hidden_size,
|
315 |
+
cross_attention_dim=cross_attention_dim,
|
316 |
+
)
|
317 |
+
attn_processors[key].load_state_dict(value_dict)
|
318 |
+
else:
|
319 |
+
raise ValueError(
|
320 |
+
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
321 |
+
)
|
322 |
+
|
323 |
+
# set correct dtype & device
|
324 |
+
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
325 |
+
|
326 |
+
# set layers
|
327 |
+
self.set_attn_processor(attn_processors)
|
328 |
+
|
329 |
+
def save_attn_procs(
|
330 |
+
self,
|
331 |
+
save_directory: Union[str, os.PathLike],
|
332 |
+
is_main_process: bool = True,
|
333 |
+
weight_name: str = None,
|
334 |
+
save_function: Callable = None,
|
335 |
+
safe_serialization: bool = False,
|
336 |
+
**kwargs,
|
337 |
+
):
|
338 |
+
r"""
|
339 |
+
Save an attention processor to a directory so that it can be reloaded using the
|
340 |
+
[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
|
341 |
+
|
342 |
+
Arguments:
|
343 |
+
save_directory (`str` or `os.PathLike`):
|
344 |
+
Directory to save an attention processor to. Will be created if it doesn't exist.
|
345 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
346 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
347 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
348 |
+
process to avoid race conditions.
|
349 |
+
save_function (`Callable`):
|
350 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
351 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
352 |
+
`DIFFUSERS_SAVE_MODE`.
|
353 |
+
|
354 |
+
"""
|
355 |
+
weight_name = weight_name or deprecate(
|
356 |
+
"weights_name",
|
357 |
+
"0.20.0",
|
358 |
+
"`weights_name` is deprecated, please use `weight_name` instead.",
|
359 |
+
take_from=kwargs,
|
360 |
+
)
|
361 |
+
if os.path.isfile(save_directory):
|
362 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
363 |
+
return
|
364 |
+
|
365 |
+
if save_function is None:
|
366 |
+
if safe_serialization:
|
367 |
+
|
368 |
+
def save_function(weights, filename):
|
369 |
+
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
370 |
+
|
371 |
+
else:
|
372 |
+
save_function = torch.save
|
373 |
+
|
374 |
+
os.makedirs(save_directory, exist_ok=True)
|
375 |
+
|
376 |
+
is_custom_diffusion = any(
|
377 |
+
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
378 |
+
for (_, x) in self.attn_processors.items()
|
379 |
+
)
|
380 |
+
if is_custom_diffusion:
|
381 |
+
model_to_save = AttnProcsLayers(
|
382 |
+
{
|
383 |
+
y: x
|
384 |
+
for (y, x) in self.attn_processors.items()
|
385 |
+
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
386 |
+
}
|
387 |
+
)
|
388 |
+
state_dict = model_to_save.state_dict()
|
389 |
+
for name, attn in self.attn_processors.items():
|
390 |
+
if len(attn.state_dict()) == 0:
|
391 |
+
state_dict[name] = {}
|
392 |
+
else:
|
393 |
+
model_to_save = AttnProcsLayers(self.attn_processors)
|
394 |
+
state_dict = model_to_save.state_dict()
|
395 |
+
|
396 |
+
if weight_name is None:
|
397 |
+
if safe_serialization:
|
398 |
+
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
399 |
+
else:
|
400 |
+
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
401 |
+
|
402 |
+
# Save the model
|
403 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
404 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
405 |
+
|
406 |
+
|
407 |
+
class TextualInversionLoaderMixin:
|
408 |
+
r"""
|
409 |
+
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
|
410 |
+
"""
|
411 |
+
|
412 |
+
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
|
413 |
+
r"""
|
414 |
+
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
415 |
+
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
416 |
+
inversion token or if the textual inversion token is a single vector, the input prompt is returned.
|
417 |
+
|
418 |
+
Parameters:
|
419 |
+
prompt (`str` or list of `str`):
|
420 |
+
The prompt or prompts to guide the image generation.
|
421 |
+
tokenizer (`PreTrainedTokenizer`):
|
422 |
+
The tokenizer responsible for encoding the prompt into input tokens.
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
`str` or list of `str`: The converted prompt
|
426 |
+
"""
|
427 |
+
if not isinstance(prompt, List):
|
428 |
+
prompts = [prompt]
|
429 |
+
else:
|
430 |
+
prompts = prompt
|
431 |
+
|
432 |
+
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
|
433 |
+
|
434 |
+
if not isinstance(prompt, List):
|
435 |
+
return prompts[0]
|
436 |
+
|
437 |
+
return prompts
|
438 |
+
|
439 |
+
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
|
440 |
+
r"""
|
441 |
+
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
|
442 |
+
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
|
443 |
+
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
444 |
+
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
|
445 |
+
|
446 |
+
Parameters:
|
447 |
+
prompt (`str`):
|
448 |
+
The prompt to guide the image generation.
|
449 |
+
tokenizer (`PreTrainedTokenizer`):
|
450 |
+
The tokenizer responsible for encoding the prompt into input tokens.
|
451 |
+
|
452 |
+
Returns:
|
453 |
+
`str`: The converted prompt
|
454 |
+
"""
|
455 |
+
tokens = tokenizer.tokenize(prompt)
|
456 |
+
unique_tokens = set(tokens)
|
457 |
+
for token in unique_tokens:
|
458 |
+
if token in tokenizer.added_tokens_encoder:
|
459 |
+
replacement = token
|
460 |
+
i = 1
|
461 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
462 |
+
replacement += f" {token}_{i}"
|
463 |
+
i += 1
|
464 |
+
|
465 |
+
prompt = prompt.replace(token, replacement)
|
466 |
+
|
467 |
+
return prompt
|
468 |
+
|
469 |
+
def load_textual_inversion(
|
470 |
+
self,
|
471 |
+
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
472 |
+
token: Optional[Union[str, List[str]]] = None,
|
473 |
+
**kwargs,
|
474 |
+
):
|
475 |
+
r"""
|
476 |
+
Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
|
477 |
+
Automatic1111 formats are supported).
|
478 |
+
|
479 |
+
Parameters:
|
480 |
+
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
481 |
+
Can be either one of the following or a list of them:
|
482 |
+
|
483 |
+
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
484 |
+
pretrained model hosted on the Hub.
|
485 |
+
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
|
486 |
+
inversion weights.
|
487 |
+
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
|
488 |
+
- A [torch state
|
489 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
490 |
+
|
491 |
+
token (`str` or `List[str]`, *optional*):
|
492 |
+
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
493 |
+
list, then `token` must also be a list of equal length.
|
494 |
+
weight_name (`str`, *optional*):
|
495 |
+
Name of a custom weight file. This should be used when:
|
496 |
+
|
497 |
+
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
|
498 |
+
name such as `text_inv.bin`.
|
499 |
+
- The saved textual inversion file is in the Automatic1111 format.
|
500 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
501 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
502 |
+
is not used.
|
503 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
504 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
505 |
+
cached versions if they exist.
|
506 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
507 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
508 |
+
incompletely downloaded files are deleted.
|
509 |
+
proxies (`Dict[str, str]`, *optional*):
|
510 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
511 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
512 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
513 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
514 |
+
won't be downloaded from the Hub.
|
515 |
+
use_auth_token (`str` or *bool*, *optional*):
|
516 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
517 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
518 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
519 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
520 |
+
allowed by Git.
|
521 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
522 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
523 |
+
mirror (`str`, *optional*):
|
524 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
525 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
526 |
+
information.
|
527 |
+
|
528 |
+
Example:
|
529 |
+
|
530 |
+
To load a textual inversion embedding vector in 🤗 Diffusers format:
|
531 |
+
|
532 |
+
```py
|
533 |
+
from diffusers import StableDiffusionPipeline
|
534 |
+
import torch
|
535 |
+
|
536 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
537 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
538 |
+
|
539 |
+
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
|
540 |
+
|
541 |
+
prompt = "A <cat-toy> backpack"
|
542 |
+
|
543 |
+
image = pipe(prompt, num_inference_steps=50).images[0]
|
544 |
+
image.save("cat-backpack.png")
|
545 |
+
```
|
546 |
+
|
547 |
+
To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
|
548 |
+
(for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
|
549 |
+
locally:
|
550 |
+
|
551 |
+
```py
|
552 |
+
from diffusers import StableDiffusionPipeline
|
553 |
+
import torch
|
554 |
+
|
555 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
556 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
557 |
+
|
558 |
+
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
|
559 |
+
|
560 |
+
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
|
561 |
+
|
562 |
+
image = pipe(prompt, num_inference_steps=50).images[0]
|
563 |
+
image.save("character.png")
|
564 |
+
```
|
565 |
+
|
566 |
+
"""
|
567 |
+
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
568 |
+
raise ValueError(
|
569 |
+
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
|
570 |
+
f" `{self.load_textual_inversion.__name__}`"
|
571 |
+
)
|
572 |
+
|
573 |
+
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
|
574 |
+
raise ValueError(
|
575 |
+
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
|
576 |
+
f" `{self.load_textual_inversion.__name__}`"
|
577 |
+
)
|
578 |
+
|
579 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
580 |
+
force_download = kwargs.pop("force_download", False)
|
581 |
+
resume_download = kwargs.pop("resume_download", False)
|
582 |
+
proxies = kwargs.pop("proxies", None)
|
583 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
584 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
585 |
+
revision = kwargs.pop("revision", None)
|
586 |
+
subfolder = kwargs.pop("subfolder", None)
|
587 |
+
weight_name = kwargs.pop("weight_name", None)
|
588 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
589 |
+
|
590 |
+
if use_safetensors and not is_safetensors_available():
|
591 |
+
raise ValueError(
|
592 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
593 |
+
)
|
594 |
+
|
595 |
+
allow_pickle = False
|
596 |
+
if use_safetensors is None:
|
597 |
+
use_safetensors = is_safetensors_available()
|
598 |
+
allow_pickle = True
|
599 |
+
|
600 |
+
user_agent = {
|
601 |
+
"file_type": "text_inversion",
|
602 |
+
"framework": "pytorch",
|
603 |
+
}
|
604 |
+
|
605 |
+
if not isinstance(pretrained_model_name_or_path, list):
|
606 |
+
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
|
607 |
+
else:
|
608 |
+
pretrained_model_name_or_paths = pretrained_model_name_or_path
|
609 |
+
|
610 |
+
if isinstance(token, str):
|
611 |
+
tokens = [token]
|
612 |
+
elif token is None:
|
613 |
+
tokens = [None] * len(pretrained_model_name_or_paths)
|
614 |
+
else:
|
615 |
+
tokens = token
|
616 |
+
|
617 |
+
if len(pretrained_model_name_or_paths) != len(tokens):
|
618 |
+
raise ValueError(
|
619 |
+
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
|
620 |
+
f"Make sure both lists have the same length."
|
621 |
+
)
|
622 |
+
|
623 |
+
valid_tokens = [t for t in tokens if t is not None]
|
624 |
+
if len(set(valid_tokens)) < len(valid_tokens):
|
625 |
+
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
|
626 |
+
|
627 |
+
token_ids_and_embeddings = []
|
628 |
+
|
629 |
+
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
|
630 |
+
if not isinstance(pretrained_model_name_or_path, dict):
|
631 |
+
# 1. Load textual inversion file
|
632 |
+
model_file = None
|
633 |
+
# Let's first try to load .safetensors weights
|
634 |
+
if (use_safetensors and weight_name is None) or (
|
635 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
636 |
+
):
|
637 |
+
try:
|
638 |
+
model_file = _get_model_file(
|
639 |
+
pretrained_model_name_or_path,
|
640 |
+
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
641 |
+
cache_dir=cache_dir,
|
642 |
+
force_download=force_download,
|
643 |
+
resume_download=resume_download,
|
644 |
+
proxies=proxies,
|
645 |
+
local_files_only=local_files_only,
|
646 |
+
use_auth_token=use_auth_token,
|
647 |
+
revision=revision,
|
648 |
+
subfolder=subfolder,
|
649 |
+
user_agent=user_agent,
|
650 |
+
)
|
651 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
652 |
+
except Exception as e:
|
653 |
+
if not allow_pickle:
|
654 |
+
raise e
|
655 |
+
|
656 |
+
model_file = None
|
657 |
+
|
658 |
+
if model_file is None:
|
659 |
+
model_file = _get_model_file(
|
660 |
+
pretrained_model_name_or_path,
|
661 |
+
weights_name=weight_name or TEXT_INVERSION_NAME,
|
662 |
+
cache_dir=cache_dir,
|
663 |
+
force_download=force_download,
|
664 |
+
resume_download=resume_download,
|
665 |
+
proxies=proxies,
|
666 |
+
local_files_only=local_files_only,
|
667 |
+
use_auth_token=use_auth_token,
|
668 |
+
revision=revision,
|
669 |
+
subfolder=subfolder,
|
670 |
+
user_agent=user_agent,
|
671 |
+
)
|
672 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
673 |
+
else:
|
674 |
+
state_dict = pretrained_model_name_or_path
|
675 |
+
|
676 |
+
# 2. Load token and embedding correcly from file
|
677 |
+
loaded_token = None
|
678 |
+
if isinstance(state_dict, torch.Tensor):
|
679 |
+
if token is None:
|
680 |
+
raise ValueError(
|
681 |
+
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
|
682 |
+
)
|
683 |
+
embedding = state_dict
|
684 |
+
elif len(state_dict) == 1:
|
685 |
+
# diffusers
|
686 |
+
loaded_token, embedding = next(iter(state_dict.items()))
|
687 |
+
elif "string_to_param" in state_dict:
|
688 |
+
# A1111
|
689 |
+
loaded_token = state_dict["name"]
|
690 |
+
embedding = state_dict["string_to_param"]["*"]
|
691 |
+
|
692 |
+
if token is not None and loaded_token != token:
|
693 |
+
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
|
694 |
+
else:
|
695 |
+
token = loaded_token
|
696 |
+
|
697 |
+
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
|
698 |
+
|
699 |
+
# 3. Make sure we don't mess up the tokenizer or text encoder
|
700 |
+
vocab = self.tokenizer.get_vocab()
|
701 |
+
if token in vocab:
|
702 |
+
raise ValueError(
|
703 |
+
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
|
704 |
+
)
|
705 |
+
elif f"{token}_1" in vocab:
|
706 |
+
multi_vector_tokens = [token]
|
707 |
+
i = 1
|
708 |
+
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
|
709 |
+
multi_vector_tokens.append(f"{token}_{i}")
|
710 |
+
i += 1
|
711 |
+
|
712 |
+
raise ValueError(
|
713 |
+
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
|
714 |
+
)
|
715 |
+
|
716 |
+
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
|
717 |
+
|
718 |
+
if is_multi_vector:
|
719 |
+
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
|
720 |
+
embeddings = [e for e in embedding] # noqa: C416
|
721 |
+
else:
|
722 |
+
tokens = [token]
|
723 |
+
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
|
724 |
+
|
725 |
+
# add tokens and get ids
|
726 |
+
self.tokenizer.add_tokens(tokens)
|
727 |
+
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
728 |
+
token_ids_and_embeddings += zip(token_ids, embeddings)
|
729 |
+
|
730 |
+
logger.info(f"Loaded textual inversion embedding for {token}.")
|
731 |
+
|
732 |
+
# resize token embeddings and set all new embeddings
|
733 |
+
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
734 |
+
for token_id, embedding in token_ids_and_embeddings:
|
735 |
+
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
736 |
+
|
737 |
+
|
738 |
+
class LoraLoaderMixin:
|
739 |
+
r"""
|
740 |
+
Load LoRA layers into [`UNet2DConditionModel`] and
|
741 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
742 |
+
"""
|
743 |
+
text_encoder_name = TEXT_ENCODER_NAME
|
744 |
+
unet_name = UNET_NAME
|
745 |
+
|
746 |
+
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
747 |
+
r"""
|
748 |
+
Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
|
749 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
750 |
+
|
751 |
+
Parameters:
|
752 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
753 |
+
Can be either:
|
754 |
+
|
755 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
756 |
+
the Hub.
|
757 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
758 |
+
with [`ModelMixin.save_pretrained`].
|
759 |
+
- A [torch state
|
760 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
761 |
+
|
762 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
763 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
764 |
+
is not used.
|
765 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
766 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
767 |
+
cached versions if they exist.
|
768 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
769 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
770 |
+
incompletely downloaded files are deleted.
|
771 |
+
proxies (`Dict[str, str]`, *optional*):
|
772 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
773 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
774 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
775 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
776 |
+
won't be downloaded from the Hub.
|
777 |
+
use_auth_token (`str` or *bool*, *optional*):
|
778 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
779 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
780 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
781 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
782 |
+
allowed by Git.
|
783 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
784 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
785 |
+
mirror (`str`, *optional*):
|
786 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
787 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
788 |
+
information.
|
789 |
+
|
790 |
+
"""
|
791 |
+
# Load the main state dict first which has the LoRA layers for either of
|
792 |
+
# UNet and text encoder or both.
|
793 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
794 |
+
force_download = kwargs.pop("force_download", False)
|
795 |
+
resume_download = kwargs.pop("resume_download", False)
|
796 |
+
proxies = kwargs.pop("proxies", None)
|
797 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
798 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
799 |
+
revision = kwargs.pop("revision", None)
|
800 |
+
subfolder = kwargs.pop("subfolder", None)
|
801 |
+
weight_name = kwargs.pop("weight_name", None)
|
802 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
803 |
+
|
804 |
+
# set lora scale to a reasonable default
|
805 |
+
self._lora_scale = 1.0
|
806 |
+
|
807 |
+
if use_safetensors and not is_safetensors_available():
|
808 |
+
raise ValueError(
|
809 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
810 |
+
)
|
811 |
+
|
812 |
+
allow_pickle = False
|
813 |
+
if use_safetensors is None:
|
814 |
+
use_safetensors = is_safetensors_available()
|
815 |
+
allow_pickle = True
|
816 |
+
|
817 |
+
user_agent = {
|
818 |
+
"file_type": "attn_procs_weights",
|
819 |
+
"framework": "pytorch",
|
820 |
+
}
|
821 |
+
|
822 |
+
model_file = None
|
823 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
824 |
+
# Let's first try to load .safetensors weights
|
825 |
+
if (use_safetensors and weight_name is None) or (
|
826 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
827 |
+
):
|
828 |
+
try:
|
829 |
+
model_file = _get_model_file(
|
830 |
+
pretrained_model_name_or_path_or_dict,
|
831 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
832 |
+
cache_dir=cache_dir,
|
833 |
+
force_download=force_download,
|
834 |
+
resume_download=resume_download,
|
835 |
+
proxies=proxies,
|
836 |
+
local_files_only=local_files_only,
|
837 |
+
use_auth_token=use_auth_token,
|
838 |
+
revision=revision,
|
839 |
+
subfolder=subfolder,
|
840 |
+
user_agent=user_agent,
|
841 |
+
)
|
842 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
843 |
+
except IOError as e:
|
844 |
+
if not allow_pickle:
|
845 |
+
raise e
|
846 |
+
# try loading non-safetensors weights
|
847 |
+
pass
|
848 |
+
if model_file is None:
|
849 |
+
model_file = _get_model_file(
|
850 |
+
pretrained_model_name_or_path_or_dict,
|
851 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
852 |
+
cache_dir=cache_dir,
|
853 |
+
force_download=force_download,
|
854 |
+
resume_download=resume_download,
|
855 |
+
proxies=proxies,
|
856 |
+
local_files_only=local_files_only,
|
857 |
+
use_auth_token=use_auth_token,
|
858 |
+
revision=revision,
|
859 |
+
subfolder=subfolder,
|
860 |
+
user_agent=user_agent,
|
861 |
+
)
|
862 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
863 |
+
else:
|
864 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
865 |
+
|
866 |
+
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
|
867 |
+
network_alpha = None
|
868 |
+
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
|
869 |
+
state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
|
870 |
+
|
871 |
+
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
872 |
+
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
873 |
+
# their prefixes.
|
874 |
+
keys = list(state_dict.keys())
|
875 |
+
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
|
876 |
+
# Load the layers corresponding to UNet.
|
877 |
+
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
|
878 |
+
logger.info(f"Loading {self.unet_name}.")
|
879 |
+
unet_lora_state_dict = {
|
880 |
+
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
881 |
+
}
|
882 |
+
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
|
883 |
+
|
884 |
+
# Load the layers corresponding to text encoder and make necessary adjustments.
|
885 |
+
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
886 |
+
text_encoder_lora_state_dict = {
|
887 |
+
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
888 |
+
}
|
889 |
+
if len(text_encoder_lora_state_dict) > 0:
|
890 |
+
logger.info(f"Loading {self.text_encoder_name}.")
|
891 |
+
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
|
892 |
+
text_encoder_lora_state_dict, network_alpha=network_alpha
|
893 |
+
)
|
894 |
+
self._modify_text_encoder(attn_procs_text_encoder)
|
895 |
+
|
896 |
+
# save lora attn procs of text encoder so that it can be easily retrieved
|
897 |
+
self._text_encoder_lora_attn_procs = attn_procs_text_encoder
|
898 |
+
|
899 |
+
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
900 |
+
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
901 |
+
elif not all(
|
902 |
+
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
903 |
+
):
|
904 |
+
self.unet.load_attn_procs(state_dict)
|
905 |
+
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
|
906 |
+
warnings.warn(warn_message)
|
907 |
+
|
908 |
+
@property
|
909 |
+
def lora_scale(self) -> float:
|
910 |
+
# property function that returns the lora scale which can be set at run time by the pipeline.
|
911 |
+
# if _lora_scale has not been set, return 1
|
912 |
+
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
913 |
+
|
914 |
+
@property
|
915 |
+
def text_encoder_lora_attn_procs(self):
|
916 |
+
if hasattr(self, "_text_encoder_lora_attn_procs"):
|
917 |
+
return self._text_encoder_lora_attn_procs
|
918 |
+
return
|
919 |
+
|
920 |
+
def _remove_text_encoder_monkey_patch(self):
|
921 |
+
# Loop over the CLIPAttention module of text_encoder
|
922 |
+
for name, attn_module in self.text_encoder.named_modules():
|
923 |
+
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
924 |
+
# Loop over the LoRA layers
|
925 |
+
for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
926 |
+
# Retrieve the q/k/v/out projection of CLIPAttention
|
927 |
+
module = attn_module.get_submodule(text_encoder_attr)
|
928 |
+
if hasattr(module, "old_forward"):
|
929 |
+
# restore original `forward` to remove monkey-patch
|
930 |
+
module.forward = module.old_forward
|
931 |
+
delattr(module, "old_forward")
|
932 |
+
|
933 |
+
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
|
934 |
+
r"""
|
935 |
+
Monkey-patches the forward passes of attention modules of the text encoder.
|
936 |
+
|
937 |
+
Parameters:
|
938 |
+
attn_processors: Dict[str, `LoRAAttnProcessor`]:
|
939 |
+
A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
|
940 |
+
"""
|
941 |
+
|
942 |
+
# First, remove any monkey-patch that might have been applied before
|
943 |
+
self._remove_text_encoder_monkey_patch()
|
944 |
+
|
945 |
+
# Loop over the CLIPAttention module of text_encoder
|
946 |
+
for name, attn_module in self.text_encoder.named_modules():
|
947 |
+
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
948 |
+
# Loop over the LoRA layers
|
949 |
+
for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
950 |
+
# Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
|
951 |
+
module = attn_module.get_submodule(text_encoder_attr)
|
952 |
+
lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
|
953 |
+
|
954 |
+
# save old_forward to module that can be used to remove monkey-patch
|
955 |
+
old_forward = module.old_forward = module.forward
|
956 |
+
|
957 |
+
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
|
958 |
+
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
|
959 |
+
def make_new_forward(old_forward, lora_layer):
|
960 |
+
def new_forward(x):
|
961 |
+
result = old_forward(x) + self.lora_scale * lora_layer(x)
|
962 |
+
return result
|
963 |
+
|
964 |
+
return new_forward
|
965 |
+
|
966 |
+
# Monkey-patch.
|
967 |
+
module.forward = make_new_forward(old_forward, lora_layer)
|
968 |
+
|
969 |
+
@property
|
970 |
+
def _lora_attn_processor_attr_to_text_encoder_attr(self):
|
971 |
+
return {
|
972 |
+
"to_q_lora": "q_proj",
|
973 |
+
"to_k_lora": "k_proj",
|
974 |
+
"to_v_lora": "v_proj",
|
975 |
+
"to_out_lora": "out_proj",
|
976 |
+
}
|
977 |
+
|
978 |
+
def _load_text_encoder_attn_procs(
|
979 |
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
|
980 |
+
):
|
981 |
+
r"""
|
982 |
+
Load pretrained attention processor layers for
|
983 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
984 |
+
|
985 |
+
<Tip warning={true}>
|
986 |
+
|
987 |
+
This function is experimental and might change in the future.
|
988 |
+
|
989 |
+
</Tip>
|
990 |
+
|
991 |
+
Parameters:
|
992 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
993 |
+
Can be either:
|
994 |
+
|
995 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
996 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
997 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
998 |
+
`./my_model_directory/`.
|
999 |
+
- A [torch state
|
1000 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
1001 |
+
|
1002 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1003 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
1004 |
+
standard cache should not be used.
|
1005 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
1006 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1007 |
+
cached versions if they exist.
|
1008 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
1009 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
1010 |
+
file exists.
|
1011 |
+
proxies (`Dict[str, str]`, *optional*):
|
1012 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
1013 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1014 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
1015 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
1016 |
+
use_auth_token (`str` or *bool*, *optional*):
|
1017 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
1018 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
1019 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
1020 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
1021 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
1022 |
+
identifier allowed by git.
|
1023 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
1024 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
1025 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
1026 |
+
mirror (`str`, *optional*):
|
1027 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
1028 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
1029 |
+
Please refer to the mirror site for more information.
|
1030 |
+
|
1031 |
+
Returns:
|
1032 |
+
`Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
|
1033 |
+
[`LoRAAttnProcessor`].
|
1034 |
+
|
1035 |
+
<Tip>
|
1036 |
+
|
1037 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
1038 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
1039 |
+
|
1040 |
+
</Tip>
|
1041 |
+
"""
|
1042 |
+
|
1043 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1044 |
+
force_download = kwargs.pop("force_download", False)
|
1045 |
+
resume_download = kwargs.pop("resume_download", False)
|
1046 |
+
proxies = kwargs.pop("proxies", None)
|
1047 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1048 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1049 |
+
revision = kwargs.pop("revision", None)
|
1050 |
+
subfolder = kwargs.pop("subfolder", None)
|
1051 |
+
weight_name = kwargs.pop("weight_name", None)
|
1052 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1053 |
+
network_alpha = kwargs.pop("network_alpha", None)
|
1054 |
+
|
1055 |
+
if use_safetensors and not is_safetensors_available():
|
1056 |
+
raise ValueError(
|
1057 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
allow_pickle = False
|
1061 |
+
if use_safetensors is None:
|
1062 |
+
use_safetensors = is_safetensors_available()
|
1063 |
+
allow_pickle = True
|
1064 |
+
|
1065 |
+
user_agent = {
|
1066 |
+
"file_type": "attn_procs_weights",
|
1067 |
+
"framework": "pytorch",
|
1068 |
+
}
|
1069 |
+
|
1070 |
+
model_file = None
|
1071 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1072 |
+
# Let's first try to load .safetensors weights
|
1073 |
+
if (use_safetensors and weight_name is None) or (
|
1074 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
1075 |
+
):
|
1076 |
+
try:
|
1077 |
+
model_file = _get_model_file(
|
1078 |
+
pretrained_model_name_or_path_or_dict,
|
1079 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
1080 |
+
cache_dir=cache_dir,
|
1081 |
+
force_download=force_download,
|
1082 |
+
resume_download=resume_download,
|
1083 |
+
proxies=proxies,
|
1084 |
+
local_files_only=local_files_only,
|
1085 |
+
use_auth_token=use_auth_token,
|
1086 |
+
revision=revision,
|
1087 |
+
subfolder=subfolder,
|
1088 |
+
user_agent=user_agent,
|
1089 |
+
)
|
1090 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
1091 |
+
except IOError as e:
|
1092 |
+
if not allow_pickle:
|
1093 |
+
raise e
|
1094 |
+
# try loading non-safetensors weights
|
1095 |
+
pass
|
1096 |
+
if model_file is None:
|
1097 |
+
model_file = _get_model_file(
|
1098 |
+
pretrained_model_name_or_path_or_dict,
|
1099 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
1100 |
+
cache_dir=cache_dir,
|
1101 |
+
force_download=force_download,
|
1102 |
+
resume_download=resume_download,
|
1103 |
+
proxies=proxies,
|
1104 |
+
local_files_only=local_files_only,
|
1105 |
+
use_auth_token=use_auth_token,
|
1106 |
+
revision=revision,
|
1107 |
+
subfolder=subfolder,
|
1108 |
+
user_agent=user_agent,
|
1109 |
+
)
|
1110 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
1111 |
+
else:
|
1112 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
1113 |
+
|
1114 |
+
# fill attn processors
|
1115 |
+
attn_processors = {}
|
1116 |
+
|
1117 |
+
is_lora = all("lora" in k for k in state_dict.keys())
|
1118 |
+
|
1119 |
+
if is_lora:
|
1120 |
+
lora_grouped_dict = defaultdict(dict)
|
1121 |
+
for key, value in state_dict.items():
|
1122 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
1123 |
+
lora_grouped_dict[attn_processor_key][sub_key] = value
|
1124 |
+
|
1125 |
+
for key, value_dict in lora_grouped_dict.items():
|
1126 |
+
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
1127 |
+
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
1128 |
+
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
1129 |
+
|
1130 |
+
attn_processor_class = (
|
1131 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
1132 |
+
)
|
1133 |
+
attn_processors[key] = attn_processor_class(
|
1134 |
+
hidden_size=hidden_size,
|
1135 |
+
cross_attention_dim=cross_attention_dim,
|
1136 |
+
rank=rank,
|
1137 |
+
network_alpha=network_alpha,
|
1138 |
+
)
|
1139 |
+
attn_processors[key].load_state_dict(value_dict)
|
1140 |
+
|
1141 |
+
else:
|
1142 |
+
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
1143 |
+
|
1144 |
+
# set correct dtype & device
|
1145 |
+
attn_processors = {
|
1146 |
+
k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
|
1147 |
+
}
|
1148 |
+
return attn_processors
|
1149 |
+
|
1150 |
+
@classmethod
|
1151 |
+
def save_lora_weights(
|
1152 |
+
self,
|
1153 |
+
save_directory: Union[str, os.PathLike],
|
1154 |
+
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1155 |
+
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
1156 |
+
is_main_process: bool = True,
|
1157 |
+
weight_name: str = None,
|
1158 |
+
save_function: Callable = None,
|
1159 |
+
safe_serialization: bool = False,
|
1160 |
+
):
|
1161 |
+
r"""
|
1162 |
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
1163 |
+
|
1164 |
+
Arguments:
|
1165 |
+
save_directory (`str` or `os.PathLike`):
|
1166 |
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
1167 |
+
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
1168 |
+
State dict of the LoRA layers corresponding to the UNet.
|
1169 |
+
text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`):
|
1170 |
+
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
1171 |
+
encoder LoRA state dict because it comes 🤗 Transformers.
|
1172 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
1173 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
1174 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
1175 |
+
process to avoid race conditions.
|
1176 |
+
save_function (`Callable`):
|
1177 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
1178 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
1179 |
+
`DIFFUSERS_SAVE_MODE`.
|
1180 |
+
"""
|
1181 |
+
if os.path.isfile(save_directory):
|
1182 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
1183 |
+
return
|
1184 |
+
|
1185 |
+
if save_function is None:
|
1186 |
+
if safe_serialization:
|
1187 |
+
|
1188 |
+
def save_function(weights, filename):
|
1189 |
+
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
1190 |
+
|
1191 |
+
else:
|
1192 |
+
save_function = torch.save
|
1193 |
+
|
1194 |
+
os.makedirs(save_directory, exist_ok=True)
|
1195 |
+
|
1196 |
+
# Create a flat dictionary.
|
1197 |
+
state_dict = {}
|
1198 |
+
if unet_lora_layers is not None:
|
1199 |
+
weights = (
|
1200 |
+
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
1201 |
+
)
|
1202 |
+
|
1203 |
+
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
1204 |
+
state_dict.update(unet_lora_state_dict)
|
1205 |
+
|
1206 |
+
if text_encoder_lora_layers is not None:
|
1207 |
+
weights = (
|
1208 |
+
text_encoder_lora_layers.state_dict()
|
1209 |
+
if isinstance(text_encoder_lora_layers, torch.nn.Module)
|
1210 |
+
else text_encoder_lora_layers
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
text_encoder_lora_state_dict = {
|
1214 |
+
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
1215 |
+
}
|
1216 |
+
state_dict.update(text_encoder_lora_state_dict)
|
1217 |
+
|
1218 |
+
# Save the model
|
1219 |
+
if weight_name is None:
|
1220 |
+
if safe_serialization:
|
1221 |
+
weight_name = LORA_WEIGHT_NAME_SAFE
|
1222 |
+
else:
|
1223 |
+
weight_name = LORA_WEIGHT_NAME
|
1224 |
+
|
1225 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
1226 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
1227 |
+
|
1228 |
+
def _convert_kohya_lora_to_diffusers(self, state_dict):
|
1229 |
+
unet_state_dict = {}
|
1230 |
+
te_state_dict = {}
|
1231 |
+
network_alpha = None
|
1232 |
+
|
1233 |
+
for key, value in state_dict.items():
|
1234 |
+
if "lora_down" in key:
|
1235 |
+
lora_name = key.split(".")[0]
|
1236 |
+
lora_name_up = lora_name + ".lora_up.weight"
|
1237 |
+
lora_name_alpha = lora_name + ".alpha"
|
1238 |
+
if lora_name_alpha in state_dict:
|
1239 |
+
alpha = state_dict[lora_name_alpha].item()
|
1240 |
+
if network_alpha is None:
|
1241 |
+
network_alpha = alpha
|
1242 |
+
elif network_alpha != alpha:
|
1243 |
+
raise ValueError("Network alpha is not consistent")
|
1244 |
+
|
1245 |
+
if lora_name.startswith("lora_unet_"):
|
1246 |
+
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
1247 |
+
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
1248 |
+
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
1249 |
+
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
1250 |
+
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
1251 |
+
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
1252 |
+
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
1253 |
+
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
1254 |
+
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
1255 |
+
if "transformer_blocks" in diffusers_name:
|
1256 |
+
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
1257 |
+
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
1258 |
+
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
1259 |
+
unet_state_dict[diffusers_name] = value
|
1260 |
+
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1261 |
+
elif lora_name.startswith("lora_te_"):
|
1262 |
+
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
1263 |
+
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
1264 |
+
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
1265 |
+
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
1266 |
+
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
1267 |
+
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
1268 |
+
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
1269 |
+
if "self_attn" in diffusers_name:
|
1270 |
+
te_state_dict[diffusers_name] = value
|
1271 |
+
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1272 |
+
|
1273 |
+
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
1274 |
+
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
|
1275 |
+
new_state_dict = {**unet_state_dict, **te_state_dict}
|
1276 |
+
return new_state_dict, network_alpha
|
1277 |
+
|
1278 |
+
|
1279 |
+
class FromSingleFileMixin:
|
1280 |
+
"""
|
1281 |
+
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
1282 |
+
"""
|
1283 |
+
|
1284 |
+
@classmethod
|
1285 |
+
def from_ckpt(cls, *args, **kwargs):
|
1286 |
+
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
1287 |
+
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
1288 |
+
return cls.from_single_file(*args, **kwargs)
|
1289 |
+
|
1290 |
+
@classmethod
|
1291 |
+
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
1292 |
+
r"""
|
1293 |
+
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
1294 |
+
is set in evaluation mode (`model.eval()`) by default.
|
1295 |
+
|
1296 |
+
Parameters:
|
1297 |
+
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
1298 |
+
Can be either:
|
1299 |
+
- A link to the `.ckpt` file (for example
|
1300 |
+
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
1301 |
+
- A path to a *file* containing all pipeline weights.
|
1302 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
1303 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
1304 |
+
dtype is automatically derived from the model's weights.
|
1305 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
1306 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1307 |
+
cached versions if they exist.
|
1308 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1309 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
1310 |
+
is not used.
|
1311 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
1312 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
1313 |
+
incompletely downloaded files are deleted.
|
1314 |
+
proxies (`Dict[str, str]`, *optional*):
|
1315 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1316 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1317 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
1318 |
+
Whether to only load local model weights and configuration files or not. If set to True, the model
|
1319 |
+
won't be downloaded from the Hub.
|
1320 |
+
use_auth_token (`str` or *bool*, *optional*):
|
1321 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
1322 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
1323 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
1324 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
1325 |
+
allowed by Git.
|
1326 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
1327 |
+
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
1328 |
+
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
1329 |
+
weights. If set to `False`, safetensors weights are not loaded.
|
1330 |
+
extract_ema (`bool`, *optional*, defaults to `False`):
|
1331 |
+
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
1332 |
+
higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
|
1333 |
+
upcast_attention (`bool`, *optional*, defaults to `None`):
|
1334 |
+
Whether the attention computation should always be upcasted.
|
1335 |
+
image_size (`int`, *optional*, defaults to 512):
|
1336 |
+
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
1337 |
+
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
1338 |
+
prediction_type (`str`, *optional*):
|
1339 |
+
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
1340 |
+
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
1341 |
+
num_in_channels (`int`, *optional*, defaults to `None`):
|
1342 |
+
The number of input channels. If `None`, it will be automatically inferred.
|
1343 |
+
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
1344 |
+
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
1345 |
+
"ddim"]`.
|
1346 |
+
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
1347 |
+
Whether to load the safety checker or not.
|
1348 |
+
text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
|
1349 |
+
An instance of
|
1350 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
|
1351 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
|
1352 |
+
variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
|
1353 |
+
needed.
|
1354 |
+
tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
|
1355 |
+
An instance of
|
1356 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
1357 |
+
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
|
1358 |
+
itself, if needed.
|
1359 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
1360 |
+
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
1361 |
+
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
1362 |
+
method. See example below for more information.
|
1363 |
+
|
1364 |
+
Examples:
|
1365 |
+
|
1366 |
+
```py
|
1367 |
+
>>> from diffusers import StableDiffusionPipeline
|
1368 |
+
|
1369 |
+
>>> # Download pipeline from huggingface.co and cache.
|
1370 |
+
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
1371 |
+
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
1372 |
+
... )
|
1373 |
+
|
1374 |
+
>>> # Download pipeline from local file
|
1375 |
+
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
1376 |
+
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
|
1377 |
+
|
1378 |
+
>>> # Enable float16 and move to GPU
|
1379 |
+
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
1380 |
+
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
1381 |
+
... torch_dtype=torch.float16,
|
1382 |
+
... )
|
1383 |
+
>>> pipeline.to("cuda")
|
1384 |
+
```
|
1385 |
+
"""
|
1386 |
+
# import here to avoid circular dependency
|
1387 |
+
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
1388 |
+
|
1389 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1390 |
+
resume_download = kwargs.pop("resume_download", False)
|
1391 |
+
force_download = kwargs.pop("force_download", False)
|
1392 |
+
proxies = kwargs.pop("proxies", None)
|
1393 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1394 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1395 |
+
revision = kwargs.pop("revision", None)
|
1396 |
+
extract_ema = kwargs.pop("extract_ema", False)
|
1397 |
+
image_size = kwargs.pop("image_size", None)
|
1398 |
+
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
1399 |
+
num_in_channels = kwargs.pop("num_in_channels", None)
|
1400 |
+
upcast_attention = kwargs.pop("upcast_attention", None)
|
1401 |
+
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
1402 |
+
prediction_type = kwargs.pop("prediction_type", None)
|
1403 |
+
text_encoder = kwargs.pop("text_encoder", None)
|
1404 |
+
tokenizer = kwargs.pop("tokenizer", None)
|
1405 |
+
|
1406 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1407 |
+
|
1408 |
+
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
1409 |
+
|
1410 |
+
pipeline_name = cls.__name__
|
1411 |
+
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
1412 |
+
from_safetensors = file_extension == "safetensors"
|
1413 |
+
|
1414 |
+
if from_safetensors and use_safetensors is False:
|
1415 |
+
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
1416 |
+
|
1417 |
+
# TODO: For now we only support stable diffusion
|
1418 |
+
stable_unclip = None
|
1419 |
+
model_type = None
|
1420 |
+
controlnet = False
|
1421 |
+
|
1422 |
+
if pipeline_name == "StableDiffusionControlNetPipeline":
|
1423 |
+
# Model type will be inferred from the checkpoint.
|
1424 |
+
controlnet = True
|
1425 |
+
elif "StableDiffusion" in pipeline_name:
|
1426 |
+
# Model type will be inferred from the checkpoint.
|
1427 |
+
pass
|
1428 |
+
elif pipeline_name == "StableUnCLIPPipeline":
|
1429 |
+
model_type = "FrozenOpenCLIPEmbedder"
|
1430 |
+
stable_unclip = "txt2img"
|
1431 |
+
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
1432 |
+
model_type = "FrozenOpenCLIPEmbedder"
|
1433 |
+
stable_unclip = "img2img"
|
1434 |
+
elif pipeline_name == "PaintByExamplePipeline":
|
1435 |
+
model_type = "PaintByExample"
|
1436 |
+
elif pipeline_name == "LDMTextToImagePipeline":
|
1437 |
+
model_type = "LDMTextToImage"
|
1438 |
+
else:
|
1439 |
+
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
1440 |
+
|
1441 |
+
# remove huggingface url
|
1442 |
+
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
1443 |
+
if pretrained_model_link_or_path.startswith(prefix):
|
1444 |
+
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
1445 |
+
|
1446 |
+
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
1447 |
+
ckpt_path = Path(pretrained_model_link_or_path)
|
1448 |
+
if not ckpt_path.is_file():
|
1449 |
+
# get repo_id and (potentially nested) file path of ckpt in repo
|
1450 |
+
repo_id = "/".join(ckpt_path.parts[:2])
|
1451 |
+
file_path = "/".join(ckpt_path.parts[2:])
|
1452 |
+
|
1453 |
+
if file_path.startswith("blob/"):
|
1454 |
+
file_path = file_path[len("blob/") :]
|
1455 |
+
|
1456 |
+
if file_path.startswith("main/"):
|
1457 |
+
file_path = file_path[len("main/") :]
|
1458 |
+
|
1459 |
+
pretrained_model_link_or_path = hf_hub_download(
|
1460 |
+
repo_id,
|
1461 |
+
filename=file_path,
|
1462 |
+
cache_dir=cache_dir,
|
1463 |
+
resume_download=resume_download,
|
1464 |
+
proxies=proxies,
|
1465 |
+
local_files_only=local_files_only,
|
1466 |
+
use_auth_token=use_auth_token,
|
1467 |
+
revision=revision,
|
1468 |
+
force_download=force_download,
|
1469 |
+
)
|
1470 |
+
|
1471 |
+
pipe = download_from_original_stable_diffusion_ckpt(
|
1472 |
+
pretrained_model_link_or_path,
|
1473 |
+
pipeline_class=cls,
|
1474 |
+
model_type=model_type,
|
1475 |
+
stable_unclip=stable_unclip,
|
1476 |
+
controlnet=controlnet,
|
1477 |
+
from_safetensors=from_safetensors,
|
1478 |
+
extract_ema=extract_ema,
|
1479 |
+
image_size=image_size,
|
1480 |
+
scheduler_type=scheduler_type,
|
1481 |
+
num_in_channels=num_in_channels,
|
1482 |
+
upcast_attention=upcast_attention,
|
1483 |
+
load_safety_checker=load_safety_checker,
|
1484 |
+
prediction_type=prediction_type,
|
1485 |
+
text_encoder=text_encoder,
|
1486 |
+
tokenizer=tokenizer,
|
1487 |
+
)
|
1488 |
+
|
1489 |
+
if torch_dtype is not None:
|
1490 |
+
pipe.to(torch_dtype=torch_dtype)
|
1491 |
+
|
1492 |
+
return pipe
|
6DoF/diffusers/models/__init__.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from ..utils import is_flax_available, is_torch_available
|
16 |
+
|
17 |
+
|
18 |
+
if is_torch_available():
|
19 |
+
from .autoencoder_kl import AutoencoderKL
|
20 |
+
from .controlnet import ControlNetModel
|
21 |
+
from .dual_transformer_2d import DualTransformer2DModel
|
22 |
+
from .modeling_utils import ModelMixin
|
23 |
+
from .prior_transformer import PriorTransformer
|
24 |
+
from .t5_film_transformer import T5FilmDecoder
|
25 |
+
from .transformer_2d import Transformer2DModel
|
26 |
+
from .unet_1d import UNet1DModel
|
27 |
+
from .unet_2d import UNet2DModel
|
28 |
+
from .unet_2d_condition import UNet2DConditionModel
|
29 |
+
from .unet_3d_condition import UNet3DConditionModel
|
30 |
+
from .vq_model import VQModel
|
31 |
+
|
32 |
+
if is_flax_available():
|
33 |
+
from .controlnet_flax import FlaxControlNetModel
|
34 |
+
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
35 |
+
from .vae_flax import FlaxAutoencoderKL
|
6DoF/diffusers/models/activations.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation(act_fn):
|
5 |
+
if act_fn in ["swish", "silu"]:
|
6 |
+
return nn.SiLU()
|
7 |
+
elif act_fn == "mish":
|
8 |
+
return nn.Mish()
|
9 |
+
elif act_fn == "gelu":
|
10 |
+
return nn.GELU()
|
11 |
+
else:
|
12 |
+
raise ValueError(f"Unsupported activation function: {act_fn}")
|
6DoF/diffusers/models/attention.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..utils import maybe_allow_in_graph
|
21 |
+
from .activations import get_activation
|
22 |
+
from .attention_processor import Attention
|
23 |
+
from .embeddings import CombinedTimestepLabelEmbeddings
|
24 |
+
|
25 |
+
|
26 |
+
@maybe_allow_in_graph
|
27 |
+
class BasicTransformerBlock(nn.Module):
|
28 |
+
r"""
|
29 |
+
A basic Transformer block.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
dim (`int`): The number of channels in the input and output.
|
33 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
34 |
+
attention_head_dim (`int`): The number of channels in each head.
|
35 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
36 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
37 |
+
only_cross_attention (`bool`, *optional*):
|
38 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
39 |
+
double_self_attention (`bool`, *optional*):
|
40 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
41 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
42 |
+
num_embeds_ada_norm (:
|
43 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
44 |
+
attention_bias (:
|
45 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
dim: int,
|
51 |
+
num_attention_heads: int,
|
52 |
+
attention_head_dim: int,
|
53 |
+
dropout=0.0,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
activation_fn: str = "geglu",
|
56 |
+
num_embeds_ada_norm: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
only_cross_attention: bool = False,
|
59 |
+
double_self_attention: bool = False,
|
60 |
+
upcast_attention: bool = False,
|
61 |
+
norm_elementwise_affine: bool = True,
|
62 |
+
norm_type: str = "layer_norm",
|
63 |
+
final_dropout: bool = False,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
self.only_cross_attention = only_cross_attention
|
67 |
+
|
68 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
69 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
70 |
+
|
71 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
72 |
+
raise ValueError(
|
73 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
74 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
75 |
+
)
|
76 |
+
|
77 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
78 |
+
# 1. Self-Attn
|
79 |
+
if self.use_ada_layer_norm:
|
80 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
81 |
+
elif self.use_ada_layer_norm_zero:
|
82 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
83 |
+
else:
|
84 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
85 |
+
self.attn1 = Attention(
|
86 |
+
query_dim=dim,
|
87 |
+
heads=num_attention_heads,
|
88 |
+
dim_head=attention_head_dim,
|
89 |
+
dropout=dropout,
|
90 |
+
bias=attention_bias,
|
91 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
92 |
+
upcast_attention=upcast_attention,
|
93 |
+
)
|
94 |
+
|
95 |
+
# 2. Cross-Attn
|
96 |
+
if cross_attention_dim is not None or double_self_attention:
|
97 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
98 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
99 |
+
# the second cross attention block.
|
100 |
+
self.norm2 = (
|
101 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
102 |
+
if self.use_ada_layer_norm
|
103 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
104 |
+
)
|
105 |
+
self.attn2 = Attention(
|
106 |
+
query_dim=dim,
|
107 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
108 |
+
heads=num_attention_heads,
|
109 |
+
dim_head=attention_head_dim,
|
110 |
+
dropout=dropout,
|
111 |
+
bias=attention_bias,
|
112 |
+
upcast_attention=upcast_attention,
|
113 |
+
) # is self-attn if encoder_hidden_states is none
|
114 |
+
else:
|
115 |
+
self.norm2 = None
|
116 |
+
self.attn2 = None
|
117 |
+
|
118 |
+
# 3. Feed-forward
|
119 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
120 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
121 |
+
|
122 |
+
# let chunk size default to None
|
123 |
+
self._chunk_size = None
|
124 |
+
self._chunk_dim = 0
|
125 |
+
|
126 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
127 |
+
# Sets chunk feed-forward
|
128 |
+
self._chunk_size = chunk_size
|
129 |
+
self._chunk_dim = dim
|
130 |
+
|
131 |
+
def forward(
|
132 |
+
self,
|
133 |
+
hidden_states: torch.FloatTensor,
|
134 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
135 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
136 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
137 |
+
timestep: Optional[torch.LongTensor] = None,
|
138 |
+
posemb: Optional = None,
|
139 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
140 |
+
class_labels: Optional[torch.LongTensor] = None,
|
141 |
+
):
|
142 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
143 |
+
# 1. Self-Attention
|
144 |
+
if self.use_ada_layer_norm:
|
145 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
146 |
+
elif self.use_ada_layer_norm_zero:
|
147 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
148 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
norm_hidden_states = self.norm1(hidden_states)
|
152 |
+
|
153 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
154 |
+
|
155 |
+
attn_output = self.attn1(
|
156 |
+
norm_hidden_states,
|
157 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
158 |
+
attention_mask=attention_mask,
|
159 |
+
posemb=posemb, # todo in self attn, posemb shoule be [pose_in, pose_in]?
|
160 |
+
**cross_attention_kwargs,
|
161 |
+
)
|
162 |
+
if self.use_ada_layer_norm_zero:
|
163 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
164 |
+
hidden_states = attn_output + hidden_states
|
165 |
+
|
166 |
+
# 2. Cross-Attention
|
167 |
+
if self.attn2 is not None:
|
168 |
+
norm_hidden_states = (
|
169 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
170 |
+
)
|
171 |
+
|
172 |
+
attn_output = self.attn2(
|
173 |
+
norm_hidden_states,
|
174 |
+
encoder_hidden_states=encoder_hidden_states,
|
175 |
+
attention_mask=encoder_attention_mask,
|
176 |
+
posemb=posemb,
|
177 |
+
**cross_attention_kwargs,
|
178 |
+
)
|
179 |
+
hidden_states = attn_output + hidden_states
|
180 |
+
|
181 |
+
# 3. Feed-forward
|
182 |
+
norm_hidden_states = self.norm3(hidden_states)
|
183 |
+
|
184 |
+
if self.use_ada_layer_norm_zero:
|
185 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
186 |
+
|
187 |
+
if self._chunk_size is not None:
|
188 |
+
# "feed_forward_chunk_size" can be used to save memory
|
189 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
190 |
+
raise ValueError(
|
191 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
192 |
+
)
|
193 |
+
|
194 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
195 |
+
ff_output = torch.cat(
|
196 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
197 |
+
dim=self._chunk_dim,
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
ff_output = self.ff(norm_hidden_states)
|
201 |
+
|
202 |
+
if self.use_ada_layer_norm_zero:
|
203 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
204 |
+
|
205 |
+
hidden_states = ff_output + hidden_states
|
206 |
+
|
207 |
+
return hidden_states
|
208 |
+
|
209 |
+
|
210 |
+
class FeedForward(nn.Module):
|
211 |
+
r"""
|
212 |
+
A feed-forward layer.
|
213 |
+
|
214 |
+
Parameters:
|
215 |
+
dim (`int`): The number of channels in the input.
|
216 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
217 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
218 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
219 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
220 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
221 |
+
"""
|
222 |
+
|
223 |
+
def __init__(
|
224 |
+
self,
|
225 |
+
dim: int,
|
226 |
+
dim_out: Optional[int] = None,
|
227 |
+
mult: int = 4,
|
228 |
+
dropout: float = 0.0,
|
229 |
+
activation_fn: str = "geglu",
|
230 |
+
final_dropout: bool = False,
|
231 |
+
):
|
232 |
+
super().__init__()
|
233 |
+
inner_dim = int(dim * mult)
|
234 |
+
dim_out = dim_out if dim_out is not None else dim
|
235 |
+
|
236 |
+
if activation_fn == "gelu":
|
237 |
+
act_fn = GELU(dim, inner_dim)
|
238 |
+
if activation_fn == "gelu-approximate":
|
239 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
240 |
+
elif activation_fn == "geglu":
|
241 |
+
act_fn = GEGLU(dim, inner_dim)
|
242 |
+
elif activation_fn == "geglu-approximate":
|
243 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
244 |
+
|
245 |
+
self.net = nn.ModuleList([])
|
246 |
+
# project in
|
247 |
+
self.net.append(act_fn)
|
248 |
+
# project dropout
|
249 |
+
self.net.append(nn.Dropout(dropout))
|
250 |
+
# project out
|
251 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
252 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
253 |
+
if final_dropout:
|
254 |
+
self.net.append(nn.Dropout(dropout))
|
255 |
+
|
256 |
+
def forward(self, hidden_states):
|
257 |
+
for module in self.net:
|
258 |
+
hidden_states = module(hidden_states)
|
259 |
+
return hidden_states
|
260 |
+
|
261 |
+
|
262 |
+
class GELU(nn.Module):
|
263 |
+
r"""
|
264 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
268 |
+
super().__init__()
|
269 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
270 |
+
self.approximate = approximate
|
271 |
+
|
272 |
+
def gelu(self, gate):
|
273 |
+
if gate.device.type != "mps":
|
274 |
+
return F.gelu(gate, approximate=self.approximate)
|
275 |
+
# mps: gelu is not implemented for float16
|
276 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
277 |
+
|
278 |
+
def forward(self, hidden_states):
|
279 |
+
hidden_states = self.proj(hidden_states)
|
280 |
+
hidden_states = self.gelu(hidden_states)
|
281 |
+
return hidden_states
|
282 |
+
|
283 |
+
|
284 |
+
class GEGLU(nn.Module):
|
285 |
+
r"""
|
286 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
287 |
+
|
288 |
+
Parameters:
|
289 |
+
dim_in (`int`): The number of channels in the input.
|
290 |
+
dim_out (`int`): The number of channels in the output.
|
291 |
+
"""
|
292 |
+
|
293 |
+
def __init__(self, dim_in: int, dim_out: int):
|
294 |
+
super().__init__()
|
295 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
296 |
+
|
297 |
+
def gelu(self, gate):
|
298 |
+
if gate.device.type != "mps":
|
299 |
+
return F.gelu(gate)
|
300 |
+
# mps: gelu is not implemented for float16
|
301 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
302 |
+
|
303 |
+
def forward(self, hidden_states):
|
304 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
305 |
+
return hidden_states * self.gelu(gate)
|
306 |
+
|
307 |
+
|
308 |
+
class ApproximateGELU(nn.Module):
|
309 |
+
"""
|
310 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
311 |
+
|
312 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
313 |
+
"""
|
314 |
+
|
315 |
+
def __init__(self, dim_in: int, dim_out: int):
|
316 |
+
super().__init__()
|
317 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
318 |
+
|
319 |
+
def forward(self, x):
|
320 |
+
x = self.proj(x)
|
321 |
+
return x * torch.sigmoid(1.702 * x)
|
322 |
+
|
323 |
+
|
324 |
+
class AdaLayerNorm(nn.Module):
|
325 |
+
"""
|
326 |
+
Norm layer modified to incorporate timestep embeddings.
|
327 |
+
"""
|
328 |
+
|
329 |
+
def __init__(self, embedding_dim, num_embeddings):
|
330 |
+
super().__init__()
|
331 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
332 |
+
self.silu = nn.SiLU()
|
333 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
334 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
335 |
+
|
336 |
+
def forward(self, x, timestep):
|
337 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
338 |
+
scale, shift = torch.chunk(emb, 2)
|
339 |
+
x = self.norm(x) * (1 + scale) + shift
|
340 |
+
return x
|
341 |
+
|
342 |
+
|
343 |
+
class AdaLayerNormZero(nn.Module):
|
344 |
+
"""
|
345 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
346 |
+
"""
|
347 |
+
|
348 |
+
def __init__(self, embedding_dim, num_embeddings):
|
349 |
+
super().__init__()
|
350 |
+
|
351 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
352 |
+
|
353 |
+
self.silu = nn.SiLU()
|
354 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
355 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
356 |
+
|
357 |
+
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
358 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
359 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
360 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
361 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
362 |
+
|
363 |
+
|
364 |
+
class AdaGroupNorm(nn.Module):
|
365 |
+
"""
|
366 |
+
GroupNorm layer modified to incorporate timestep embeddings.
|
367 |
+
"""
|
368 |
+
|
369 |
+
def __init__(
|
370 |
+
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
|
371 |
+
):
|
372 |
+
super().__init__()
|
373 |
+
self.num_groups = num_groups
|
374 |
+
self.eps = eps
|
375 |
+
|
376 |
+
if act_fn is None:
|
377 |
+
self.act = None
|
378 |
+
else:
|
379 |
+
self.act = get_activation(act_fn)
|
380 |
+
|
381 |
+
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
382 |
+
|
383 |
+
def forward(self, x, emb):
|
384 |
+
if self.act:
|
385 |
+
emb = self.act(emb)
|
386 |
+
emb = self.linear(emb)
|
387 |
+
emb = emb[:, :, None, None]
|
388 |
+
scale, shift = emb.chunk(2, dim=1)
|
389 |
+
|
390 |
+
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
391 |
+
x = x * (1 + scale) + shift
|
392 |
+
return x
|
6DoF/diffusers/models/attention_flax.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import functools
|
16 |
+
import math
|
17 |
+
|
18 |
+
import flax.linen as nn
|
19 |
+
import jax
|
20 |
+
import jax.numpy as jnp
|
21 |
+
|
22 |
+
|
23 |
+
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
|
24 |
+
"""Multi-head dot product attention with a limited number of queries."""
|
25 |
+
num_kv, num_heads, k_features = key.shape[-3:]
|
26 |
+
v_features = value.shape[-1]
|
27 |
+
key_chunk_size = min(key_chunk_size, num_kv)
|
28 |
+
query = query / jnp.sqrt(k_features)
|
29 |
+
|
30 |
+
@functools.partial(jax.checkpoint, prevent_cse=False)
|
31 |
+
def summarize_chunk(query, key, value):
|
32 |
+
attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
|
33 |
+
|
34 |
+
max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
|
35 |
+
max_score = jax.lax.stop_gradient(max_score)
|
36 |
+
exp_weights = jnp.exp(attn_weights - max_score)
|
37 |
+
|
38 |
+
exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
|
39 |
+
max_score = jnp.einsum("...qhk->...qh", max_score)
|
40 |
+
|
41 |
+
return (exp_values, exp_weights.sum(axis=-1), max_score)
|
42 |
+
|
43 |
+
def chunk_scanner(chunk_idx):
|
44 |
+
# julienne key array
|
45 |
+
key_chunk = jax.lax.dynamic_slice(
|
46 |
+
operand=key,
|
47 |
+
start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
|
48 |
+
slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
|
49 |
+
)
|
50 |
+
|
51 |
+
# julienne value array
|
52 |
+
value_chunk = jax.lax.dynamic_slice(
|
53 |
+
operand=value,
|
54 |
+
start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
|
55 |
+
slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
|
56 |
+
)
|
57 |
+
|
58 |
+
return summarize_chunk(query, key_chunk, value_chunk)
|
59 |
+
|
60 |
+
chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
|
61 |
+
|
62 |
+
global_max = jnp.max(chunk_max, axis=0, keepdims=True)
|
63 |
+
max_diffs = jnp.exp(chunk_max - global_max)
|
64 |
+
|
65 |
+
chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
|
66 |
+
chunk_weights *= max_diffs
|
67 |
+
|
68 |
+
all_values = chunk_values.sum(axis=0)
|
69 |
+
all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
|
70 |
+
|
71 |
+
return all_values / all_weights
|
72 |
+
|
73 |
+
|
74 |
+
def jax_memory_efficient_attention(
|
75 |
+
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
|
76 |
+
):
|
77 |
+
r"""
|
78 |
+
Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
|
79 |
+
https://github.com/AminRezaei0x443/memory-efficient-attention
|
80 |
+
|
81 |
+
Args:
|
82 |
+
query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
|
83 |
+
key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
|
84 |
+
value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
|
85 |
+
precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
|
86 |
+
numerical precision for computation
|
87 |
+
query_chunk_size (`int`, *optional*, defaults to 1024):
|
88 |
+
chunk size to divide query array value must divide query_length equally without remainder
|
89 |
+
key_chunk_size (`int`, *optional*, defaults to 4096):
|
90 |
+
chunk size to divide key and value array value must divide key_value_length equally without remainder
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
(`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
|
94 |
+
"""
|
95 |
+
num_q, num_heads, q_features = query.shape[-3:]
|
96 |
+
|
97 |
+
def chunk_scanner(chunk_idx, _):
|
98 |
+
# julienne query array
|
99 |
+
query_chunk = jax.lax.dynamic_slice(
|
100 |
+
operand=query,
|
101 |
+
start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
|
102 |
+
slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
|
103 |
+
)
|
104 |
+
|
105 |
+
return (
|
106 |
+
chunk_idx + query_chunk_size, # unused ignore it
|
107 |
+
_query_chunk_attention(
|
108 |
+
query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
|
109 |
+
),
|
110 |
+
)
|
111 |
+
|
112 |
+
_, res = jax.lax.scan(
|
113 |
+
f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
|
114 |
+
)
|
115 |
+
|
116 |
+
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
|
117 |
+
|
118 |
+
|
119 |
+
class FlaxAttention(nn.Module):
|
120 |
+
r"""
|
121 |
+
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
query_dim (:obj:`int`):
|
125 |
+
Input hidden states dimension
|
126 |
+
heads (:obj:`int`, *optional*, defaults to 8):
|
127 |
+
Number of heads
|
128 |
+
dim_head (:obj:`int`, *optional*, defaults to 64):
|
129 |
+
Hidden states dimension inside each head
|
130 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
131 |
+
Dropout rate
|
132 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
133 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
134 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
135 |
+
Parameters `dtype`
|
136 |
+
|
137 |
+
"""
|
138 |
+
query_dim: int
|
139 |
+
heads: int = 8
|
140 |
+
dim_head: int = 64
|
141 |
+
dropout: float = 0.0
|
142 |
+
use_memory_efficient_attention: bool = False
|
143 |
+
dtype: jnp.dtype = jnp.float32
|
144 |
+
|
145 |
+
def setup(self):
|
146 |
+
inner_dim = self.dim_head * self.heads
|
147 |
+
self.scale = self.dim_head**-0.5
|
148 |
+
|
149 |
+
# Weights were exported with old names {to_q, to_k, to_v, to_out}
|
150 |
+
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
|
151 |
+
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
|
152 |
+
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
|
153 |
+
|
154 |
+
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
|
155 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
156 |
+
|
157 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
158 |
+
batch_size, seq_len, dim = tensor.shape
|
159 |
+
head_size = self.heads
|
160 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
161 |
+
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
|
162 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
163 |
+
return tensor
|
164 |
+
|
165 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
166 |
+
batch_size, seq_len, dim = tensor.shape
|
167 |
+
head_size = self.heads
|
168 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
169 |
+
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
|
170 |
+
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
|
171 |
+
return tensor
|
172 |
+
|
173 |
+
def __call__(self, hidden_states, context=None, deterministic=True):
|
174 |
+
context = hidden_states if context is None else context
|
175 |
+
|
176 |
+
query_proj = self.query(hidden_states)
|
177 |
+
key_proj = self.key(context)
|
178 |
+
value_proj = self.value(context)
|
179 |
+
|
180 |
+
query_states = self.reshape_heads_to_batch_dim(query_proj)
|
181 |
+
key_states = self.reshape_heads_to_batch_dim(key_proj)
|
182 |
+
value_states = self.reshape_heads_to_batch_dim(value_proj)
|
183 |
+
|
184 |
+
if self.use_memory_efficient_attention:
|
185 |
+
query_states = query_states.transpose(1, 0, 2)
|
186 |
+
key_states = key_states.transpose(1, 0, 2)
|
187 |
+
value_states = value_states.transpose(1, 0, 2)
|
188 |
+
|
189 |
+
# this if statement create a chunk size for each layer of the unet
|
190 |
+
# the chunk size is equal to the query_length dimension of the deepest layer of the unet
|
191 |
+
|
192 |
+
flatten_latent_dim = query_states.shape[-3]
|
193 |
+
if flatten_latent_dim % 64 == 0:
|
194 |
+
query_chunk_size = int(flatten_latent_dim / 64)
|
195 |
+
elif flatten_latent_dim % 16 == 0:
|
196 |
+
query_chunk_size = int(flatten_latent_dim / 16)
|
197 |
+
elif flatten_latent_dim % 4 == 0:
|
198 |
+
query_chunk_size = int(flatten_latent_dim / 4)
|
199 |
+
else:
|
200 |
+
query_chunk_size = int(flatten_latent_dim)
|
201 |
+
|
202 |
+
hidden_states = jax_memory_efficient_attention(
|
203 |
+
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
|
204 |
+
)
|
205 |
+
|
206 |
+
hidden_states = hidden_states.transpose(1, 0, 2)
|
207 |
+
else:
|
208 |
+
# compute attentions
|
209 |
+
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
|
210 |
+
attention_scores = attention_scores * self.scale
|
211 |
+
attention_probs = nn.softmax(attention_scores, axis=2)
|
212 |
+
|
213 |
+
# attend to values
|
214 |
+
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
|
215 |
+
|
216 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
217 |
+
hidden_states = self.proj_attn(hidden_states)
|
218 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
219 |
+
|
220 |
+
|
221 |
+
class FlaxBasicTransformerBlock(nn.Module):
|
222 |
+
r"""
|
223 |
+
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
|
224 |
+
https://arxiv.org/abs/1706.03762
|
225 |
+
|
226 |
+
|
227 |
+
Parameters:
|
228 |
+
dim (:obj:`int`):
|
229 |
+
Inner hidden states dimension
|
230 |
+
n_heads (:obj:`int`):
|
231 |
+
Number of heads
|
232 |
+
d_head (:obj:`int`):
|
233 |
+
Hidden states dimension inside each head
|
234 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
235 |
+
Dropout rate
|
236 |
+
only_cross_attention (`bool`, defaults to `False`):
|
237 |
+
Whether to only apply cross attention.
|
238 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
239 |
+
Parameters `dtype`
|
240 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
241 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
242 |
+
"""
|
243 |
+
dim: int
|
244 |
+
n_heads: int
|
245 |
+
d_head: int
|
246 |
+
dropout: float = 0.0
|
247 |
+
only_cross_attention: bool = False
|
248 |
+
dtype: jnp.dtype = jnp.float32
|
249 |
+
use_memory_efficient_attention: bool = False
|
250 |
+
|
251 |
+
def setup(self):
|
252 |
+
# self attention (or cross_attention if only_cross_attention is True)
|
253 |
+
self.attn1 = FlaxAttention(
|
254 |
+
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
|
255 |
+
)
|
256 |
+
# cross attention
|
257 |
+
self.attn2 = FlaxAttention(
|
258 |
+
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
|
259 |
+
)
|
260 |
+
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
|
261 |
+
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
262 |
+
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
263 |
+
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
264 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
265 |
+
|
266 |
+
def __call__(self, hidden_states, context, deterministic=True):
|
267 |
+
# self attention
|
268 |
+
residual = hidden_states
|
269 |
+
if self.only_cross_attention:
|
270 |
+
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
|
271 |
+
else:
|
272 |
+
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
|
273 |
+
hidden_states = hidden_states + residual
|
274 |
+
|
275 |
+
# cross attention
|
276 |
+
residual = hidden_states
|
277 |
+
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
|
278 |
+
hidden_states = hidden_states + residual
|
279 |
+
|
280 |
+
# feed forward
|
281 |
+
residual = hidden_states
|
282 |
+
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
|
283 |
+
hidden_states = hidden_states + residual
|
284 |
+
|
285 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
286 |
+
|
287 |
+
|
288 |
+
class FlaxTransformer2DModel(nn.Module):
|
289 |
+
r"""
|
290 |
+
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
|
291 |
+
https://arxiv.org/pdf/1506.02025.pdf
|
292 |
+
|
293 |
+
|
294 |
+
Parameters:
|
295 |
+
in_channels (:obj:`int`):
|
296 |
+
Input number of channels
|
297 |
+
n_heads (:obj:`int`):
|
298 |
+
Number of heads
|
299 |
+
d_head (:obj:`int`):
|
300 |
+
Hidden states dimension inside each head
|
301 |
+
depth (:obj:`int`, *optional*, defaults to 1):
|
302 |
+
Number of transformers block
|
303 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
304 |
+
Dropout rate
|
305 |
+
use_linear_projection (`bool`, defaults to `False`): tbd
|
306 |
+
only_cross_attention (`bool`, defaults to `False`): tbd
|
307 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
308 |
+
Parameters `dtype`
|
309 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
310 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
311 |
+
"""
|
312 |
+
in_channels: int
|
313 |
+
n_heads: int
|
314 |
+
d_head: int
|
315 |
+
depth: int = 1
|
316 |
+
dropout: float = 0.0
|
317 |
+
use_linear_projection: bool = False
|
318 |
+
only_cross_attention: bool = False
|
319 |
+
dtype: jnp.dtype = jnp.float32
|
320 |
+
use_memory_efficient_attention: bool = False
|
321 |
+
|
322 |
+
def setup(self):
|
323 |
+
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
324 |
+
|
325 |
+
inner_dim = self.n_heads * self.d_head
|
326 |
+
if self.use_linear_projection:
|
327 |
+
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
|
328 |
+
else:
|
329 |
+
self.proj_in = nn.Conv(
|
330 |
+
inner_dim,
|
331 |
+
kernel_size=(1, 1),
|
332 |
+
strides=(1, 1),
|
333 |
+
padding="VALID",
|
334 |
+
dtype=self.dtype,
|
335 |
+
)
|
336 |
+
|
337 |
+
self.transformer_blocks = [
|
338 |
+
FlaxBasicTransformerBlock(
|
339 |
+
inner_dim,
|
340 |
+
self.n_heads,
|
341 |
+
self.d_head,
|
342 |
+
dropout=self.dropout,
|
343 |
+
only_cross_attention=self.only_cross_attention,
|
344 |
+
dtype=self.dtype,
|
345 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
346 |
+
)
|
347 |
+
for _ in range(self.depth)
|
348 |
+
]
|
349 |
+
|
350 |
+
if self.use_linear_projection:
|
351 |
+
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
|
352 |
+
else:
|
353 |
+
self.proj_out = nn.Conv(
|
354 |
+
inner_dim,
|
355 |
+
kernel_size=(1, 1),
|
356 |
+
strides=(1, 1),
|
357 |
+
padding="VALID",
|
358 |
+
dtype=self.dtype,
|
359 |
+
)
|
360 |
+
|
361 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
362 |
+
|
363 |
+
def __call__(self, hidden_states, context, deterministic=True):
|
364 |
+
batch, height, width, channels = hidden_states.shape
|
365 |
+
residual = hidden_states
|
366 |
+
hidden_states = self.norm(hidden_states)
|
367 |
+
if self.use_linear_projection:
|
368 |
+
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
369 |
+
hidden_states = self.proj_in(hidden_states)
|
370 |
+
else:
|
371 |
+
hidden_states = self.proj_in(hidden_states)
|
372 |
+
hidden_states = hidden_states.reshape(batch, height * width, channels)
|
373 |
+
|
374 |
+
for transformer_block in self.transformer_blocks:
|
375 |
+
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
|
376 |
+
|
377 |
+
if self.use_linear_projection:
|
378 |
+
hidden_states = self.proj_out(hidden_states)
|
379 |
+
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
380 |
+
else:
|
381 |
+
hidden_states = hidden_states.reshape(batch, height, width, channels)
|
382 |
+
hidden_states = self.proj_out(hidden_states)
|
383 |
+
|
384 |
+
hidden_states = hidden_states + residual
|
385 |
+
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
386 |
+
|
387 |
+
|
388 |
+
class FlaxFeedForward(nn.Module):
|
389 |
+
r"""
|
390 |
+
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
|
391 |
+
[`FeedForward`] class, with the following simplifications:
|
392 |
+
- The activation function is currently hardcoded to a gated linear unit from:
|
393 |
+
https://arxiv.org/abs/2002.05202
|
394 |
+
- `dim_out` is equal to `dim`.
|
395 |
+
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
|
396 |
+
|
397 |
+
Parameters:
|
398 |
+
dim (:obj:`int`):
|
399 |
+
Inner hidden states dimension
|
400 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
401 |
+
Dropout rate
|
402 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
403 |
+
Parameters `dtype`
|
404 |
+
"""
|
405 |
+
dim: int
|
406 |
+
dropout: float = 0.0
|
407 |
+
dtype: jnp.dtype = jnp.float32
|
408 |
+
|
409 |
+
def setup(self):
|
410 |
+
# The second linear layer needs to be called
|
411 |
+
# net_2 for now to match the index of the Sequential layer
|
412 |
+
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
|
413 |
+
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
|
414 |
+
|
415 |
+
def __call__(self, hidden_states, deterministic=True):
|
416 |
+
hidden_states = self.net_0(hidden_states, deterministic=deterministic)
|
417 |
+
hidden_states = self.net_2(hidden_states)
|
418 |
+
return hidden_states
|
419 |
+
|
420 |
+
|
421 |
+
class FlaxGEGLU(nn.Module):
|
422 |
+
r"""
|
423 |
+
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
|
424 |
+
https://arxiv.org/abs/2002.05202.
|
425 |
+
|
426 |
+
Parameters:
|
427 |
+
dim (:obj:`int`):
|
428 |
+
Input hidden states dimension
|
429 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
430 |
+
Dropout rate
|
431 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
432 |
+
Parameters `dtype`
|
433 |
+
"""
|
434 |
+
dim: int
|
435 |
+
dropout: float = 0.0
|
436 |
+
dtype: jnp.dtype = jnp.float32
|
437 |
+
|
438 |
+
def setup(self):
|
439 |
+
inner_dim = self.dim * 4
|
440 |
+
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
|
441 |
+
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
442 |
+
|
443 |
+
def __call__(self, hidden_states, deterministic=True):
|
444 |
+
hidden_states = self.proj(hidden_states)
|
445 |
+
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
|
446 |
+
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
|
6DoF/diffusers/models/attention_processor.py
ADDED
@@ -0,0 +1,1684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Callable, Optional, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..utils import deprecate, logging, maybe_allow_in_graph
|
21 |
+
from ..utils.import_utils import is_xformers_available
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
25 |
+
|
26 |
+
|
27 |
+
if is_xformers_available():
|
28 |
+
import xformers
|
29 |
+
import xformers.ops
|
30 |
+
else:
|
31 |
+
xformers = None
|
32 |
+
|
33 |
+
|
34 |
+
# 6DoF CaPE
|
35 |
+
import einops
|
36 |
+
def cape_embed(f, P):
|
37 |
+
# f is feature vector of shape [..., d]
|
38 |
+
# P is 4x4 transformation matrix
|
39 |
+
f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
|
40 |
+
return einops.rearrange(f@P, '... d k -> ... (d k)', k=4)
|
41 |
+
|
42 |
+
@maybe_allow_in_graph
|
43 |
+
class Attention(nn.Module):
|
44 |
+
r"""
|
45 |
+
A cross attention layer.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
query_dim (`int`): The number of channels in the query.
|
49 |
+
cross_attention_dim (`int`, *optional*):
|
50 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
51 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
52 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
53 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
54 |
+
bias (`bool`, *optional*, defaults to False):
|
55 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
query_dim: int,
|
61 |
+
cross_attention_dim: Optional[int] = None,
|
62 |
+
heads: int = 8,
|
63 |
+
dim_head: int = 64,
|
64 |
+
dropout: float = 0.0,
|
65 |
+
bias=False,
|
66 |
+
upcast_attention: bool = False,
|
67 |
+
upcast_softmax: bool = False,
|
68 |
+
cross_attention_norm: Optional[str] = None,
|
69 |
+
cross_attention_norm_num_groups: int = 32,
|
70 |
+
added_kv_proj_dim: Optional[int] = None,
|
71 |
+
norm_num_groups: Optional[int] = None,
|
72 |
+
spatial_norm_dim: Optional[int] = None,
|
73 |
+
out_bias: bool = True,
|
74 |
+
scale_qk: bool = True,
|
75 |
+
only_cross_attention: bool = False,
|
76 |
+
eps: float = 1e-5,
|
77 |
+
rescale_output_factor: float = 1.0,
|
78 |
+
residual_connection: bool = False,
|
79 |
+
_from_deprecated_attn_block=False,
|
80 |
+
processor: Optional["AttnProcessor"] = None,
|
81 |
+
):
|
82 |
+
super().__init__()
|
83 |
+
inner_dim = dim_head * heads
|
84 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
85 |
+
self.upcast_attention = upcast_attention
|
86 |
+
self.upcast_softmax = upcast_softmax
|
87 |
+
self.rescale_output_factor = rescale_output_factor
|
88 |
+
self.residual_connection = residual_connection
|
89 |
+
self.dropout = dropout
|
90 |
+
|
91 |
+
# we make use of this private variable to know whether this class is loaded
|
92 |
+
# with an deprecated state dict so that we can convert it on the fly
|
93 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
94 |
+
|
95 |
+
self.scale_qk = scale_qk
|
96 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
97 |
+
|
98 |
+
self.heads = heads
|
99 |
+
# for slice_size > 0 the attention score computation
|
100 |
+
# is split across the batch axis to save memory
|
101 |
+
# You can set slice_size with `set_attention_slice`
|
102 |
+
self.sliceable_head_dim = heads
|
103 |
+
|
104 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
105 |
+
self.only_cross_attention = only_cross_attention
|
106 |
+
|
107 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
108 |
+
raise ValueError(
|
109 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
110 |
+
)
|
111 |
+
|
112 |
+
if norm_num_groups is not None:
|
113 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
114 |
+
else:
|
115 |
+
self.group_norm = None
|
116 |
+
|
117 |
+
if spatial_norm_dim is not None:
|
118 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
119 |
+
else:
|
120 |
+
self.spatial_norm = None
|
121 |
+
|
122 |
+
if cross_attention_norm is None:
|
123 |
+
self.norm_cross = None
|
124 |
+
elif cross_attention_norm == "layer_norm":
|
125 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
126 |
+
elif cross_attention_norm == "group_norm":
|
127 |
+
if self.added_kv_proj_dim is not None:
|
128 |
+
# The given `encoder_hidden_states` are initially of shape
|
129 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
130 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
131 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
132 |
+
# the number of channels for the group norm.
|
133 |
+
norm_cross_num_channels = added_kv_proj_dim
|
134 |
+
else:
|
135 |
+
norm_cross_num_channels = cross_attention_dim
|
136 |
+
|
137 |
+
self.norm_cross = nn.GroupNorm(
|
138 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
raise ValueError(
|
142 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
143 |
+
)
|
144 |
+
|
145 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
146 |
+
|
147 |
+
if not self.only_cross_attention:
|
148 |
+
# only relevant for the `AddedKVProcessor` classes
|
149 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
150 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
151 |
+
else:
|
152 |
+
self.to_k = None
|
153 |
+
self.to_v = None
|
154 |
+
|
155 |
+
if self.added_kv_proj_dim is not None:
|
156 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
157 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
158 |
+
|
159 |
+
self.to_out = nn.ModuleList([])
|
160 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
161 |
+
self.to_out.append(nn.Dropout(dropout))
|
162 |
+
|
163 |
+
# set attention processor
|
164 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
165 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
166 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
167 |
+
if processor is None:
|
168 |
+
processor = (
|
169 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
170 |
+
)
|
171 |
+
self.set_processor(processor)
|
172 |
+
|
173 |
+
def set_use_memory_efficient_attention_xformers(
|
174 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
175 |
+
):
|
176 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
177 |
+
self.processor,
|
178 |
+
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
179 |
+
)
|
180 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
181 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
182 |
+
)
|
183 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
184 |
+
self.processor,
|
185 |
+
(
|
186 |
+
AttnAddedKVProcessor,
|
187 |
+
AttnAddedKVProcessor2_0,
|
188 |
+
SlicedAttnAddedKVProcessor,
|
189 |
+
XFormersAttnAddedKVProcessor,
|
190 |
+
LoRAAttnAddedKVProcessor,
|
191 |
+
),
|
192 |
+
)
|
193 |
+
|
194 |
+
if use_memory_efficient_attention_xformers:
|
195 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
196 |
+
raise NotImplementedError(
|
197 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
198 |
+
)
|
199 |
+
if not is_xformers_available():
|
200 |
+
raise ModuleNotFoundError(
|
201 |
+
(
|
202 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
203 |
+
" xformers"
|
204 |
+
),
|
205 |
+
name="xformers",
|
206 |
+
)
|
207 |
+
elif not torch.cuda.is_available():
|
208 |
+
raise ValueError(
|
209 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
210 |
+
" only available for GPU "
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
try:
|
214 |
+
# Make sure we can run the memory efficient attention
|
215 |
+
_ = xformers.ops.memory_efficient_attention(
|
216 |
+
torch.randn((1, 2, 40), device="cuda"),
|
217 |
+
torch.randn((1, 2, 40), device="cuda"),
|
218 |
+
torch.randn((1, 2, 40), device="cuda"),
|
219 |
+
)
|
220 |
+
except Exception as e:
|
221 |
+
raise e
|
222 |
+
|
223 |
+
if is_lora:
|
224 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
225 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
226 |
+
processor = LoRAXFormersAttnProcessor(
|
227 |
+
hidden_size=self.processor.hidden_size,
|
228 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
229 |
+
rank=self.processor.rank,
|
230 |
+
attention_op=attention_op,
|
231 |
+
)
|
232 |
+
processor.load_state_dict(self.processor.state_dict())
|
233 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
234 |
+
elif is_custom_diffusion:
|
235 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
236 |
+
train_kv=self.processor.train_kv,
|
237 |
+
train_q_out=self.processor.train_q_out,
|
238 |
+
hidden_size=self.processor.hidden_size,
|
239 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
240 |
+
attention_op=attention_op,
|
241 |
+
)
|
242 |
+
processor.load_state_dict(self.processor.state_dict())
|
243 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
244 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
245 |
+
elif is_added_kv_processor:
|
246 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
247 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
248 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
249 |
+
# throw warning
|
250 |
+
logger.info(
|
251 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
252 |
+
)
|
253 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
254 |
+
else:
|
255 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
256 |
+
else:
|
257 |
+
if is_lora:
|
258 |
+
attn_processor_class = (
|
259 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
260 |
+
)
|
261 |
+
processor = attn_processor_class(
|
262 |
+
hidden_size=self.processor.hidden_size,
|
263 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
264 |
+
rank=self.processor.rank,
|
265 |
+
)
|
266 |
+
processor.load_state_dict(self.processor.state_dict())
|
267 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
268 |
+
elif is_custom_diffusion:
|
269 |
+
processor = CustomDiffusionAttnProcessor(
|
270 |
+
train_kv=self.processor.train_kv,
|
271 |
+
train_q_out=self.processor.train_q_out,
|
272 |
+
hidden_size=self.processor.hidden_size,
|
273 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
274 |
+
)
|
275 |
+
processor.load_state_dict(self.processor.state_dict())
|
276 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
277 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
278 |
+
else:
|
279 |
+
# set attention processor
|
280 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
281 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
282 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
283 |
+
processor = (
|
284 |
+
AttnProcessor2_0()
|
285 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
286 |
+
else AttnProcessor()
|
287 |
+
)
|
288 |
+
|
289 |
+
self.set_processor(processor)
|
290 |
+
|
291 |
+
def set_attention_slice(self, slice_size):
|
292 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
293 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
294 |
+
|
295 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
296 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
297 |
+
elif slice_size is not None:
|
298 |
+
processor = SlicedAttnProcessor(slice_size)
|
299 |
+
elif self.added_kv_proj_dim is not None:
|
300 |
+
processor = AttnAddedKVProcessor()
|
301 |
+
else:
|
302 |
+
# set attention processor
|
303 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
304 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
305 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
306 |
+
processor = (
|
307 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
308 |
+
)
|
309 |
+
|
310 |
+
self.set_processor(processor)
|
311 |
+
|
312 |
+
def set_processor(self, processor: "AttnProcessor"):
|
313 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
314 |
+
# pop `processor` from `self._modules`
|
315 |
+
if (
|
316 |
+
hasattr(self, "processor")
|
317 |
+
and isinstance(self.processor, torch.nn.Module)
|
318 |
+
and not isinstance(processor, torch.nn.Module)
|
319 |
+
):
|
320 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
321 |
+
self._modules.pop("processor")
|
322 |
+
|
323 |
+
self.processor = processor
|
324 |
+
|
325 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
326 |
+
# The `Attention` class can call different attention processors / attention functions
|
327 |
+
# here we simply pass along all tensors to the selected processor class
|
328 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
329 |
+
return self.processor(
|
330 |
+
self,
|
331 |
+
hidden_states,
|
332 |
+
encoder_hidden_states=encoder_hidden_states,
|
333 |
+
attention_mask=attention_mask,
|
334 |
+
**cross_attention_kwargs,
|
335 |
+
)
|
336 |
+
|
337 |
+
def batch_to_head_dim(self, tensor):
|
338 |
+
head_size = self.heads
|
339 |
+
batch_size, seq_len, dim = tensor.shape
|
340 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
341 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
342 |
+
return tensor
|
343 |
+
|
344 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
345 |
+
head_size = self.heads
|
346 |
+
batch_size, seq_len, dim = tensor.shape
|
347 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
348 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
349 |
+
|
350 |
+
if out_dim == 3:
|
351 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
352 |
+
|
353 |
+
return tensor
|
354 |
+
|
355 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
356 |
+
dtype = query.dtype
|
357 |
+
if self.upcast_attention:
|
358 |
+
query = query.float()
|
359 |
+
key = key.float()
|
360 |
+
|
361 |
+
if attention_mask is None:
|
362 |
+
baddbmm_input = torch.empty(
|
363 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
364 |
+
)
|
365 |
+
beta = 0
|
366 |
+
else:
|
367 |
+
baddbmm_input = attention_mask
|
368 |
+
beta = 1
|
369 |
+
|
370 |
+
attention_scores = torch.baddbmm(
|
371 |
+
baddbmm_input,
|
372 |
+
query,
|
373 |
+
key.transpose(-1, -2),
|
374 |
+
beta=beta,
|
375 |
+
alpha=self.scale,
|
376 |
+
)
|
377 |
+
del baddbmm_input
|
378 |
+
|
379 |
+
if self.upcast_softmax:
|
380 |
+
attention_scores = attention_scores.float()
|
381 |
+
|
382 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
383 |
+
del attention_scores
|
384 |
+
|
385 |
+
attention_probs = attention_probs.to(dtype)
|
386 |
+
|
387 |
+
return attention_probs
|
388 |
+
|
389 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
390 |
+
if batch_size is None:
|
391 |
+
deprecate(
|
392 |
+
"batch_size=None",
|
393 |
+
"0.0.15",
|
394 |
+
(
|
395 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
396 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
397 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
398 |
+
),
|
399 |
+
)
|
400 |
+
batch_size = 1
|
401 |
+
|
402 |
+
head_size = self.heads
|
403 |
+
if attention_mask is None:
|
404 |
+
return attention_mask
|
405 |
+
|
406 |
+
current_length: int = attention_mask.shape[-1]
|
407 |
+
if current_length != target_length:
|
408 |
+
if attention_mask.device.type == "mps":
|
409 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
410 |
+
# Instead, we can manually construct the padding tensor.
|
411 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
412 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
413 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
414 |
+
else:
|
415 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
416 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
417 |
+
# remaining_length: int = target_length - current_length
|
418 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
419 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
420 |
+
|
421 |
+
if out_dim == 3:
|
422 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
423 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
424 |
+
elif out_dim == 4:
|
425 |
+
attention_mask = attention_mask.unsqueeze(1)
|
426 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
427 |
+
|
428 |
+
return attention_mask
|
429 |
+
|
430 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
431 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
432 |
+
|
433 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
434 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
435 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
436 |
+
# Group norm norms along the channels dimension and expects
|
437 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
438 |
+
# to norm along the hidden dimension, so we need to move
|
439 |
+
# (batch_size, sequence_length, hidden_size) ->
|
440 |
+
# (batch_size, hidden_size, sequence_length)
|
441 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
442 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
443 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
444 |
+
else:
|
445 |
+
assert False
|
446 |
+
|
447 |
+
return encoder_hidden_states
|
448 |
+
|
449 |
+
|
450 |
+
class AttnProcessor:
|
451 |
+
r"""
|
452 |
+
Default processor for performing attention-related computations.
|
453 |
+
"""
|
454 |
+
|
455 |
+
def __call__(
|
456 |
+
self,
|
457 |
+
attn: Attention,
|
458 |
+
hidden_states,
|
459 |
+
encoder_hidden_states=None,
|
460 |
+
attention_mask=None,
|
461 |
+
temb=None,
|
462 |
+
):
|
463 |
+
residual = hidden_states
|
464 |
+
|
465 |
+
if attn.spatial_norm is not None:
|
466 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
467 |
+
|
468 |
+
input_ndim = hidden_states.ndim
|
469 |
+
|
470 |
+
if input_ndim == 4:
|
471 |
+
batch_size, channel, height, width = hidden_states.shape
|
472 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
473 |
+
|
474 |
+
batch_size, sequence_length, _ = (
|
475 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
476 |
+
)
|
477 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
478 |
+
|
479 |
+
if attn.group_norm is not None:
|
480 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
481 |
+
|
482 |
+
query = attn.to_q(hidden_states)
|
483 |
+
|
484 |
+
if encoder_hidden_states is None:
|
485 |
+
encoder_hidden_states = hidden_states
|
486 |
+
elif attn.norm_cross:
|
487 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
488 |
+
|
489 |
+
key = attn.to_k(encoder_hidden_states)
|
490 |
+
value = attn.to_v(encoder_hidden_states)
|
491 |
+
|
492 |
+
query = attn.head_to_batch_dim(query)
|
493 |
+
key = attn.head_to_batch_dim(key)
|
494 |
+
value = attn.head_to_batch_dim(value)
|
495 |
+
|
496 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
497 |
+
hidden_states = torch.bmm(attention_probs, value)
|
498 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
499 |
+
|
500 |
+
# linear proj
|
501 |
+
hidden_states = attn.to_out[0](hidden_states)
|
502 |
+
# dropout
|
503 |
+
hidden_states = attn.to_out[1](hidden_states)
|
504 |
+
|
505 |
+
if input_ndim == 4:
|
506 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
507 |
+
|
508 |
+
if attn.residual_connection:
|
509 |
+
hidden_states = hidden_states + residual
|
510 |
+
|
511 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
512 |
+
|
513 |
+
return hidden_states
|
514 |
+
|
515 |
+
|
516 |
+
class LoRALinearLayer(nn.Module):
|
517 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
518 |
+
super().__init__()
|
519 |
+
|
520 |
+
if rank > min(in_features, out_features):
|
521 |
+
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
522 |
+
|
523 |
+
self.down = nn.Linear(in_features, rank, bias=False)
|
524 |
+
self.up = nn.Linear(rank, out_features, bias=False)
|
525 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
526 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
527 |
+
self.network_alpha = network_alpha
|
528 |
+
self.rank = rank
|
529 |
+
|
530 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
531 |
+
nn.init.zeros_(self.up.weight)
|
532 |
+
|
533 |
+
def forward(self, hidden_states):
|
534 |
+
orig_dtype = hidden_states.dtype
|
535 |
+
dtype = self.down.weight.dtype
|
536 |
+
|
537 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
538 |
+
up_hidden_states = self.up(down_hidden_states)
|
539 |
+
|
540 |
+
if self.network_alpha is not None:
|
541 |
+
up_hidden_states *= self.network_alpha / self.rank
|
542 |
+
|
543 |
+
return up_hidden_states.to(orig_dtype)
|
544 |
+
|
545 |
+
|
546 |
+
class LoRAAttnProcessor(nn.Module):
|
547 |
+
r"""
|
548 |
+
Processor for implementing the LoRA attention mechanism.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
hidden_size (`int`, *optional*):
|
552 |
+
The hidden size of the attention layer.
|
553 |
+
cross_attention_dim (`int`, *optional*):
|
554 |
+
The number of channels in the `encoder_hidden_states`.
|
555 |
+
rank (`int`, defaults to 4):
|
556 |
+
The dimension of the LoRA update matrices.
|
557 |
+
network_alpha (`int`, *optional*):
|
558 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
559 |
+
"""
|
560 |
+
|
561 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
562 |
+
super().__init__()
|
563 |
+
|
564 |
+
self.hidden_size = hidden_size
|
565 |
+
self.cross_attention_dim = cross_attention_dim
|
566 |
+
self.rank = rank
|
567 |
+
|
568 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
569 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
570 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
571 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
572 |
+
|
573 |
+
def __call__(
|
574 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
575 |
+
):
|
576 |
+
residual = hidden_states
|
577 |
+
|
578 |
+
if attn.spatial_norm is not None:
|
579 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
580 |
+
|
581 |
+
input_ndim = hidden_states.ndim
|
582 |
+
|
583 |
+
if input_ndim == 4:
|
584 |
+
batch_size, channel, height, width = hidden_states.shape
|
585 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
586 |
+
|
587 |
+
batch_size, sequence_length, _ = (
|
588 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
589 |
+
)
|
590 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
591 |
+
|
592 |
+
if attn.group_norm is not None:
|
593 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
594 |
+
|
595 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
596 |
+
query = attn.head_to_batch_dim(query)
|
597 |
+
|
598 |
+
if encoder_hidden_states is None:
|
599 |
+
encoder_hidden_states = hidden_states
|
600 |
+
elif attn.norm_cross:
|
601 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
602 |
+
|
603 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
604 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
605 |
+
|
606 |
+
key = attn.head_to_batch_dim(key)
|
607 |
+
value = attn.head_to_batch_dim(value)
|
608 |
+
|
609 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
610 |
+
hidden_states = torch.bmm(attention_probs, value)
|
611 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
612 |
+
|
613 |
+
# linear proj
|
614 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
615 |
+
# dropout
|
616 |
+
hidden_states = attn.to_out[1](hidden_states)
|
617 |
+
|
618 |
+
if input_ndim == 4:
|
619 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
620 |
+
|
621 |
+
if attn.residual_connection:
|
622 |
+
hidden_states = hidden_states + residual
|
623 |
+
|
624 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
625 |
+
|
626 |
+
return hidden_states
|
627 |
+
|
628 |
+
|
629 |
+
class CustomDiffusionAttnProcessor(nn.Module):
|
630 |
+
r"""
|
631 |
+
Processor for implementing attention for the Custom Diffusion method.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
train_kv (`bool`, defaults to `True`):
|
635 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
636 |
+
train_q_out (`bool`, defaults to `True`):
|
637 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
638 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
639 |
+
The hidden size of the attention layer.
|
640 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
641 |
+
The number of channels in the `encoder_hidden_states`.
|
642 |
+
out_bias (`bool`, defaults to `True`):
|
643 |
+
Whether to include the bias parameter in `train_q_out`.
|
644 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
645 |
+
The dropout probability to use.
|
646 |
+
"""
|
647 |
+
|
648 |
+
def __init__(
|
649 |
+
self,
|
650 |
+
train_kv=True,
|
651 |
+
train_q_out=True,
|
652 |
+
hidden_size=None,
|
653 |
+
cross_attention_dim=None,
|
654 |
+
out_bias=True,
|
655 |
+
dropout=0.0,
|
656 |
+
):
|
657 |
+
super().__init__()
|
658 |
+
self.train_kv = train_kv
|
659 |
+
self.train_q_out = train_q_out
|
660 |
+
|
661 |
+
self.hidden_size = hidden_size
|
662 |
+
self.cross_attention_dim = cross_attention_dim
|
663 |
+
|
664 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
665 |
+
if self.train_kv:
|
666 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
667 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
668 |
+
if self.train_q_out:
|
669 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
670 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
671 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
672 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
673 |
+
|
674 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
675 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
676 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
677 |
+
if self.train_q_out:
|
678 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
679 |
+
else:
|
680 |
+
query = attn.to_q(hidden_states)
|
681 |
+
|
682 |
+
if encoder_hidden_states is None:
|
683 |
+
crossattn = False
|
684 |
+
encoder_hidden_states = hidden_states
|
685 |
+
else:
|
686 |
+
crossattn = True
|
687 |
+
if attn.norm_cross:
|
688 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
689 |
+
|
690 |
+
if self.train_kv:
|
691 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
692 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
693 |
+
else:
|
694 |
+
key = attn.to_k(encoder_hidden_states)
|
695 |
+
value = attn.to_v(encoder_hidden_states)
|
696 |
+
|
697 |
+
if crossattn:
|
698 |
+
detach = torch.ones_like(key)
|
699 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
700 |
+
key = detach * key + (1 - detach) * key.detach()
|
701 |
+
value = detach * value + (1 - detach) * value.detach()
|
702 |
+
|
703 |
+
query = attn.head_to_batch_dim(query)
|
704 |
+
key = attn.head_to_batch_dim(key)
|
705 |
+
value = attn.head_to_batch_dim(value)
|
706 |
+
|
707 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
708 |
+
hidden_states = torch.bmm(attention_probs, value)
|
709 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
710 |
+
|
711 |
+
if self.train_q_out:
|
712 |
+
# linear proj
|
713 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
714 |
+
# dropout
|
715 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
716 |
+
else:
|
717 |
+
# linear proj
|
718 |
+
hidden_states = attn.to_out[0](hidden_states)
|
719 |
+
# dropout
|
720 |
+
hidden_states = attn.to_out[1](hidden_states)
|
721 |
+
|
722 |
+
return hidden_states
|
723 |
+
|
724 |
+
|
725 |
+
class AttnAddedKVProcessor:
|
726 |
+
r"""
|
727 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
728 |
+
encoder.
|
729 |
+
"""
|
730 |
+
|
731 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
732 |
+
residual = hidden_states
|
733 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
734 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
735 |
+
|
736 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
737 |
+
|
738 |
+
if encoder_hidden_states is None:
|
739 |
+
encoder_hidden_states = hidden_states
|
740 |
+
elif attn.norm_cross:
|
741 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
742 |
+
|
743 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
744 |
+
|
745 |
+
query = attn.to_q(hidden_states)
|
746 |
+
query = attn.head_to_batch_dim(query)
|
747 |
+
|
748 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
749 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
750 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
751 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
752 |
+
|
753 |
+
if not attn.only_cross_attention:
|
754 |
+
key = attn.to_k(hidden_states)
|
755 |
+
value = attn.to_v(hidden_states)
|
756 |
+
key = attn.head_to_batch_dim(key)
|
757 |
+
value = attn.head_to_batch_dim(value)
|
758 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
759 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
760 |
+
else:
|
761 |
+
key = encoder_hidden_states_key_proj
|
762 |
+
value = encoder_hidden_states_value_proj
|
763 |
+
|
764 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
765 |
+
hidden_states = torch.bmm(attention_probs, value)
|
766 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
767 |
+
|
768 |
+
# linear proj
|
769 |
+
hidden_states = attn.to_out[0](hidden_states)
|
770 |
+
# dropout
|
771 |
+
hidden_states = attn.to_out[1](hidden_states)
|
772 |
+
|
773 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
774 |
+
hidden_states = hidden_states + residual
|
775 |
+
|
776 |
+
return hidden_states
|
777 |
+
|
778 |
+
|
779 |
+
class AttnAddedKVProcessor2_0:
|
780 |
+
r"""
|
781 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
782 |
+
learnable key and value matrices for the text encoder.
|
783 |
+
"""
|
784 |
+
|
785 |
+
def __init__(self):
|
786 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
787 |
+
raise ImportError(
|
788 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
789 |
+
)
|
790 |
+
|
791 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
792 |
+
residual = hidden_states
|
793 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
794 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
795 |
+
|
796 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
797 |
+
|
798 |
+
if encoder_hidden_states is None:
|
799 |
+
encoder_hidden_states = hidden_states
|
800 |
+
elif attn.norm_cross:
|
801 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
802 |
+
|
803 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
804 |
+
|
805 |
+
query = attn.to_q(hidden_states)
|
806 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
807 |
+
|
808 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
809 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
810 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
811 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
812 |
+
|
813 |
+
if not attn.only_cross_attention:
|
814 |
+
key = attn.to_k(hidden_states)
|
815 |
+
value = attn.to_v(hidden_states)
|
816 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
817 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
818 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
819 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
820 |
+
else:
|
821 |
+
key = encoder_hidden_states_key_proj
|
822 |
+
value = encoder_hidden_states_value_proj
|
823 |
+
|
824 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
825 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
826 |
+
hidden_states = F.scaled_dot_product_attention(
|
827 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
828 |
+
)
|
829 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
830 |
+
|
831 |
+
# linear proj
|
832 |
+
hidden_states = attn.to_out[0](hidden_states)
|
833 |
+
# dropout
|
834 |
+
hidden_states = attn.to_out[1](hidden_states)
|
835 |
+
|
836 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
837 |
+
hidden_states = hidden_states + residual
|
838 |
+
|
839 |
+
return hidden_states
|
840 |
+
|
841 |
+
|
842 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
843 |
+
r"""
|
844 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
845 |
+
encoder.
|
846 |
+
|
847 |
+
Args:
|
848 |
+
hidden_size (`int`, *optional*):
|
849 |
+
The hidden size of the attention layer.
|
850 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
851 |
+
The number of channels in the `encoder_hidden_states`.
|
852 |
+
rank (`int`, defaults to 4):
|
853 |
+
The dimension of the LoRA update matrices.
|
854 |
+
|
855 |
+
"""
|
856 |
+
|
857 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
858 |
+
super().__init__()
|
859 |
+
|
860 |
+
self.hidden_size = hidden_size
|
861 |
+
self.cross_attention_dim = cross_attention_dim
|
862 |
+
self.rank = rank
|
863 |
+
|
864 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
865 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
866 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
867 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
868 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
869 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
870 |
+
|
871 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
872 |
+
residual = hidden_states
|
873 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
874 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
875 |
+
|
876 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
877 |
+
|
878 |
+
if encoder_hidden_states is None:
|
879 |
+
encoder_hidden_states = hidden_states
|
880 |
+
elif attn.norm_cross:
|
881 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
882 |
+
|
883 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
884 |
+
|
885 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
886 |
+
query = attn.head_to_batch_dim(query)
|
887 |
+
|
888 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
889 |
+
encoder_hidden_states
|
890 |
+
)
|
891 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
892 |
+
encoder_hidden_states
|
893 |
+
)
|
894 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
895 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
896 |
+
|
897 |
+
if not attn.only_cross_attention:
|
898 |
+
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
899 |
+
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
900 |
+
key = attn.head_to_batch_dim(key)
|
901 |
+
value = attn.head_to_batch_dim(value)
|
902 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
903 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
904 |
+
else:
|
905 |
+
key = encoder_hidden_states_key_proj
|
906 |
+
value = encoder_hidden_states_value_proj
|
907 |
+
|
908 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
909 |
+
hidden_states = torch.bmm(attention_probs, value)
|
910 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
911 |
+
|
912 |
+
# linear proj
|
913 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
914 |
+
# dropout
|
915 |
+
hidden_states = attn.to_out[1](hidden_states)
|
916 |
+
|
917 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
918 |
+
hidden_states = hidden_states + residual
|
919 |
+
|
920 |
+
return hidden_states
|
921 |
+
|
922 |
+
|
923 |
+
class XFormersAttnAddedKVProcessor:
|
924 |
+
r"""
|
925 |
+
Processor for implementing memory efficient attention using xFormers.
|
926 |
+
|
927 |
+
Args:
|
928 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
929 |
+
The base
|
930 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
931 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
932 |
+
operator.
|
933 |
+
"""
|
934 |
+
|
935 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
936 |
+
self.attention_op = attention_op
|
937 |
+
|
938 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
939 |
+
residual = hidden_states
|
940 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
941 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
942 |
+
|
943 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
944 |
+
|
945 |
+
if encoder_hidden_states is None:
|
946 |
+
encoder_hidden_states = hidden_states
|
947 |
+
elif attn.norm_cross:
|
948 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
949 |
+
|
950 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
951 |
+
|
952 |
+
query = attn.to_q(hidden_states)
|
953 |
+
query = attn.head_to_batch_dim(query)
|
954 |
+
|
955 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
956 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
957 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
958 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
959 |
+
|
960 |
+
if not attn.only_cross_attention:
|
961 |
+
key = attn.to_k(hidden_states)
|
962 |
+
value = attn.to_v(hidden_states)
|
963 |
+
key = attn.head_to_batch_dim(key)
|
964 |
+
value = attn.head_to_batch_dim(value)
|
965 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
966 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
967 |
+
else:
|
968 |
+
key = encoder_hidden_states_key_proj
|
969 |
+
value = encoder_hidden_states_value_proj
|
970 |
+
|
971 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
972 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
973 |
+
)
|
974 |
+
hidden_states = hidden_states.to(query.dtype)
|
975 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
976 |
+
|
977 |
+
# linear proj
|
978 |
+
hidden_states = attn.to_out[0](hidden_states)
|
979 |
+
# dropout
|
980 |
+
hidden_states = attn.to_out[1](hidden_states)
|
981 |
+
|
982 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
983 |
+
hidden_states = hidden_states + residual
|
984 |
+
|
985 |
+
return hidden_states
|
986 |
+
|
987 |
+
|
988 |
+
class XFormersAttnProcessor:
|
989 |
+
r"""
|
990 |
+
Processor for implementing memory efficient attention using xFormers.
|
991 |
+
|
992 |
+
Args:
|
993 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
994 |
+
The base
|
995 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
996 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
997 |
+
operator.
|
998 |
+
"""
|
999 |
+
|
1000 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
1001 |
+
self.attention_op = attention_op
|
1002 |
+
|
1003 |
+
def __call__(
|
1004 |
+
self,
|
1005 |
+
attn: Attention,
|
1006 |
+
hidden_states: torch.FloatTensor,
|
1007 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1008 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1009 |
+
temb: Optional[torch.FloatTensor] = None,
|
1010 |
+
posemb: Optional = None,
|
1011 |
+
):
|
1012 |
+
residual = hidden_states
|
1013 |
+
|
1014 |
+
if attn.spatial_norm is not None:
|
1015 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1016 |
+
|
1017 |
+
input_ndim = hidden_states.ndim
|
1018 |
+
|
1019 |
+
if input_ndim == 4:
|
1020 |
+
batch_size, channel, height, width = hidden_states.shape
|
1021 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1022 |
+
|
1023 |
+
if posemb is not None:
|
1024 |
+
# turn 2d attention into multiview attention
|
1025 |
+
self_attn = encoder_hidden_states is None # check if self attn or cross attn
|
1026 |
+
[p_out, p_out_inv], [p_in, p_in_inv] = posemb
|
1027 |
+
t_out, t_in = p_out.shape[1], p_in.shape[1] # t size
|
1028 |
+
hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)
|
1029 |
+
|
1030 |
+
batch_size, key_tokens, _ = (
|
1031 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1035 |
+
if attention_mask is not None:
|
1036 |
+
# expand our mask's singleton query_tokens dimension:
|
1037 |
+
# [batch*heads, 1, key_tokens] ->
|
1038 |
+
# [batch*heads, query_tokens, key_tokens]
|
1039 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1040 |
+
# [batch*heads, query_tokens, key_tokens]
|
1041 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1042 |
+
_, query_tokens, _ = hidden_states.shape
|
1043 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1044 |
+
|
1045 |
+
if attn.group_norm is not None:
|
1046 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1047 |
+
|
1048 |
+
query = attn.to_q(hidden_states)
|
1049 |
+
if encoder_hidden_states is None:
|
1050 |
+
encoder_hidden_states = hidden_states
|
1051 |
+
elif attn.norm_cross:
|
1052 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1053 |
+
|
1054 |
+
key = attn.to_k(encoder_hidden_states)
|
1055 |
+
value = attn.to_v(encoder_hidden_states)
|
1056 |
+
|
1057 |
+
|
1058 |
+
# apply 6DoF, todo now only for xformer processor
|
1059 |
+
if posemb is not None:
|
1060 |
+
p_out_inv = einops.repeat(p_out_inv, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out) # query shape
|
1061 |
+
if self_attn:
|
1062 |
+
p_in = einops.repeat(p_out, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out) # query shape
|
1063 |
+
else:
|
1064 |
+
p_in = einops.repeat(p_in, 'b t_in f g -> b (t_in l) f g', l=key.shape[1] // t_in) # key shape
|
1065 |
+
query = cape_embed(query, p_out_inv) # query f_q @ (p_out)^(-T) .permute(0, 1, 3, 2)
|
1066 |
+
key = cape_embed(key, p_in) # key f_k @ p_in
|
1067 |
+
|
1068 |
+
|
1069 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1070 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1071 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1072 |
+
|
1073 |
+
# self-ttn (bm) l c x (bm) l c -> (bm) l c
|
1074 |
+
# cross-ttn (bm) l c x b (nl) c -> (bm) l c
|
1075 |
+
# reuse 2d attention for multiview attention
|
1076 |
+
# self-ttn b (ml) c x b (ml) c -> b (ml) c
|
1077 |
+
# cross-ttn b (ml) c x b (nl) c -> b (ml) c
|
1078 |
+
hidden_states = xformers.ops.memory_efficient_attention( # query: (bm) l c -> b (ml) c; key: b (nl) c
|
1079 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1080 |
+
)
|
1081 |
+
hidden_states = hidden_states.to(query.dtype)
|
1082 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1083 |
+
|
1084 |
+
# linear proj
|
1085 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1086 |
+
# dropout
|
1087 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1088 |
+
|
1089 |
+
if posemb is not None:
|
1090 |
+
# reshape back
|
1091 |
+
hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)
|
1092 |
+
|
1093 |
+
if input_ndim == 4:
|
1094 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1095 |
+
|
1096 |
+
if attn.residual_connection:
|
1097 |
+
hidden_states = hidden_states + residual
|
1098 |
+
|
1099 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1100 |
+
|
1101 |
+
|
1102 |
+
return hidden_states
|
1103 |
+
|
1104 |
+
|
1105 |
+
class AttnProcessor2_0:
|
1106 |
+
r"""
|
1107 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1108 |
+
"""
|
1109 |
+
|
1110 |
+
def __init__(self):
|
1111 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1112 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1113 |
+
|
1114 |
+
def __call__(
|
1115 |
+
self,
|
1116 |
+
attn: Attention,
|
1117 |
+
hidden_states,
|
1118 |
+
encoder_hidden_states=None,
|
1119 |
+
attention_mask=None,
|
1120 |
+
temb=None,
|
1121 |
+
):
|
1122 |
+
residual = hidden_states
|
1123 |
+
|
1124 |
+
if attn.spatial_norm is not None:
|
1125 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1126 |
+
|
1127 |
+
input_ndim = hidden_states.ndim
|
1128 |
+
|
1129 |
+
if input_ndim == 4:
|
1130 |
+
batch_size, channel, height, width = hidden_states.shape
|
1131 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1132 |
+
|
1133 |
+
batch_size, sequence_length, _ = (
|
1134 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1135 |
+
)
|
1136 |
+
inner_dim = hidden_states.shape[-1]
|
1137 |
+
|
1138 |
+
if attention_mask is not None:
|
1139 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1140 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1141 |
+
# (batch, heads, source_length, target_length)
|
1142 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1143 |
+
|
1144 |
+
if attn.group_norm is not None:
|
1145 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1146 |
+
|
1147 |
+
query = attn.to_q(hidden_states)
|
1148 |
+
|
1149 |
+
if encoder_hidden_states is None:
|
1150 |
+
encoder_hidden_states = hidden_states
|
1151 |
+
elif attn.norm_cross:
|
1152 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1153 |
+
|
1154 |
+
key = attn.to_k(encoder_hidden_states)
|
1155 |
+
value = attn.to_v(encoder_hidden_states)
|
1156 |
+
|
1157 |
+
head_dim = inner_dim // attn.heads
|
1158 |
+
|
1159 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1160 |
+
|
1161 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1162 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1163 |
+
|
1164 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1165 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1166 |
+
hidden_states = F.scaled_dot_product_attention(
|
1167 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1168 |
+
)
|
1169 |
+
|
1170 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1171 |
+
hidden_states = hidden_states.to(query.dtype)
|
1172 |
+
|
1173 |
+
# linear proj
|
1174 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1175 |
+
# dropout
|
1176 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1177 |
+
|
1178 |
+
if input_ndim == 4:
|
1179 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1180 |
+
|
1181 |
+
if attn.residual_connection:
|
1182 |
+
hidden_states = hidden_states + residual
|
1183 |
+
|
1184 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1185 |
+
|
1186 |
+
return hidden_states
|
1187 |
+
|
1188 |
+
|
1189 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1190 |
+
r"""
|
1191 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1192 |
+
|
1193 |
+
Args:
|
1194 |
+
hidden_size (`int`, *optional*):
|
1195 |
+
The hidden size of the attention layer.
|
1196 |
+
cross_attention_dim (`int`, *optional*):
|
1197 |
+
The number of channels in the `encoder_hidden_states`.
|
1198 |
+
rank (`int`, defaults to 4):
|
1199 |
+
The dimension of the LoRA update matrices.
|
1200 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1201 |
+
The base
|
1202 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1203 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1204 |
+
operator.
|
1205 |
+
network_alpha (`int`, *optional*):
|
1206 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1207 |
+
|
1208 |
+
"""
|
1209 |
+
|
1210 |
+
def __init__(
|
1211 |
+
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
|
1212 |
+
):
|
1213 |
+
super().__init__()
|
1214 |
+
|
1215 |
+
self.hidden_size = hidden_size
|
1216 |
+
self.cross_attention_dim = cross_attention_dim
|
1217 |
+
self.rank = rank
|
1218 |
+
self.attention_op = attention_op
|
1219 |
+
|
1220 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1221 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1222 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1223 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1224 |
+
|
1225 |
+
def __call__(
|
1226 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
1227 |
+
):
|
1228 |
+
residual = hidden_states
|
1229 |
+
|
1230 |
+
if attn.spatial_norm is not None:
|
1231 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1232 |
+
|
1233 |
+
input_ndim = hidden_states.ndim
|
1234 |
+
|
1235 |
+
if input_ndim == 4:
|
1236 |
+
batch_size, channel, height, width = hidden_states.shape
|
1237 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1238 |
+
|
1239 |
+
batch_size, sequence_length, _ = (
|
1240 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1241 |
+
)
|
1242 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1243 |
+
|
1244 |
+
if attn.group_norm is not None:
|
1245 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1246 |
+
|
1247 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1248 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1249 |
+
|
1250 |
+
if encoder_hidden_states is None:
|
1251 |
+
encoder_hidden_states = hidden_states
|
1252 |
+
elif attn.norm_cross:
|
1253 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1254 |
+
|
1255 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1256 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1257 |
+
|
1258 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1259 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1260 |
+
|
1261 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1262 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1263 |
+
)
|
1264 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1265 |
+
|
1266 |
+
# linear proj
|
1267 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1268 |
+
# dropout
|
1269 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1270 |
+
|
1271 |
+
if input_ndim == 4:
|
1272 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1273 |
+
|
1274 |
+
if attn.residual_connection:
|
1275 |
+
hidden_states = hidden_states + residual
|
1276 |
+
|
1277 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1278 |
+
|
1279 |
+
return hidden_states
|
1280 |
+
|
1281 |
+
|
1282 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1283 |
+
r"""
|
1284 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1285 |
+
attention.
|
1286 |
+
|
1287 |
+
Args:
|
1288 |
+
hidden_size (`int`):
|
1289 |
+
The hidden size of the attention layer.
|
1290 |
+
cross_attention_dim (`int`, *optional*):
|
1291 |
+
The number of channels in the `encoder_hidden_states`.
|
1292 |
+
rank (`int`, defaults to 4):
|
1293 |
+
The dimension of the LoRA update matrices.
|
1294 |
+
network_alpha (`int`, *optional*):
|
1295 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1296 |
+
"""
|
1297 |
+
|
1298 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
1299 |
+
super().__init__()
|
1300 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1301 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1302 |
+
|
1303 |
+
self.hidden_size = hidden_size
|
1304 |
+
self.cross_attention_dim = cross_attention_dim
|
1305 |
+
self.rank = rank
|
1306 |
+
|
1307 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1308 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1309 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1310 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1311 |
+
|
1312 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1313 |
+
residual = hidden_states
|
1314 |
+
|
1315 |
+
input_ndim = hidden_states.ndim
|
1316 |
+
|
1317 |
+
if input_ndim == 4:
|
1318 |
+
batch_size, channel, height, width = hidden_states.shape
|
1319 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1320 |
+
|
1321 |
+
batch_size, sequence_length, _ = (
|
1322 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1323 |
+
)
|
1324 |
+
inner_dim = hidden_states.shape[-1]
|
1325 |
+
|
1326 |
+
if attention_mask is not None:
|
1327 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1328 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1329 |
+
# (batch, heads, source_length, target_length)
|
1330 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1331 |
+
|
1332 |
+
if attn.group_norm is not None:
|
1333 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1334 |
+
|
1335 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1336 |
+
|
1337 |
+
if encoder_hidden_states is None:
|
1338 |
+
encoder_hidden_states = hidden_states
|
1339 |
+
elif attn.norm_cross:
|
1340 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1341 |
+
|
1342 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1343 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1344 |
+
|
1345 |
+
head_dim = inner_dim // attn.heads
|
1346 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1347 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1348 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1349 |
+
|
1350 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1351 |
+
hidden_states = F.scaled_dot_product_attention(
|
1352 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1353 |
+
)
|
1354 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1355 |
+
hidden_states = hidden_states.to(query.dtype)
|
1356 |
+
|
1357 |
+
# linear proj
|
1358 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1359 |
+
# dropout
|
1360 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1361 |
+
|
1362 |
+
if input_ndim == 4:
|
1363 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1364 |
+
|
1365 |
+
if attn.residual_connection:
|
1366 |
+
hidden_states = hidden_states + residual
|
1367 |
+
|
1368 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1369 |
+
|
1370 |
+
return hidden_states
|
1371 |
+
|
1372 |
+
|
1373 |
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1374 |
+
r"""
|
1375 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1376 |
+
|
1377 |
+
Args:
|
1378 |
+
train_kv (`bool`, defaults to `True`):
|
1379 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1380 |
+
train_q_out (`bool`, defaults to `True`):
|
1381 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1382 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1383 |
+
The hidden size of the attention layer.
|
1384 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1385 |
+
The number of channels in the `encoder_hidden_states`.
|
1386 |
+
out_bias (`bool`, defaults to `True`):
|
1387 |
+
Whether to include the bias parameter in `train_q_out`.
|
1388 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1389 |
+
The dropout probability to use.
|
1390 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1391 |
+
The base
|
1392 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1393 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1394 |
+
"""
|
1395 |
+
|
1396 |
+
def __init__(
|
1397 |
+
self,
|
1398 |
+
train_kv=True,
|
1399 |
+
train_q_out=False,
|
1400 |
+
hidden_size=None,
|
1401 |
+
cross_attention_dim=None,
|
1402 |
+
out_bias=True,
|
1403 |
+
dropout=0.0,
|
1404 |
+
attention_op: Optional[Callable] = None,
|
1405 |
+
):
|
1406 |
+
super().__init__()
|
1407 |
+
self.train_kv = train_kv
|
1408 |
+
self.train_q_out = train_q_out
|
1409 |
+
|
1410 |
+
self.hidden_size = hidden_size
|
1411 |
+
self.cross_attention_dim = cross_attention_dim
|
1412 |
+
self.attention_op = attention_op
|
1413 |
+
|
1414 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1415 |
+
if self.train_kv:
|
1416 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1417 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1418 |
+
if self.train_q_out:
|
1419 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1420 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1421 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1422 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1423 |
+
|
1424 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1425 |
+
batch_size, sequence_length, _ = (
|
1426 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1427 |
+
)
|
1428 |
+
|
1429 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1430 |
+
|
1431 |
+
if self.train_q_out:
|
1432 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
1433 |
+
else:
|
1434 |
+
query = attn.to_q(hidden_states)
|
1435 |
+
|
1436 |
+
if encoder_hidden_states is None:
|
1437 |
+
crossattn = False
|
1438 |
+
encoder_hidden_states = hidden_states
|
1439 |
+
else:
|
1440 |
+
crossattn = True
|
1441 |
+
if attn.norm_cross:
|
1442 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1443 |
+
|
1444 |
+
if self.train_kv:
|
1445 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
1446 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
1447 |
+
else:
|
1448 |
+
key = attn.to_k(encoder_hidden_states)
|
1449 |
+
value = attn.to_v(encoder_hidden_states)
|
1450 |
+
|
1451 |
+
if crossattn:
|
1452 |
+
detach = torch.ones_like(key)
|
1453 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1454 |
+
key = detach * key + (1 - detach) * key.detach()
|
1455 |
+
value = detach * value + (1 - detach) * value.detach()
|
1456 |
+
|
1457 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1458 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1459 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1460 |
+
|
1461 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1462 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1463 |
+
)
|
1464 |
+
hidden_states = hidden_states.to(query.dtype)
|
1465 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1466 |
+
|
1467 |
+
if self.train_q_out:
|
1468 |
+
# linear proj
|
1469 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1470 |
+
# dropout
|
1471 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1472 |
+
else:
|
1473 |
+
# linear proj
|
1474 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1475 |
+
# dropout
|
1476 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1477 |
+
return hidden_states
|
1478 |
+
|
1479 |
+
|
1480 |
+
class SlicedAttnProcessor:
|
1481 |
+
r"""
|
1482 |
+
Processor for implementing sliced attention.
|
1483 |
+
|
1484 |
+
Args:
|
1485 |
+
slice_size (`int`, *optional*):
|
1486 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1487 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1488 |
+
"""
|
1489 |
+
|
1490 |
+
def __init__(self, slice_size):
|
1491 |
+
self.slice_size = slice_size
|
1492 |
+
|
1493 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1494 |
+
residual = hidden_states
|
1495 |
+
|
1496 |
+
input_ndim = hidden_states.ndim
|
1497 |
+
|
1498 |
+
if input_ndim == 4:
|
1499 |
+
batch_size, channel, height, width = hidden_states.shape
|
1500 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1501 |
+
|
1502 |
+
batch_size, sequence_length, _ = (
|
1503 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1504 |
+
)
|
1505 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1506 |
+
|
1507 |
+
if attn.group_norm is not None:
|
1508 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1509 |
+
|
1510 |
+
query = attn.to_q(hidden_states)
|
1511 |
+
dim = query.shape[-1]
|
1512 |
+
query = attn.head_to_batch_dim(query)
|
1513 |
+
|
1514 |
+
if encoder_hidden_states is None:
|
1515 |
+
encoder_hidden_states = hidden_states
|
1516 |
+
elif attn.norm_cross:
|
1517 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1518 |
+
|
1519 |
+
key = attn.to_k(encoder_hidden_states)
|
1520 |
+
value = attn.to_v(encoder_hidden_states)
|
1521 |
+
key = attn.head_to_batch_dim(key)
|
1522 |
+
value = attn.head_to_batch_dim(value)
|
1523 |
+
|
1524 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1525 |
+
hidden_states = torch.zeros(
|
1526 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1527 |
+
)
|
1528 |
+
|
1529 |
+
for i in range(batch_size_attention // self.slice_size):
|
1530 |
+
start_idx = i * self.slice_size
|
1531 |
+
end_idx = (i + 1) * self.slice_size
|
1532 |
+
|
1533 |
+
query_slice = query[start_idx:end_idx]
|
1534 |
+
key_slice = key[start_idx:end_idx]
|
1535 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1536 |
+
|
1537 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1538 |
+
|
1539 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1540 |
+
|
1541 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1542 |
+
|
1543 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1544 |
+
|
1545 |
+
# linear proj
|
1546 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1547 |
+
# dropout
|
1548 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1549 |
+
|
1550 |
+
if input_ndim == 4:
|
1551 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1552 |
+
|
1553 |
+
if attn.residual_connection:
|
1554 |
+
hidden_states = hidden_states + residual
|
1555 |
+
|
1556 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1557 |
+
|
1558 |
+
return hidden_states
|
1559 |
+
|
1560 |
+
|
1561 |
+
class SlicedAttnAddedKVProcessor:
|
1562 |
+
r"""
|
1563 |
+
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
|
1564 |
+
|
1565 |
+
Args:
|
1566 |
+
slice_size (`int`, *optional*):
|
1567 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1568 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1569 |
+
"""
|
1570 |
+
|
1571 |
+
def __init__(self, slice_size):
|
1572 |
+
self.slice_size = slice_size
|
1573 |
+
|
1574 |
+
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
1575 |
+
residual = hidden_states
|
1576 |
+
|
1577 |
+
if attn.spatial_norm is not None:
|
1578 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1579 |
+
|
1580 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1581 |
+
|
1582 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1583 |
+
|
1584 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1585 |
+
|
1586 |
+
if encoder_hidden_states is None:
|
1587 |
+
encoder_hidden_states = hidden_states
|
1588 |
+
elif attn.norm_cross:
|
1589 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1590 |
+
|
1591 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1592 |
+
|
1593 |
+
query = attn.to_q(hidden_states)
|
1594 |
+
dim = query.shape[-1]
|
1595 |
+
query = attn.head_to_batch_dim(query)
|
1596 |
+
|
1597 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1598 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1599 |
+
|
1600 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1601 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1602 |
+
|
1603 |
+
if not attn.only_cross_attention:
|
1604 |
+
key = attn.to_k(hidden_states)
|
1605 |
+
value = attn.to_v(hidden_states)
|
1606 |
+
key = attn.head_to_batch_dim(key)
|
1607 |
+
value = attn.head_to_batch_dim(value)
|
1608 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1609 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1610 |
+
else:
|
1611 |
+
key = encoder_hidden_states_key_proj
|
1612 |
+
value = encoder_hidden_states_value_proj
|
1613 |
+
|
1614 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1615 |
+
hidden_states = torch.zeros(
|
1616 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1617 |
+
)
|
1618 |
+
|
1619 |
+
for i in range(batch_size_attention // self.slice_size):
|
1620 |
+
start_idx = i * self.slice_size
|
1621 |
+
end_idx = (i + 1) * self.slice_size
|
1622 |
+
|
1623 |
+
query_slice = query[start_idx:end_idx]
|
1624 |
+
key_slice = key[start_idx:end_idx]
|
1625 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1626 |
+
|
1627 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1628 |
+
|
1629 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1630 |
+
|
1631 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1632 |
+
|
1633 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1634 |
+
|
1635 |
+
# linear proj
|
1636 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1637 |
+
# dropout
|
1638 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1639 |
+
|
1640 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1641 |
+
hidden_states = hidden_states + residual
|
1642 |
+
|
1643 |
+
return hidden_states
|
1644 |
+
|
1645 |
+
|
1646 |
+
AttentionProcessor = Union[
|
1647 |
+
AttnProcessor,
|
1648 |
+
AttnProcessor2_0,
|
1649 |
+
XFormersAttnProcessor,
|
1650 |
+
SlicedAttnProcessor,
|
1651 |
+
AttnAddedKVProcessor,
|
1652 |
+
SlicedAttnAddedKVProcessor,
|
1653 |
+
AttnAddedKVProcessor2_0,
|
1654 |
+
XFormersAttnAddedKVProcessor,
|
1655 |
+
LoRAAttnProcessor,
|
1656 |
+
LoRAXFormersAttnProcessor,
|
1657 |
+
LoRAAttnProcessor2_0,
|
1658 |
+
LoRAAttnAddedKVProcessor,
|
1659 |
+
CustomDiffusionAttnProcessor,
|
1660 |
+
CustomDiffusionXFormersAttnProcessor,
|
1661 |
+
]
|
1662 |
+
|
1663 |
+
|
1664 |
+
class SpatialNorm(nn.Module):
|
1665 |
+
"""
|
1666 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
|
1667 |
+
"""
|
1668 |
+
|
1669 |
+
def __init__(
|
1670 |
+
self,
|
1671 |
+
f_channels,
|
1672 |
+
zq_channels,
|
1673 |
+
):
|
1674 |
+
super().__init__()
|
1675 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
1676 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1677 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1678 |
+
|
1679 |
+
def forward(self, f, zq):
|
1680 |
+
f_size = f.shape[-2:]
|
1681 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
1682 |
+
norm_f = self.norm_layer(f)
|
1683 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
1684 |
+
return new_f
|
6DoF/diffusers/models/autoencoder_kl.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Dict, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils import BaseOutput, apply_forward_hook
|
22 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
23 |
+
from .modeling_utils import ModelMixin
|
24 |
+
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class AutoencoderKLOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
Output of AutoencoderKL encoding method.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
latent_dist (`DiagonalGaussianDistribution`):
|
34 |
+
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
35 |
+
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
36 |
+
"""
|
37 |
+
|
38 |
+
latent_dist: "DiagonalGaussianDistribution"
|
39 |
+
|
40 |
+
|
41 |
+
class AutoencoderKL(ModelMixin, ConfigMixin):
|
42 |
+
r"""
|
43 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
44 |
+
|
45 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
46 |
+
for all models (such as downloading or saving).
|
47 |
+
|
48 |
+
Parameters:
|
49 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
50 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
51 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
52 |
+
Tuple of downsample block types.
|
53 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
54 |
+
Tuple of upsample block types.
|
55 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
56 |
+
Tuple of block output channels.
|
57 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
58 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
59 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
60 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
61 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
62 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
63 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
64 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
65 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
66 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
67 |
+
"""
|
68 |
+
|
69 |
+
_supports_gradient_checkpointing = True
|
70 |
+
|
71 |
+
@register_to_config
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
in_channels: int = 3,
|
75 |
+
out_channels: int = 3,
|
76 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
77 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
78 |
+
block_out_channels: Tuple[int] = (64,),
|
79 |
+
layers_per_block: int = 1,
|
80 |
+
act_fn: str = "silu",
|
81 |
+
latent_channels: int = 4,
|
82 |
+
norm_num_groups: int = 32,
|
83 |
+
sample_size: int = 32,
|
84 |
+
scaling_factor: float = 0.18215,
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
|
88 |
+
# pass init params to Encoder
|
89 |
+
self.encoder = Encoder(
|
90 |
+
in_channels=in_channels,
|
91 |
+
out_channels=latent_channels,
|
92 |
+
down_block_types=down_block_types,
|
93 |
+
block_out_channels=block_out_channels,
|
94 |
+
layers_per_block=layers_per_block,
|
95 |
+
act_fn=act_fn,
|
96 |
+
norm_num_groups=norm_num_groups,
|
97 |
+
double_z=True,
|
98 |
+
)
|
99 |
+
|
100 |
+
# pass init params to Decoder
|
101 |
+
self.decoder = Decoder(
|
102 |
+
in_channels=latent_channels,
|
103 |
+
out_channels=out_channels,
|
104 |
+
up_block_types=up_block_types,
|
105 |
+
block_out_channels=block_out_channels,
|
106 |
+
layers_per_block=layers_per_block,
|
107 |
+
norm_num_groups=norm_num_groups,
|
108 |
+
act_fn=act_fn,
|
109 |
+
)
|
110 |
+
|
111 |
+
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
112 |
+
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
113 |
+
|
114 |
+
self.use_slicing = False
|
115 |
+
self.use_tiling = False
|
116 |
+
|
117 |
+
# only relevant if vae tiling is enabled
|
118 |
+
self.tile_sample_min_size = self.config.sample_size
|
119 |
+
sample_size = (
|
120 |
+
self.config.sample_size[0]
|
121 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
122 |
+
else self.config.sample_size
|
123 |
+
)
|
124 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
125 |
+
self.tile_overlap_factor = 0.25
|
126 |
+
|
127 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
128 |
+
if isinstance(module, (Encoder, Decoder)):
|
129 |
+
module.gradient_checkpointing = value
|
130 |
+
|
131 |
+
def enable_tiling(self, use_tiling: bool = True):
|
132 |
+
r"""
|
133 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
134 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
135 |
+
processing larger images.
|
136 |
+
"""
|
137 |
+
self.use_tiling = use_tiling
|
138 |
+
|
139 |
+
def disable_tiling(self):
|
140 |
+
r"""
|
141 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
142 |
+
decoding in one step.
|
143 |
+
"""
|
144 |
+
self.enable_tiling(False)
|
145 |
+
|
146 |
+
def enable_slicing(self):
|
147 |
+
r"""
|
148 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
149 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
150 |
+
"""
|
151 |
+
self.use_slicing = True
|
152 |
+
|
153 |
+
def disable_slicing(self):
|
154 |
+
r"""
|
155 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
156 |
+
decoding in one step.
|
157 |
+
"""
|
158 |
+
self.use_slicing = False
|
159 |
+
|
160 |
+
@property
|
161 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
162 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
163 |
+
r"""
|
164 |
+
Returns:
|
165 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
166 |
+
indexed by its weight name.
|
167 |
+
"""
|
168 |
+
# set recursively
|
169 |
+
processors = {}
|
170 |
+
|
171 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
172 |
+
if hasattr(module, "set_processor"):
|
173 |
+
processors[f"{name}.processor"] = module.processor
|
174 |
+
|
175 |
+
for sub_name, child in module.named_children():
|
176 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
177 |
+
|
178 |
+
return processors
|
179 |
+
|
180 |
+
for name, module in self.named_children():
|
181 |
+
fn_recursive_add_processors(name, module, processors)
|
182 |
+
|
183 |
+
return processors
|
184 |
+
|
185 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
186 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
187 |
+
r"""
|
188 |
+
Sets the attention processor to use to compute attention.
|
189 |
+
|
190 |
+
Parameters:
|
191 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
192 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
193 |
+
for **all** `Attention` layers.
|
194 |
+
|
195 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
196 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
197 |
+
|
198 |
+
"""
|
199 |
+
count = len(self.attn_processors.keys())
|
200 |
+
|
201 |
+
if isinstance(processor, dict) and len(processor) != count:
|
202 |
+
raise ValueError(
|
203 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
204 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
205 |
+
)
|
206 |
+
|
207 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
208 |
+
if hasattr(module, "set_processor"):
|
209 |
+
if not isinstance(processor, dict):
|
210 |
+
module.set_processor(processor)
|
211 |
+
else:
|
212 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
213 |
+
|
214 |
+
for sub_name, child in module.named_children():
|
215 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
216 |
+
|
217 |
+
for name, module in self.named_children():
|
218 |
+
fn_recursive_attn_processor(name, module, processor)
|
219 |
+
|
220 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
221 |
+
def set_default_attn_processor(self):
|
222 |
+
"""
|
223 |
+
Disables custom attention processors and sets the default attention implementation.
|
224 |
+
"""
|
225 |
+
self.set_attn_processor(AttnProcessor())
|
226 |
+
|
227 |
+
@apply_forward_hook
|
228 |
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
229 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
230 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
231 |
+
|
232 |
+
if self.use_slicing and x.shape[0] > 1:
|
233 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
234 |
+
h = torch.cat(encoded_slices)
|
235 |
+
else:
|
236 |
+
h = self.encoder(x)
|
237 |
+
|
238 |
+
moments = self.quant_conv(h)
|
239 |
+
posterior = DiagonalGaussianDistribution(moments)
|
240 |
+
|
241 |
+
if not return_dict:
|
242 |
+
return (posterior,)
|
243 |
+
|
244 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
245 |
+
|
246 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
247 |
+
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
248 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
249 |
+
|
250 |
+
z = self.post_quant_conv(z)
|
251 |
+
dec = self.decoder(z)
|
252 |
+
|
253 |
+
if not return_dict:
|
254 |
+
return (dec,)
|
255 |
+
|
256 |
+
return DecoderOutput(sample=dec)
|
257 |
+
|
258 |
+
@apply_forward_hook
|
259 |
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
260 |
+
if self.use_slicing and z.shape[0] > 1:
|
261 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
262 |
+
decoded = torch.cat(decoded_slices)
|
263 |
+
else:
|
264 |
+
decoded = self._decode(z).sample
|
265 |
+
|
266 |
+
if not return_dict:
|
267 |
+
return (decoded,)
|
268 |
+
|
269 |
+
return DecoderOutput(sample=decoded)
|
270 |
+
|
271 |
+
def blend_v(self, a, b, blend_extent):
|
272 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
273 |
+
for y in range(blend_extent):
|
274 |
+
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
275 |
+
return b
|
276 |
+
|
277 |
+
def blend_h(self, a, b, blend_extent):
|
278 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
279 |
+
for x in range(blend_extent):
|
280 |
+
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
281 |
+
return b
|
282 |
+
|
283 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
284 |
+
r"""Encode a batch of images using a tiled encoder.
|
285 |
+
|
286 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
287 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
288 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
289 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
290 |
+
output, but they should be much less noticeable.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
x (`torch.FloatTensor`): Input batch of images.
|
294 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
295 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
299 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
300 |
+
`tuple` is returned.
|
301 |
+
"""
|
302 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
303 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
304 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
305 |
+
|
306 |
+
# Split the image into 512x512 tiles and encode them separately.
|
307 |
+
rows = []
|
308 |
+
for i in range(0, x.shape[2], overlap_size):
|
309 |
+
row = []
|
310 |
+
for j in range(0, x.shape[3], overlap_size):
|
311 |
+
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
312 |
+
tile = self.encoder(tile)
|
313 |
+
tile = self.quant_conv(tile)
|
314 |
+
row.append(tile)
|
315 |
+
rows.append(row)
|
316 |
+
result_rows = []
|
317 |
+
for i, row in enumerate(rows):
|
318 |
+
result_row = []
|
319 |
+
for j, tile in enumerate(row):
|
320 |
+
# blend the above tile and the left tile
|
321 |
+
# to the current tile and add the current tile to the result row
|
322 |
+
if i > 0:
|
323 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
324 |
+
if j > 0:
|
325 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
326 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
327 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
328 |
+
|
329 |
+
moments = torch.cat(result_rows, dim=2)
|
330 |
+
posterior = DiagonalGaussianDistribution(moments)
|
331 |
+
|
332 |
+
if not return_dict:
|
333 |
+
return (posterior,)
|
334 |
+
|
335 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
336 |
+
|
337 |
+
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
338 |
+
r"""
|
339 |
+
Decode a batch of images using a tiled decoder.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
343 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
344 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
348 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
349 |
+
returned.
|
350 |
+
"""
|
351 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
352 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
353 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
354 |
+
|
355 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
356 |
+
# The tiles have an overlap to avoid seams between tiles.
|
357 |
+
rows = []
|
358 |
+
for i in range(0, z.shape[2], overlap_size):
|
359 |
+
row = []
|
360 |
+
for j in range(0, z.shape[3], overlap_size):
|
361 |
+
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
362 |
+
tile = self.post_quant_conv(tile)
|
363 |
+
decoded = self.decoder(tile)
|
364 |
+
row.append(decoded)
|
365 |
+
rows.append(row)
|
366 |
+
result_rows = []
|
367 |
+
for i, row in enumerate(rows):
|
368 |
+
result_row = []
|
369 |
+
for j, tile in enumerate(row):
|
370 |
+
# blend the above tile and the left tile
|
371 |
+
# to the current tile and add the current tile to the result row
|
372 |
+
if i > 0:
|
373 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
374 |
+
if j > 0:
|
375 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
376 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
377 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
378 |
+
|
379 |
+
dec = torch.cat(result_rows, dim=2)
|
380 |
+
if not return_dict:
|
381 |
+
return (dec,)
|
382 |
+
|
383 |
+
return DecoderOutput(sample=dec)
|
384 |
+
|
385 |
+
def forward(
|
386 |
+
self,
|
387 |
+
sample: torch.FloatTensor,
|
388 |
+
sample_posterior: bool = False,
|
389 |
+
return_dict: bool = True,
|
390 |
+
generator: Optional[torch.Generator] = None,
|
391 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
392 |
+
r"""
|
393 |
+
Args:
|
394 |
+
sample (`torch.FloatTensor`): Input sample.
|
395 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
396 |
+
Whether to sample from the posterior.
|
397 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
398 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
399 |
+
"""
|
400 |
+
x = sample
|
401 |
+
posterior = self.encode(x).latent_dist
|
402 |
+
if sample_posterior:
|
403 |
+
z = posterior.sample(generator=generator)
|
404 |
+
else:
|
405 |
+
z = posterior.mode()
|
406 |
+
dec = self.decode(z).sample
|
407 |
+
|
408 |
+
if not return_dict:
|
409 |
+
return (dec,)
|
410 |
+
|
411 |
+
return DecoderOutput(sample=dec)
|
6DoF/diffusers/models/controlnet.py
ADDED
@@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..utils import BaseOutput, logging
|
23 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
24 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
25 |
+
from .modeling_utils import ModelMixin
|
26 |
+
from .unet_2d_blocks import (
|
27 |
+
CrossAttnDownBlock2D,
|
28 |
+
DownBlock2D,
|
29 |
+
UNetMidBlock2DCrossAttn,
|
30 |
+
get_down_block,
|
31 |
+
)
|
32 |
+
from .unet_2d_condition import UNet2DConditionModel
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class ControlNetOutput(BaseOutput):
|
40 |
+
"""
|
41 |
+
The output of [`ControlNetModel`].
|
42 |
+
|
43 |
+
Args:
|
44 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
45 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
46 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
47 |
+
used to condition the original UNet's downsampling activations.
|
48 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
49 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
50 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
51 |
+
Output can be used to condition the original UNet's middle block activation.
|
52 |
+
"""
|
53 |
+
|
54 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
55 |
+
mid_block_res_sample: torch.Tensor
|
56 |
+
|
57 |
+
|
58 |
+
class ControlNetConditioningEmbedding(nn.Module):
|
59 |
+
"""
|
60 |
+
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
|
61 |
+
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
|
62 |
+
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
|
63 |
+
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
|
64 |
+
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
|
65 |
+
model) to encode image-space conditions ... into feature maps ..."
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
conditioning_embedding_channels: int,
|
71 |
+
conditioning_channels: int = 3,
|
72 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
77 |
+
|
78 |
+
self.blocks = nn.ModuleList([])
|
79 |
+
|
80 |
+
for i in range(len(block_out_channels) - 1):
|
81 |
+
channel_in = block_out_channels[i]
|
82 |
+
channel_out = block_out_channels[i + 1]
|
83 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
|
84 |
+
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
85 |
+
|
86 |
+
self.conv_out = zero_module(
|
87 |
+
nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(self, conditioning):
|
91 |
+
embedding = self.conv_in(conditioning)
|
92 |
+
embedding = F.silu(embedding)
|
93 |
+
|
94 |
+
for block in self.blocks:
|
95 |
+
embedding = block(embedding)
|
96 |
+
embedding = F.silu(embedding)
|
97 |
+
|
98 |
+
embedding = self.conv_out(embedding)
|
99 |
+
|
100 |
+
return embedding
|
101 |
+
|
102 |
+
|
103 |
+
class ControlNetModel(ModelMixin, ConfigMixin):
|
104 |
+
"""
|
105 |
+
A ControlNet model.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
in_channels (`int`, defaults to 4):
|
109 |
+
The number of channels in the input sample.
|
110 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
111 |
+
Whether to flip the sin to cos in the time embedding.
|
112 |
+
freq_shift (`int`, defaults to 0):
|
113 |
+
The frequency shift to apply to the time embedding.
|
114 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
115 |
+
The tuple of downsample blocks to use.
|
116 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
117 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
118 |
+
The tuple of output channels for each block.
|
119 |
+
layers_per_block (`int`, defaults to 2):
|
120 |
+
The number of layers per block.
|
121 |
+
downsample_padding (`int`, defaults to 1):
|
122 |
+
The padding to use for the downsampling convolution.
|
123 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
124 |
+
The scale factor to use for the mid block.
|
125 |
+
act_fn (`str`, defaults to "silu"):
|
126 |
+
The activation function to use.
|
127 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
128 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
129 |
+
in post-processing.
|
130 |
+
norm_eps (`float`, defaults to 1e-5):
|
131 |
+
The epsilon to use for the normalization.
|
132 |
+
cross_attention_dim (`int`, defaults to 1280):
|
133 |
+
The dimension of the cross attention features.
|
134 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
135 |
+
The dimension of the attention heads.
|
136 |
+
use_linear_projection (`bool`, defaults to `False`):
|
137 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
138 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
139 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
140 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
141 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
142 |
+
class conditioning with `class_embed_type` equal to `None`.
|
143 |
+
upcast_attention (`bool`, defaults to `False`):
|
144 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
145 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
146 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
147 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
148 |
+
`class_embed_type="projection"`.
|
149 |
+
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
150 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
151 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
152 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
153 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
154 |
+
"""
|
155 |
+
|
156 |
+
_supports_gradient_checkpointing = True
|
157 |
+
|
158 |
+
@register_to_config
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
in_channels: int = 4,
|
162 |
+
conditioning_channels: int = 3,
|
163 |
+
flip_sin_to_cos: bool = True,
|
164 |
+
freq_shift: int = 0,
|
165 |
+
down_block_types: Tuple[str] = (
|
166 |
+
"CrossAttnDownBlock2D",
|
167 |
+
"CrossAttnDownBlock2D",
|
168 |
+
"CrossAttnDownBlock2D",
|
169 |
+
"DownBlock2D",
|
170 |
+
),
|
171 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
172 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
173 |
+
layers_per_block: int = 2,
|
174 |
+
downsample_padding: int = 1,
|
175 |
+
mid_block_scale_factor: float = 1,
|
176 |
+
act_fn: str = "silu",
|
177 |
+
norm_num_groups: Optional[int] = 32,
|
178 |
+
norm_eps: float = 1e-5,
|
179 |
+
cross_attention_dim: int = 1280,
|
180 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
181 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
182 |
+
use_linear_projection: bool = False,
|
183 |
+
class_embed_type: Optional[str] = None,
|
184 |
+
num_class_embeds: Optional[int] = None,
|
185 |
+
upcast_attention: bool = False,
|
186 |
+
resnet_time_scale_shift: str = "default",
|
187 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
188 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
189 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
190 |
+
global_pool_conditions: bool = False,
|
191 |
+
):
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
195 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
196 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
197 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
198 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
199 |
+
# which is why we correct for the naming here.
|
200 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
201 |
+
|
202 |
+
# Check inputs
|
203 |
+
if len(block_out_channels) != len(down_block_types):
|
204 |
+
raise ValueError(
|
205 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
206 |
+
)
|
207 |
+
|
208 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
209 |
+
raise ValueError(
|
210 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
211 |
+
)
|
212 |
+
|
213 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
214 |
+
raise ValueError(
|
215 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
216 |
+
)
|
217 |
+
|
218 |
+
# input
|
219 |
+
conv_in_kernel = 3
|
220 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
221 |
+
self.conv_in = nn.Conv2d(
|
222 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
223 |
+
)
|
224 |
+
|
225 |
+
# time
|
226 |
+
time_embed_dim = block_out_channels[0] * 4
|
227 |
+
|
228 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
229 |
+
timestep_input_dim = block_out_channels[0]
|
230 |
+
|
231 |
+
self.time_embedding = TimestepEmbedding(
|
232 |
+
timestep_input_dim,
|
233 |
+
time_embed_dim,
|
234 |
+
act_fn=act_fn,
|
235 |
+
)
|
236 |
+
|
237 |
+
# class embedding
|
238 |
+
if class_embed_type is None and num_class_embeds is not None:
|
239 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
240 |
+
elif class_embed_type == "timestep":
|
241 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
242 |
+
elif class_embed_type == "identity":
|
243 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
244 |
+
elif class_embed_type == "projection":
|
245 |
+
if projection_class_embeddings_input_dim is None:
|
246 |
+
raise ValueError(
|
247 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
248 |
+
)
|
249 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
250 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
251 |
+
# 2. it projects from an arbitrary input dimension.
|
252 |
+
#
|
253 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
254 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
255 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
256 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
257 |
+
else:
|
258 |
+
self.class_embedding = None
|
259 |
+
|
260 |
+
# control net conditioning embedding
|
261 |
+
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
262 |
+
conditioning_embedding_channels=block_out_channels[0],
|
263 |
+
block_out_channels=conditioning_embedding_out_channels,
|
264 |
+
conditioning_channels=conditioning_channels,
|
265 |
+
)
|
266 |
+
|
267 |
+
self.down_blocks = nn.ModuleList([])
|
268 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
269 |
+
|
270 |
+
if isinstance(only_cross_attention, bool):
|
271 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
272 |
+
|
273 |
+
if isinstance(attention_head_dim, int):
|
274 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
275 |
+
|
276 |
+
if isinstance(num_attention_heads, int):
|
277 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
278 |
+
|
279 |
+
# down
|
280 |
+
output_channel = block_out_channels[0]
|
281 |
+
|
282 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
283 |
+
controlnet_block = zero_module(controlnet_block)
|
284 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
285 |
+
|
286 |
+
for i, down_block_type in enumerate(down_block_types):
|
287 |
+
input_channel = output_channel
|
288 |
+
output_channel = block_out_channels[i]
|
289 |
+
is_final_block = i == len(block_out_channels) - 1
|
290 |
+
|
291 |
+
down_block = get_down_block(
|
292 |
+
down_block_type,
|
293 |
+
num_layers=layers_per_block,
|
294 |
+
in_channels=input_channel,
|
295 |
+
out_channels=output_channel,
|
296 |
+
temb_channels=time_embed_dim,
|
297 |
+
add_downsample=not is_final_block,
|
298 |
+
resnet_eps=norm_eps,
|
299 |
+
resnet_act_fn=act_fn,
|
300 |
+
resnet_groups=norm_num_groups,
|
301 |
+
cross_attention_dim=cross_attention_dim,
|
302 |
+
num_attention_heads=num_attention_heads[i],
|
303 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
304 |
+
downsample_padding=downsample_padding,
|
305 |
+
use_linear_projection=use_linear_projection,
|
306 |
+
only_cross_attention=only_cross_attention[i],
|
307 |
+
upcast_attention=upcast_attention,
|
308 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
309 |
+
)
|
310 |
+
self.down_blocks.append(down_block)
|
311 |
+
|
312 |
+
for _ in range(layers_per_block):
|
313 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
314 |
+
controlnet_block = zero_module(controlnet_block)
|
315 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
316 |
+
|
317 |
+
if not is_final_block:
|
318 |
+
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
319 |
+
controlnet_block = zero_module(controlnet_block)
|
320 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
321 |
+
|
322 |
+
# mid
|
323 |
+
mid_block_channel = block_out_channels[-1]
|
324 |
+
|
325 |
+
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
326 |
+
controlnet_block = zero_module(controlnet_block)
|
327 |
+
self.controlnet_mid_block = controlnet_block
|
328 |
+
|
329 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
330 |
+
in_channels=mid_block_channel,
|
331 |
+
temb_channels=time_embed_dim,
|
332 |
+
resnet_eps=norm_eps,
|
333 |
+
resnet_act_fn=act_fn,
|
334 |
+
output_scale_factor=mid_block_scale_factor,
|
335 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
336 |
+
cross_attention_dim=cross_attention_dim,
|
337 |
+
num_attention_heads=num_attention_heads[-1],
|
338 |
+
resnet_groups=norm_num_groups,
|
339 |
+
use_linear_projection=use_linear_projection,
|
340 |
+
upcast_attention=upcast_attention,
|
341 |
+
)
|
342 |
+
|
343 |
+
@classmethod
|
344 |
+
def from_unet(
|
345 |
+
cls,
|
346 |
+
unet: UNet2DConditionModel,
|
347 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
348 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
349 |
+
load_weights_from_unet: bool = True,
|
350 |
+
):
|
351 |
+
r"""
|
352 |
+
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
|
353 |
+
|
354 |
+
Parameters:
|
355 |
+
unet (`UNet2DConditionModel`):
|
356 |
+
The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
|
357 |
+
where applicable.
|
358 |
+
"""
|
359 |
+
controlnet = cls(
|
360 |
+
in_channels=unet.config.in_channels,
|
361 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
362 |
+
freq_shift=unet.config.freq_shift,
|
363 |
+
down_block_types=unet.config.down_block_types,
|
364 |
+
only_cross_attention=unet.config.only_cross_attention,
|
365 |
+
block_out_channels=unet.config.block_out_channels,
|
366 |
+
layers_per_block=unet.config.layers_per_block,
|
367 |
+
downsample_padding=unet.config.downsample_padding,
|
368 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
369 |
+
act_fn=unet.config.act_fn,
|
370 |
+
norm_num_groups=unet.config.norm_num_groups,
|
371 |
+
norm_eps=unet.config.norm_eps,
|
372 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
373 |
+
attention_head_dim=unet.config.attention_head_dim,
|
374 |
+
num_attention_heads=unet.config.num_attention_heads,
|
375 |
+
use_linear_projection=unet.config.use_linear_projection,
|
376 |
+
class_embed_type=unet.config.class_embed_type,
|
377 |
+
num_class_embeds=unet.config.num_class_embeds,
|
378 |
+
upcast_attention=unet.config.upcast_attention,
|
379 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
380 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
381 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
382 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
383 |
+
)
|
384 |
+
|
385 |
+
if load_weights_from_unet:
|
386 |
+
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
387 |
+
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
388 |
+
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
389 |
+
|
390 |
+
if controlnet.class_embedding:
|
391 |
+
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
392 |
+
|
393 |
+
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
|
394 |
+
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
|
395 |
+
|
396 |
+
return controlnet
|
397 |
+
|
398 |
+
@property
|
399 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
400 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
401 |
+
r"""
|
402 |
+
Returns:
|
403 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
404 |
+
indexed by its weight name.
|
405 |
+
"""
|
406 |
+
# set recursively
|
407 |
+
processors = {}
|
408 |
+
|
409 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
410 |
+
if hasattr(module, "set_processor"):
|
411 |
+
processors[f"{name}.processor"] = module.processor
|
412 |
+
|
413 |
+
for sub_name, child in module.named_children():
|
414 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
415 |
+
|
416 |
+
return processors
|
417 |
+
|
418 |
+
for name, module in self.named_children():
|
419 |
+
fn_recursive_add_processors(name, module, processors)
|
420 |
+
|
421 |
+
return processors
|
422 |
+
|
423 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
424 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
425 |
+
r"""
|
426 |
+
Sets the attention processor to use to compute attention.
|
427 |
+
|
428 |
+
Parameters:
|
429 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
430 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
431 |
+
for **all** `Attention` layers.
|
432 |
+
|
433 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
434 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
435 |
+
|
436 |
+
"""
|
437 |
+
count = len(self.attn_processors.keys())
|
438 |
+
|
439 |
+
if isinstance(processor, dict) and len(processor) != count:
|
440 |
+
raise ValueError(
|
441 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
442 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
443 |
+
)
|
444 |
+
|
445 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
446 |
+
if hasattr(module, "set_processor"):
|
447 |
+
if not isinstance(processor, dict):
|
448 |
+
module.set_processor(processor)
|
449 |
+
else:
|
450 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
451 |
+
|
452 |
+
for sub_name, child in module.named_children():
|
453 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
454 |
+
|
455 |
+
for name, module in self.named_children():
|
456 |
+
fn_recursive_attn_processor(name, module, processor)
|
457 |
+
|
458 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
459 |
+
def set_default_attn_processor(self):
|
460 |
+
"""
|
461 |
+
Disables custom attention processors and sets the default attention implementation.
|
462 |
+
"""
|
463 |
+
self.set_attn_processor(AttnProcessor())
|
464 |
+
|
465 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
466 |
+
def set_attention_slice(self, slice_size):
|
467 |
+
r"""
|
468 |
+
Enable sliced attention computation.
|
469 |
+
|
470 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
471 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
472 |
+
|
473 |
+
Args:
|
474 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
475 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
476 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
477 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
478 |
+
must be a multiple of `slice_size`.
|
479 |
+
"""
|
480 |
+
sliceable_head_dims = []
|
481 |
+
|
482 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
483 |
+
if hasattr(module, "set_attention_slice"):
|
484 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
485 |
+
|
486 |
+
for child in module.children():
|
487 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
488 |
+
|
489 |
+
# retrieve number of attention layers
|
490 |
+
for module in self.children():
|
491 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
492 |
+
|
493 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
494 |
+
|
495 |
+
if slice_size == "auto":
|
496 |
+
# half the attention head size is usually a good trade-off between
|
497 |
+
# speed and memory
|
498 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
499 |
+
elif slice_size == "max":
|
500 |
+
# make smallest slice possible
|
501 |
+
slice_size = num_sliceable_layers * [1]
|
502 |
+
|
503 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
504 |
+
|
505 |
+
if len(slice_size) != len(sliceable_head_dims):
|
506 |
+
raise ValueError(
|
507 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
508 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
509 |
+
)
|
510 |
+
|
511 |
+
for i in range(len(slice_size)):
|
512 |
+
size = slice_size[i]
|
513 |
+
dim = sliceable_head_dims[i]
|
514 |
+
if size is not None and size > dim:
|
515 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
516 |
+
|
517 |
+
# Recursively walk through all the children.
|
518 |
+
# Any children which exposes the set_attention_slice method
|
519 |
+
# gets the message
|
520 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
521 |
+
if hasattr(module, "set_attention_slice"):
|
522 |
+
module.set_attention_slice(slice_size.pop())
|
523 |
+
|
524 |
+
for child in module.children():
|
525 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
526 |
+
|
527 |
+
reversed_slice_size = list(reversed(slice_size))
|
528 |
+
for module in self.children():
|
529 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
530 |
+
|
531 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
532 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
533 |
+
module.gradient_checkpointing = value
|
534 |
+
|
535 |
+
def forward(
|
536 |
+
self,
|
537 |
+
sample: torch.FloatTensor,
|
538 |
+
timestep: Union[torch.Tensor, float, int],
|
539 |
+
encoder_hidden_states: torch.Tensor,
|
540 |
+
controlnet_cond: torch.FloatTensor,
|
541 |
+
conditioning_scale: float = 1.0,
|
542 |
+
class_labels: Optional[torch.Tensor] = None,
|
543 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
544 |
+
attention_mask: Optional[torch.Tensor] = None,
|
545 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
546 |
+
guess_mode: bool = False,
|
547 |
+
return_dict: bool = True,
|
548 |
+
) -> Union[ControlNetOutput, Tuple]:
|
549 |
+
"""
|
550 |
+
The [`ControlNetModel`] forward method.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
sample (`torch.FloatTensor`):
|
554 |
+
The noisy input tensor.
|
555 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
556 |
+
The number of timesteps to denoise an input.
|
557 |
+
encoder_hidden_states (`torch.Tensor`):
|
558 |
+
The encoder hidden states.
|
559 |
+
controlnet_cond (`torch.FloatTensor`):
|
560 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
561 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
562 |
+
The scale factor for ControlNet outputs.
|
563 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
564 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
565 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
566 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
567 |
+
cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
|
568 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
569 |
+
guess_mode (`bool`, defaults to `False`):
|
570 |
+
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
571 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
572 |
+
return_dict (`bool`, defaults to `True`):
|
573 |
+
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
577 |
+
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
578 |
+
returned where the first element is the sample tensor.
|
579 |
+
"""
|
580 |
+
# check channel order
|
581 |
+
channel_order = self.config.controlnet_conditioning_channel_order
|
582 |
+
|
583 |
+
if channel_order == "rgb":
|
584 |
+
# in rgb order by default
|
585 |
+
...
|
586 |
+
elif channel_order == "bgr":
|
587 |
+
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
|
588 |
+
else:
|
589 |
+
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
590 |
+
|
591 |
+
# prepare attention_mask
|
592 |
+
if attention_mask is not None:
|
593 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
594 |
+
attention_mask = attention_mask.unsqueeze(1)
|
595 |
+
|
596 |
+
# 1. time
|
597 |
+
timesteps = timestep
|
598 |
+
if not torch.is_tensor(timesteps):
|
599 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
600 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
601 |
+
is_mps = sample.device.type == "mps"
|
602 |
+
if isinstance(timestep, float):
|
603 |
+
dtype = torch.float32 if is_mps else torch.float64
|
604 |
+
else:
|
605 |
+
dtype = torch.int32 if is_mps else torch.int64
|
606 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
607 |
+
elif len(timesteps.shape) == 0:
|
608 |
+
timesteps = timesteps[None].to(sample.device)
|
609 |
+
|
610 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
611 |
+
timesteps = timesteps.expand(sample.shape[0])
|
612 |
+
|
613 |
+
t_emb = self.time_proj(timesteps)
|
614 |
+
|
615 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
616 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
617 |
+
# there might be better ways to encapsulate this.
|
618 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
619 |
+
|
620 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
621 |
+
|
622 |
+
if self.class_embedding is not None:
|
623 |
+
if class_labels is None:
|
624 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
625 |
+
|
626 |
+
if self.config.class_embed_type == "timestep":
|
627 |
+
class_labels = self.time_proj(class_labels)
|
628 |
+
|
629 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
630 |
+
emb = emb + class_emb
|
631 |
+
|
632 |
+
# 2. pre-process
|
633 |
+
sample = self.conv_in(sample)
|
634 |
+
|
635 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
636 |
+
|
637 |
+
sample = sample + controlnet_cond
|
638 |
+
|
639 |
+
# 3. down
|
640 |
+
down_block_res_samples = (sample,)
|
641 |
+
for downsample_block in self.down_blocks:
|
642 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
643 |
+
sample, res_samples = downsample_block(
|
644 |
+
hidden_states=sample,
|
645 |
+
temb=emb,
|
646 |
+
encoder_hidden_states=encoder_hidden_states,
|
647 |
+
attention_mask=attention_mask,
|
648 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
649 |
+
)
|
650 |
+
else:
|
651 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
652 |
+
|
653 |
+
down_block_res_samples += res_samples
|
654 |
+
|
655 |
+
# 4. mid
|
656 |
+
if self.mid_block is not None:
|
657 |
+
sample = self.mid_block(
|
658 |
+
sample,
|
659 |
+
emb,
|
660 |
+
encoder_hidden_states=encoder_hidden_states,
|
661 |
+
attention_mask=attention_mask,
|
662 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
663 |
+
)
|
664 |
+
|
665 |
+
# 5. Control net blocks
|
666 |
+
|
667 |
+
controlnet_down_block_res_samples = ()
|
668 |
+
|
669 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
670 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
671 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
672 |
+
|
673 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
674 |
+
|
675 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
676 |
+
|
677 |
+
# 6. scaling
|
678 |
+
if guess_mode and not self.config.global_pool_conditions:
|
679 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
680 |
+
|
681 |
+
scales = scales * conditioning_scale
|
682 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
683 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
684 |
+
else:
|
685 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
686 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
687 |
+
|
688 |
+
if self.config.global_pool_conditions:
|
689 |
+
down_block_res_samples = [
|
690 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
691 |
+
]
|
692 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
693 |
+
|
694 |
+
if not return_dict:
|
695 |
+
return (down_block_res_samples, mid_block_res_sample)
|
696 |
+
|
697 |
+
return ControlNetOutput(
|
698 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
699 |
+
)
|
700 |
+
|
701 |
+
|
702 |
+
def zero_module(module):
|
703 |
+
for p in module.parameters():
|
704 |
+
nn.init.zeros_(p)
|
705 |
+
return module
|
6DoF/diffusers/models/controlnet_flax.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple, Union
|
15 |
+
|
16 |
+
import flax
|
17 |
+
import flax.linen as nn
|
18 |
+
import jax
|
19 |
+
import jax.numpy as jnp
|
20 |
+
from flax.core.frozen_dict import FrozenDict
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
23 |
+
from ..utils import BaseOutput
|
24 |
+
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
25 |
+
from .modeling_flax_utils import FlaxModelMixin
|
26 |
+
from .unet_2d_blocks_flax import (
|
27 |
+
FlaxCrossAttnDownBlock2D,
|
28 |
+
FlaxDownBlock2D,
|
29 |
+
FlaxUNetMidBlock2DCrossAttn,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
@flax.struct.dataclass
|
34 |
+
class FlaxControlNetOutput(BaseOutput):
|
35 |
+
"""
|
36 |
+
The output of [`FlaxControlNetModel`].
|
37 |
+
|
38 |
+
Args:
|
39 |
+
down_block_res_samples (`jnp.ndarray`):
|
40 |
+
mid_block_res_sample (`jnp.ndarray`):
|
41 |
+
"""
|
42 |
+
|
43 |
+
down_block_res_samples: jnp.ndarray
|
44 |
+
mid_block_res_sample: jnp.ndarray
|
45 |
+
|
46 |
+
|
47 |
+
class FlaxControlNetConditioningEmbedding(nn.Module):
|
48 |
+
conditioning_embedding_channels: int
|
49 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256)
|
50 |
+
dtype: jnp.dtype = jnp.float32
|
51 |
+
|
52 |
+
def setup(self):
|
53 |
+
self.conv_in = nn.Conv(
|
54 |
+
self.block_out_channels[0],
|
55 |
+
kernel_size=(3, 3),
|
56 |
+
padding=((1, 1), (1, 1)),
|
57 |
+
dtype=self.dtype,
|
58 |
+
)
|
59 |
+
|
60 |
+
blocks = []
|
61 |
+
for i in range(len(self.block_out_channels) - 1):
|
62 |
+
channel_in = self.block_out_channels[i]
|
63 |
+
channel_out = self.block_out_channels[i + 1]
|
64 |
+
conv1 = nn.Conv(
|
65 |
+
channel_in,
|
66 |
+
kernel_size=(3, 3),
|
67 |
+
padding=((1, 1), (1, 1)),
|
68 |
+
dtype=self.dtype,
|
69 |
+
)
|
70 |
+
blocks.append(conv1)
|
71 |
+
conv2 = nn.Conv(
|
72 |
+
channel_out,
|
73 |
+
kernel_size=(3, 3),
|
74 |
+
strides=(2, 2),
|
75 |
+
padding=((1, 1), (1, 1)),
|
76 |
+
dtype=self.dtype,
|
77 |
+
)
|
78 |
+
blocks.append(conv2)
|
79 |
+
self.blocks = blocks
|
80 |
+
|
81 |
+
self.conv_out = nn.Conv(
|
82 |
+
self.conditioning_embedding_channels,
|
83 |
+
kernel_size=(3, 3),
|
84 |
+
padding=((1, 1), (1, 1)),
|
85 |
+
kernel_init=nn.initializers.zeros_init(),
|
86 |
+
bias_init=nn.initializers.zeros_init(),
|
87 |
+
dtype=self.dtype,
|
88 |
+
)
|
89 |
+
|
90 |
+
def __call__(self, conditioning):
|
91 |
+
embedding = self.conv_in(conditioning)
|
92 |
+
embedding = nn.silu(embedding)
|
93 |
+
|
94 |
+
for block in self.blocks:
|
95 |
+
embedding = block(embedding)
|
96 |
+
embedding = nn.silu(embedding)
|
97 |
+
|
98 |
+
embedding = self.conv_out(embedding)
|
99 |
+
|
100 |
+
return embedding
|
101 |
+
|
102 |
+
|
103 |
+
@flax_register_to_config
|
104 |
+
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
105 |
+
r"""
|
106 |
+
A ControlNet model.
|
107 |
+
|
108 |
+
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
|
109 |
+
implemented for all models (such as downloading or saving).
|
110 |
+
|
111 |
+
This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
112 |
+
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
|
113 |
+
general usage and behavior.
|
114 |
+
|
115 |
+
Inherent JAX features such as the following are supported:
|
116 |
+
|
117 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
118 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
119 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
120 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
121 |
+
|
122 |
+
Parameters:
|
123 |
+
sample_size (`int`, *optional*):
|
124 |
+
The size of the input sample.
|
125 |
+
in_channels (`int`, *optional*, defaults to 4):
|
126 |
+
The number of channels in the input sample.
|
127 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
128 |
+
The tuple of downsample blocks to use.
|
129 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
130 |
+
The tuple of output channels for each block.
|
131 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
132 |
+
The number of layers per block.
|
133 |
+
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
134 |
+
The dimension of the attention heads.
|
135 |
+
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
136 |
+
The number of attention heads.
|
137 |
+
cross_attention_dim (`int`, *optional*, defaults to 768):
|
138 |
+
The dimension of the cross attention features.
|
139 |
+
dropout (`float`, *optional*, defaults to 0):
|
140 |
+
Dropout probability for down, up and bottleneck blocks.
|
141 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to flip the sin to cos in the time embedding.
|
143 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
144 |
+
controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
|
145 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
146 |
+
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
|
147 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
148 |
+
"""
|
149 |
+
sample_size: int = 32
|
150 |
+
in_channels: int = 4
|
151 |
+
down_block_types: Tuple[str] = (
|
152 |
+
"CrossAttnDownBlock2D",
|
153 |
+
"CrossAttnDownBlock2D",
|
154 |
+
"CrossAttnDownBlock2D",
|
155 |
+
"DownBlock2D",
|
156 |
+
)
|
157 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False
|
158 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
159 |
+
layers_per_block: int = 2
|
160 |
+
attention_head_dim: Union[int, Tuple[int]] = 8
|
161 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
|
162 |
+
cross_attention_dim: int = 1280
|
163 |
+
dropout: float = 0.0
|
164 |
+
use_linear_projection: bool = False
|
165 |
+
dtype: jnp.dtype = jnp.float32
|
166 |
+
flip_sin_to_cos: bool = True
|
167 |
+
freq_shift: int = 0
|
168 |
+
controlnet_conditioning_channel_order: str = "rgb"
|
169 |
+
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
|
170 |
+
|
171 |
+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
172 |
+
# init input tensors
|
173 |
+
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
174 |
+
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
175 |
+
timesteps = jnp.ones((1,), dtype=jnp.int32)
|
176 |
+
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
|
177 |
+
controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
|
178 |
+
controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
|
179 |
+
|
180 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
181 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
182 |
+
|
183 |
+
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
|
184 |
+
|
185 |
+
def setup(self):
|
186 |
+
block_out_channels = self.block_out_channels
|
187 |
+
time_embed_dim = block_out_channels[0] * 4
|
188 |
+
|
189 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
190 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
191 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
192 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
193 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
194 |
+
# which is why we correct for the naming here.
|
195 |
+
num_attention_heads = self.num_attention_heads or self.attention_head_dim
|
196 |
+
|
197 |
+
# input
|
198 |
+
self.conv_in = nn.Conv(
|
199 |
+
block_out_channels[0],
|
200 |
+
kernel_size=(3, 3),
|
201 |
+
strides=(1, 1),
|
202 |
+
padding=((1, 1), (1, 1)),
|
203 |
+
dtype=self.dtype,
|
204 |
+
)
|
205 |
+
|
206 |
+
# time
|
207 |
+
self.time_proj = FlaxTimesteps(
|
208 |
+
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
|
209 |
+
)
|
210 |
+
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
211 |
+
|
212 |
+
self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
|
213 |
+
conditioning_embedding_channels=block_out_channels[0],
|
214 |
+
block_out_channels=self.conditioning_embedding_out_channels,
|
215 |
+
)
|
216 |
+
|
217 |
+
only_cross_attention = self.only_cross_attention
|
218 |
+
if isinstance(only_cross_attention, bool):
|
219 |
+
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
220 |
+
|
221 |
+
if isinstance(num_attention_heads, int):
|
222 |
+
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
|
223 |
+
|
224 |
+
# down
|
225 |
+
down_blocks = []
|
226 |
+
controlnet_down_blocks = []
|
227 |
+
|
228 |
+
output_channel = block_out_channels[0]
|
229 |
+
|
230 |
+
controlnet_block = nn.Conv(
|
231 |
+
output_channel,
|
232 |
+
kernel_size=(1, 1),
|
233 |
+
padding="VALID",
|
234 |
+
kernel_init=nn.initializers.zeros_init(),
|
235 |
+
bias_init=nn.initializers.zeros_init(),
|
236 |
+
dtype=self.dtype,
|
237 |
+
)
|
238 |
+
controlnet_down_blocks.append(controlnet_block)
|
239 |
+
|
240 |
+
for i, down_block_type in enumerate(self.down_block_types):
|
241 |
+
input_channel = output_channel
|
242 |
+
output_channel = block_out_channels[i]
|
243 |
+
is_final_block = i == len(block_out_channels) - 1
|
244 |
+
|
245 |
+
if down_block_type == "CrossAttnDownBlock2D":
|
246 |
+
down_block = FlaxCrossAttnDownBlock2D(
|
247 |
+
in_channels=input_channel,
|
248 |
+
out_channels=output_channel,
|
249 |
+
dropout=self.dropout,
|
250 |
+
num_layers=self.layers_per_block,
|
251 |
+
num_attention_heads=num_attention_heads[i],
|
252 |
+
add_downsample=not is_final_block,
|
253 |
+
use_linear_projection=self.use_linear_projection,
|
254 |
+
only_cross_attention=only_cross_attention[i],
|
255 |
+
dtype=self.dtype,
|
256 |
+
)
|
257 |
+
else:
|
258 |
+
down_block = FlaxDownBlock2D(
|
259 |
+
in_channels=input_channel,
|
260 |
+
out_channels=output_channel,
|
261 |
+
dropout=self.dropout,
|
262 |
+
num_layers=self.layers_per_block,
|
263 |
+
add_downsample=not is_final_block,
|
264 |
+
dtype=self.dtype,
|
265 |
+
)
|
266 |
+
|
267 |
+
down_blocks.append(down_block)
|
268 |
+
|
269 |
+
for _ in range(self.layers_per_block):
|
270 |
+
controlnet_block = nn.Conv(
|
271 |
+
output_channel,
|
272 |
+
kernel_size=(1, 1),
|
273 |
+
padding="VALID",
|
274 |
+
kernel_init=nn.initializers.zeros_init(),
|
275 |
+
bias_init=nn.initializers.zeros_init(),
|
276 |
+
dtype=self.dtype,
|
277 |
+
)
|
278 |
+
controlnet_down_blocks.append(controlnet_block)
|
279 |
+
|
280 |
+
if not is_final_block:
|
281 |
+
controlnet_block = nn.Conv(
|
282 |
+
output_channel,
|
283 |
+
kernel_size=(1, 1),
|
284 |
+
padding="VALID",
|
285 |
+
kernel_init=nn.initializers.zeros_init(),
|
286 |
+
bias_init=nn.initializers.zeros_init(),
|
287 |
+
dtype=self.dtype,
|
288 |
+
)
|
289 |
+
controlnet_down_blocks.append(controlnet_block)
|
290 |
+
|
291 |
+
self.down_blocks = down_blocks
|
292 |
+
self.controlnet_down_blocks = controlnet_down_blocks
|
293 |
+
|
294 |
+
# mid
|
295 |
+
mid_block_channel = block_out_channels[-1]
|
296 |
+
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
297 |
+
in_channels=mid_block_channel,
|
298 |
+
dropout=self.dropout,
|
299 |
+
num_attention_heads=num_attention_heads[-1],
|
300 |
+
use_linear_projection=self.use_linear_projection,
|
301 |
+
dtype=self.dtype,
|
302 |
+
)
|
303 |
+
|
304 |
+
self.controlnet_mid_block = nn.Conv(
|
305 |
+
mid_block_channel,
|
306 |
+
kernel_size=(1, 1),
|
307 |
+
padding="VALID",
|
308 |
+
kernel_init=nn.initializers.zeros_init(),
|
309 |
+
bias_init=nn.initializers.zeros_init(),
|
310 |
+
dtype=self.dtype,
|
311 |
+
)
|
312 |
+
|
313 |
+
def __call__(
|
314 |
+
self,
|
315 |
+
sample,
|
316 |
+
timesteps,
|
317 |
+
encoder_hidden_states,
|
318 |
+
controlnet_cond,
|
319 |
+
conditioning_scale: float = 1.0,
|
320 |
+
return_dict: bool = True,
|
321 |
+
train: bool = False,
|
322 |
+
) -> Union[FlaxControlNetOutput, Tuple]:
|
323 |
+
r"""
|
324 |
+
Args:
|
325 |
+
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
326 |
+
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
327 |
+
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
|
328 |
+
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
329 |
+
conditioning_scale: (`float`) the scale factor for controlnet outputs
|
330 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
331 |
+
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
332 |
+
plain tuple.
|
333 |
+
train (`bool`, *optional*, defaults to `False`):
|
334 |
+
Use deterministic functions and disable dropout when not training.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
338 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
339 |
+
When returning a tuple, the first element is the sample tensor.
|
340 |
+
"""
|
341 |
+
channel_order = self.controlnet_conditioning_channel_order
|
342 |
+
if channel_order == "bgr":
|
343 |
+
controlnet_cond = jnp.flip(controlnet_cond, axis=1)
|
344 |
+
|
345 |
+
# 1. time
|
346 |
+
if not isinstance(timesteps, jnp.ndarray):
|
347 |
+
timesteps = jnp.array([timesteps], dtype=jnp.int32)
|
348 |
+
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
|
349 |
+
timesteps = timesteps.astype(dtype=jnp.float32)
|
350 |
+
timesteps = jnp.expand_dims(timesteps, 0)
|
351 |
+
|
352 |
+
t_emb = self.time_proj(timesteps)
|
353 |
+
t_emb = self.time_embedding(t_emb)
|
354 |
+
|
355 |
+
# 2. pre-process
|
356 |
+
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
357 |
+
sample = self.conv_in(sample)
|
358 |
+
|
359 |
+
controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
|
360 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
361 |
+
sample += controlnet_cond
|
362 |
+
|
363 |
+
# 3. down
|
364 |
+
down_block_res_samples = (sample,)
|
365 |
+
for down_block in self.down_blocks:
|
366 |
+
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
|
367 |
+
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
368 |
+
else:
|
369 |
+
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
|
370 |
+
down_block_res_samples += res_samples
|
371 |
+
|
372 |
+
# 4. mid
|
373 |
+
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
374 |
+
|
375 |
+
# 5. contronet blocks
|
376 |
+
controlnet_down_block_res_samples = ()
|
377 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
378 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
379 |
+
controlnet_down_block_res_samples += (down_block_res_sample,)
|
380 |
+
|
381 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
382 |
+
|
383 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
384 |
+
|
385 |
+
# 6. scaling
|
386 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
387 |
+
mid_block_res_sample *= conditioning_scale
|
388 |
+
|
389 |
+
if not return_dict:
|
390 |
+
return (down_block_res_samples, mid_block_res_sample)
|
391 |
+
|
392 |
+
return FlaxControlNetOutput(
|
393 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
394 |
+
)
|
6DoF/diffusers/models/cross_attention.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from ..utils import deprecate
|
15 |
+
from .attention_processor import ( # noqa: F401
|
16 |
+
Attention,
|
17 |
+
AttentionProcessor,
|
18 |
+
AttnAddedKVProcessor,
|
19 |
+
AttnProcessor2_0,
|
20 |
+
LoRAAttnProcessor,
|
21 |
+
LoRALinearLayer,
|
22 |
+
LoRAXFormersAttnProcessor,
|
23 |
+
SlicedAttnAddedKVProcessor,
|
24 |
+
SlicedAttnProcessor,
|
25 |
+
XFormersAttnProcessor,
|
26 |
+
)
|
27 |
+
from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
|
28 |
+
|
29 |
+
|
30 |
+
deprecate(
|
31 |
+
"cross_attention",
|
32 |
+
"0.20.0",
|
33 |
+
"Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.",
|
34 |
+
standard_warn=False,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
AttnProcessor = AttentionProcessor
|
39 |
+
|
40 |
+
|
41 |
+
class CrossAttention(Attention):
|
42 |
+
def __init__(self, *args, **kwargs):
|
43 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
44 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
45 |
+
super().__init__(*args, **kwargs)
|
46 |
+
|
47 |
+
|
48 |
+
class CrossAttnProcessor(AttnProcessorRename):
|
49 |
+
def __init__(self, *args, **kwargs):
|
50 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
51 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
52 |
+
super().__init__(*args, **kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
class LoRACrossAttnProcessor(LoRAAttnProcessor):
|
56 |
+
def __init__(self, *args, **kwargs):
|
57 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
58 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
59 |
+
super().__init__(*args, **kwargs)
|
60 |
+
|
61 |
+
|
62 |
+
class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
|
63 |
+
def __init__(self, *args, **kwargs):
|
64 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
65 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
66 |
+
super().__init__(*args, **kwargs)
|
67 |
+
|
68 |
+
|
69 |
+
class XFormersCrossAttnProcessor(XFormersAttnProcessor):
|
70 |
+
def __init__(self, *args, **kwargs):
|
71 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
72 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
73 |
+
super().__init__(*args, **kwargs)
|
74 |
+
|
75 |
+
|
76 |
+
class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
|
77 |
+
def __init__(self, *args, **kwargs):
|
78 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
79 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
80 |
+
super().__init__(*args, **kwargs)
|
81 |
+
|
82 |
+
|
83 |
+
class SlicedCrossAttnProcessor(SlicedAttnProcessor):
|
84 |
+
def __init__(self, *args, **kwargs):
|
85 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
86 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
87 |
+
super().__init__(*args, **kwargs)
|
88 |
+
|
89 |
+
|
90 |
+
class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
|
91 |
+
def __init__(self, *args, **kwargs):
|
92 |
+
deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead."
|
93 |
+
deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
|
94 |
+
super().__init__(*args, **kwargs)
|
6DoF/diffusers/models/dual_transformer_2d.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class DualTransformer2DModel(nn.Module):
|
22 |
+
"""
|
23 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
27 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
28 |
+
in_channels (`int`, *optional*):
|
29 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
30 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
31 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
32 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
33 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
34 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
35 |
+
`ImagePositionalEmbeddings`.
|
36 |
+
num_vector_embeds (`int`, *optional*):
|
37 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
38 |
+
Includes the class for the masked latent pixel.
|
39 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
40 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
41 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
42 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
43 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
44 |
+
attention_bias (`bool`, *optional*):
|
45 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
num_attention_heads: int = 16,
|
51 |
+
attention_head_dim: int = 88,
|
52 |
+
in_channels: Optional[int] = None,
|
53 |
+
num_layers: int = 1,
|
54 |
+
dropout: float = 0.0,
|
55 |
+
norm_num_groups: int = 32,
|
56 |
+
cross_attention_dim: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
sample_size: Optional[int] = None,
|
59 |
+
num_vector_embeds: Optional[int] = None,
|
60 |
+
activation_fn: str = "geglu",
|
61 |
+
num_embeds_ada_norm: Optional[int] = None,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.transformers = nn.ModuleList(
|
65 |
+
[
|
66 |
+
Transformer2DModel(
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=attention_head_dim,
|
69 |
+
in_channels=in_channels,
|
70 |
+
num_layers=num_layers,
|
71 |
+
dropout=dropout,
|
72 |
+
norm_num_groups=norm_num_groups,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attention_bias=attention_bias,
|
75 |
+
sample_size=sample_size,
|
76 |
+
num_vector_embeds=num_vector_embeds,
|
77 |
+
activation_fn=activation_fn,
|
78 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
79 |
+
)
|
80 |
+
for _ in range(2)
|
81 |
+
]
|
82 |
+
)
|
83 |
+
|
84 |
+
# Variables that can be set by a pipeline:
|
85 |
+
|
86 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
87 |
+
self.mix_ratio = 0.5
|
88 |
+
|
89 |
+
# The shape of `encoder_hidden_states` is expected to be
|
90 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
91 |
+
self.condition_lengths = [77, 257]
|
92 |
+
|
93 |
+
# Which transformer to use to encode which condition.
|
94 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
95 |
+
self.transformer_index_for_condition = [1, 0]
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states,
|
100 |
+
encoder_hidden_states,
|
101 |
+
timestep=None,
|
102 |
+
attention_mask=None,
|
103 |
+
cross_attention_kwargs=None,
|
104 |
+
return_dict: bool = True,
|
105 |
+
):
|
106 |
+
"""
|
107 |
+
Args:
|
108 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
109 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
110 |
+
hidden_states
|
111 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
112 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
113 |
+
self-attention.
|
114 |
+
timestep ( `torch.long`, *optional*):
|
115 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
116 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
117 |
+
Optional attention mask to be applied in Attention
|
118 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
119 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
123 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
124 |
+
returning a tuple, the first element is the sample tensor.
|
125 |
+
"""
|
126 |
+
input_states = hidden_states
|
127 |
+
|
128 |
+
encoded_states = []
|
129 |
+
tokens_start = 0
|
130 |
+
# attention_mask is not used yet
|
131 |
+
for i in range(2):
|
132 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
133 |
+
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
134 |
+
transformer_index = self.transformer_index_for_condition[i]
|
135 |
+
encoded_state = self.transformers[transformer_index](
|
136 |
+
input_states,
|
137 |
+
encoder_hidden_states=condition_state,
|
138 |
+
timestep=timestep,
|
139 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
140 |
+
return_dict=False,
|
141 |
+
)[0]
|
142 |
+
encoded_states.append(encoded_state - input_states)
|
143 |
+
tokens_start += self.condition_lengths[i]
|
144 |
+
|
145 |
+
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
146 |
+
output_states = output_states + input_states
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (output_states,)
|
150 |
+
|
151 |
+
return Transformer2DModelOutput(sample=output_states)
|
6DoF/diffusers/models/embeddings.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from .activations import get_activation
|
22 |
+
|
23 |
+
|
24 |
+
def get_timestep_embedding(
|
25 |
+
timesteps: torch.Tensor,
|
26 |
+
embedding_dim: int,
|
27 |
+
flip_sin_to_cos: bool = False,
|
28 |
+
downscale_freq_shift: float = 1,
|
29 |
+
scale: float = 1,
|
30 |
+
max_period: int = 10000,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
34 |
+
|
35 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
36 |
+
These may be fractional.
|
37 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
38 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
39 |
+
"""
|
40 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
41 |
+
|
42 |
+
half_dim = embedding_dim // 2
|
43 |
+
exponent = -math.log(max_period) * torch.arange(
|
44 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
45 |
+
)
|
46 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
47 |
+
|
48 |
+
emb = torch.exp(exponent)
|
49 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
50 |
+
|
51 |
+
# scale embeddings
|
52 |
+
emb = scale * emb
|
53 |
+
|
54 |
+
# concat sine and cosine embeddings
|
55 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
56 |
+
|
57 |
+
# flip sine and cosine embeddings
|
58 |
+
if flip_sin_to_cos:
|
59 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
60 |
+
|
61 |
+
# zero pad
|
62 |
+
if embedding_dim % 2 == 1:
|
63 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
64 |
+
return emb
|
65 |
+
|
66 |
+
|
67 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
68 |
+
"""
|
69 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
70 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
71 |
+
"""
|
72 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
73 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
74 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
75 |
+
grid = np.stack(grid, axis=0)
|
76 |
+
|
77 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
78 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
79 |
+
if cls_token and extra_tokens > 0:
|
80 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
81 |
+
return pos_embed
|
82 |
+
|
83 |
+
|
84 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
85 |
+
if embed_dim % 2 != 0:
|
86 |
+
raise ValueError("embed_dim must be divisible by 2")
|
87 |
+
|
88 |
+
# use half of dimensions to encode grid_h
|
89 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
90 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
91 |
+
|
92 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
93 |
+
return emb
|
94 |
+
|
95 |
+
|
96 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
97 |
+
"""
|
98 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
99 |
+
"""
|
100 |
+
if embed_dim % 2 != 0:
|
101 |
+
raise ValueError("embed_dim must be divisible by 2")
|
102 |
+
|
103 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
104 |
+
omega /= embed_dim / 2.0
|
105 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
106 |
+
|
107 |
+
pos = pos.reshape(-1) # (M,)
|
108 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
109 |
+
|
110 |
+
emb_sin = np.sin(out) # (M, D/2)
|
111 |
+
emb_cos = np.cos(out) # (M, D/2)
|
112 |
+
|
113 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
114 |
+
return emb
|
115 |
+
|
116 |
+
|
117 |
+
class PatchEmbed(nn.Module):
|
118 |
+
"""2D Image to Patch Embedding"""
|
119 |
+
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
height=224,
|
123 |
+
width=224,
|
124 |
+
patch_size=16,
|
125 |
+
in_channels=3,
|
126 |
+
embed_dim=768,
|
127 |
+
layer_norm=False,
|
128 |
+
flatten=True,
|
129 |
+
bias=True,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
134 |
+
self.flatten = flatten
|
135 |
+
self.layer_norm = layer_norm
|
136 |
+
|
137 |
+
self.proj = nn.Conv2d(
|
138 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
139 |
+
)
|
140 |
+
if layer_norm:
|
141 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
142 |
+
else:
|
143 |
+
self.norm = None
|
144 |
+
|
145 |
+
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
146 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
147 |
+
|
148 |
+
def forward(self, latent):
|
149 |
+
latent = self.proj(latent)
|
150 |
+
if self.flatten:
|
151 |
+
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
152 |
+
if self.layer_norm:
|
153 |
+
latent = self.norm(latent)
|
154 |
+
return latent + self.pos_embed
|
155 |
+
|
156 |
+
|
157 |
+
class TimestepEmbedding(nn.Module):
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
in_channels: int,
|
161 |
+
time_embed_dim: int,
|
162 |
+
act_fn: str = "silu",
|
163 |
+
out_dim: int = None,
|
164 |
+
post_act_fn: Optional[str] = None,
|
165 |
+
cond_proj_dim=None,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
170 |
+
|
171 |
+
if cond_proj_dim is not None:
|
172 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
173 |
+
else:
|
174 |
+
self.cond_proj = None
|
175 |
+
|
176 |
+
self.act = get_activation(act_fn)
|
177 |
+
|
178 |
+
if out_dim is not None:
|
179 |
+
time_embed_dim_out = out_dim
|
180 |
+
else:
|
181 |
+
time_embed_dim_out = time_embed_dim
|
182 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
183 |
+
|
184 |
+
if post_act_fn is None:
|
185 |
+
self.post_act = None
|
186 |
+
else:
|
187 |
+
self.post_act = get_activation(post_act_fn)
|
188 |
+
|
189 |
+
def forward(self, sample, condition=None):
|
190 |
+
if condition is not None:
|
191 |
+
sample = sample + self.cond_proj(condition)
|
192 |
+
sample = self.linear_1(sample)
|
193 |
+
|
194 |
+
if self.act is not None:
|
195 |
+
sample = self.act(sample)
|
196 |
+
|
197 |
+
sample = self.linear_2(sample)
|
198 |
+
|
199 |
+
if self.post_act is not None:
|
200 |
+
sample = self.post_act(sample)
|
201 |
+
return sample
|
202 |
+
|
203 |
+
|
204 |
+
class Timesteps(nn.Module):
|
205 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
206 |
+
super().__init__()
|
207 |
+
self.num_channels = num_channels
|
208 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
209 |
+
self.downscale_freq_shift = downscale_freq_shift
|
210 |
+
|
211 |
+
def forward(self, timesteps):
|
212 |
+
t_emb = get_timestep_embedding(
|
213 |
+
timesteps,
|
214 |
+
self.num_channels,
|
215 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
216 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
217 |
+
)
|
218 |
+
return t_emb
|
219 |
+
|
220 |
+
|
221 |
+
class GaussianFourierProjection(nn.Module):
|
222 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
223 |
+
|
224 |
+
def __init__(
|
225 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
226 |
+
):
|
227 |
+
super().__init__()
|
228 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
229 |
+
self.log = log
|
230 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
231 |
+
|
232 |
+
if set_W_to_weight:
|
233 |
+
# to delete later
|
234 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
235 |
+
|
236 |
+
self.weight = self.W
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
if self.log:
|
240 |
+
x = torch.log(x)
|
241 |
+
|
242 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
243 |
+
|
244 |
+
if self.flip_sin_to_cos:
|
245 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
246 |
+
else:
|
247 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
248 |
+
return out
|
249 |
+
|
250 |
+
|
251 |
+
class ImagePositionalEmbeddings(nn.Module):
|
252 |
+
"""
|
253 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
254 |
+
height and width of the latent space.
|
255 |
+
|
256 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
257 |
+
|
258 |
+
For VQ-diffusion:
|
259 |
+
|
260 |
+
Output vector embeddings are used as input for the transformer.
|
261 |
+
|
262 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
num_embed (`int`):
|
266 |
+
Number of embeddings for the latent pixels embeddings.
|
267 |
+
height (`int`):
|
268 |
+
Height of the latent image i.e. the number of height embeddings.
|
269 |
+
width (`int`):
|
270 |
+
Width of the latent image i.e. the number of width embeddings.
|
271 |
+
embed_dim (`int`):
|
272 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
273 |
+
"""
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
num_embed: int,
|
278 |
+
height: int,
|
279 |
+
width: int,
|
280 |
+
embed_dim: int,
|
281 |
+
):
|
282 |
+
super().__init__()
|
283 |
+
|
284 |
+
self.height = height
|
285 |
+
self.width = width
|
286 |
+
self.num_embed = num_embed
|
287 |
+
self.embed_dim = embed_dim
|
288 |
+
|
289 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
290 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
291 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
292 |
+
|
293 |
+
def forward(self, index):
|
294 |
+
emb = self.emb(index)
|
295 |
+
|
296 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
297 |
+
|
298 |
+
# 1 x H x D -> 1 x H x 1 x D
|
299 |
+
height_emb = height_emb.unsqueeze(2)
|
300 |
+
|
301 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
302 |
+
|
303 |
+
# 1 x W x D -> 1 x 1 x W x D
|
304 |
+
width_emb = width_emb.unsqueeze(1)
|
305 |
+
|
306 |
+
pos_emb = height_emb + width_emb
|
307 |
+
|
308 |
+
# 1 x H x W x D -> 1 x L xD
|
309 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
310 |
+
|
311 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
312 |
+
|
313 |
+
return emb
|
314 |
+
|
315 |
+
|
316 |
+
class LabelEmbedding(nn.Module):
|
317 |
+
"""
|
318 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
319 |
+
|
320 |
+
Args:
|
321 |
+
num_classes (`int`): The number of classes.
|
322 |
+
hidden_size (`int`): The size of the vector embeddings.
|
323 |
+
dropout_prob (`float`): The probability of dropping a label.
|
324 |
+
"""
|
325 |
+
|
326 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
327 |
+
super().__init__()
|
328 |
+
use_cfg_embedding = dropout_prob > 0
|
329 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
330 |
+
self.num_classes = num_classes
|
331 |
+
self.dropout_prob = dropout_prob
|
332 |
+
|
333 |
+
def token_drop(self, labels, force_drop_ids=None):
|
334 |
+
"""
|
335 |
+
Drops labels to enable classifier-free guidance.
|
336 |
+
"""
|
337 |
+
if force_drop_ids is None:
|
338 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
339 |
+
else:
|
340 |
+
drop_ids = torch.tensor(force_drop_ids == 1)
|
341 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
342 |
+
return labels
|
343 |
+
|
344 |
+
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
345 |
+
use_dropout = self.dropout_prob > 0
|
346 |
+
if (self.training and use_dropout) or (force_drop_ids is not None):
|
347 |
+
labels = self.token_drop(labels, force_drop_ids)
|
348 |
+
embeddings = self.embedding_table(labels)
|
349 |
+
return embeddings
|
350 |
+
|
351 |
+
|
352 |
+
class TextImageProjection(nn.Module):
|
353 |
+
def __init__(
|
354 |
+
self,
|
355 |
+
text_embed_dim: int = 1024,
|
356 |
+
image_embed_dim: int = 768,
|
357 |
+
cross_attention_dim: int = 768,
|
358 |
+
num_image_text_embeds: int = 10,
|
359 |
+
):
|
360 |
+
super().__init__()
|
361 |
+
|
362 |
+
self.num_image_text_embeds = num_image_text_embeds
|
363 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
364 |
+
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
365 |
+
|
366 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
367 |
+
batch_size = text_embeds.shape[0]
|
368 |
+
|
369 |
+
# image
|
370 |
+
image_text_embeds = self.image_embeds(image_embeds)
|
371 |
+
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
372 |
+
|
373 |
+
# text
|
374 |
+
text_embeds = self.text_proj(text_embeds)
|
375 |
+
|
376 |
+
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
377 |
+
|
378 |
+
|
379 |
+
class ImageProjection(nn.Module):
|
380 |
+
def __init__(
|
381 |
+
self,
|
382 |
+
image_embed_dim: int = 768,
|
383 |
+
cross_attention_dim: int = 768,
|
384 |
+
num_image_text_embeds: int = 32,
|
385 |
+
):
|
386 |
+
super().__init__()
|
387 |
+
|
388 |
+
self.num_image_text_embeds = num_image_text_embeds
|
389 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
390 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
391 |
+
|
392 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
393 |
+
batch_size = image_embeds.shape[0]
|
394 |
+
|
395 |
+
# image
|
396 |
+
image_embeds = self.image_embeds(image_embeds)
|
397 |
+
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
398 |
+
image_embeds = self.norm(image_embeds)
|
399 |
+
return image_embeds
|
400 |
+
|
401 |
+
|
402 |
+
class CombinedTimestepLabelEmbeddings(nn.Module):
|
403 |
+
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
404 |
+
super().__init__()
|
405 |
+
|
406 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
407 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
408 |
+
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
409 |
+
|
410 |
+
def forward(self, timestep, class_labels, hidden_dtype=None):
|
411 |
+
timesteps_proj = self.time_proj(timestep)
|
412 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
413 |
+
|
414 |
+
class_labels = self.class_embedder(class_labels) # (N, D)
|
415 |
+
|
416 |
+
conditioning = timesteps_emb + class_labels # (N, D)
|
417 |
+
|
418 |
+
return conditioning
|
419 |
+
|
420 |
+
|
421 |
+
class TextTimeEmbedding(nn.Module):
|
422 |
+
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
423 |
+
super().__init__()
|
424 |
+
self.norm1 = nn.LayerNorm(encoder_dim)
|
425 |
+
self.pool = AttentionPooling(num_heads, encoder_dim)
|
426 |
+
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
427 |
+
self.norm2 = nn.LayerNorm(time_embed_dim)
|
428 |
+
|
429 |
+
def forward(self, hidden_states):
|
430 |
+
hidden_states = self.norm1(hidden_states)
|
431 |
+
hidden_states = self.pool(hidden_states)
|
432 |
+
hidden_states = self.proj(hidden_states)
|
433 |
+
hidden_states = self.norm2(hidden_states)
|
434 |
+
return hidden_states
|
435 |
+
|
436 |
+
|
437 |
+
class TextImageTimeEmbedding(nn.Module):
|
438 |
+
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
439 |
+
super().__init__()
|
440 |
+
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
441 |
+
self.text_norm = nn.LayerNorm(time_embed_dim)
|
442 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
443 |
+
|
444 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
445 |
+
# text
|
446 |
+
time_text_embeds = self.text_proj(text_embeds)
|
447 |
+
time_text_embeds = self.text_norm(time_text_embeds)
|
448 |
+
|
449 |
+
# image
|
450 |
+
time_image_embeds = self.image_proj(image_embeds)
|
451 |
+
|
452 |
+
return time_image_embeds + time_text_embeds
|
453 |
+
|
454 |
+
|
455 |
+
class ImageTimeEmbedding(nn.Module):
|
456 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
457 |
+
super().__init__()
|
458 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
459 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
460 |
+
|
461 |
+
def forward(self, image_embeds: torch.FloatTensor):
|
462 |
+
# image
|
463 |
+
time_image_embeds = self.image_proj(image_embeds)
|
464 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
465 |
+
return time_image_embeds
|
466 |
+
|
467 |
+
|
468 |
+
class ImageHintTimeEmbedding(nn.Module):
|
469 |
+
def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
470 |
+
super().__init__()
|
471 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
472 |
+
self.image_norm = nn.LayerNorm(time_embed_dim)
|
473 |
+
self.input_hint_block = nn.Sequential(
|
474 |
+
nn.Conv2d(3, 16, 3, padding=1),
|
475 |
+
nn.SiLU(),
|
476 |
+
nn.Conv2d(16, 16, 3, padding=1),
|
477 |
+
nn.SiLU(),
|
478 |
+
nn.Conv2d(16, 32, 3, padding=1, stride=2),
|
479 |
+
nn.SiLU(),
|
480 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
481 |
+
nn.SiLU(),
|
482 |
+
nn.Conv2d(32, 96, 3, padding=1, stride=2),
|
483 |
+
nn.SiLU(),
|
484 |
+
nn.Conv2d(96, 96, 3, padding=1),
|
485 |
+
nn.SiLU(),
|
486 |
+
nn.Conv2d(96, 256, 3, padding=1, stride=2),
|
487 |
+
nn.SiLU(),
|
488 |
+
nn.Conv2d(256, 4, 3, padding=1),
|
489 |
+
)
|
490 |
+
|
491 |
+
def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
|
492 |
+
# image
|
493 |
+
time_image_embeds = self.image_proj(image_embeds)
|
494 |
+
time_image_embeds = self.image_norm(time_image_embeds)
|
495 |
+
hint = self.input_hint_block(hint)
|
496 |
+
return time_image_embeds, hint
|
497 |
+
|
498 |
+
|
499 |
+
class AttentionPooling(nn.Module):
|
500 |
+
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
501 |
+
|
502 |
+
def __init__(self, num_heads, embed_dim, dtype=None):
|
503 |
+
super().__init__()
|
504 |
+
self.dtype = dtype
|
505 |
+
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
506 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
507 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
508 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
509 |
+
self.num_heads = num_heads
|
510 |
+
self.dim_per_head = embed_dim // self.num_heads
|
511 |
+
|
512 |
+
def forward(self, x):
|
513 |
+
bs, length, width = x.size()
|
514 |
+
|
515 |
+
def shape(x):
|
516 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
517 |
+
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
518 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
519 |
+
x = x.transpose(1, 2)
|
520 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
521 |
+
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
522 |
+
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
523 |
+
x = x.transpose(1, 2)
|
524 |
+
return x
|
525 |
+
|
526 |
+
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
527 |
+
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
528 |
+
|
529 |
+
# (bs*n_heads, class_token_length, dim_per_head)
|
530 |
+
q = shape(self.q_proj(class_token))
|
531 |
+
# (bs*n_heads, length+class_token_length, dim_per_head)
|
532 |
+
k = shape(self.k_proj(x))
|
533 |
+
v = shape(self.v_proj(x))
|
534 |
+
|
535 |
+
# (bs*n_heads, class_token_length, length+class_token_length):
|
536 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
537 |
+
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
538 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
539 |
+
|
540 |
+
# (bs*n_heads, dim_per_head, class_token_length)
|
541 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
542 |
+
|
543 |
+
# (bs, length+1, width)
|
544 |
+
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
545 |
+
|
546 |
+
return a[:, 0, :] # cls_token
|
6DoF/diffusers/models/embeddings_flax.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
|
16 |
+
import flax.linen as nn
|
17 |
+
import jax.numpy as jnp
|
18 |
+
|
19 |
+
|
20 |
+
def get_sinusoidal_embeddings(
|
21 |
+
timesteps: jnp.ndarray,
|
22 |
+
embedding_dim: int,
|
23 |
+
freq_shift: float = 1,
|
24 |
+
min_timescale: float = 1,
|
25 |
+
max_timescale: float = 1.0e4,
|
26 |
+
flip_sin_to_cos: bool = False,
|
27 |
+
scale: float = 1.0,
|
28 |
+
) -> jnp.ndarray:
|
29 |
+
"""Returns the positional encoding (same as Tensor2Tensor).
|
30 |
+
|
31 |
+
Args:
|
32 |
+
timesteps: a 1-D Tensor of N indices, one per batch element.
|
33 |
+
These may be fractional.
|
34 |
+
embedding_dim: The number of output channels.
|
35 |
+
min_timescale: The smallest time unit (should probably be 0.0).
|
36 |
+
max_timescale: The largest time unit.
|
37 |
+
Returns:
|
38 |
+
a Tensor of timing signals [N, num_channels]
|
39 |
+
"""
|
40 |
+
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
|
41 |
+
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
|
42 |
+
num_timescales = float(embedding_dim // 2)
|
43 |
+
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
|
44 |
+
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
|
45 |
+
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
|
46 |
+
|
47 |
+
# scale embeddings
|
48 |
+
scaled_time = scale * emb
|
49 |
+
|
50 |
+
if flip_sin_to_cos:
|
51 |
+
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
|
52 |
+
else:
|
53 |
+
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
|
54 |
+
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
|
55 |
+
return signal
|
56 |
+
|
57 |
+
|
58 |
+
class FlaxTimestepEmbedding(nn.Module):
|
59 |
+
r"""
|
60 |
+
Time step Embedding Module. Learns embeddings for input time steps.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
time_embed_dim (`int`, *optional*, defaults to `32`):
|
64 |
+
Time step embedding dimension
|
65 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
66 |
+
Parameters `dtype`
|
67 |
+
"""
|
68 |
+
time_embed_dim: int = 32
|
69 |
+
dtype: jnp.dtype = jnp.float32
|
70 |
+
|
71 |
+
@nn.compact
|
72 |
+
def __call__(self, temb):
|
73 |
+
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
|
74 |
+
temb = nn.silu(temb)
|
75 |
+
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
|
76 |
+
return temb
|
77 |
+
|
78 |
+
|
79 |
+
class FlaxTimesteps(nn.Module):
|
80 |
+
r"""
|
81 |
+
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
|
82 |
+
|
83 |
+
Args:
|
84 |
+
dim (`int`, *optional*, defaults to `32`):
|
85 |
+
Time step embedding dimension
|
86 |
+
"""
|
87 |
+
dim: int = 32
|
88 |
+
flip_sin_to_cos: bool = False
|
89 |
+
freq_shift: float = 1
|
90 |
+
|
91 |
+
@nn.compact
|
92 |
+
def __call__(self, timesteps):
|
93 |
+
return get_sinusoidal_embeddings(
|
94 |
+
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
|
95 |
+
)
|
6DoF/diffusers/models/modeling_flax_pytorch_utils.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Flax general utilities."""
|
16 |
+
import re
|
17 |
+
|
18 |
+
import jax.numpy as jnp
|
19 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
20 |
+
from jax.random import PRNGKey
|
21 |
+
|
22 |
+
from ..utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def rename_key(key):
|
29 |
+
regex = r"\w+[.]\d+"
|
30 |
+
pats = re.findall(regex, key)
|
31 |
+
for pat in pats:
|
32 |
+
key = key.replace(pat, "_".join(pat.split(".")))
|
33 |
+
return key
|
34 |
+
|
35 |
+
|
36 |
+
#####################
|
37 |
+
# PyTorch => Flax #
|
38 |
+
#####################
|
39 |
+
|
40 |
+
|
41 |
+
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
42 |
+
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
43 |
+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
44 |
+
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
|
45 |
+
|
46 |
+
# conv norm or layer norm
|
47 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
48 |
+
if (
|
49 |
+
any("norm" in str_ for str_ in pt_tuple_key)
|
50 |
+
and (pt_tuple_key[-1] == "bias")
|
51 |
+
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
|
52 |
+
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
|
53 |
+
):
|
54 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
55 |
+
return renamed_pt_tuple_key, pt_tensor
|
56 |
+
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
57 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
58 |
+
return renamed_pt_tuple_key, pt_tensor
|
59 |
+
|
60 |
+
# embedding
|
61 |
+
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
62 |
+
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
63 |
+
return renamed_pt_tuple_key, pt_tensor
|
64 |
+
|
65 |
+
# conv layer
|
66 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
67 |
+
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
|
68 |
+
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
69 |
+
return renamed_pt_tuple_key, pt_tensor
|
70 |
+
|
71 |
+
# linear layer
|
72 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
73 |
+
if pt_tuple_key[-1] == "weight":
|
74 |
+
pt_tensor = pt_tensor.T
|
75 |
+
return renamed_pt_tuple_key, pt_tensor
|
76 |
+
|
77 |
+
# old PyTorch layer norm weight
|
78 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
79 |
+
if pt_tuple_key[-1] == "gamma":
|
80 |
+
return renamed_pt_tuple_key, pt_tensor
|
81 |
+
|
82 |
+
# old PyTorch layer norm bias
|
83 |
+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
84 |
+
if pt_tuple_key[-1] == "beta":
|
85 |
+
return renamed_pt_tuple_key, pt_tensor
|
86 |
+
|
87 |
+
return pt_tuple_key, pt_tensor
|
88 |
+
|
89 |
+
|
90 |
+
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
|
91 |
+
# Step 1: Convert pytorch tensor to numpy
|
92 |
+
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
93 |
+
|
94 |
+
# Step 2: Since the model is stateless, get random Flax params
|
95 |
+
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
|
96 |
+
|
97 |
+
random_flax_state_dict = flatten_dict(random_flax_params)
|
98 |
+
flax_state_dict = {}
|
99 |
+
|
100 |
+
# Need to change some parameters name to match Flax names
|
101 |
+
for pt_key, pt_tensor in pt_state_dict.items():
|
102 |
+
renamed_pt_key = rename_key(pt_key)
|
103 |
+
pt_tuple_key = tuple(renamed_pt_key.split("."))
|
104 |
+
|
105 |
+
# Correctly rename weight parameters
|
106 |
+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
|
107 |
+
|
108 |
+
if flax_key in random_flax_state_dict:
|
109 |
+
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
|
110 |
+
raise ValueError(
|
111 |
+
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
112 |
+
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
113 |
+
)
|
114 |
+
|
115 |
+
# also add unexpected weight so that warning is thrown
|
116 |
+
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
117 |
+
|
118 |
+
return unflatten_dict(flax_state_dict)
|
6DoF/diffusers/models/modeling_flax_utils.py
ADDED
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
from pickle import UnpicklingError
|
18 |
+
from typing import Any, Dict, Union
|
19 |
+
|
20 |
+
import jax
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import msgpack.exceptions
|
23 |
+
from flax.core.frozen_dict import FrozenDict, unfreeze
|
24 |
+
from flax.serialization import from_bytes, to_bytes
|
25 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
26 |
+
from huggingface_hub import hf_hub_download
|
27 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
28 |
+
from requests import HTTPError
|
29 |
+
|
30 |
+
from .. import __version__, is_torch_available
|
31 |
+
from ..utils import (
|
32 |
+
CONFIG_NAME,
|
33 |
+
DIFFUSERS_CACHE,
|
34 |
+
FLAX_WEIGHTS_NAME,
|
35 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
36 |
+
WEIGHTS_NAME,
|
37 |
+
logging,
|
38 |
+
)
|
39 |
+
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
40 |
+
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__)
|
43 |
+
|
44 |
+
|
45 |
+
class FlaxModelMixin:
|
46 |
+
r"""
|
47 |
+
Base class for all Flax models.
|
48 |
+
|
49 |
+
[`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
50 |
+
saving models.
|
51 |
+
|
52 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
|
53 |
+
"""
|
54 |
+
config_name = CONFIG_NAME
|
55 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
56 |
+
_flax_internal_args = ["name", "parent", "dtype"]
|
57 |
+
|
58 |
+
@classmethod
|
59 |
+
def _from_config(cls, config, **kwargs):
|
60 |
+
"""
|
61 |
+
All context managers that the model should be initialized under go here.
|
62 |
+
"""
|
63 |
+
return cls(config, **kwargs)
|
64 |
+
|
65 |
+
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
|
66 |
+
"""
|
67 |
+
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
|
68 |
+
"""
|
69 |
+
|
70 |
+
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
|
71 |
+
def conditional_cast(param):
|
72 |
+
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
|
73 |
+
param = param.astype(dtype)
|
74 |
+
return param
|
75 |
+
|
76 |
+
if mask is None:
|
77 |
+
return jax.tree_map(conditional_cast, params)
|
78 |
+
|
79 |
+
flat_params = flatten_dict(params)
|
80 |
+
flat_mask, _ = jax.tree_flatten(mask)
|
81 |
+
|
82 |
+
for masked, key in zip(flat_mask, flat_params.keys()):
|
83 |
+
if masked:
|
84 |
+
param = flat_params[key]
|
85 |
+
flat_params[key] = conditional_cast(param)
|
86 |
+
|
87 |
+
return unflatten_dict(flat_params)
|
88 |
+
|
89 |
+
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
90 |
+
r"""
|
91 |
+
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
|
92 |
+
the `params` in place.
|
93 |
+
|
94 |
+
This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
|
95 |
+
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
|
96 |
+
|
97 |
+
Arguments:
|
98 |
+
params (`Union[Dict, FrozenDict]`):
|
99 |
+
A `PyTree` of model parameters.
|
100 |
+
mask (`Union[Dict, FrozenDict]`):
|
101 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
102 |
+
for params you want to cast, and `False` for those you want to skip.
|
103 |
+
|
104 |
+
Examples:
|
105 |
+
|
106 |
+
```python
|
107 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
108 |
+
|
109 |
+
>>> # load model
|
110 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
111 |
+
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
|
112 |
+
>>> params = model.to_bf16(params)
|
113 |
+
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
|
114 |
+
>>> # then pass the mask as follows
|
115 |
+
>>> from flax import traverse_util
|
116 |
+
|
117 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
118 |
+
>>> flat_params = traverse_util.flatten_dict(params)
|
119 |
+
>>> mask = {
|
120 |
+
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
121 |
+
... for path in flat_params
|
122 |
+
... }
|
123 |
+
>>> mask = traverse_util.unflatten_dict(mask)
|
124 |
+
>>> params = model.to_bf16(params, mask)
|
125 |
+
```"""
|
126 |
+
return self._cast_floating_to(params, jnp.bfloat16, mask)
|
127 |
+
|
128 |
+
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
129 |
+
r"""
|
130 |
+
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
|
131 |
+
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
|
132 |
+
|
133 |
+
Arguments:
|
134 |
+
params (`Union[Dict, FrozenDict]`):
|
135 |
+
A `PyTree` of model parameters.
|
136 |
+
mask (`Union[Dict, FrozenDict]`):
|
137 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
138 |
+
for params you want to cast, and `False` for those you want to skip.
|
139 |
+
|
140 |
+
Examples:
|
141 |
+
|
142 |
+
```python
|
143 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
144 |
+
|
145 |
+
>>> # Download model and configuration from huggingface.co
|
146 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
147 |
+
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
|
148 |
+
>>> # we'll first cast to fp16 and back to fp32
|
149 |
+
>>> params = model.to_f16(params)
|
150 |
+
>>> # now cast back to fp32
|
151 |
+
>>> params = model.to_fp32(params)
|
152 |
+
```"""
|
153 |
+
return self._cast_floating_to(params, jnp.float32, mask)
|
154 |
+
|
155 |
+
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
|
156 |
+
r"""
|
157 |
+
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
|
158 |
+
`params` in place.
|
159 |
+
|
160 |
+
This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
|
161 |
+
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
|
162 |
+
|
163 |
+
Arguments:
|
164 |
+
params (`Union[Dict, FrozenDict]`):
|
165 |
+
A `PyTree` of model parameters.
|
166 |
+
mask (`Union[Dict, FrozenDict]`):
|
167 |
+
A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
|
168 |
+
for params you want to cast, and `False` for those you want to skip.
|
169 |
+
|
170 |
+
Examples:
|
171 |
+
|
172 |
+
```python
|
173 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
174 |
+
|
175 |
+
>>> # load model
|
176 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
177 |
+
>>> # By default, the model params will be in fp32, to cast these to float16
|
178 |
+
>>> params = model.to_fp16(params)
|
179 |
+
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
|
180 |
+
>>> # then pass the mask as follows
|
181 |
+
>>> from flax import traverse_util
|
182 |
+
|
183 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
184 |
+
>>> flat_params = traverse_util.flatten_dict(params)
|
185 |
+
>>> mask = {
|
186 |
+
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
|
187 |
+
... for path in flat_params
|
188 |
+
... }
|
189 |
+
>>> mask = traverse_util.unflatten_dict(mask)
|
190 |
+
>>> params = model.to_fp16(params, mask)
|
191 |
+
```"""
|
192 |
+
return self._cast_floating_to(params, jnp.float16, mask)
|
193 |
+
|
194 |
+
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
|
195 |
+
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
196 |
+
|
197 |
+
@classmethod
|
198 |
+
def from_pretrained(
|
199 |
+
cls,
|
200 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
201 |
+
dtype: jnp.dtype = jnp.float32,
|
202 |
+
*model_args,
|
203 |
+
**kwargs,
|
204 |
+
):
|
205 |
+
r"""
|
206 |
+
Instantiate a pretrained Flax model from a pretrained model configuration.
|
207 |
+
|
208 |
+
Parameters:
|
209 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
210 |
+
Can be either:
|
211 |
+
|
212 |
+
- A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
|
213 |
+
hosted on the Hub.
|
214 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
215 |
+
using [`~FlaxModelMixin.save_pretrained`].
|
216 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
217 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
218 |
+
`jax.numpy.bfloat16` (on TPUs).
|
219 |
+
|
220 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
221 |
+
specified, all the computation will be performed with the given `dtype`.
|
222 |
+
|
223 |
+
<Tip>
|
224 |
+
|
225 |
+
This only specifies the dtype of the *computation* and does not influence the dtype of model
|
226 |
+
parameters.
|
227 |
+
|
228 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
|
229 |
+
[`~FlaxModelMixin.to_bf16`].
|
230 |
+
|
231 |
+
</Tip>
|
232 |
+
|
233 |
+
model_args (sequence of positional arguments, *optional*):
|
234 |
+
All remaining positional arguments are passed to the underlying model's `__init__` method.
|
235 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
236 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
237 |
+
is not used.
|
238 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
239 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
240 |
+
cached versions if they exist.
|
241 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
242 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
243 |
+
incompletely downloaded files are deleted.
|
244 |
+
proxies (`Dict[str, str]`, *optional*):
|
245 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
246 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
247 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
248 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
249 |
+
won't be downloaded from the Hub.
|
250 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
251 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
252 |
+
allowed by Git.
|
253 |
+
from_pt (`bool`, *optional*, defaults to `False`):
|
254 |
+
Load the model weights from a PyTorch checkpoint save file.
|
255 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
256 |
+
Can be used to update the configuration object (after it is loaded) and initiate the model (for
|
257 |
+
example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
258 |
+
automatically loaded:
|
259 |
+
|
260 |
+
- If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
|
261 |
+
model's `__init__` method (we assume all relevant updates to the configuration have already been
|
262 |
+
done).
|
263 |
+
- If a configuration is not provided, `kwargs` are first passed to the configuration class
|
264 |
+
initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
|
265 |
+
to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
|
266 |
+
Remaining keys that do not correspond to any configuration attribute are passed to the underlying
|
267 |
+
model's `__init__` function.
|
268 |
+
|
269 |
+
Examples:
|
270 |
+
|
271 |
+
```python
|
272 |
+
>>> from diffusers import FlaxUNet2DConditionModel
|
273 |
+
|
274 |
+
>>> # Download model and configuration from huggingface.co and cache.
|
275 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
|
276 |
+
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
|
277 |
+
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
|
278 |
+
```
|
279 |
+
|
280 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
281 |
+
|
282 |
+
```bash
|
283 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
284 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
285 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
286 |
+
```
|
287 |
+
"""
|
288 |
+
config = kwargs.pop("config", None)
|
289 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
290 |
+
force_download = kwargs.pop("force_download", False)
|
291 |
+
from_pt = kwargs.pop("from_pt", False)
|
292 |
+
resume_download = kwargs.pop("resume_download", False)
|
293 |
+
proxies = kwargs.pop("proxies", None)
|
294 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
295 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
296 |
+
revision = kwargs.pop("revision", None)
|
297 |
+
subfolder = kwargs.pop("subfolder", None)
|
298 |
+
|
299 |
+
user_agent = {
|
300 |
+
"diffusers": __version__,
|
301 |
+
"file_type": "model",
|
302 |
+
"framework": "flax",
|
303 |
+
}
|
304 |
+
|
305 |
+
# Load config if we don't provide a configuration
|
306 |
+
config_path = config if config is not None else pretrained_model_name_or_path
|
307 |
+
model, model_kwargs = cls.from_config(
|
308 |
+
config_path,
|
309 |
+
cache_dir=cache_dir,
|
310 |
+
return_unused_kwargs=True,
|
311 |
+
force_download=force_download,
|
312 |
+
resume_download=resume_download,
|
313 |
+
proxies=proxies,
|
314 |
+
local_files_only=local_files_only,
|
315 |
+
use_auth_token=use_auth_token,
|
316 |
+
revision=revision,
|
317 |
+
subfolder=subfolder,
|
318 |
+
# model args
|
319 |
+
dtype=dtype,
|
320 |
+
**kwargs,
|
321 |
+
)
|
322 |
+
|
323 |
+
# Load model
|
324 |
+
pretrained_path_with_subfolder = (
|
325 |
+
pretrained_model_name_or_path
|
326 |
+
if subfolder is None
|
327 |
+
else os.path.join(pretrained_model_name_or_path, subfolder)
|
328 |
+
)
|
329 |
+
if os.path.isdir(pretrained_path_with_subfolder):
|
330 |
+
if from_pt:
|
331 |
+
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
332 |
+
raise EnvironmentError(
|
333 |
+
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
|
334 |
+
)
|
335 |
+
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
|
336 |
+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
|
337 |
+
# Load from a Flax checkpoint
|
338 |
+
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
|
339 |
+
# Check if pytorch weights exist instead
|
340 |
+
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
|
341 |
+
raise EnvironmentError(
|
342 |
+
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
|
343 |
+
" using `from_pt=True`."
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
raise EnvironmentError(
|
347 |
+
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
348 |
+
f"{pretrained_path_with_subfolder}."
|
349 |
+
)
|
350 |
+
else:
|
351 |
+
try:
|
352 |
+
model_file = hf_hub_download(
|
353 |
+
pretrained_model_name_or_path,
|
354 |
+
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
|
355 |
+
cache_dir=cache_dir,
|
356 |
+
force_download=force_download,
|
357 |
+
proxies=proxies,
|
358 |
+
resume_download=resume_download,
|
359 |
+
local_files_only=local_files_only,
|
360 |
+
use_auth_token=use_auth_token,
|
361 |
+
user_agent=user_agent,
|
362 |
+
subfolder=subfolder,
|
363 |
+
revision=revision,
|
364 |
+
)
|
365 |
+
|
366 |
+
except RepositoryNotFoundError:
|
367 |
+
raise EnvironmentError(
|
368 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
369 |
+
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
370 |
+
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
371 |
+
"login`."
|
372 |
+
)
|
373 |
+
except RevisionNotFoundError:
|
374 |
+
raise EnvironmentError(
|
375 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
376 |
+
"this model name. Check the model page at "
|
377 |
+
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
378 |
+
)
|
379 |
+
except EntryNotFoundError:
|
380 |
+
raise EnvironmentError(
|
381 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
|
382 |
+
)
|
383 |
+
except HTTPError as err:
|
384 |
+
raise EnvironmentError(
|
385 |
+
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
|
386 |
+
f"{err}"
|
387 |
+
)
|
388 |
+
except ValueError:
|
389 |
+
raise EnvironmentError(
|
390 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
391 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
392 |
+
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
|
393 |
+
" internet connection or see how to run the library in offline mode at"
|
394 |
+
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
395 |
+
)
|
396 |
+
except EnvironmentError:
|
397 |
+
raise EnvironmentError(
|
398 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
399 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
400 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
401 |
+
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
402 |
+
)
|
403 |
+
|
404 |
+
if from_pt:
|
405 |
+
if is_torch_available():
|
406 |
+
from .modeling_utils import load_state_dict
|
407 |
+
else:
|
408 |
+
raise EnvironmentError(
|
409 |
+
"Can't load the model in PyTorch format because PyTorch is not installed. "
|
410 |
+
"Please, install PyTorch or use native Flax weights."
|
411 |
+
)
|
412 |
+
|
413 |
+
# Step 1: Get the pytorch file
|
414 |
+
pytorch_model_file = load_state_dict(model_file)
|
415 |
+
|
416 |
+
# Step 2: Convert the weights
|
417 |
+
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
|
418 |
+
else:
|
419 |
+
try:
|
420 |
+
with open(model_file, "rb") as state_f:
|
421 |
+
state = from_bytes(cls, state_f.read())
|
422 |
+
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
423 |
+
try:
|
424 |
+
with open(model_file) as f:
|
425 |
+
if f.read().startswith("version"):
|
426 |
+
raise OSError(
|
427 |
+
"You seem to have cloned a repository without having git-lfs installed. Please"
|
428 |
+
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
429 |
+
" folder you cloned."
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
raise ValueError from e
|
433 |
+
except (UnicodeDecodeError, ValueError):
|
434 |
+
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
435 |
+
# make sure all arrays are stored as jnp.ndarray
|
436 |
+
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
437 |
+
# https://github.com/google/flax/issues/1261
|
438 |
+
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
439 |
+
|
440 |
+
# flatten dicts
|
441 |
+
state = flatten_dict(state)
|
442 |
+
|
443 |
+
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
|
444 |
+
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
445 |
+
|
446 |
+
shape_state = flatten_dict(unfreeze(params_shape_tree))
|
447 |
+
|
448 |
+
missing_keys = required_params - set(state.keys())
|
449 |
+
unexpected_keys = set(state.keys()) - required_params
|
450 |
+
|
451 |
+
if missing_keys:
|
452 |
+
logger.warning(
|
453 |
+
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
454 |
+
"Make sure to call model.init_weights to initialize the missing weights."
|
455 |
+
)
|
456 |
+
cls._missing_keys = missing_keys
|
457 |
+
|
458 |
+
for key in state.keys():
|
459 |
+
if key in shape_state and state[key].shape != shape_state[key].shape:
|
460 |
+
raise ValueError(
|
461 |
+
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
|
462 |
+
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
|
463 |
+
)
|
464 |
+
|
465 |
+
# remove unexpected keys to not be saved again
|
466 |
+
for unexpected_key in unexpected_keys:
|
467 |
+
del state[unexpected_key]
|
468 |
+
|
469 |
+
if len(unexpected_keys) > 0:
|
470 |
+
logger.warning(
|
471 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
472 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
473 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
474 |
+
" with another architecture."
|
475 |
+
)
|
476 |
+
else:
|
477 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
478 |
+
|
479 |
+
if len(missing_keys) > 0:
|
480 |
+
logger.warning(
|
481 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
482 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
483 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
484 |
+
)
|
485 |
+
else:
|
486 |
+
logger.info(
|
487 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
488 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
489 |
+
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
490 |
+
" training."
|
491 |
+
)
|
492 |
+
|
493 |
+
return model, unflatten_dict(state)
|
494 |
+
|
495 |
+
def save_pretrained(
|
496 |
+
self,
|
497 |
+
save_directory: Union[str, os.PathLike],
|
498 |
+
params: Union[Dict, FrozenDict],
|
499 |
+
is_main_process: bool = True,
|
500 |
+
):
|
501 |
+
"""
|
502 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
503 |
+
[`~FlaxModelMixin.from_pretrained`] class method.
|
504 |
+
|
505 |
+
Arguments:
|
506 |
+
save_directory (`str` or `os.PathLike`):
|
507 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
508 |
+
params (`Union[Dict, FrozenDict]`):
|
509 |
+
A `PyTree` of model parameters.
|
510 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
511 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
512 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
513 |
+
process to avoid race conditions.
|
514 |
+
"""
|
515 |
+
if os.path.isfile(save_directory):
|
516 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
517 |
+
return
|
518 |
+
|
519 |
+
os.makedirs(save_directory, exist_ok=True)
|
520 |
+
|
521 |
+
model_to_save = self
|
522 |
+
|
523 |
+
# Attach architecture to the config
|
524 |
+
# Save the config
|
525 |
+
if is_main_process:
|
526 |
+
model_to_save.save_config(save_directory)
|
527 |
+
|
528 |
+
# save model
|
529 |
+
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
|
530 |
+
with open(output_model_file, "wb") as f:
|
531 |
+
model_bytes = to_bytes(params)
|
532 |
+
f.write(model_bytes)
|
533 |
+
|
534 |
+
logger.info(f"Model weights saved in {output_model_file}")
|
6DoF/diffusers/models/modeling_pytorch_flax_utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch - Flax general utilities."""
|
16 |
+
|
17 |
+
from pickle import UnpicklingError
|
18 |
+
|
19 |
+
import jax
|
20 |
+
import jax.numpy as jnp
|
21 |
+
import numpy as np
|
22 |
+
from flax.serialization import from_bytes
|
23 |
+
from flax.traverse_util import flatten_dict
|
24 |
+
|
25 |
+
from ..utils import logging
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
#####################
|
32 |
+
# Flax => PyTorch #
|
33 |
+
#####################
|
34 |
+
|
35 |
+
|
36 |
+
# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
|
37 |
+
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
|
38 |
+
try:
|
39 |
+
with open(model_file, "rb") as flax_state_f:
|
40 |
+
flax_state = from_bytes(None, flax_state_f.read())
|
41 |
+
except UnpicklingError as e:
|
42 |
+
try:
|
43 |
+
with open(model_file) as f:
|
44 |
+
if f.read().startswith("version"):
|
45 |
+
raise OSError(
|
46 |
+
"You seem to have cloned a repository without having git-lfs installed. Please"
|
47 |
+
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
|
48 |
+
" folder you cloned."
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
raise ValueError from e
|
52 |
+
except (UnicodeDecodeError, ValueError):
|
53 |
+
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
|
54 |
+
|
55 |
+
return load_flax_weights_in_pytorch_model(pt_model, flax_state)
|
56 |
+
|
57 |
+
|
58 |
+
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
59 |
+
"""Load flax checkpoints in a PyTorch model"""
|
60 |
+
|
61 |
+
try:
|
62 |
+
import torch # noqa: F401
|
63 |
+
except ImportError:
|
64 |
+
logger.error(
|
65 |
+
"Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
|
66 |
+
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
67 |
+
" instructions."
|
68 |
+
)
|
69 |
+
raise
|
70 |
+
|
71 |
+
# check if we have bf16 weights
|
72 |
+
is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
|
73 |
+
if any(is_type_bf16):
|
74 |
+
# convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
|
75 |
+
|
76 |
+
# and bf16 is not fully supported in PT yet.
|
77 |
+
logger.warning(
|
78 |
+
"Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
|
79 |
+
"before loading those in PyTorch model."
|
80 |
+
)
|
81 |
+
flax_state = jax.tree_util.tree_map(
|
82 |
+
lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
|
83 |
+
)
|
84 |
+
|
85 |
+
pt_model.base_model_prefix = ""
|
86 |
+
|
87 |
+
flax_state_dict = flatten_dict(flax_state, sep=".")
|
88 |
+
pt_model_dict = pt_model.state_dict()
|
89 |
+
|
90 |
+
# keep track of unexpected & missing keys
|
91 |
+
unexpected_keys = []
|
92 |
+
missing_keys = set(pt_model_dict.keys())
|
93 |
+
|
94 |
+
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
95 |
+
flax_key_tuple_array = flax_key_tuple.split(".")
|
96 |
+
|
97 |
+
if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
|
98 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
99 |
+
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
|
100 |
+
elif flax_key_tuple_array[-1] == "kernel":
|
101 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
102 |
+
flax_tensor = flax_tensor.T
|
103 |
+
elif flax_key_tuple_array[-1] == "scale":
|
104 |
+
flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
|
105 |
+
|
106 |
+
if "time_embedding" not in flax_key_tuple_array:
|
107 |
+
for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
|
108 |
+
flax_key_tuple_array[i] = (
|
109 |
+
flax_key_tuple_string.replace("_0", ".0")
|
110 |
+
.replace("_1", ".1")
|
111 |
+
.replace("_2", ".2")
|
112 |
+
.replace("_3", ".3")
|
113 |
+
.replace("_4", ".4")
|
114 |
+
.replace("_5", ".5")
|
115 |
+
.replace("_6", ".6")
|
116 |
+
.replace("_7", ".7")
|
117 |
+
.replace("_8", ".8")
|
118 |
+
.replace("_9", ".9")
|
119 |
+
)
|
120 |
+
|
121 |
+
flax_key = ".".join(flax_key_tuple_array)
|
122 |
+
|
123 |
+
if flax_key in pt_model_dict:
|
124 |
+
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
125 |
+
raise ValueError(
|
126 |
+
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
|
127 |
+
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
# add weight to pytorch dict
|
131 |
+
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
132 |
+
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
133 |
+
# remove from missing keys
|
134 |
+
missing_keys.remove(flax_key)
|
135 |
+
else:
|
136 |
+
# weight is not expected by PyTorch model
|
137 |
+
unexpected_keys.append(flax_key)
|
138 |
+
|
139 |
+
pt_model.load_state_dict(pt_model_dict)
|
140 |
+
|
141 |
+
# re-transform missing_keys to list
|
142 |
+
missing_keys = list(missing_keys)
|
143 |
+
|
144 |
+
if len(unexpected_keys) > 0:
|
145 |
+
logger.warning(
|
146 |
+
"Some weights of the Flax model were not used when initializing the PyTorch model"
|
147 |
+
f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
|
148 |
+
f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
|
149 |
+
" (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
|
150 |
+
f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
|
151 |
+
" to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
|
152 |
+
" FlaxBertForSequenceClassification model)."
|
153 |
+
)
|
154 |
+
if len(missing_keys) > 0:
|
155 |
+
logger.warning(
|
156 |
+
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
|
157 |
+
f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
|
158 |
+
" use it for predictions and inference."
|
159 |
+
)
|
160 |
+
|
161 |
+
return pt_model
|
6DoF/diffusers/models/modeling_utils.py
ADDED
@@ -0,0 +1,980 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from functools import partial
|
22 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import Tensor, device, nn
|
26 |
+
|
27 |
+
from .. import __version__
|
28 |
+
from ..utils import (
|
29 |
+
CONFIG_NAME,
|
30 |
+
DIFFUSERS_CACHE,
|
31 |
+
FLAX_WEIGHTS_NAME,
|
32 |
+
HF_HUB_OFFLINE,
|
33 |
+
SAFETENSORS_WEIGHTS_NAME,
|
34 |
+
WEIGHTS_NAME,
|
35 |
+
_add_variant,
|
36 |
+
_get_model_file,
|
37 |
+
deprecate,
|
38 |
+
is_accelerate_available,
|
39 |
+
is_safetensors_available,
|
40 |
+
is_torch_version,
|
41 |
+
logging,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
|
48 |
+
if is_torch_version(">=", "1.9.0"):
|
49 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
50 |
+
else:
|
51 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
52 |
+
|
53 |
+
|
54 |
+
if is_accelerate_available():
|
55 |
+
import accelerate
|
56 |
+
from accelerate.utils import set_module_tensor_to_device
|
57 |
+
from accelerate.utils.versions import is_torch_version
|
58 |
+
|
59 |
+
if is_safetensors_available():
|
60 |
+
import safetensors
|
61 |
+
|
62 |
+
|
63 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
64 |
+
try:
|
65 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
66 |
+
return next(parameters_and_buffers).device
|
67 |
+
except StopIteration:
|
68 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
69 |
+
|
70 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
71 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
72 |
+
return tuples
|
73 |
+
|
74 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
75 |
+
first_tuple = next(gen)
|
76 |
+
return first_tuple[1].device
|
77 |
+
|
78 |
+
|
79 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
80 |
+
try:
|
81 |
+
params = tuple(parameter.parameters())
|
82 |
+
if len(params) > 0:
|
83 |
+
return params[0].dtype
|
84 |
+
|
85 |
+
buffers = tuple(parameter.buffers())
|
86 |
+
if len(buffers) > 0:
|
87 |
+
return buffers[0].dtype
|
88 |
+
|
89 |
+
except StopIteration:
|
90 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
91 |
+
|
92 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
93 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
94 |
+
return tuples
|
95 |
+
|
96 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
97 |
+
first_tuple = next(gen)
|
98 |
+
return first_tuple[1].dtype
|
99 |
+
|
100 |
+
|
101 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
102 |
+
"""
|
103 |
+
Reads a checkpoint file, returning properly formatted errors if they arise.
|
104 |
+
"""
|
105 |
+
try:
|
106 |
+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
107 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
108 |
+
else:
|
109 |
+
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
110 |
+
except Exception as e:
|
111 |
+
try:
|
112 |
+
with open(checkpoint_file) as f:
|
113 |
+
if f.read().startswith("version"):
|
114 |
+
raise OSError(
|
115 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
116 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
117 |
+
"you cloned."
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
raise ValueError(
|
121 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
122 |
+
"model. Make sure you have saved the model properly."
|
123 |
+
) from e
|
124 |
+
except (UnicodeDecodeError, ValueError):
|
125 |
+
raise OSError(
|
126 |
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
127 |
+
f"at '{checkpoint_file}'. "
|
128 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
133 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
134 |
+
# copy state_dict so _load_from_state_dict can modify it
|
135 |
+
state_dict = state_dict.copy()
|
136 |
+
error_msgs = []
|
137 |
+
|
138 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
139 |
+
# so we need to apply the function recursively.
|
140 |
+
def load(module: torch.nn.Module, prefix=""):
|
141 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
142 |
+
module._load_from_state_dict(*args)
|
143 |
+
|
144 |
+
for name, child in module._modules.items():
|
145 |
+
if child is not None:
|
146 |
+
load(child, prefix + name + ".")
|
147 |
+
|
148 |
+
load(model_to_load)
|
149 |
+
|
150 |
+
return error_msgs
|
151 |
+
|
152 |
+
|
153 |
+
class ModelMixin(torch.nn.Module):
|
154 |
+
r"""
|
155 |
+
Base class for all models.
|
156 |
+
|
157 |
+
[`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
|
158 |
+
saving models.
|
159 |
+
|
160 |
+
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
161 |
+
"""
|
162 |
+
config_name = CONFIG_NAME
|
163 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
164 |
+
_supports_gradient_checkpointing = False
|
165 |
+
_keys_to_ignore_on_load_unexpected = None
|
166 |
+
|
167 |
+
def __init__(self):
|
168 |
+
super().__init__()
|
169 |
+
|
170 |
+
def __getattr__(self, name: str) -> Any:
|
171 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
172 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
173 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
174 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
175 |
+
"""
|
176 |
+
|
177 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
178 |
+
is_attribute = name in self.__dict__
|
179 |
+
|
180 |
+
if is_in_config and not is_attribute:
|
181 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
182 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
183 |
+
return self._internal_dict[name]
|
184 |
+
|
185 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
186 |
+
return super().__getattr__(name)
|
187 |
+
|
188 |
+
@property
|
189 |
+
def is_gradient_checkpointing(self) -> bool:
|
190 |
+
"""
|
191 |
+
Whether gradient checkpointing is activated for this model or not.
|
192 |
+
"""
|
193 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
194 |
+
|
195 |
+
def enable_gradient_checkpointing(self):
|
196 |
+
"""
|
197 |
+
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
198 |
+
*checkpoint activations* in other frameworks).
|
199 |
+
"""
|
200 |
+
if not self._supports_gradient_checkpointing:
|
201 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
202 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
203 |
+
|
204 |
+
def disable_gradient_checkpointing(self):
|
205 |
+
"""
|
206 |
+
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
207 |
+
*checkpoint activations* in other frameworks).
|
208 |
+
"""
|
209 |
+
if self._supports_gradient_checkpointing:
|
210 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
211 |
+
|
212 |
+
def set_use_memory_efficient_attention_xformers(
|
213 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
214 |
+
) -> None:
|
215 |
+
# Recursively walk through all the children.
|
216 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
217 |
+
# gets the message
|
218 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
219 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
220 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
221 |
+
|
222 |
+
for child in module.children():
|
223 |
+
fn_recursive_set_mem_eff(child)
|
224 |
+
|
225 |
+
for module in self.children():
|
226 |
+
if isinstance(module, torch.nn.Module):
|
227 |
+
fn_recursive_set_mem_eff(module)
|
228 |
+
|
229 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
230 |
+
r"""
|
231 |
+
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
232 |
+
|
233 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
|
234 |
+
inference. Speed up during training is not guaranteed.
|
235 |
+
|
236 |
+
<Tip warning={true}>
|
237 |
+
|
238 |
+
⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
|
239 |
+
precedent.
|
240 |
+
|
241 |
+
</Tip>
|
242 |
+
|
243 |
+
Parameters:
|
244 |
+
attention_op (`Callable`, *optional*):
|
245 |
+
Override the default `None` operator for use as `op` argument to the
|
246 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
247 |
+
function of xFormers.
|
248 |
+
|
249 |
+
Examples:
|
250 |
+
|
251 |
+
```py
|
252 |
+
>>> import torch
|
253 |
+
>>> from diffusers import UNet2DConditionModel
|
254 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
255 |
+
|
256 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
257 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
258 |
+
... )
|
259 |
+
>>> model = model.to("cuda")
|
260 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
261 |
+
```
|
262 |
+
"""
|
263 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
264 |
+
|
265 |
+
def disable_xformers_memory_efficient_attention(self):
|
266 |
+
r"""
|
267 |
+
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
268 |
+
"""
|
269 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
270 |
+
|
271 |
+
def save_pretrained(
|
272 |
+
self,
|
273 |
+
save_directory: Union[str, os.PathLike],
|
274 |
+
is_main_process: bool = True,
|
275 |
+
save_function: Callable = None,
|
276 |
+
safe_serialization: bool = False,
|
277 |
+
variant: Optional[str] = None,
|
278 |
+
):
|
279 |
+
"""
|
280 |
+
Save a model and its configuration file to a directory so that it can be reloaded using the
|
281 |
+
[`~models.ModelMixin.from_pretrained`] class method.
|
282 |
+
|
283 |
+
Arguments:
|
284 |
+
save_directory (`str` or `os.PathLike`):
|
285 |
+
Directory to save a model and its configuration file to. Will be created if it doesn't exist.
|
286 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
287 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
288 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
289 |
+
process to avoid race conditions.
|
290 |
+
save_function (`Callable`):
|
291 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
292 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
293 |
+
`DIFFUSERS_SAVE_MODE`.
|
294 |
+
safe_serialization (`bool`, *optional*, defaults to `False`):
|
295 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
296 |
+
variant (`str`, *optional*):
|
297 |
+
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
298 |
+
"""
|
299 |
+
if safe_serialization and not is_safetensors_available():
|
300 |
+
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
301 |
+
|
302 |
+
if os.path.isfile(save_directory):
|
303 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
304 |
+
return
|
305 |
+
|
306 |
+
os.makedirs(save_directory, exist_ok=True)
|
307 |
+
|
308 |
+
model_to_save = self
|
309 |
+
|
310 |
+
# Attach architecture to the config
|
311 |
+
# Save the config
|
312 |
+
if is_main_process:
|
313 |
+
model_to_save.save_config(save_directory)
|
314 |
+
|
315 |
+
# Save the model
|
316 |
+
state_dict = model_to_save.state_dict()
|
317 |
+
|
318 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
319 |
+
weights_name = _add_variant(weights_name, variant)
|
320 |
+
|
321 |
+
# Save the model
|
322 |
+
if safe_serialization:
|
323 |
+
safetensors.torch.save_file(
|
324 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
325 |
+
)
|
326 |
+
else:
|
327 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
328 |
+
|
329 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
330 |
+
|
331 |
+
@classmethod
|
332 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
333 |
+
r"""
|
334 |
+
Instantiate a pretrained PyTorch model from a pretrained model configuration.
|
335 |
+
|
336 |
+
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
|
337 |
+
train the model, set it back in training mode with `model.train()`.
|
338 |
+
|
339 |
+
Parameters:
|
340 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
341 |
+
Can be either:
|
342 |
+
|
343 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
344 |
+
the Hub.
|
345 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
346 |
+
with [`~ModelMixin.save_pretrained`].
|
347 |
+
|
348 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
349 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
350 |
+
is not used.
|
351 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
352 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
353 |
+
dtype is automatically derived from the model's weights.
|
354 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
355 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
356 |
+
cached versions if they exist.
|
357 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
358 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
359 |
+
incompletely downloaded files are deleted.
|
360 |
+
proxies (`Dict[str, str]`, *optional*):
|
361 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
362 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
363 |
+
output_loading_info (`bool`, *optional*, defaults to `False`):
|
364 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
365 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
366 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
367 |
+
won't be downloaded from the Hub.
|
368 |
+
use_auth_token (`str` or *bool*, *optional*):
|
369 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
370 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
371 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
372 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
373 |
+
allowed by Git.
|
374 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
375 |
+
Load the model weights from a Flax checkpoint save file.
|
376 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
377 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
378 |
+
mirror (`str`, *optional*):
|
379 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
380 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
381 |
+
information.
|
382 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
383 |
+
A map that specifies where each submodule should go. It doesn't need to be defined for each
|
384 |
+
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
385 |
+
same device.
|
386 |
+
|
387 |
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
388 |
+
more information about each option see [designing a device
|
389 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
390 |
+
max_memory (`Dict`, *optional*):
|
391 |
+
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
392 |
+
each GPU and the available CPU RAM if unset.
|
393 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
394 |
+
The path to offload weights if `device_map` contains the value `"disk"`.
|
395 |
+
offload_state_dict (`bool`, *optional*):
|
396 |
+
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
397 |
+
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
398 |
+
when there is some disk offload.
|
399 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
400 |
+
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
401 |
+
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
402 |
+
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
403 |
+
argument to `True` will raise an error.
|
404 |
+
variant (`str`, *optional*):
|
405 |
+
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
|
406 |
+
loading `from_flax`.
|
407 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
408 |
+
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
409 |
+
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
410 |
+
weights. If set to `False`, `safetensors` weights are not loaded.
|
411 |
+
|
412 |
+
<Tip>
|
413 |
+
|
414 |
+
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
415 |
+
`huggingface-cli login`. You can also activate the special
|
416 |
+
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
417 |
+
firewalled environment.
|
418 |
+
|
419 |
+
</Tip>
|
420 |
+
|
421 |
+
Example:
|
422 |
+
|
423 |
+
```py
|
424 |
+
from diffusers import UNet2DConditionModel
|
425 |
+
|
426 |
+
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
|
427 |
+
```
|
428 |
+
|
429 |
+
If you get the error message below, you need to finetune the weights for your downstream task:
|
430 |
+
|
431 |
+
```bash
|
432 |
+
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
433 |
+
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
434 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
435 |
+
```
|
436 |
+
"""
|
437 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
438 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
439 |
+
force_download = kwargs.pop("force_download", False)
|
440 |
+
from_flax = kwargs.pop("from_flax", False)
|
441 |
+
resume_download = kwargs.pop("resume_download", False)
|
442 |
+
proxies = kwargs.pop("proxies", None)
|
443 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
444 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
445 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
446 |
+
revision = kwargs.pop("revision", None)
|
447 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
448 |
+
subfolder = kwargs.pop("subfolder", None)
|
449 |
+
device_map = kwargs.pop("device_map", None)
|
450 |
+
max_memory = kwargs.pop("max_memory", None)
|
451 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
452 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
453 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
454 |
+
variant = kwargs.pop("variant", None)
|
455 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
456 |
+
|
457 |
+
if use_safetensors and not is_safetensors_available():
|
458 |
+
raise ValueError(
|
459 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
|
460 |
+
)
|
461 |
+
|
462 |
+
allow_pickle = False
|
463 |
+
if use_safetensors is None:
|
464 |
+
use_safetensors = is_safetensors_available()
|
465 |
+
allow_pickle = True
|
466 |
+
|
467 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
468 |
+
low_cpu_mem_usage = False
|
469 |
+
logger.warning(
|
470 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
471 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
472 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
473 |
+
" install accelerate\n```\n."
|
474 |
+
)
|
475 |
+
|
476 |
+
if device_map is not None and not is_accelerate_available():
|
477 |
+
raise NotImplementedError(
|
478 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
479 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
480 |
+
)
|
481 |
+
|
482 |
+
# Check if we can handle device_map and dispatching the weights
|
483 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
484 |
+
raise NotImplementedError(
|
485 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
486 |
+
" `device_map=None`."
|
487 |
+
)
|
488 |
+
|
489 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
490 |
+
raise NotImplementedError(
|
491 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
492 |
+
" `low_cpu_mem_usage=False`."
|
493 |
+
)
|
494 |
+
|
495 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
496 |
+
raise ValueError(
|
497 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
498 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
499 |
+
)
|
500 |
+
|
501 |
+
# Load config if we don't provide a configuration
|
502 |
+
config_path = pretrained_model_name_or_path
|
503 |
+
|
504 |
+
user_agent = {
|
505 |
+
"diffusers": __version__,
|
506 |
+
"file_type": "model",
|
507 |
+
"framework": "pytorch",
|
508 |
+
}
|
509 |
+
|
510 |
+
# load config
|
511 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
512 |
+
config_path,
|
513 |
+
cache_dir=cache_dir,
|
514 |
+
return_unused_kwargs=True,
|
515 |
+
return_commit_hash=True,
|
516 |
+
force_download=force_download,
|
517 |
+
resume_download=resume_download,
|
518 |
+
proxies=proxies,
|
519 |
+
local_files_only=local_files_only,
|
520 |
+
use_auth_token=use_auth_token,
|
521 |
+
revision=revision,
|
522 |
+
subfolder=subfolder,
|
523 |
+
device_map=device_map,
|
524 |
+
max_memory=max_memory,
|
525 |
+
offload_folder=offload_folder,
|
526 |
+
offload_state_dict=offload_state_dict,
|
527 |
+
user_agent=user_agent,
|
528 |
+
**kwargs,
|
529 |
+
)
|
530 |
+
|
531 |
+
# load model
|
532 |
+
model_file = None
|
533 |
+
if from_flax:
|
534 |
+
model_file = _get_model_file(
|
535 |
+
pretrained_model_name_or_path,
|
536 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
537 |
+
cache_dir=cache_dir,
|
538 |
+
force_download=force_download,
|
539 |
+
resume_download=resume_download,
|
540 |
+
proxies=proxies,
|
541 |
+
local_files_only=local_files_only,
|
542 |
+
use_auth_token=use_auth_token,
|
543 |
+
revision=revision,
|
544 |
+
subfolder=subfolder,
|
545 |
+
user_agent=user_agent,
|
546 |
+
commit_hash=commit_hash,
|
547 |
+
)
|
548 |
+
model = cls.from_config(config, **unused_kwargs)
|
549 |
+
|
550 |
+
# Convert the weights
|
551 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
552 |
+
|
553 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
554 |
+
else:
|
555 |
+
if use_safetensors:
|
556 |
+
try:
|
557 |
+
model_file = _get_model_file(
|
558 |
+
pretrained_model_name_or_path,
|
559 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
560 |
+
cache_dir=cache_dir,
|
561 |
+
force_download=force_download,
|
562 |
+
resume_download=resume_download,
|
563 |
+
proxies=proxies,
|
564 |
+
local_files_only=local_files_only,
|
565 |
+
use_auth_token=use_auth_token,
|
566 |
+
revision=revision,
|
567 |
+
subfolder=subfolder,
|
568 |
+
user_agent=user_agent,
|
569 |
+
commit_hash=commit_hash,
|
570 |
+
)
|
571 |
+
except IOError as e:
|
572 |
+
if not allow_pickle:
|
573 |
+
raise e
|
574 |
+
pass
|
575 |
+
if model_file is None:
|
576 |
+
model_file = _get_model_file(
|
577 |
+
pretrained_model_name_or_path,
|
578 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
579 |
+
cache_dir=cache_dir,
|
580 |
+
force_download=force_download,
|
581 |
+
resume_download=resume_download,
|
582 |
+
proxies=proxies,
|
583 |
+
local_files_only=local_files_only,
|
584 |
+
use_auth_token=use_auth_token,
|
585 |
+
revision=revision,
|
586 |
+
subfolder=subfolder,
|
587 |
+
user_agent=user_agent,
|
588 |
+
commit_hash=commit_hash,
|
589 |
+
)
|
590 |
+
|
591 |
+
if low_cpu_mem_usage:
|
592 |
+
# Instantiate model with empty weights
|
593 |
+
with accelerate.init_empty_weights():
|
594 |
+
model = cls.from_config(config, **unused_kwargs)
|
595 |
+
|
596 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
597 |
+
if device_map is None:
|
598 |
+
param_device = "cpu"
|
599 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
600 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
601 |
+
# move the params from meta device to cpu
|
602 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
603 |
+
if len(missing_keys) > 0:
|
604 |
+
raise ValueError(
|
605 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
606 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
607 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
608 |
+
" those weights or else make sure your checkpoint file is correct."
|
609 |
+
)
|
610 |
+
unexpected_keys = []
|
611 |
+
|
612 |
+
empty_state_dict = model.state_dict()
|
613 |
+
for param_name, param in state_dict.items():
|
614 |
+
accepts_dtype = "dtype" in set(
|
615 |
+
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
616 |
+
)
|
617 |
+
|
618 |
+
if param_name not in empty_state_dict:
|
619 |
+
unexpected_keys.append(param_name)
|
620 |
+
continue
|
621 |
+
|
622 |
+
if empty_state_dict[param_name].shape != param.shape:
|
623 |
+
raise ValueError(
|
624 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
625 |
+
)
|
626 |
+
|
627 |
+
if accepts_dtype:
|
628 |
+
set_module_tensor_to_device(
|
629 |
+
model, param_name, param_device, value=param, dtype=torch_dtype
|
630 |
+
)
|
631 |
+
else:
|
632 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
633 |
+
|
634 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
635 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
636 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
637 |
+
|
638 |
+
if len(unexpected_keys) > 0:
|
639 |
+
logger.warn(
|
640 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
641 |
+
)
|
642 |
+
|
643 |
+
else: # else let accelerate handle loading and dispatching.
|
644 |
+
# Load weights and dispatch according to the device_map
|
645 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
646 |
+
try:
|
647 |
+
accelerate.load_checkpoint_and_dispatch(
|
648 |
+
model,
|
649 |
+
model_file,
|
650 |
+
device_map,
|
651 |
+
max_memory=max_memory,
|
652 |
+
offload_folder=offload_folder,
|
653 |
+
offload_state_dict=offload_state_dict,
|
654 |
+
dtype=torch_dtype,
|
655 |
+
)
|
656 |
+
except AttributeError as e:
|
657 |
+
# When using accelerate loading, we do not have the ability to load the state
|
658 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
659 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
660 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
661 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
662 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
663 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
664 |
+
# the weights so we don't have to do this again.
|
665 |
+
|
666 |
+
if "'Attention' object has no attribute" in str(e):
|
667 |
+
logger.warn(
|
668 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
669 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
670 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
671 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
672 |
+
" please also re-upload it or open a PR on the original repository."
|
673 |
+
)
|
674 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
675 |
+
accelerate.load_checkpoint_and_dispatch(
|
676 |
+
model,
|
677 |
+
model_file,
|
678 |
+
device_map,
|
679 |
+
max_memory=max_memory,
|
680 |
+
offload_folder=offload_folder,
|
681 |
+
offload_state_dict=offload_state_dict,
|
682 |
+
dtype=torch_dtype,
|
683 |
+
)
|
684 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
685 |
+
else:
|
686 |
+
raise e
|
687 |
+
|
688 |
+
loading_info = {
|
689 |
+
"missing_keys": [],
|
690 |
+
"unexpected_keys": [],
|
691 |
+
"mismatched_keys": [],
|
692 |
+
"error_msgs": [],
|
693 |
+
}
|
694 |
+
else:
|
695 |
+
model = cls.from_config(config, **unused_kwargs)
|
696 |
+
|
697 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
698 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
699 |
+
|
700 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
701 |
+
model,
|
702 |
+
state_dict,
|
703 |
+
model_file,
|
704 |
+
pretrained_model_name_or_path,
|
705 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
706 |
+
)
|
707 |
+
|
708 |
+
loading_info = {
|
709 |
+
"missing_keys": missing_keys,
|
710 |
+
"unexpected_keys": unexpected_keys,
|
711 |
+
"mismatched_keys": mismatched_keys,
|
712 |
+
"error_msgs": error_msgs,
|
713 |
+
}
|
714 |
+
|
715 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
716 |
+
raise ValueError(
|
717 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
718 |
+
)
|
719 |
+
elif torch_dtype is not None:
|
720 |
+
model = model.to(torch_dtype)
|
721 |
+
|
722 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
723 |
+
|
724 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
725 |
+
model.eval()
|
726 |
+
if output_loading_info:
|
727 |
+
return model, loading_info
|
728 |
+
|
729 |
+
return model
|
730 |
+
|
731 |
+
@classmethod
|
732 |
+
def _load_pretrained_model(
|
733 |
+
cls,
|
734 |
+
model,
|
735 |
+
state_dict,
|
736 |
+
resolved_archive_file,
|
737 |
+
pretrained_model_name_or_path,
|
738 |
+
ignore_mismatched_sizes=False,
|
739 |
+
):
|
740 |
+
# Retrieve missing & unexpected_keys
|
741 |
+
model_state_dict = model.state_dict()
|
742 |
+
loaded_keys = list(state_dict.keys())
|
743 |
+
|
744 |
+
expected_keys = list(model_state_dict.keys())
|
745 |
+
|
746 |
+
original_loaded_keys = loaded_keys
|
747 |
+
|
748 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
749 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
750 |
+
|
751 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
752 |
+
model_to_load = model
|
753 |
+
|
754 |
+
def _find_mismatched_keys(
|
755 |
+
state_dict,
|
756 |
+
model_state_dict,
|
757 |
+
loaded_keys,
|
758 |
+
ignore_mismatched_sizes,
|
759 |
+
):
|
760 |
+
mismatched_keys = []
|
761 |
+
if ignore_mismatched_sizes:
|
762 |
+
for checkpoint_key in loaded_keys:
|
763 |
+
model_key = checkpoint_key
|
764 |
+
|
765 |
+
if (
|
766 |
+
model_key in model_state_dict
|
767 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
768 |
+
):
|
769 |
+
mismatched_keys.append(
|
770 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
771 |
+
)
|
772 |
+
del state_dict[checkpoint_key]
|
773 |
+
return mismatched_keys
|
774 |
+
|
775 |
+
if state_dict is not None:
|
776 |
+
# Whole checkpoint
|
777 |
+
mismatched_keys = _find_mismatched_keys(
|
778 |
+
state_dict,
|
779 |
+
model_state_dict,
|
780 |
+
original_loaded_keys,
|
781 |
+
ignore_mismatched_sizes,
|
782 |
+
)
|
783 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
784 |
+
|
785 |
+
if len(error_msgs) > 0:
|
786 |
+
error_msg = "\n\t".join(error_msgs)
|
787 |
+
if "size mismatch" in error_msg:
|
788 |
+
error_msg += (
|
789 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
790 |
+
)
|
791 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
792 |
+
|
793 |
+
if len(unexpected_keys) > 0:
|
794 |
+
logger.warning(
|
795 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
796 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
797 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
798 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
799 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
800 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
801 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
802 |
+
" BertForSequenceClassification model)."
|
803 |
+
)
|
804 |
+
else:
|
805 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
806 |
+
if len(missing_keys) > 0:
|
807 |
+
logger.warning(
|
808 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
809 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
810 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
811 |
+
)
|
812 |
+
elif len(mismatched_keys) == 0:
|
813 |
+
logger.info(
|
814 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
815 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
816 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
817 |
+
" without further training."
|
818 |
+
)
|
819 |
+
if len(mismatched_keys) > 0:
|
820 |
+
mismatched_warning = "\n".join(
|
821 |
+
[
|
822 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
823 |
+
for key, shape1, shape2 in mismatched_keys
|
824 |
+
]
|
825 |
+
)
|
826 |
+
logger.warning(
|
827 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
828 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
829 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
830 |
+
" able to use it for predictions and inference."
|
831 |
+
)
|
832 |
+
|
833 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
834 |
+
|
835 |
+
@property
|
836 |
+
def device(self) -> device:
|
837 |
+
"""
|
838 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
839 |
+
device).
|
840 |
+
"""
|
841 |
+
return get_parameter_device(self)
|
842 |
+
|
843 |
+
@property
|
844 |
+
def dtype(self) -> torch.dtype:
|
845 |
+
"""
|
846 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
847 |
+
"""
|
848 |
+
return get_parameter_dtype(self)
|
849 |
+
|
850 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
851 |
+
"""
|
852 |
+
Get number of (trainable or non-embedding) parameters in the module.
|
853 |
+
|
854 |
+
Args:
|
855 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
856 |
+
Whether or not to return only the number of trainable parameters.
|
857 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
858 |
+
Whether or not to return only the number of non-embedding parameters.
|
859 |
+
|
860 |
+
Returns:
|
861 |
+
`int`: The number of parameters.
|
862 |
+
|
863 |
+
Example:
|
864 |
+
|
865 |
+
```py
|
866 |
+
from diffusers import UNet2DConditionModel
|
867 |
+
|
868 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
869 |
+
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
|
870 |
+
unet.num_parameters(only_trainable=True)
|
871 |
+
859520964
|
872 |
+
```
|
873 |
+
"""
|
874 |
+
|
875 |
+
if exclude_embeddings:
|
876 |
+
embedding_param_names = [
|
877 |
+
f"{name}.weight"
|
878 |
+
for name, module_type in self.named_modules()
|
879 |
+
if isinstance(module_type, torch.nn.Embedding)
|
880 |
+
]
|
881 |
+
non_embedding_parameters = [
|
882 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
883 |
+
]
|
884 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
885 |
+
else:
|
886 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
887 |
+
|
888 |
+
def _convert_deprecated_attention_blocks(self, state_dict):
|
889 |
+
deprecated_attention_block_paths = []
|
890 |
+
|
891 |
+
def recursive_find_attn_block(name, module):
|
892 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
893 |
+
deprecated_attention_block_paths.append(name)
|
894 |
+
|
895 |
+
for sub_name, sub_module in module.named_children():
|
896 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
897 |
+
recursive_find_attn_block(sub_name, sub_module)
|
898 |
+
|
899 |
+
recursive_find_attn_block("", self)
|
900 |
+
|
901 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
902 |
+
# because it is possible we are loading from a state dict that was already
|
903 |
+
# converted
|
904 |
+
|
905 |
+
for path in deprecated_attention_block_paths:
|
906 |
+
# group_norm path stays the same
|
907 |
+
|
908 |
+
# query -> to_q
|
909 |
+
if f"{path}.query.weight" in state_dict:
|
910 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
911 |
+
if f"{path}.query.bias" in state_dict:
|
912 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
913 |
+
|
914 |
+
# key -> to_k
|
915 |
+
if f"{path}.key.weight" in state_dict:
|
916 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
917 |
+
if f"{path}.key.bias" in state_dict:
|
918 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
919 |
+
|
920 |
+
# value -> to_v
|
921 |
+
if f"{path}.value.weight" in state_dict:
|
922 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
923 |
+
if f"{path}.value.bias" in state_dict:
|
924 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
925 |
+
|
926 |
+
# proj_attn -> to_out.0
|
927 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
928 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
929 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
930 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
931 |
+
|
932 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
933 |
+
deprecated_attention_block_modules = []
|
934 |
+
|
935 |
+
def recursive_find_attn_block(module):
|
936 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
937 |
+
deprecated_attention_block_modules.append(module)
|
938 |
+
|
939 |
+
for sub_module in module.children():
|
940 |
+
recursive_find_attn_block(sub_module)
|
941 |
+
|
942 |
+
recursive_find_attn_block(self)
|
943 |
+
|
944 |
+
for module in deprecated_attention_block_modules:
|
945 |
+
module.query = module.to_q
|
946 |
+
module.key = module.to_k
|
947 |
+
module.value = module.to_v
|
948 |
+
module.proj_attn = module.to_out[0]
|
949 |
+
|
950 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
951 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
952 |
+
# making an incorrect assumption that this model should be converted when
|
953 |
+
# it really shouldn't be.
|
954 |
+
del module.to_q
|
955 |
+
del module.to_k
|
956 |
+
del module.to_v
|
957 |
+
del module.to_out
|
958 |
+
|
959 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
960 |
+
deprecated_attention_block_modules = []
|
961 |
+
|
962 |
+
def recursive_find_attn_block(module):
|
963 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
964 |
+
deprecated_attention_block_modules.append(module)
|
965 |
+
|
966 |
+
for sub_module in module.children():
|
967 |
+
recursive_find_attn_block(sub_module)
|
968 |
+
|
969 |
+
recursive_find_attn_block(self)
|
970 |
+
|
971 |
+
for module in deprecated_attention_block_modules:
|
972 |
+
module.to_q = module.query
|
973 |
+
module.to_k = module.key
|
974 |
+
module.to_v = module.value
|
975 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
976 |
+
|
977 |
+
del module.query
|
978 |
+
del module.key
|
979 |
+
del module.value
|
980 |
+
del module.proj_attn
|
6DoF/diffusers/models/prior_transformer.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Dict, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from ..utils import BaseOutput
|
10 |
+
from .attention import BasicTransformerBlock
|
11 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
12 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
13 |
+
from .modeling_utils import ModelMixin
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class PriorTransformerOutput(BaseOutput):
|
18 |
+
"""
|
19 |
+
The output of [`PriorTransformer`].
|
20 |
+
|
21 |
+
Args:
|
22 |
+
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
23 |
+
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
24 |
+
"""
|
25 |
+
|
26 |
+
predicted_image_embedding: torch.FloatTensor
|
27 |
+
|
28 |
+
|
29 |
+
class PriorTransformer(ModelMixin, ConfigMixin):
|
30 |
+
"""
|
31 |
+
A Prior Transformer model.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
35 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
36 |
+
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
37 |
+
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
38 |
+
num_embeddings (`int`, *optional*, defaults to 77):
|
39 |
+
The number of embeddings of the model input `hidden_states`
|
40 |
+
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
41 |
+
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
42 |
+
additional_embeddings`.
|
43 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
44 |
+
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
45 |
+
The activation function to use to create timestep embeddings.
|
46 |
+
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
47 |
+
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
48 |
+
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
49 |
+
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
50 |
+
needed.
|
51 |
+
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
52 |
+
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
53 |
+
`encoder_hidden_states` is `None`.
|
54 |
+
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
55 |
+
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
56 |
+
product between the text embedding and image embedding as proposed in the unclip paper
|
57 |
+
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
58 |
+
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
59 |
+
If None, will be set to `num_attention_heads * attention_head_dim`
|
60 |
+
embedding_proj_dim (`int`, *optional*, default to None):
|
61 |
+
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
62 |
+
clip_embed_dim (`int`, *optional*, default to None):
|
63 |
+
The dimension of the output. If None, will be set to `embedding_dim`.
|
64 |
+
"""
|
65 |
+
|
66 |
+
@register_to_config
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
num_attention_heads: int = 32,
|
70 |
+
attention_head_dim: int = 64,
|
71 |
+
num_layers: int = 20,
|
72 |
+
embedding_dim: int = 768,
|
73 |
+
num_embeddings=77,
|
74 |
+
additional_embeddings=4,
|
75 |
+
dropout: float = 0.0,
|
76 |
+
time_embed_act_fn: str = "silu",
|
77 |
+
norm_in_type: Optional[str] = None, # layer
|
78 |
+
embedding_proj_norm_type: Optional[str] = None, # layer
|
79 |
+
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
80 |
+
added_emb_type: Optional[str] = "prd", # prd
|
81 |
+
time_embed_dim: Optional[int] = None,
|
82 |
+
embedding_proj_dim: Optional[int] = None,
|
83 |
+
clip_embed_dim: Optional[int] = None,
|
84 |
+
):
|
85 |
+
super().__init__()
|
86 |
+
self.num_attention_heads = num_attention_heads
|
87 |
+
self.attention_head_dim = attention_head_dim
|
88 |
+
inner_dim = num_attention_heads * attention_head_dim
|
89 |
+
self.additional_embeddings = additional_embeddings
|
90 |
+
|
91 |
+
time_embed_dim = time_embed_dim or inner_dim
|
92 |
+
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
93 |
+
clip_embed_dim = clip_embed_dim or embedding_dim
|
94 |
+
|
95 |
+
self.time_proj = Timesteps(inner_dim, True, 0)
|
96 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
97 |
+
|
98 |
+
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
99 |
+
|
100 |
+
if embedding_proj_norm_type is None:
|
101 |
+
self.embedding_proj_norm = None
|
102 |
+
elif embedding_proj_norm_type == "layer":
|
103 |
+
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
106 |
+
|
107 |
+
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
108 |
+
|
109 |
+
if encoder_hid_proj_type is None:
|
110 |
+
self.encoder_hidden_states_proj = None
|
111 |
+
elif encoder_hid_proj_type == "linear":
|
112 |
+
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
113 |
+
else:
|
114 |
+
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
115 |
+
|
116 |
+
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
117 |
+
|
118 |
+
if added_emb_type == "prd":
|
119 |
+
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
120 |
+
elif added_emb_type is None:
|
121 |
+
self.prd_embedding = None
|
122 |
+
else:
|
123 |
+
raise ValueError(
|
124 |
+
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
125 |
+
)
|
126 |
+
|
127 |
+
self.transformer_blocks = nn.ModuleList(
|
128 |
+
[
|
129 |
+
BasicTransformerBlock(
|
130 |
+
inner_dim,
|
131 |
+
num_attention_heads,
|
132 |
+
attention_head_dim,
|
133 |
+
dropout=dropout,
|
134 |
+
activation_fn="gelu",
|
135 |
+
attention_bias=True,
|
136 |
+
)
|
137 |
+
for d in range(num_layers)
|
138 |
+
]
|
139 |
+
)
|
140 |
+
|
141 |
+
if norm_in_type == "layer":
|
142 |
+
self.norm_in = nn.LayerNorm(inner_dim)
|
143 |
+
elif norm_in_type is None:
|
144 |
+
self.norm_in = None
|
145 |
+
else:
|
146 |
+
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
147 |
+
|
148 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
149 |
+
|
150 |
+
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
151 |
+
|
152 |
+
causal_attention_mask = torch.full(
|
153 |
+
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
154 |
+
)
|
155 |
+
causal_attention_mask.triu_(1)
|
156 |
+
causal_attention_mask = causal_attention_mask[None, ...]
|
157 |
+
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
158 |
+
|
159 |
+
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
160 |
+
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
161 |
+
|
162 |
+
@property
|
163 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
164 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
165 |
+
r"""
|
166 |
+
Returns:
|
167 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
168 |
+
indexed by its weight name.
|
169 |
+
"""
|
170 |
+
# set recursively
|
171 |
+
processors = {}
|
172 |
+
|
173 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
174 |
+
if hasattr(module, "set_processor"):
|
175 |
+
processors[f"{name}.processor"] = module.processor
|
176 |
+
|
177 |
+
for sub_name, child in module.named_children():
|
178 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
179 |
+
|
180 |
+
return processors
|
181 |
+
|
182 |
+
for name, module in self.named_children():
|
183 |
+
fn_recursive_add_processors(name, module, processors)
|
184 |
+
|
185 |
+
return processors
|
186 |
+
|
187 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
188 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
189 |
+
r"""
|
190 |
+
Sets the attention processor to use to compute attention.
|
191 |
+
|
192 |
+
Parameters:
|
193 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
194 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
195 |
+
for **all** `Attention` layers.
|
196 |
+
|
197 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
198 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
199 |
+
|
200 |
+
"""
|
201 |
+
count = len(self.attn_processors.keys())
|
202 |
+
|
203 |
+
if isinstance(processor, dict) and len(processor) != count:
|
204 |
+
raise ValueError(
|
205 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
206 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
207 |
+
)
|
208 |
+
|
209 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
210 |
+
if hasattr(module, "set_processor"):
|
211 |
+
if not isinstance(processor, dict):
|
212 |
+
module.set_processor(processor)
|
213 |
+
else:
|
214 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
215 |
+
|
216 |
+
for sub_name, child in module.named_children():
|
217 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
218 |
+
|
219 |
+
for name, module in self.named_children():
|
220 |
+
fn_recursive_attn_processor(name, module, processor)
|
221 |
+
|
222 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
223 |
+
def set_default_attn_processor(self):
|
224 |
+
"""
|
225 |
+
Disables custom attention processors and sets the default attention implementation.
|
226 |
+
"""
|
227 |
+
self.set_attn_processor(AttnProcessor())
|
228 |
+
|
229 |
+
def forward(
|
230 |
+
self,
|
231 |
+
hidden_states,
|
232 |
+
timestep: Union[torch.Tensor, float, int],
|
233 |
+
proj_embedding: torch.FloatTensor,
|
234 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
235 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
236 |
+
return_dict: bool = True,
|
237 |
+
):
|
238 |
+
"""
|
239 |
+
The [`PriorTransformer`] forward method.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
243 |
+
The currently predicted image embeddings.
|
244 |
+
timestep (`torch.LongTensor`):
|
245 |
+
Current denoising step.
|
246 |
+
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
247 |
+
Projected embedding vector the denoising process is conditioned on.
|
248 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
249 |
+
Hidden states of the text embeddings the denoising process is conditioned on.
|
250 |
+
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
251 |
+
Text mask for the text embeddings.
|
252 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
253 |
+
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
254 |
+
tuple.
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
258 |
+
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
259 |
+
tuple is returned where the first element is the sample tensor.
|
260 |
+
"""
|
261 |
+
batch_size = hidden_states.shape[0]
|
262 |
+
|
263 |
+
timesteps = timestep
|
264 |
+
if not torch.is_tensor(timesteps):
|
265 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
266 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
267 |
+
timesteps = timesteps[None].to(hidden_states.device)
|
268 |
+
|
269 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
270 |
+
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
271 |
+
|
272 |
+
timesteps_projected = self.time_proj(timesteps)
|
273 |
+
|
274 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
275 |
+
# but time_embedding might be fp16, so we need to cast here.
|
276 |
+
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
277 |
+
time_embeddings = self.time_embedding(timesteps_projected)
|
278 |
+
|
279 |
+
if self.embedding_proj_norm is not None:
|
280 |
+
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
281 |
+
|
282 |
+
proj_embeddings = self.embedding_proj(proj_embedding)
|
283 |
+
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
284 |
+
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
285 |
+
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
286 |
+
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
287 |
+
|
288 |
+
hidden_states = self.proj_in(hidden_states)
|
289 |
+
|
290 |
+
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
291 |
+
|
292 |
+
additional_embeds = []
|
293 |
+
additional_embeddings_len = 0
|
294 |
+
|
295 |
+
if encoder_hidden_states is not None:
|
296 |
+
additional_embeds.append(encoder_hidden_states)
|
297 |
+
additional_embeddings_len += encoder_hidden_states.shape[1]
|
298 |
+
|
299 |
+
if len(proj_embeddings.shape) == 2:
|
300 |
+
proj_embeddings = proj_embeddings[:, None, :]
|
301 |
+
|
302 |
+
if len(hidden_states.shape) == 2:
|
303 |
+
hidden_states = hidden_states[:, None, :]
|
304 |
+
|
305 |
+
additional_embeds = additional_embeds + [
|
306 |
+
proj_embeddings,
|
307 |
+
time_embeddings[:, None, :],
|
308 |
+
hidden_states,
|
309 |
+
]
|
310 |
+
|
311 |
+
if self.prd_embedding is not None:
|
312 |
+
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
313 |
+
additional_embeds.append(prd_embedding)
|
314 |
+
|
315 |
+
hidden_states = torch.cat(
|
316 |
+
additional_embeds,
|
317 |
+
dim=1,
|
318 |
+
)
|
319 |
+
|
320 |
+
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
|
321 |
+
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
322 |
+
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
323 |
+
positional_embeddings = F.pad(
|
324 |
+
positional_embeddings,
|
325 |
+
(
|
326 |
+
0,
|
327 |
+
0,
|
328 |
+
additional_embeddings_len,
|
329 |
+
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
330 |
+
),
|
331 |
+
value=0.0,
|
332 |
+
)
|
333 |
+
|
334 |
+
hidden_states = hidden_states + positional_embeddings
|
335 |
+
|
336 |
+
if attention_mask is not None:
|
337 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
338 |
+
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
339 |
+
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
340 |
+
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
341 |
+
|
342 |
+
if self.norm_in is not None:
|
343 |
+
hidden_states = self.norm_in(hidden_states)
|
344 |
+
|
345 |
+
for block in self.transformer_blocks:
|
346 |
+
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
347 |
+
|
348 |
+
hidden_states = self.norm_out(hidden_states)
|
349 |
+
|
350 |
+
if self.prd_embedding is not None:
|
351 |
+
hidden_states = hidden_states[:, -1]
|
352 |
+
else:
|
353 |
+
hidden_states = hidden_states[:, additional_embeddings_len:]
|
354 |
+
|
355 |
+
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
356 |
+
|
357 |
+
if not return_dict:
|
358 |
+
return (predicted_image_embedding,)
|
359 |
+
|
360 |
+
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
361 |
+
|
362 |
+
def post_process_latents(self, prior_latents):
|
363 |
+
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
364 |
+
return prior_latents
|
6DoF/diffusers/models/resnet.py
ADDED
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from .activations import get_activation
|
24 |
+
from .attention import AdaGroupNorm
|
25 |
+
from .attention_processor import SpatialNorm
|
26 |
+
|
27 |
+
|
28 |
+
class Upsample1D(nn.Module):
|
29 |
+
"""A 1D upsampling layer with an optional convolution.
|
30 |
+
|
31 |
+
Parameters:
|
32 |
+
channels (`int`):
|
33 |
+
number of channels in the inputs and outputs.
|
34 |
+
use_conv (`bool`, default `False`):
|
35 |
+
option to use a convolution.
|
36 |
+
use_conv_transpose (`bool`, default `False`):
|
37 |
+
option to use a convolution transpose.
|
38 |
+
out_channels (`int`, optional):
|
39 |
+
number of output channels. Defaults to `channels`.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
43 |
+
super().__init__()
|
44 |
+
self.channels = channels
|
45 |
+
self.out_channels = out_channels or channels
|
46 |
+
self.use_conv = use_conv
|
47 |
+
self.use_conv_transpose = use_conv_transpose
|
48 |
+
self.name = name
|
49 |
+
|
50 |
+
self.conv = None
|
51 |
+
if use_conv_transpose:
|
52 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
53 |
+
elif use_conv:
|
54 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
55 |
+
|
56 |
+
def forward(self, inputs):
|
57 |
+
assert inputs.shape[1] == self.channels
|
58 |
+
if self.use_conv_transpose:
|
59 |
+
return self.conv(inputs)
|
60 |
+
|
61 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
62 |
+
|
63 |
+
if self.use_conv:
|
64 |
+
outputs = self.conv(outputs)
|
65 |
+
|
66 |
+
return outputs
|
67 |
+
|
68 |
+
|
69 |
+
class Downsample1D(nn.Module):
|
70 |
+
"""A 1D downsampling layer with an optional convolution.
|
71 |
+
|
72 |
+
Parameters:
|
73 |
+
channels (`int`):
|
74 |
+
number of channels in the inputs and outputs.
|
75 |
+
use_conv (`bool`, default `False`):
|
76 |
+
option to use a convolution.
|
77 |
+
out_channels (`int`, optional):
|
78 |
+
number of output channels. Defaults to `channels`.
|
79 |
+
padding (`int`, default `1`):
|
80 |
+
padding for the convolution.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
84 |
+
super().__init__()
|
85 |
+
self.channels = channels
|
86 |
+
self.out_channels = out_channels or channels
|
87 |
+
self.use_conv = use_conv
|
88 |
+
self.padding = padding
|
89 |
+
stride = 2
|
90 |
+
self.name = name
|
91 |
+
|
92 |
+
if use_conv:
|
93 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
94 |
+
else:
|
95 |
+
assert self.channels == self.out_channels
|
96 |
+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
97 |
+
|
98 |
+
def forward(self, inputs):
|
99 |
+
assert inputs.shape[1] == self.channels
|
100 |
+
return self.conv(inputs)
|
101 |
+
|
102 |
+
|
103 |
+
class Upsample2D(nn.Module):
|
104 |
+
"""A 2D upsampling layer with an optional convolution.
|
105 |
+
|
106 |
+
Parameters:
|
107 |
+
channels (`int`):
|
108 |
+
number of channels in the inputs and outputs.
|
109 |
+
use_conv (`bool`, default `False`):
|
110 |
+
option to use a convolution.
|
111 |
+
use_conv_transpose (`bool`, default `False`):
|
112 |
+
option to use a convolution transpose.
|
113 |
+
out_channels (`int`, optional):
|
114 |
+
number of output channels. Defaults to `channels`.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
118 |
+
super().__init__()
|
119 |
+
self.channels = channels
|
120 |
+
self.out_channels = out_channels or channels
|
121 |
+
self.use_conv = use_conv
|
122 |
+
self.use_conv_transpose = use_conv_transpose
|
123 |
+
self.name = name
|
124 |
+
|
125 |
+
conv = None
|
126 |
+
if use_conv_transpose:
|
127 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
128 |
+
elif use_conv:
|
129 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
130 |
+
|
131 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
132 |
+
if name == "conv":
|
133 |
+
self.conv = conv
|
134 |
+
else:
|
135 |
+
self.Conv2d_0 = conv
|
136 |
+
|
137 |
+
def forward(self, hidden_states, output_size=None):
|
138 |
+
assert hidden_states.shape[1] == self.channels
|
139 |
+
|
140 |
+
if self.use_conv_transpose:
|
141 |
+
return self.conv(hidden_states)
|
142 |
+
|
143 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
144 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
145 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
146 |
+
dtype = hidden_states.dtype
|
147 |
+
if dtype == torch.bfloat16:
|
148 |
+
hidden_states = hidden_states.to(torch.float32)
|
149 |
+
|
150 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
151 |
+
if hidden_states.shape[0] >= 64:
|
152 |
+
hidden_states = hidden_states.contiguous()
|
153 |
+
|
154 |
+
# if `output_size` is passed we force the interpolation output
|
155 |
+
# size and do not make use of `scale_factor=2`
|
156 |
+
if output_size is None:
|
157 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
158 |
+
else:
|
159 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
160 |
+
|
161 |
+
# If the input is bfloat16, we cast back to bfloat16
|
162 |
+
if dtype == torch.bfloat16:
|
163 |
+
hidden_states = hidden_states.to(dtype)
|
164 |
+
|
165 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
166 |
+
if self.use_conv:
|
167 |
+
if self.name == "conv":
|
168 |
+
hidden_states = self.conv(hidden_states)
|
169 |
+
else:
|
170 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
171 |
+
|
172 |
+
return hidden_states
|
173 |
+
|
174 |
+
|
175 |
+
class Downsample2D(nn.Module):
|
176 |
+
"""A 2D downsampling layer with an optional convolution.
|
177 |
+
|
178 |
+
Parameters:
|
179 |
+
channels (`int`):
|
180 |
+
number of channels in the inputs and outputs.
|
181 |
+
use_conv (`bool`, default `False`):
|
182 |
+
option to use a convolution.
|
183 |
+
out_channels (`int`, optional):
|
184 |
+
number of output channels. Defaults to `channels`.
|
185 |
+
padding (`int`, default `1`):
|
186 |
+
padding for the convolution.
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
190 |
+
super().__init__()
|
191 |
+
self.channels = channels
|
192 |
+
self.out_channels = out_channels or channels
|
193 |
+
self.use_conv = use_conv
|
194 |
+
self.padding = padding
|
195 |
+
stride = 2
|
196 |
+
self.name = name
|
197 |
+
|
198 |
+
if use_conv:
|
199 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
200 |
+
else:
|
201 |
+
assert self.channels == self.out_channels
|
202 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
203 |
+
|
204 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
205 |
+
if name == "conv":
|
206 |
+
self.Conv2d_0 = conv
|
207 |
+
self.conv = conv
|
208 |
+
elif name == "Conv2d_0":
|
209 |
+
self.conv = conv
|
210 |
+
else:
|
211 |
+
self.conv = conv
|
212 |
+
|
213 |
+
def forward(self, hidden_states):
|
214 |
+
assert hidden_states.shape[1] == self.channels
|
215 |
+
if self.use_conv and self.padding == 0:
|
216 |
+
pad = (0, 1, 0, 1)
|
217 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
218 |
+
|
219 |
+
assert hidden_states.shape[1] == self.channels
|
220 |
+
hidden_states = self.conv(hidden_states)
|
221 |
+
|
222 |
+
return hidden_states
|
223 |
+
|
224 |
+
|
225 |
+
class FirUpsample2D(nn.Module):
|
226 |
+
"""A 2D FIR upsampling layer with an optional convolution.
|
227 |
+
|
228 |
+
Parameters:
|
229 |
+
channels (`int`):
|
230 |
+
number of channels in the inputs and outputs.
|
231 |
+
use_conv (`bool`, default `False`):
|
232 |
+
option to use a convolution.
|
233 |
+
out_channels (`int`, optional):
|
234 |
+
number of output channels. Defaults to `channels`.
|
235 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
236 |
+
kernel for the FIR filter.
|
237 |
+
"""
|
238 |
+
|
239 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
240 |
+
super().__init__()
|
241 |
+
out_channels = out_channels if out_channels else channels
|
242 |
+
if use_conv:
|
243 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
244 |
+
self.use_conv = use_conv
|
245 |
+
self.fir_kernel = fir_kernel
|
246 |
+
self.out_channels = out_channels
|
247 |
+
|
248 |
+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
249 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
250 |
+
|
251 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
252 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
253 |
+
arbitrary order.
|
254 |
+
|
255 |
+
Args:
|
256 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
257 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
258 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
259 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
260 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
261 |
+
factor: Integer upsampling factor (default: 2).
|
262 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
266 |
+
datatype as `hidden_states`.
|
267 |
+
"""
|
268 |
+
|
269 |
+
assert isinstance(factor, int) and factor >= 1
|
270 |
+
|
271 |
+
# Setup filter kernel.
|
272 |
+
if kernel is None:
|
273 |
+
kernel = [1] * factor
|
274 |
+
|
275 |
+
# setup kernel
|
276 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
277 |
+
if kernel.ndim == 1:
|
278 |
+
kernel = torch.outer(kernel, kernel)
|
279 |
+
kernel /= torch.sum(kernel)
|
280 |
+
|
281 |
+
kernel = kernel * (gain * (factor**2))
|
282 |
+
|
283 |
+
if self.use_conv:
|
284 |
+
convH = weight.shape[2]
|
285 |
+
convW = weight.shape[3]
|
286 |
+
inC = weight.shape[1]
|
287 |
+
|
288 |
+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
289 |
+
|
290 |
+
stride = (factor, factor)
|
291 |
+
# Determine data dimensions.
|
292 |
+
output_shape = (
|
293 |
+
(hidden_states.shape[2] - 1) * factor + convH,
|
294 |
+
(hidden_states.shape[3] - 1) * factor + convW,
|
295 |
+
)
|
296 |
+
output_padding = (
|
297 |
+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
298 |
+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
299 |
+
)
|
300 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
301 |
+
num_groups = hidden_states.shape[1] // inC
|
302 |
+
|
303 |
+
# Transpose weights.
|
304 |
+
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
305 |
+
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
306 |
+
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
307 |
+
|
308 |
+
inverse_conv = F.conv_transpose2d(
|
309 |
+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
|
310 |
+
)
|
311 |
+
|
312 |
+
output = upfirdn2d_native(
|
313 |
+
inverse_conv,
|
314 |
+
torch.tensor(kernel, device=inverse_conv.device),
|
315 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
pad_value = kernel.shape[0] - factor
|
319 |
+
output = upfirdn2d_native(
|
320 |
+
hidden_states,
|
321 |
+
torch.tensor(kernel, device=hidden_states.device),
|
322 |
+
up=factor,
|
323 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
324 |
+
)
|
325 |
+
|
326 |
+
return output
|
327 |
+
|
328 |
+
def forward(self, hidden_states):
|
329 |
+
if self.use_conv:
|
330 |
+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
331 |
+
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
332 |
+
else:
|
333 |
+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
334 |
+
|
335 |
+
return height
|
336 |
+
|
337 |
+
|
338 |
+
class FirDownsample2D(nn.Module):
|
339 |
+
"""A 2D FIR downsampling layer with an optional convolution.
|
340 |
+
|
341 |
+
Parameters:
|
342 |
+
channels (`int`):
|
343 |
+
number of channels in the inputs and outputs.
|
344 |
+
use_conv (`bool`, default `False`):
|
345 |
+
option to use a convolution.
|
346 |
+
out_channels (`int`, optional):
|
347 |
+
number of output channels. Defaults to `channels`.
|
348 |
+
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
349 |
+
kernel for the FIR filter.
|
350 |
+
"""
|
351 |
+
|
352 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
353 |
+
super().__init__()
|
354 |
+
out_channels = out_channels if out_channels else channels
|
355 |
+
if use_conv:
|
356 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
357 |
+
self.fir_kernel = fir_kernel
|
358 |
+
self.use_conv = use_conv
|
359 |
+
self.out_channels = out_channels
|
360 |
+
|
361 |
+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
362 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
363 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
364 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
365 |
+
arbitrary order.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
369 |
+
weight:
|
370 |
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
371 |
+
performed by `inChannels = x.shape[0] // numGroups`.
|
372 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
373 |
+
factor`, which corresponds to average pooling.
|
374 |
+
factor: Integer downsampling factor (default: 2).
|
375 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
379 |
+
same datatype as `x`.
|
380 |
+
"""
|
381 |
+
|
382 |
+
assert isinstance(factor, int) and factor >= 1
|
383 |
+
if kernel is None:
|
384 |
+
kernel = [1] * factor
|
385 |
+
|
386 |
+
# setup kernel
|
387 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
388 |
+
if kernel.ndim == 1:
|
389 |
+
kernel = torch.outer(kernel, kernel)
|
390 |
+
kernel /= torch.sum(kernel)
|
391 |
+
|
392 |
+
kernel = kernel * gain
|
393 |
+
|
394 |
+
if self.use_conv:
|
395 |
+
_, _, convH, convW = weight.shape
|
396 |
+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
397 |
+
stride_value = [factor, factor]
|
398 |
+
upfirdn_input = upfirdn2d_native(
|
399 |
+
hidden_states,
|
400 |
+
torch.tensor(kernel, device=hidden_states.device),
|
401 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
402 |
+
)
|
403 |
+
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
404 |
+
else:
|
405 |
+
pad_value = kernel.shape[0] - factor
|
406 |
+
output = upfirdn2d_native(
|
407 |
+
hidden_states,
|
408 |
+
torch.tensor(kernel, device=hidden_states.device),
|
409 |
+
down=factor,
|
410 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
411 |
+
)
|
412 |
+
|
413 |
+
return output
|
414 |
+
|
415 |
+
def forward(self, hidden_states):
|
416 |
+
if self.use_conv:
|
417 |
+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
418 |
+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
419 |
+
else:
|
420 |
+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
421 |
+
|
422 |
+
return hidden_states
|
423 |
+
|
424 |
+
|
425 |
+
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
426 |
+
class KDownsample2D(nn.Module):
|
427 |
+
def __init__(self, pad_mode="reflect"):
|
428 |
+
super().__init__()
|
429 |
+
self.pad_mode = pad_mode
|
430 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
431 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
432 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
433 |
+
|
434 |
+
def forward(self, inputs):
|
435 |
+
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
436 |
+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
437 |
+
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
438 |
+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
439 |
+
weight[indices, indices] = kernel
|
440 |
+
return F.conv2d(inputs, weight, stride=2)
|
441 |
+
|
442 |
+
|
443 |
+
class KUpsample2D(nn.Module):
|
444 |
+
def __init__(self, pad_mode="reflect"):
|
445 |
+
super().__init__()
|
446 |
+
self.pad_mode = pad_mode
|
447 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
448 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
449 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
450 |
+
|
451 |
+
def forward(self, inputs):
|
452 |
+
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
453 |
+
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
454 |
+
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
455 |
+
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
456 |
+
weight[indices, indices] = kernel
|
457 |
+
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
458 |
+
|
459 |
+
|
460 |
+
class ResnetBlock2D(nn.Module):
|
461 |
+
r"""
|
462 |
+
A Resnet block.
|
463 |
+
|
464 |
+
Parameters:
|
465 |
+
in_channels (`int`): The number of channels in the input.
|
466 |
+
out_channels (`int`, *optional*, default to be `None`):
|
467 |
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
468 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
469 |
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
470 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
471 |
+
groups_out (`int`, *optional*, default to None):
|
472 |
+
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
473 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
474 |
+
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
475 |
+
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
476 |
+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
|
477 |
+
"ada_group" for a stronger conditioning with scale and shift.
|
478 |
+
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
|
479 |
+
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
480 |
+
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
481 |
+
use_in_shortcut (`bool`, *optional*, default to `True`):
|
482 |
+
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
483 |
+
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
484 |
+
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
485 |
+
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
486 |
+
`conv_shortcut` output.
|
487 |
+
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
488 |
+
If None, same as `out_channels`.
|
489 |
+
"""
|
490 |
+
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
*,
|
494 |
+
in_channels,
|
495 |
+
out_channels=None,
|
496 |
+
conv_shortcut=False,
|
497 |
+
dropout=0.0,
|
498 |
+
temb_channels=512,
|
499 |
+
groups=32,
|
500 |
+
groups_out=None,
|
501 |
+
pre_norm=True,
|
502 |
+
eps=1e-6,
|
503 |
+
non_linearity="swish",
|
504 |
+
skip_time_act=False,
|
505 |
+
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
|
506 |
+
kernel=None,
|
507 |
+
output_scale_factor=1.0,
|
508 |
+
use_in_shortcut=None,
|
509 |
+
up=False,
|
510 |
+
down=False,
|
511 |
+
conv_shortcut_bias: bool = True,
|
512 |
+
conv_2d_out_channels: Optional[int] = None,
|
513 |
+
):
|
514 |
+
super().__init__()
|
515 |
+
self.pre_norm = pre_norm
|
516 |
+
self.pre_norm = True
|
517 |
+
self.in_channels = in_channels
|
518 |
+
out_channels = in_channels if out_channels is None else out_channels
|
519 |
+
self.out_channels = out_channels
|
520 |
+
self.use_conv_shortcut = conv_shortcut
|
521 |
+
self.up = up
|
522 |
+
self.down = down
|
523 |
+
self.output_scale_factor = output_scale_factor
|
524 |
+
self.time_embedding_norm = time_embedding_norm
|
525 |
+
self.skip_time_act = skip_time_act
|
526 |
+
|
527 |
+
if groups_out is None:
|
528 |
+
groups_out = groups
|
529 |
+
|
530 |
+
if self.time_embedding_norm == "ada_group":
|
531 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
532 |
+
elif self.time_embedding_norm == "spatial":
|
533 |
+
self.norm1 = SpatialNorm(in_channels, temb_channels)
|
534 |
+
else:
|
535 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
536 |
+
|
537 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
538 |
+
|
539 |
+
if temb_channels is not None:
|
540 |
+
if self.time_embedding_norm == "default":
|
541 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
542 |
+
elif self.time_embedding_norm == "scale_shift":
|
543 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
544 |
+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
545 |
+
self.time_emb_proj = None
|
546 |
+
else:
|
547 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
548 |
+
else:
|
549 |
+
self.time_emb_proj = None
|
550 |
+
|
551 |
+
if self.time_embedding_norm == "ada_group":
|
552 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
553 |
+
elif self.time_embedding_norm == "spatial":
|
554 |
+
self.norm2 = SpatialNorm(out_channels, temb_channels)
|
555 |
+
else:
|
556 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
557 |
+
|
558 |
+
self.dropout = torch.nn.Dropout(dropout)
|
559 |
+
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
560 |
+
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
561 |
+
|
562 |
+
self.nonlinearity = get_activation(non_linearity)
|
563 |
+
|
564 |
+
self.upsample = self.downsample = None
|
565 |
+
if self.up:
|
566 |
+
if kernel == "fir":
|
567 |
+
fir_kernel = (1, 3, 3, 1)
|
568 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
569 |
+
elif kernel == "sde_vp":
|
570 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
571 |
+
else:
|
572 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
573 |
+
elif self.down:
|
574 |
+
if kernel == "fir":
|
575 |
+
fir_kernel = (1, 3, 3, 1)
|
576 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
577 |
+
elif kernel == "sde_vp":
|
578 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
579 |
+
else:
|
580 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
581 |
+
|
582 |
+
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
|
583 |
+
|
584 |
+
self.conv_shortcut = None
|
585 |
+
if self.use_in_shortcut:
|
586 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
587 |
+
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
|
588 |
+
)
|
589 |
+
|
590 |
+
def forward(self, input_tensor, temb):
|
591 |
+
hidden_states = input_tensor
|
592 |
+
|
593 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
594 |
+
hidden_states = self.norm1(hidden_states, temb)
|
595 |
+
else:
|
596 |
+
hidden_states = self.norm1(hidden_states)
|
597 |
+
|
598 |
+
hidden_states = self.nonlinearity(hidden_states)
|
599 |
+
|
600 |
+
if self.upsample is not None:
|
601 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
602 |
+
if hidden_states.shape[0] >= 64:
|
603 |
+
input_tensor = input_tensor.contiguous()
|
604 |
+
hidden_states = hidden_states.contiguous()
|
605 |
+
input_tensor = self.upsample(input_tensor)
|
606 |
+
hidden_states = self.upsample(hidden_states)
|
607 |
+
elif self.downsample is not None:
|
608 |
+
input_tensor = self.downsample(input_tensor)
|
609 |
+
hidden_states = self.downsample(hidden_states)
|
610 |
+
|
611 |
+
hidden_states = self.conv1(hidden_states)
|
612 |
+
|
613 |
+
if self.time_emb_proj is not None:
|
614 |
+
if not self.skip_time_act:
|
615 |
+
temb = self.nonlinearity(temb)
|
616 |
+
temb = self.time_emb_proj(temb)[:, :, None, None]
|
617 |
+
|
618 |
+
if temb is not None and self.time_embedding_norm == "default":
|
619 |
+
hidden_states = hidden_states + temb
|
620 |
+
|
621 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
622 |
+
hidden_states = self.norm2(hidden_states, temb)
|
623 |
+
else:
|
624 |
+
hidden_states = self.norm2(hidden_states)
|
625 |
+
|
626 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
627 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
628 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
629 |
+
|
630 |
+
hidden_states = self.nonlinearity(hidden_states)
|
631 |
+
|
632 |
+
hidden_states = self.dropout(hidden_states)
|
633 |
+
hidden_states = self.conv2(hidden_states)
|
634 |
+
|
635 |
+
if self.conv_shortcut is not None:
|
636 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
637 |
+
|
638 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
639 |
+
|
640 |
+
return output_tensor
|
641 |
+
|
642 |
+
|
643 |
+
# unet_rl.py
|
644 |
+
def rearrange_dims(tensor):
|
645 |
+
if len(tensor.shape) == 2:
|
646 |
+
return tensor[:, :, None]
|
647 |
+
if len(tensor.shape) == 3:
|
648 |
+
return tensor[:, :, None, :]
|
649 |
+
elif len(tensor.shape) == 4:
|
650 |
+
return tensor[:, :, 0, :]
|
651 |
+
else:
|
652 |
+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
653 |
+
|
654 |
+
|
655 |
+
class Conv1dBlock(nn.Module):
|
656 |
+
"""
|
657 |
+
Conv1d --> GroupNorm --> Mish
|
658 |
+
"""
|
659 |
+
|
660 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
661 |
+
super().__init__()
|
662 |
+
|
663 |
+
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
664 |
+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
665 |
+
self.mish = nn.Mish()
|
666 |
+
|
667 |
+
def forward(self, inputs):
|
668 |
+
intermediate_repr = self.conv1d(inputs)
|
669 |
+
intermediate_repr = rearrange_dims(intermediate_repr)
|
670 |
+
intermediate_repr = self.group_norm(intermediate_repr)
|
671 |
+
intermediate_repr = rearrange_dims(intermediate_repr)
|
672 |
+
output = self.mish(intermediate_repr)
|
673 |
+
return output
|
674 |
+
|
675 |
+
|
676 |
+
# unet_rl.py
|
677 |
+
class ResidualTemporalBlock1D(nn.Module):
|
678 |
+
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
679 |
+
super().__init__()
|
680 |
+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
681 |
+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
682 |
+
|
683 |
+
self.time_emb_act = nn.Mish()
|
684 |
+
self.time_emb = nn.Linear(embed_dim, out_channels)
|
685 |
+
|
686 |
+
self.residual_conv = (
|
687 |
+
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
688 |
+
)
|
689 |
+
|
690 |
+
def forward(self, inputs, t):
|
691 |
+
"""
|
692 |
+
Args:
|
693 |
+
inputs : [ batch_size x inp_channels x horizon ]
|
694 |
+
t : [ batch_size x embed_dim ]
|
695 |
+
|
696 |
+
returns:
|
697 |
+
out : [ batch_size x out_channels x horizon ]
|
698 |
+
"""
|
699 |
+
t = self.time_emb_act(t)
|
700 |
+
t = self.time_emb(t)
|
701 |
+
out = self.conv_in(inputs) + rearrange_dims(t)
|
702 |
+
out = self.conv_out(out)
|
703 |
+
return out + self.residual_conv(inputs)
|
704 |
+
|
705 |
+
|
706 |
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
707 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
708 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
709 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
710 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
711 |
+
a: multiple of the upsampling factor.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
715 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
716 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
717 |
+
factor: Integer upsampling factor (default: 2).
|
718 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
719 |
+
|
720 |
+
Returns:
|
721 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]`
|
722 |
+
"""
|
723 |
+
assert isinstance(factor, int) and factor >= 1
|
724 |
+
if kernel is None:
|
725 |
+
kernel = [1] * factor
|
726 |
+
|
727 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
728 |
+
if kernel.ndim == 1:
|
729 |
+
kernel = torch.outer(kernel, kernel)
|
730 |
+
kernel /= torch.sum(kernel)
|
731 |
+
|
732 |
+
kernel = kernel * (gain * (factor**2))
|
733 |
+
pad_value = kernel.shape[0] - factor
|
734 |
+
output = upfirdn2d_native(
|
735 |
+
hidden_states,
|
736 |
+
kernel.to(device=hidden_states.device),
|
737 |
+
up=factor,
|
738 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
739 |
+
)
|
740 |
+
return output
|
741 |
+
|
742 |
+
|
743 |
+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
744 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
745 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
746 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
747 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
748 |
+
shape is a multiple of the downsampling factor.
|
749 |
+
|
750 |
+
Args:
|
751 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
752 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
753 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
754 |
+
factor: Integer downsampling factor (default: 2).
|
755 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
756 |
+
|
757 |
+
Returns:
|
758 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]`
|
759 |
+
"""
|
760 |
+
|
761 |
+
assert isinstance(factor, int) and factor >= 1
|
762 |
+
if kernel is None:
|
763 |
+
kernel = [1] * factor
|
764 |
+
|
765 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
766 |
+
if kernel.ndim == 1:
|
767 |
+
kernel = torch.outer(kernel, kernel)
|
768 |
+
kernel /= torch.sum(kernel)
|
769 |
+
|
770 |
+
kernel = kernel * gain
|
771 |
+
pad_value = kernel.shape[0] - factor
|
772 |
+
output = upfirdn2d_native(
|
773 |
+
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
|
774 |
+
)
|
775 |
+
return output
|
776 |
+
|
777 |
+
|
778 |
+
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
779 |
+
up_x = up_y = up
|
780 |
+
down_x = down_y = down
|
781 |
+
pad_x0 = pad_y0 = pad[0]
|
782 |
+
pad_x1 = pad_y1 = pad[1]
|
783 |
+
|
784 |
+
_, channel, in_h, in_w = tensor.shape
|
785 |
+
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
786 |
+
|
787 |
+
_, in_h, in_w, minor = tensor.shape
|
788 |
+
kernel_h, kernel_w = kernel.shape
|
789 |
+
|
790 |
+
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
791 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
792 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
793 |
+
|
794 |
+
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
795 |
+
out = out.to(tensor.device) # Move back to mps if necessary
|
796 |
+
out = out[
|
797 |
+
:,
|
798 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
799 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
800 |
+
:,
|
801 |
+
]
|
802 |
+
|
803 |
+
out = out.permute(0, 3, 1, 2)
|
804 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
805 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
806 |
+
out = F.conv2d(out, w)
|
807 |
+
out = out.reshape(
|
808 |
+
-1,
|
809 |
+
minor,
|
810 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
811 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
812 |
+
)
|
813 |
+
out = out.permute(0, 2, 3, 1)
|
814 |
+
out = out[:, ::down_y, ::down_x, :]
|
815 |
+
|
816 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
817 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
818 |
+
|
819 |
+
return out.view(-1, channel, out_h, out_w)
|
820 |
+
|
821 |
+
|
822 |
+
class TemporalConvLayer(nn.Module):
|
823 |
+
"""
|
824 |
+
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
825 |
+
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
826 |
+
"""
|
827 |
+
|
828 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0):
|
829 |
+
super().__init__()
|
830 |
+
out_dim = out_dim or in_dim
|
831 |
+
self.in_dim = in_dim
|
832 |
+
self.out_dim = out_dim
|
833 |
+
|
834 |
+
# conv layers
|
835 |
+
self.conv1 = nn.Sequential(
|
836 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
|
837 |
+
)
|
838 |
+
self.conv2 = nn.Sequential(
|
839 |
+
nn.GroupNorm(32, out_dim),
|
840 |
+
nn.SiLU(),
|
841 |
+
nn.Dropout(dropout),
|
842 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
843 |
+
)
|
844 |
+
self.conv3 = nn.Sequential(
|
845 |
+
nn.GroupNorm(32, out_dim),
|
846 |
+
nn.SiLU(),
|
847 |
+
nn.Dropout(dropout),
|
848 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
849 |
+
)
|
850 |
+
self.conv4 = nn.Sequential(
|
851 |
+
nn.GroupNorm(32, out_dim),
|
852 |
+
nn.SiLU(),
|
853 |
+
nn.Dropout(dropout),
|
854 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
855 |
+
)
|
856 |
+
|
857 |
+
# zero out the last layer params,so the conv block is identity
|
858 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
859 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
860 |
+
|
861 |
+
def forward(self, hidden_states, num_frames=1):
|
862 |
+
hidden_states = (
|
863 |
+
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
|
864 |
+
)
|
865 |
+
|
866 |
+
identity = hidden_states
|
867 |
+
hidden_states = self.conv1(hidden_states)
|
868 |
+
hidden_states = self.conv2(hidden_states)
|
869 |
+
hidden_states = self.conv3(hidden_states)
|
870 |
+
hidden_states = self.conv4(hidden_states)
|
871 |
+
|
872 |
+
hidden_states = identity + hidden_states
|
873 |
+
|
874 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
|
875 |
+
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
|
876 |
+
)
|
877 |
+
return hidden_states
|
6DoF/diffusers/models/resnet_flax.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import flax.linen as nn
|
15 |
+
import jax
|
16 |
+
import jax.numpy as jnp
|
17 |
+
|
18 |
+
|
19 |
+
class FlaxUpsample2D(nn.Module):
|
20 |
+
out_channels: int
|
21 |
+
dtype: jnp.dtype = jnp.float32
|
22 |
+
|
23 |
+
def setup(self):
|
24 |
+
self.conv = nn.Conv(
|
25 |
+
self.out_channels,
|
26 |
+
kernel_size=(3, 3),
|
27 |
+
strides=(1, 1),
|
28 |
+
padding=((1, 1), (1, 1)),
|
29 |
+
dtype=self.dtype,
|
30 |
+
)
|
31 |
+
|
32 |
+
def __call__(self, hidden_states):
|
33 |
+
batch, height, width, channels = hidden_states.shape
|
34 |
+
hidden_states = jax.image.resize(
|
35 |
+
hidden_states,
|
36 |
+
shape=(batch, height * 2, width * 2, channels),
|
37 |
+
method="nearest",
|
38 |
+
)
|
39 |
+
hidden_states = self.conv(hidden_states)
|
40 |
+
return hidden_states
|
41 |
+
|
42 |
+
|
43 |
+
class FlaxDownsample2D(nn.Module):
|
44 |
+
out_channels: int
|
45 |
+
dtype: jnp.dtype = jnp.float32
|
46 |
+
|
47 |
+
def setup(self):
|
48 |
+
self.conv = nn.Conv(
|
49 |
+
self.out_channels,
|
50 |
+
kernel_size=(3, 3),
|
51 |
+
strides=(2, 2),
|
52 |
+
padding=((1, 1), (1, 1)), # padding="VALID",
|
53 |
+
dtype=self.dtype,
|
54 |
+
)
|
55 |
+
|
56 |
+
def __call__(self, hidden_states):
|
57 |
+
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
|
58 |
+
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
|
59 |
+
hidden_states = self.conv(hidden_states)
|
60 |
+
return hidden_states
|
61 |
+
|
62 |
+
|
63 |
+
class FlaxResnetBlock2D(nn.Module):
|
64 |
+
in_channels: int
|
65 |
+
out_channels: int = None
|
66 |
+
dropout_prob: float = 0.0
|
67 |
+
use_nin_shortcut: bool = None
|
68 |
+
dtype: jnp.dtype = jnp.float32
|
69 |
+
|
70 |
+
def setup(self):
|
71 |
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
72 |
+
|
73 |
+
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
74 |
+
self.conv1 = nn.Conv(
|
75 |
+
out_channels,
|
76 |
+
kernel_size=(3, 3),
|
77 |
+
strides=(1, 1),
|
78 |
+
padding=((1, 1), (1, 1)),
|
79 |
+
dtype=self.dtype,
|
80 |
+
)
|
81 |
+
|
82 |
+
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
|
83 |
+
|
84 |
+
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
85 |
+
self.dropout = nn.Dropout(self.dropout_prob)
|
86 |
+
self.conv2 = nn.Conv(
|
87 |
+
out_channels,
|
88 |
+
kernel_size=(3, 3),
|
89 |
+
strides=(1, 1),
|
90 |
+
padding=((1, 1), (1, 1)),
|
91 |
+
dtype=self.dtype,
|
92 |
+
)
|
93 |
+
|
94 |
+
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
|
95 |
+
|
96 |
+
self.conv_shortcut = None
|
97 |
+
if use_nin_shortcut:
|
98 |
+
self.conv_shortcut = nn.Conv(
|
99 |
+
out_channels,
|
100 |
+
kernel_size=(1, 1),
|
101 |
+
strides=(1, 1),
|
102 |
+
padding="VALID",
|
103 |
+
dtype=self.dtype,
|
104 |
+
)
|
105 |
+
|
106 |
+
def __call__(self, hidden_states, temb, deterministic=True):
|
107 |
+
residual = hidden_states
|
108 |
+
hidden_states = self.norm1(hidden_states)
|
109 |
+
hidden_states = nn.swish(hidden_states)
|
110 |
+
hidden_states = self.conv1(hidden_states)
|
111 |
+
|
112 |
+
temb = self.time_emb_proj(nn.swish(temb))
|
113 |
+
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
|
114 |
+
hidden_states = hidden_states + temb
|
115 |
+
|
116 |
+
hidden_states = self.norm2(hidden_states)
|
117 |
+
hidden_states = nn.swish(hidden_states)
|
118 |
+
hidden_states = self.dropout(hidden_states, deterministic)
|
119 |
+
hidden_states = self.conv2(hidden_states)
|
120 |
+
|
121 |
+
if self.conv_shortcut is not None:
|
122 |
+
residual = self.conv_shortcut(residual)
|
123 |
+
|
124 |
+
return hidden_states + residual
|
6DoF/diffusers/models/t5_film_transformer.py
ADDED
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from .attention_processor import Attention
|
21 |
+
from .embeddings import get_timestep_embedding
|
22 |
+
from .modeling_utils import ModelMixin
|
23 |
+
|
24 |
+
|
25 |
+
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
26 |
+
@register_to_config
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
input_dims: int = 128,
|
30 |
+
targets_length: int = 256,
|
31 |
+
max_decoder_noise_time: float = 2000.0,
|
32 |
+
d_model: int = 768,
|
33 |
+
num_layers: int = 12,
|
34 |
+
num_heads: int = 12,
|
35 |
+
d_kv: int = 64,
|
36 |
+
d_ff: int = 2048,
|
37 |
+
dropout_rate: float = 0.1,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.conditioning_emb = nn.Sequential(
|
42 |
+
nn.Linear(d_model, d_model * 4, bias=False),
|
43 |
+
nn.SiLU(),
|
44 |
+
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
45 |
+
nn.SiLU(),
|
46 |
+
)
|
47 |
+
|
48 |
+
self.position_encoding = nn.Embedding(targets_length, d_model)
|
49 |
+
self.position_encoding.weight.requires_grad = False
|
50 |
+
|
51 |
+
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
52 |
+
|
53 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
54 |
+
|
55 |
+
self.decoders = nn.ModuleList()
|
56 |
+
for lyr_num in range(num_layers):
|
57 |
+
# FiLM conditional T5 decoder
|
58 |
+
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
59 |
+
self.decoders.append(lyr)
|
60 |
+
|
61 |
+
self.decoder_norm = T5LayerNorm(d_model)
|
62 |
+
|
63 |
+
self.post_dropout = nn.Dropout(p=dropout_rate)
|
64 |
+
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
65 |
+
|
66 |
+
def encoder_decoder_mask(self, query_input, key_input):
|
67 |
+
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
68 |
+
return mask.unsqueeze(-3)
|
69 |
+
|
70 |
+
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
71 |
+
batch, _, _ = decoder_input_tokens.shape
|
72 |
+
assert decoder_noise_time.shape == (batch,)
|
73 |
+
|
74 |
+
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
75 |
+
time_steps = get_timestep_embedding(
|
76 |
+
decoder_noise_time * self.config.max_decoder_noise_time,
|
77 |
+
embedding_dim=self.config.d_model,
|
78 |
+
max_period=self.config.max_decoder_noise_time,
|
79 |
+
).to(dtype=self.dtype)
|
80 |
+
|
81 |
+
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
|
82 |
+
|
83 |
+
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
84 |
+
|
85 |
+
seq_length = decoder_input_tokens.shape[1]
|
86 |
+
|
87 |
+
# If we want to use relative positions for audio context, we can just offset
|
88 |
+
# this sequence by the length of encodings_and_masks.
|
89 |
+
decoder_positions = torch.broadcast_to(
|
90 |
+
torch.arange(seq_length, device=decoder_input_tokens.device),
|
91 |
+
(batch, seq_length),
|
92 |
+
)
|
93 |
+
|
94 |
+
position_encodings = self.position_encoding(decoder_positions)
|
95 |
+
|
96 |
+
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
97 |
+
inputs += position_encodings
|
98 |
+
y = self.dropout(inputs)
|
99 |
+
|
100 |
+
# decoder: No padding present.
|
101 |
+
decoder_mask = torch.ones(
|
102 |
+
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
103 |
+
)
|
104 |
+
|
105 |
+
# Translate encoding masks to encoder-decoder masks.
|
106 |
+
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
107 |
+
|
108 |
+
# cross attend style: concat encodings
|
109 |
+
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
110 |
+
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
|
111 |
+
|
112 |
+
for lyr in self.decoders:
|
113 |
+
y = lyr(
|
114 |
+
y,
|
115 |
+
conditioning_emb=conditioning_emb,
|
116 |
+
encoder_hidden_states=encoded,
|
117 |
+
encoder_attention_mask=encoder_decoder_mask,
|
118 |
+
)[0]
|
119 |
+
|
120 |
+
y = self.decoder_norm(y)
|
121 |
+
y = self.post_dropout(y)
|
122 |
+
|
123 |
+
spec_out = self.spec_out(y)
|
124 |
+
return spec_out
|
125 |
+
|
126 |
+
|
127 |
+
class DecoderLayer(nn.Module):
|
128 |
+
def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
|
129 |
+
super().__init__()
|
130 |
+
self.layer = nn.ModuleList()
|
131 |
+
|
132 |
+
# cond self attention: layer 0
|
133 |
+
self.layer.append(
|
134 |
+
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
135 |
+
)
|
136 |
+
|
137 |
+
# cross attention: layer 1
|
138 |
+
self.layer.append(
|
139 |
+
T5LayerCrossAttention(
|
140 |
+
d_model=d_model,
|
141 |
+
d_kv=d_kv,
|
142 |
+
num_heads=num_heads,
|
143 |
+
dropout_rate=dropout_rate,
|
144 |
+
layer_norm_epsilon=layer_norm_epsilon,
|
145 |
+
)
|
146 |
+
)
|
147 |
+
|
148 |
+
# Film Cond MLP + dropout: last layer
|
149 |
+
self.layer.append(
|
150 |
+
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
151 |
+
)
|
152 |
+
|
153 |
+
def forward(
|
154 |
+
self,
|
155 |
+
hidden_states,
|
156 |
+
conditioning_emb=None,
|
157 |
+
attention_mask=None,
|
158 |
+
encoder_hidden_states=None,
|
159 |
+
encoder_attention_mask=None,
|
160 |
+
encoder_decoder_position_bias=None,
|
161 |
+
):
|
162 |
+
hidden_states = self.layer[0](
|
163 |
+
hidden_states,
|
164 |
+
conditioning_emb=conditioning_emb,
|
165 |
+
attention_mask=attention_mask,
|
166 |
+
)
|
167 |
+
|
168 |
+
if encoder_hidden_states is not None:
|
169 |
+
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
170 |
+
encoder_hidden_states.dtype
|
171 |
+
)
|
172 |
+
|
173 |
+
hidden_states = self.layer[1](
|
174 |
+
hidden_states,
|
175 |
+
key_value_states=encoder_hidden_states,
|
176 |
+
attention_mask=encoder_extended_attention_mask,
|
177 |
+
)
|
178 |
+
|
179 |
+
# Apply Film Conditional Feed Forward layer
|
180 |
+
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
181 |
+
|
182 |
+
return (hidden_states,)
|
183 |
+
|
184 |
+
|
185 |
+
class T5LayerSelfAttentionCond(nn.Module):
|
186 |
+
def __init__(self, d_model, d_kv, num_heads, dropout_rate):
|
187 |
+
super().__init__()
|
188 |
+
self.layer_norm = T5LayerNorm(d_model)
|
189 |
+
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
190 |
+
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
191 |
+
self.dropout = nn.Dropout(dropout_rate)
|
192 |
+
|
193 |
+
def forward(
|
194 |
+
self,
|
195 |
+
hidden_states,
|
196 |
+
conditioning_emb=None,
|
197 |
+
attention_mask=None,
|
198 |
+
):
|
199 |
+
# pre_self_attention_layer_norm
|
200 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
201 |
+
|
202 |
+
if conditioning_emb is not None:
|
203 |
+
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
204 |
+
|
205 |
+
# Self-attention block
|
206 |
+
attention_output = self.attention(normed_hidden_states)
|
207 |
+
|
208 |
+
hidden_states = hidden_states + self.dropout(attention_output)
|
209 |
+
|
210 |
+
return hidden_states
|
211 |
+
|
212 |
+
|
213 |
+
class T5LayerCrossAttention(nn.Module):
|
214 |
+
def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
|
215 |
+
super().__init__()
|
216 |
+
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
217 |
+
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
218 |
+
self.dropout = nn.Dropout(dropout_rate)
|
219 |
+
|
220 |
+
def forward(
|
221 |
+
self,
|
222 |
+
hidden_states,
|
223 |
+
key_value_states=None,
|
224 |
+
attention_mask=None,
|
225 |
+
):
|
226 |
+
normed_hidden_states = self.layer_norm(hidden_states)
|
227 |
+
attention_output = self.attention(
|
228 |
+
normed_hidden_states,
|
229 |
+
encoder_hidden_states=key_value_states,
|
230 |
+
attention_mask=attention_mask.squeeze(1),
|
231 |
+
)
|
232 |
+
layer_output = hidden_states + self.dropout(attention_output)
|
233 |
+
return layer_output
|
234 |
+
|
235 |
+
|
236 |
+
class T5LayerFFCond(nn.Module):
|
237 |
+
def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
|
238 |
+
super().__init__()
|
239 |
+
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
240 |
+
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
241 |
+
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
242 |
+
self.dropout = nn.Dropout(dropout_rate)
|
243 |
+
|
244 |
+
def forward(self, hidden_states, conditioning_emb=None):
|
245 |
+
forwarded_states = self.layer_norm(hidden_states)
|
246 |
+
if conditioning_emb is not None:
|
247 |
+
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
248 |
+
|
249 |
+
forwarded_states = self.DenseReluDense(forwarded_states)
|
250 |
+
hidden_states = hidden_states + self.dropout(forwarded_states)
|
251 |
+
return hidden_states
|
252 |
+
|
253 |
+
|
254 |
+
class T5DenseGatedActDense(nn.Module):
|
255 |
+
def __init__(self, d_model, d_ff, dropout_rate):
|
256 |
+
super().__init__()
|
257 |
+
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
258 |
+
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
259 |
+
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
260 |
+
self.dropout = nn.Dropout(dropout_rate)
|
261 |
+
self.act = NewGELUActivation()
|
262 |
+
|
263 |
+
def forward(self, hidden_states):
|
264 |
+
hidden_gelu = self.act(self.wi_0(hidden_states))
|
265 |
+
hidden_linear = self.wi_1(hidden_states)
|
266 |
+
hidden_states = hidden_gelu * hidden_linear
|
267 |
+
hidden_states = self.dropout(hidden_states)
|
268 |
+
|
269 |
+
hidden_states = self.wo(hidden_states)
|
270 |
+
return hidden_states
|
271 |
+
|
272 |
+
|
273 |
+
class T5LayerNorm(nn.Module):
|
274 |
+
def __init__(self, hidden_size, eps=1e-6):
|
275 |
+
"""
|
276 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
277 |
+
"""
|
278 |
+
super().__init__()
|
279 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
280 |
+
self.variance_epsilon = eps
|
281 |
+
|
282 |
+
def forward(self, hidden_states):
|
283 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
284 |
+
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
285 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
286 |
+
# half-precision inputs is done in fp32
|
287 |
+
|
288 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
289 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
290 |
+
|
291 |
+
# convert into half-precision if necessary
|
292 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
293 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
294 |
+
|
295 |
+
return self.weight * hidden_states
|
296 |
+
|
297 |
+
|
298 |
+
class NewGELUActivation(nn.Module):
|
299 |
+
"""
|
300 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
301 |
+
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
302 |
+
"""
|
303 |
+
|
304 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
305 |
+
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
306 |
+
|
307 |
+
|
308 |
+
class T5FiLMLayer(nn.Module):
|
309 |
+
"""
|
310 |
+
FiLM Layer
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(self, in_features, out_features):
|
314 |
+
super().__init__()
|
315 |
+
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
316 |
+
|
317 |
+
def forward(self, x, conditioning_emb):
|
318 |
+
emb = self.scale_bias(conditioning_emb)
|
319 |
+
scale, shift = torch.chunk(emb, 2, -1)
|
320 |
+
x = x * (1 + scale) + shift
|
321 |
+
return x
|
6DoF/diffusers/models/transformer_2d.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..models.embeddings import ImagePositionalEmbeddings
|
23 |
+
from ..utils import BaseOutput, deprecate
|
24 |
+
from .attention import BasicTransformerBlock
|
25 |
+
from .embeddings import PatchEmbed
|
26 |
+
from .modeling_utils import ModelMixin
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class Transformer2DModelOutput(BaseOutput):
|
31 |
+
"""
|
32 |
+
The output of [`Transformer2DModel`].
|
33 |
+
|
34 |
+
Args:
|
35 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
36 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
37 |
+
distributions for the unnoised latent pixels.
|
38 |
+
"""
|
39 |
+
|
40 |
+
sample: torch.FloatTensor
|
41 |
+
|
42 |
+
|
43 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
44 |
+
"""
|
45 |
+
A 2D Transformer model for image-like data.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
49 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
50 |
+
in_channels (`int`, *optional*):
|
51 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
52 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
53 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
54 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
55 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
56 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
57 |
+
num_vector_embeds (`int`, *optional*):
|
58 |
+
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
59 |
+
Includes the class for the masked latent pixel.
|
60 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
61 |
+
num_embeds_ada_norm ( `int`, *optional*):
|
62 |
+
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
63 |
+
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
64 |
+
added to the hidden states.
|
65 |
+
|
66 |
+
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
67 |
+
attention_bias (`bool`, *optional*):
|
68 |
+
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
69 |
+
"""
|
70 |
+
|
71 |
+
@register_to_config
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
num_attention_heads: int = 16,
|
75 |
+
attention_head_dim: int = 88,
|
76 |
+
in_channels: Optional[int] = None,
|
77 |
+
out_channels: Optional[int] = None,
|
78 |
+
num_layers: int = 1,
|
79 |
+
dropout: float = 0.0,
|
80 |
+
norm_num_groups: int = 32,
|
81 |
+
cross_attention_dim: Optional[int] = None,
|
82 |
+
attention_bias: bool = False,
|
83 |
+
sample_size: Optional[int] = None,
|
84 |
+
num_vector_embeds: Optional[int] = None,
|
85 |
+
patch_size: Optional[int] = None,
|
86 |
+
activation_fn: str = "geglu",
|
87 |
+
num_embeds_ada_norm: Optional[int] = None,
|
88 |
+
use_linear_projection: bool = False,
|
89 |
+
only_cross_attention: bool = False,
|
90 |
+
upcast_attention: bool = False,
|
91 |
+
norm_type: str = "layer_norm",
|
92 |
+
norm_elementwise_affine: bool = True,
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
self.use_linear_projection = use_linear_projection
|
96 |
+
self.num_attention_heads = num_attention_heads
|
97 |
+
self.attention_head_dim = attention_head_dim
|
98 |
+
inner_dim = num_attention_heads * attention_head_dim
|
99 |
+
|
100 |
+
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
101 |
+
# Define whether input is continuous or discrete depending on configuration
|
102 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
103 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
104 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
105 |
+
|
106 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
107 |
+
deprecation_message = (
|
108 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
109 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
110 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
111 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
112 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
113 |
+
)
|
114 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
115 |
+
norm_type = "ada_norm"
|
116 |
+
|
117 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
118 |
+
raise ValueError(
|
119 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
120 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
121 |
+
)
|
122 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
123 |
+
raise ValueError(
|
124 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
125 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
126 |
+
)
|
127 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
128 |
+
raise ValueError(
|
129 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
130 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
131 |
+
)
|
132 |
+
|
133 |
+
# 2. Define input layers
|
134 |
+
if self.is_input_continuous:
|
135 |
+
self.in_channels = in_channels
|
136 |
+
|
137 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
138 |
+
if use_linear_projection:
|
139 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
140 |
+
else:
|
141 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
142 |
+
elif self.is_input_vectorized:
|
143 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
144 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
145 |
+
|
146 |
+
self.height = sample_size
|
147 |
+
self.width = sample_size
|
148 |
+
self.num_vector_embeds = num_vector_embeds
|
149 |
+
self.num_latent_pixels = self.height * self.width
|
150 |
+
|
151 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
152 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
153 |
+
)
|
154 |
+
elif self.is_input_patches:
|
155 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
156 |
+
|
157 |
+
self.height = sample_size
|
158 |
+
self.width = sample_size
|
159 |
+
|
160 |
+
self.patch_size = patch_size
|
161 |
+
self.pos_embed = PatchEmbed(
|
162 |
+
height=sample_size,
|
163 |
+
width=sample_size,
|
164 |
+
patch_size=patch_size,
|
165 |
+
in_channels=in_channels,
|
166 |
+
embed_dim=inner_dim,
|
167 |
+
)
|
168 |
+
|
169 |
+
# 3. Define transformers blocks
|
170 |
+
self.transformer_blocks = nn.ModuleList(
|
171 |
+
[
|
172 |
+
BasicTransformerBlock(
|
173 |
+
inner_dim,
|
174 |
+
num_attention_heads,
|
175 |
+
attention_head_dim,
|
176 |
+
dropout=dropout,
|
177 |
+
cross_attention_dim=cross_attention_dim,
|
178 |
+
activation_fn=activation_fn,
|
179 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
180 |
+
attention_bias=attention_bias,
|
181 |
+
only_cross_attention=only_cross_attention,
|
182 |
+
upcast_attention=upcast_attention,
|
183 |
+
norm_type=norm_type,
|
184 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
185 |
+
)
|
186 |
+
for d in range(num_layers)
|
187 |
+
]
|
188 |
+
)
|
189 |
+
|
190 |
+
# 4. Define output layers
|
191 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
192 |
+
if self.is_input_continuous:
|
193 |
+
# TODO: should use out_channels for continuous projections
|
194 |
+
if use_linear_projection:
|
195 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
196 |
+
else:
|
197 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
198 |
+
elif self.is_input_vectorized:
|
199 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
200 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
201 |
+
elif self.is_input_patches:
|
202 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
203 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
204 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
205 |
+
|
206 |
+
def forward(
|
207 |
+
self,
|
208 |
+
hidden_states: torch.Tensor,
|
209 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
210 |
+
timestep: Optional[torch.LongTensor] = None,
|
211 |
+
class_labels: Optional[torch.LongTensor] = None,
|
212 |
+
posemb: Optional = None,
|
213 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
214 |
+
attention_mask: Optional[torch.Tensor] = None,
|
215 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
216 |
+
return_dict: bool = True,
|
217 |
+
):
|
218 |
+
"""
|
219 |
+
The [`Transformer2DModel`] forward method.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
223 |
+
Input `hidden_states`.
|
224 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
225 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
226 |
+
self-attention.
|
227 |
+
timestep ( `torch.LongTensor`, *optional*):
|
228 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
229 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
230 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
231 |
+
`AdaLayerZeroNorm`.
|
232 |
+
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
233 |
+
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
234 |
+
|
235 |
+
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
236 |
+
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
237 |
+
|
238 |
+
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
239 |
+
above. This bias will be added to the cross-attention scores.
|
240 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
241 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
242 |
+
tuple.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
246 |
+
`tuple` where the first element is the sample tensor.
|
247 |
+
"""
|
248 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
249 |
+
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
250 |
+
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
251 |
+
# expects mask of shape:
|
252 |
+
# [batch, key_tokens]
|
253 |
+
# adds singleton query_tokens dimension:
|
254 |
+
# [batch, 1, key_tokens]
|
255 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
256 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
257 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
258 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
259 |
+
# assume that mask is expressed as:
|
260 |
+
# (1 = keep, 0 = discard)
|
261 |
+
# convert mask into a bias that can be added to attention scores:
|
262 |
+
# (keep = +0, discard = -10000.0)
|
263 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
264 |
+
attention_mask = attention_mask.unsqueeze(1)
|
265 |
+
|
266 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
267 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
268 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
269 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
270 |
+
|
271 |
+
# 1. Input
|
272 |
+
if self.is_input_continuous:
|
273 |
+
batch, _, height, width = hidden_states.shape
|
274 |
+
residual = hidden_states
|
275 |
+
|
276 |
+
hidden_states = self.norm(hidden_states)
|
277 |
+
if not self.use_linear_projection:
|
278 |
+
hidden_states = self.proj_in(hidden_states)
|
279 |
+
inner_dim = hidden_states.shape[1]
|
280 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
281 |
+
else:
|
282 |
+
inner_dim = hidden_states.shape[1]
|
283 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
284 |
+
hidden_states = self.proj_in(hidden_states)
|
285 |
+
elif self.is_input_vectorized:
|
286 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
287 |
+
elif self.is_input_patches:
|
288 |
+
hidden_states = self.pos_embed(hidden_states)
|
289 |
+
|
290 |
+
# 2. Blocks
|
291 |
+
for block in self.transformer_blocks:
|
292 |
+
hidden_states = block(
|
293 |
+
hidden_states,
|
294 |
+
attention_mask=attention_mask,
|
295 |
+
encoder_hidden_states=encoder_hidden_states,
|
296 |
+
encoder_attention_mask=encoder_attention_mask,
|
297 |
+
timestep=timestep,
|
298 |
+
posemb=posemb,
|
299 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
300 |
+
class_labels=class_labels,
|
301 |
+
)
|
302 |
+
|
303 |
+
# 3. Output
|
304 |
+
if self.is_input_continuous:
|
305 |
+
if not self.use_linear_projection:
|
306 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
307 |
+
hidden_states = self.proj_out(hidden_states)
|
308 |
+
else:
|
309 |
+
hidden_states = self.proj_out(hidden_states)
|
310 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
311 |
+
|
312 |
+
output = hidden_states + residual
|
313 |
+
elif self.is_input_vectorized:
|
314 |
+
hidden_states = self.norm_out(hidden_states)
|
315 |
+
logits = self.out(hidden_states)
|
316 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
317 |
+
logits = logits.permute(0, 2, 1)
|
318 |
+
|
319 |
+
# log(p(x_0))
|
320 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
321 |
+
elif self.is_input_patches:
|
322 |
+
# TODO: cleanup!
|
323 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
324 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
325 |
+
)
|
326 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
327 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
328 |
+
hidden_states = self.proj_out_2(hidden_states)
|
329 |
+
|
330 |
+
# unpatchify
|
331 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
332 |
+
hidden_states = hidden_states.reshape(
|
333 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
334 |
+
)
|
335 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
336 |
+
output = hidden_states.reshape(
|
337 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
338 |
+
)
|
339 |
+
|
340 |
+
if not return_dict:
|
341 |
+
return (output,)
|
342 |
+
|
343 |
+
return Transformer2DModelOutput(sample=output)
|
6DoF/diffusers/models/transformer_temporal.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils import BaseOutput
|
22 |
+
from .attention import BasicTransformerBlock
|
23 |
+
from .modeling_utils import ModelMixin
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class TransformerTemporalModelOutput(BaseOutput):
|
28 |
+
"""
|
29 |
+
The output of [`TransformerTemporalModel`].
|
30 |
+
|
31 |
+
Args:
|
32 |
+
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
33 |
+
The hidden states output conditioned on `encoder_hidden_states` input.
|
34 |
+
"""
|
35 |
+
|
36 |
+
sample: torch.FloatTensor
|
37 |
+
|
38 |
+
|
39 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
40 |
+
"""
|
41 |
+
A Transformer model for video-like data.
|
42 |
+
|
43 |
+
Parameters:
|
44 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
45 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
46 |
+
in_channels (`int`, *optional*):
|
47 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
48 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
49 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
50 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
51 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
52 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
53 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
54 |
+
attention_bias (`bool`, *optional*):
|
55 |
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
56 |
+
double_self_attention (`bool`, *optional*):
|
57 |
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
58 |
+
"""
|
59 |
+
|
60 |
+
@register_to_config
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
num_attention_heads: int = 16,
|
64 |
+
attention_head_dim: int = 88,
|
65 |
+
in_channels: Optional[int] = None,
|
66 |
+
out_channels: Optional[int] = None,
|
67 |
+
num_layers: int = 1,
|
68 |
+
dropout: float = 0.0,
|
69 |
+
norm_num_groups: int = 32,
|
70 |
+
cross_attention_dim: Optional[int] = None,
|
71 |
+
attention_bias: bool = False,
|
72 |
+
sample_size: Optional[int] = None,
|
73 |
+
activation_fn: str = "geglu",
|
74 |
+
norm_elementwise_affine: bool = True,
|
75 |
+
double_self_attention: bool = True,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
self.num_attention_heads = num_attention_heads
|
79 |
+
self.attention_head_dim = attention_head_dim
|
80 |
+
inner_dim = num_attention_heads * attention_head_dim
|
81 |
+
|
82 |
+
self.in_channels = in_channels
|
83 |
+
|
84 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
85 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
86 |
+
|
87 |
+
# 3. Define transformers blocks
|
88 |
+
self.transformer_blocks = nn.ModuleList(
|
89 |
+
[
|
90 |
+
BasicTransformerBlock(
|
91 |
+
inner_dim,
|
92 |
+
num_attention_heads,
|
93 |
+
attention_head_dim,
|
94 |
+
dropout=dropout,
|
95 |
+
cross_attention_dim=cross_attention_dim,
|
96 |
+
activation_fn=activation_fn,
|
97 |
+
attention_bias=attention_bias,
|
98 |
+
double_self_attention=double_self_attention,
|
99 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
100 |
+
)
|
101 |
+
for d in range(num_layers)
|
102 |
+
]
|
103 |
+
)
|
104 |
+
|
105 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
106 |
+
|
107 |
+
def forward(
|
108 |
+
self,
|
109 |
+
hidden_states,
|
110 |
+
encoder_hidden_states=None,
|
111 |
+
timestep=None,
|
112 |
+
class_labels=None,
|
113 |
+
num_frames=1,
|
114 |
+
cross_attention_kwargs=None,
|
115 |
+
return_dict: bool = True,
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
The [`TransformerTemporal`] forward method.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
122 |
+
Input hidden_states.
|
123 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
124 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
125 |
+
self-attention.
|
126 |
+
timestep ( `torch.long`, *optional*):
|
127 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
128 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
129 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
130 |
+
`AdaLayerZeroNorm`.
|
131 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
132 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
133 |
+
tuple.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
137 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
138 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
139 |
+
"""
|
140 |
+
# 1. Input
|
141 |
+
batch_frames, channel, height, width = hidden_states.shape
|
142 |
+
batch_size = batch_frames // num_frames
|
143 |
+
|
144 |
+
residual = hidden_states
|
145 |
+
|
146 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
147 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
148 |
+
|
149 |
+
hidden_states = self.norm(hidden_states)
|
150 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
151 |
+
|
152 |
+
hidden_states = self.proj_in(hidden_states)
|
153 |
+
|
154 |
+
# 2. Blocks
|
155 |
+
for block in self.transformer_blocks:
|
156 |
+
hidden_states = block(
|
157 |
+
hidden_states,
|
158 |
+
encoder_hidden_states=encoder_hidden_states,
|
159 |
+
timestep=timestep,
|
160 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
161 |
+
class_labels=class_labels,
|
162 |
+
)
|
163 |
+
|
164 |
+
# 3. Output
|
165 |
+
hidden_states = self.proj_out(hidden_states)
|
166 |
+
hidden_states = (
|
167 |
+
hidden_states[None, None, :]
|
168 |
+
.reshape(batch_size, height, width, channel, num_frames)
|
169 |
+
.permute(0, 3, 4, 1, 2)
|
170 |
+
.contiguous()
|
171 |
+
)
|
172 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
173 |
+
|
174 |
+
output = hidden_states + residual
|
175 |
+
|
176 |
+
if not return_dict:
|
177 |
+
return (output,)
|
178 |
+
|
179 |
+
return TransformerTemporalModelOutput(sample=output)
|
6DoF/diffusers/models/unet_1d.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..utils import BaseOutput
|
23 |
+
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
24 |
+
from .modeling_utils import ModelMixin
|
25 |
+
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class UNet1DOutput(BaseOutput):
|
30 |
+
"""
|
31 |
+
The output of [`UNet1DModel`].
|
32 |
+
|
33 |
+
Args:
|
34 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
35 |
+
The hidden states output from the last layer of the model.
|
36 |
+
"""
|
37 |
+
|
38 |
+
sample: torch.FloatTensor
|
39 |
+
|
40 |
+
|
41 |
+
class UNet1DModel(ModelMixin, ConfigMixin):
|
42 |
+
r"""
|
43 |
+
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
44 |
+
|
45 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
46 |
+
for all models (such as downloading or saving).
|
47 |
+
|
48 |
+
Parameters:
|
49 |
+
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
50 |
+
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
51 |
+
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
52 |
+
extra_in_channels (`int`, *optional*, defaults to 0):
|
53 |
+
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
54 |
+
input data has more channels than what the model was initially designed for.
|
55 |
+
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
56 |
+
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
57 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
58 |
+
Whether to flip sin to cos for Fourier time embedding.
|
59 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`):
|
60 |
+
Tuple of downsample block types.
|
61 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`):
|
62 |
+
Tuple of upsample block types.
|
63 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
64 |
+
Tuple of block output channels.
|
65 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
66 |
+
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
67 |
+
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
68 |
+
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
69 |
+
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
70 |
+
downsample_each_block (`int`, *optional*, defaults to `False`):
|
71 |
+
Experimental feature for using a UNet without upsampling.
|
72 |
+
"""
|
73 |
+
|
74 |
+
@register_to_config
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
sample_size: int = 65536,
|
78 |
+
sample_rate: Optional[int] = None,
|
79 |
+
in_channels: int = 2,
|
80 |
+
out_channels: int = 2,
|
81 |
+
extra_in_channels: int = 0,
|
82 |
+
time_embedding_type: str = "fourier",
|
83 |
+
flip_sin_to_cos: bool = True,
|
84 |
+
use_timestep_embedding: bool = False,
|
85 |
+
freq_shift: float = 0.0,
|
86 |
+
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
87 |
+
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
88 |
+
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
89 |
+
out_block_type: str = None,
|
90 |
+
block_out_channels: Tuple[int] = (32, 32, 64),
|
91 |
+
act_fn: str = None,
|
92 |
+
norm_num_groups: int = 8,
|
93 |
+
layers_per_block: int = 1,
|
94 |
+
downsample_each_block: bool = False,
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.sample_size = sample_size
|
98 |
+
|
99 |
+
# time
|
100 |
+
if time_embedding_type == "fourier":
|
101 |
+
self.time_proj = GaussianFourierProjection(
|
102 |
+
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
103 |
+
)
|
104 |
+
timestep_input_dim = 2 * block_out_channels[0]
|
105 |
+
elif time_embedding_type == "positional":
|
106 |
+
self.time_proj = Timesteps(
|
107 |
+
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
108 |
+
)
|
109 |
+
timestep_input_dim = block_out_channels[0]
|
110 |
+
|
111 |
+
if use_timestep_embedding:
|
112 |
+
time_embed_dim = block_out_channels[0] * 4
|
113 |
+
self.time_mlp = TimestepEmbedding(
|
114 |
+
in_channels=timestep_input_dim,
|
115 |
+
time_embed_dim=time_embed_dim,
|
116 |
+
act_fn=act_fn,
|
117 |
+
out_dim=block_out_channels[0],
|
118 |
+
)
|
119 |
+
|
120 |
+
self.down_blocks = nn.ModuleList([])
|
121 |
+
self.mid_block = None
|
122 |
+
self.up_blocks = nn.ModuleList([])
|
123 |
+
self.out_block = None
|
124 |
+
|
125 |
+
# down
|
126 |
+
output_channel = in_channels
|
127 |
+
for i, down_block_type in enumerate(down_block_types):
|
128 |
+
input_channel = output_channel
|
129 |
+
output_channel = block_out_channels[i]
|
130 |
+
|
131 |
+
if i == 0:
|
132 |
+
input_channel += extra_in_channels
|
133 |
+
|
134 |
+
is_final_block = i == len(block_out_channels) - 1
|
135 |
+
|
136 |
+
down_block = get_down_block(
|
137 |
+
down_block_type,
|
138 |
+
num_layers=layers_per_block,
|
139 |
+
in_channels=input_channel,
|
140 |
+
out_channels=output_channel,
|
141 |
+
temb_channels=block_out_channels[0],
|
142 |
+
add_downsample=not is_final_block or downsample_each_block,
|
143 |
+
)
|
144 |
+
self.down_blocks.append(down_block)
|
145 |
+
|
146 |
+
# mid
|
147 |
+
self.mid_block = get_mid_block(
|
148 |
+
mid_block_type,
|
149 |
+
in_channels=block_out_channels[-1],
|
150 |
+
mid_channels=block_out_channels[-1],
|
151 |
+
out_channels=block_out_channels[-1],
|
152 |
+
embed_dim=block_out_channels[0],
|
153 |
+
num_layers=layers_per_block,
|
154 |
+
add_downsample=downsample_each_block,
|
155 |
+
)
|
156 |
+
|
157 |
+
# up
|
158 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
159 |
+
output_channel = reversed_block_out_channels[0]
|
160 |
+
if out_block_type is None:
|
161 |
+
final_upsample_channels = out_channels
|
162 |
+
else:
|
163 |
+
final_upsample_channels = block_out_channels[0]
|
164 |
+
|
165 |
+
for i, up_block_type in enumerate(up_block_types):
|
166 |
+
prev_output_channel = output_channel
|
167 |
+
output_channel = (
|
168 |
+
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
169 |
+
)
|
170 |
+
|
171 |
+
is_final_block = i == len(block_out_channels) - 1
|
172 |
+
|
173 |
+
up_block = get_up_block(
|
174 |
+
up_block_type,
|
175 |
+
num_layers=layers_per_block,
|
176 |
+
in_channels=prev_output_channel,
|
177 |
+
out_channels=output_channel,
|
178 |
+
temb_channels=block_out_channels[0],
|
179 |
+
add_upsample=not is_final_block,
|
180 |
+
)
|
181 |
+
self.up_blocks.append(up_block)
|
182 |
+
prev_output_channel = output_channel
|
183 |
+
|
184 |
+
# out
|
185 |
+
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
186 |
+
self.out_block = get_out_block(
|
187 |
+
out_block_type=out_block_type,
|
188 |
+
num_groups_out=num_groups_out,
|
189 |
+
embed_dim=block_out_channels[0],
|
190 |
+
out_channels=out_channels,
|
191 |
+
act_fn=act_fn,
|
192 |
+
fc_dim=block_out_channels[-1] // 4,
|
193 |
+
)
|
194 |
+
|
195 |
+
def forward(
|
196 |
+
self,
|
197 |
+
sample: torch.FloatTensor,
|
198 |
+
timestep: Union[torch.Tensor, float, int],
|
199 |
+
return_dict: bool = True,
|
200 |
+
) -> Union[UNet1DOutput, Tuple]:
|
201 |
+
r"""
|
202 |
+
The [`UNet1DModel`] forward method.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
sample (`torch.FloatTensor`):
|
206 |
+
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
207 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
208 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
209 |
+
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
210 |
+
|
211 |
+
Returns:
|
212 |
+
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
213 |
+
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
214 |
+
returned where the first element is the sample tensor.
|
215 |
+
"""
|
216 |
+
|
217 |
+
# 1. time
|
218 |
+
timesteps = timestep
|
219 |
+
if not torch.is_tensor(timesteps):
|
220 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
221 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
222 |
+
timesteps = timesteps[None].to(sample.device)
|
223 |
+
|
224 |
+
timestep_embed = self.time_proj(timesteps)
|
225 |
+
if self.config.use_timestep_embedding:
|
226 |
+
timestep_embed = self.time_mlp(timestep_embed)
|
227 |
+
else:
|
228 |
+
timestep_embed = timestep_embed[..., None]
|
229 |
+
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
230 |
+
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
231 |
+
|
232 |
+
# 2. down
|
233 |
+
down_block_res_samples = ()
|
234 |
+
for downsample_block in self.down_blocks:
|
235 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
236 |
+
down_block_res_samples += res_samples
|
237 |
+
|
238 |
+
# 3. mid
|
239 |
+
if self.mid_block:
|
240 |
+
sample = self.mid_block(sample, timestep_embed)
|
241 |
+
|
242 |
+
# 4. up
|
243 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
244 |
+
res_samples = down_block_res_samples[-1:]
|
245 |
+
down_block_res_samples = down_block_res_samples[:-1]
|
246 |
+
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
247 |
+
|
248 |
+
# 5. post-process
|
249 |
+
if self.out_block:
|
250 |
+
sample = self.out_block(sample, timestep_embed)
|
251 |
+
|
252 |
+
if not return_dict:
|
253 |
+
return (sample,)
|
254 |
+
|
255 |
+
return UNet1DOutput(sample=sample)
|
6DoF/diffusers/models/unet_1d_blocks.py
ADDED
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from .activations import get_activation
|
21 |
+
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
22 |
+
|
23 |
+
|
24 |
+
class DownResnetBlock1D(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
in_channels,
|
28 |
+
out_channels=None,
|
29 |
+
num_layers=1,
|
30 |
+
conv_shortcut=False,
|
31 |
+
temb_channels=32,
|
32 |
+
groups=32,
|
33 |
+
groups_out=None,
|
34 |
+
non_linearity=None,
|
35 |
+
time_embedding_norm="default",
|
36 |
+
output_scale_factor=1.0,
|
37 |
+
add_downsample=True,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.in_channels = in_channels
|
41 |
+
out_channels = in_channels if out_channels is None else out_channels
|
42 |
+
self.out_channels = out_channels
|
43 |
+
self.use_conv_shortcut = conv_shortcut
|
44 |
+
self.time_embedding_norm = time_embedding_norm
|
45 |
+
self.add_downsample = add_downsample
|
46 |
+
self.output_scale_factor = output_scale_factor
|
47 |
+
|
48 |
+
if groups_out is None:
|
49 |
+
groups_out = groups
|
50 |
+
|
51 |
+
# there will always be at least one resnet
|
52 |
+
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
53 |
+
|
54 |
+
for _ in range(num_layers):
|
55 |
+
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
56 |
+
|
57 |
+
self.resnets = nn.ModuleList(resnets)
|
58 |
+
|
59 |
+
if non_linearity is None:
|
60 |
+
self.nonlinearity = None
|
61 |
+
else:
|
62 |
+
self.nonlinearity = get_activation(non_linearity)
|
63 |
+
|
64 |
+
self.downsample = None
|
65 |
+
if add_downsample:
|
66 |
+
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
67 |
+
|
68 |
+
def forward(self, hidden_states, temb=None):
|
69 |
+
output_states = ()
|
70 |
+
|
71 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
72 |
+
for resnet in self.resnets[1:]:
|
73 |
+
hidden_states = resnet(hidden_states, temb)
|
74 |
+
|
75 |
+
output_states += (hidden_states,)
|
76 |
+
|
77 |
+
if self.nonlinearity is not None:
|
78 |
+
hidden_states = self.nonlinearity(hidden_states)
|
79 |
+
|
80 |
+
if self.downsample is not None:
|
81 |
+
hidden_states = self.downsample(hidden_states)
|
82 |
+
|
83 |
+
return hidden_states, output_states
|
84 |
+
|
85 |
+
|
86 |
+
class UpResnetBlock1D(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
in_channels,
|
90 |
+
out_channels=None,
|
91 |
+
num_layers=1,
|
92 |
+
temb_channels=32,
|
93 |
+
groups=32,
|
94 |
+
groups_out=None,
|
95 |
+
non_linearity=None,
|
96 |
+
time_embedding_norm="default",
|
97 |
+
output_scale_factor=1.0,
|
98 |
+
add_upsample=True,
|
99 |
+
):
|
100 |
+
super().__init__()
|
101 |
+
self.in_channels = in_channels
|
102 |
+
out_channels = in_channels if out_channels is None else out_channels
|
103 |
+
self.out_channels = out_channels
|
104 |
+
self.time_embedding_norm = time_embedding_norm
|
105 |
+
self.add_upsample = add_upsample
|
106 |
+
self.output_scale_factor = output_scale_factor
|
107 |
+
|
108 |
+
if groups_out is None:
|
109 |
+
groups_out = groups
|
110 |
+
|
111 |
+
# there will always be at least one resnet
|
112 |
+
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
113 |
+
|
114 |
+
for _ in range(num_layers):
|
115 |
+
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
116 |
+
|
117 |
+
self.resnets = nn.ModuleList(resnets)
|
118 |
+
|
119 |
+
if non_linearity is None:
|
120 |
+
self.nonlinearity = None
|
121 |
+
else:
|
122 |
+
self.nonlinearity = get_activation(non_linearity)
|
123 |
+
|
124 |
+
self.upsample = None
|
125 |
+
if add_upsample:
|
126 |
+
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
127 |
+
|
128 |
+
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
|
129 |
+
if res_hidden_states_tuple is not None:
|
130 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
131 |
+
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
132 |
+
|
133 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
134 |
+
for resnet in self.resnets[1:]:
|
135 |
+
hidden_states = resnet(hidden_states, temb)
|
136 |
+
|
137 |
+
if self.nonlinearity is not None:
|
138 |
+
hidden_states = self.nonlinearity(hidden_states)
|
139 |
+
|
140 |
+
if self.upsample is not None:
|
141 |
+
hidden_states = self.upsample(hidden_states)
|
142 |
+
|
143 |
+
return hidden_states
|
144 |
+
|
145 |
+
|
146 |
+
class ValueFunctionMidBlock1D(nn.Module):
|
147 |
+
def __init__(self, in_channels, out_channels, embed_dim):
|
148 |
+
super().__init__()
|
149 |
+
self.in_channels = in_channels
|
150 |
+
self.out_channels = out_channels
|
151 |
+
self.embed_dim = embed_dim
|
152 |
+
|
153 |
+
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
154 |
+
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
155 |
+
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
156 |
+
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
157 |
+
|
158 |
+
def forward(self, x, temb=None):
|
159 |
+
x = self.res1(x, temb)
|
160 |
+
x = self.down1(x)
|
161 |
+
x = self.res2(x, temb)
|
162 |
+
x = self.down2(x)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
class MidResTemporalBlock1D(nn.Module):
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
in_channels,
|
170 |
+
out_channels,
|
171 |
+
embed_dim,
|
172 |
+
num_layers: int = 1,
|
173 |
+
add_downsample: bool = False,
|
174 |
+
add_upsample: bool = False,
|
175 |
+
non_linearity=None,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
self.in_channels = in_channels
|
179 |
+
self.out_channels = out_channels
|
180 |
+
self.add_downsample = add_downsample
|
181 |
+
|
182 |
+
# there will always be at least one resnet
|
183 |
+
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
184 |
+
|
185 |
+
for _ in range(num_layers):
|
186 |
+
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
187 |
+
|
188 |
+
self.resnets = nn.ModuleList(resnets)
|
189 |
+
|
190 |
+
if non_linearity is None:
|
191 |
+
self.nonlinearity = None
|
192 |
+
else:
|
193 |
+
self.nonlinearity = get_activation(non_linearity)
|
194 |
+
|
195 |
+
self.upsample = None
|
196 |
+
if add_upsample:
|
197 |
+
self.upsample = Downsample1D(out_channels, use_conv=True)
|
198 |
+
|
199 |
+
self.downsample = None
|
200 |
+
if add_downsample:
|
201 |
+
self.downsample = Downsample1D(out_channels, use_conv=True)
|
202 |
+
|
203 |
+
if self.upsample and self.downsample:
|
204 |
+
raise ValueError("Block cannot downsample and upsample")
|
205 |
+
|
206 |
+
def forward(self, hidden_states, temb):
|
207 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
208 |
+
for resnet in self.resnets[1:]:
|
209 |
+
hidden_states = resnet(hidden_states, temb)
|
210 |
+
|
211 |
+
if self.upsample:
|
212 |
+
hidden_states = self.upsample(hidden_states)
|
213 |
+
if self.downsample:
|
214 |
+
self.downsample = self.downsample(hidden_states)
|
215 |
+
|
216 |
+
return hidden_states
|
217 |
+
|
218 |
+
|
219 |
+
class OutConv1DBlock(nn.Module):
|
220 |
+
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
|
221 |
+
super().__init__()
|
222 |
+
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
223 |
+
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
224 |
+
self.final_conv1d_act = get_activation(act_fn)
|
225 |
+
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
226 |
+
|
227 |
+
def forward(self, hidden_states, temb=None):
|
228 |
+
hidden_states = self.final_conv1d_1(hidden_states)
|
229 |
+
hidden_states = rearrange_dims(hidden_states)
|
230 |
+
hidden_states = self.final_conv1d_gn(hidden_states)
|
231 |
+
hidden_states = rearrange_dims(hidden_states)
|
232 |
+
hidden_states = self.final_conv1d_act(hidden_states)
|
233 |
+
hidden_states = self.final_conv1d_2(hidden_states)
|
234 |
+
return hidden_states
|
235 |
+
|
236 |
+
|
237 |
+
class OutValueFunctionBlock(nn.Module):
|
238 |
+
def __init__(self, fc_dim, embed_dim):
|
239 |
+
super().__init__()
|
240 |
+
self.final_block = nn.ModuleList(
|
241 |
+
[
|
242 |
+
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
243 |
+
nn.Mish(),
|
244 |
+
nn.Linear(fc_dim // 2, 1),
|
245 |
+
]
|
246 |
+
)
|
247 |
+
|
248 |
+
def forward(self, hidden_states, temb):
|
249 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
250 |
+
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
251 |
+
for layer in self.final_block:
|
252 |
+
hidden_states = layer(hidden_states)
|
253 |
+
|
254 |
+
return hidden_states
|
255 |
+
|
256 |
+
|
257 |
+
_kernels = {
|
258 |
+
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
259 |
+
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
260 |
+
"lanczos3": [
|
261 |
+
0.003689131001010537,
|
262 |
+
0.015056144446134567,
|
263 |
+
-0.03399861603975296,
|
264 |
+
-0.066637322306633,
|
265 |
+
0.13550527393817902,
|
266 |
+
0.44638532400131226,
|
267 |
+
0.44638532400131226,
|
268 |
+
0.13550527393817902,
|
269 |
+
-0.066637322306633,
|
270 |
+
-0.03399861603975296,
|
271 |
+
0.015056144446134567,
|
272 |
+
0.003689131001010537,
|
273 |
+
],
|
274 |
+
}
|
275 |
+
|
276 |
+
|
277 |
+
class Downsample1d(nn.Module):
|
278 |
+
def __init__(self, kernel="linear", pad_mode="reflect"):
|
279 |
+
super().__init__()
|
280 |
+
self.pad_mode = pad_mode
|
281 |
+
kernel_1d = torch.tensor(_kernels[kernel])
|
282 |
+
self.pad = kernel_1d.shape[0] // 2 - 1
|
283 |
+
self.register_buffer("kernel", kernel_1d)
|
284 |
+
|
285 |
+
def forward(self, hidden_states):
|
286 |
+
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
287 |
+
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
288 |
+
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
289 |
+
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
290 |
+
weight[indices, indices] = kernel
|
291 |
+
return F.conv1d(hidden_states, weight, stride=2)
|
292 |
+
|
293 |
+
|
294 |
+
class Upsample1d(nn.Module):
|
295 |
+
def __init__(self, kernel="linear", pad_mode="reflect"):
|
296 |
+
super().__init__()
|
297 |
+
self.pad_mode = pad_mode
|
298 |
+
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
299 |
+
self.pad = kernel_1d.shape[0] // 2 - 1
|
300 |
+
self.register_buffer("kernel", kernel_1d)
|
301 |
+
|
302 |
+
def forward(self, hidden_states, temb=None):
|
303 |
+
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
304 |
+
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
305 |
+
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
306 |
+
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
307 |
+
weight[indices, indices] = kernel
|
308 |
+
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
309 |
+
|
310 |
+
|
311 |
+
class SelfAttention1d(nn.Module):
|
312 |
+
def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
|
313 |
+
super().__init__()
|
314 |
+
self.channels = in_channels
|
315 |
+
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
316 |
+
self.num_heads = n_head
|
317 |
+
|
318 |
+
self.query = nn.Linear(self.channels, self.channels)
|
319 |
+
self.key = nn.Linear(self.channels, self.channels)
|
320 |
+
self.value = nn.Linear(self.channels, self.channels)
|
321 |
+
|
322 |
+
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
323 |
+
|
324 |
+
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
325 |
+
|
326 |
+
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
327 |
+
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
328 |
+
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
329 |
+
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
330 |
+
return new_projection
|
331 |
+
|
332 |
+
def forward(self, hidden_states):
|
333 |
+
residual = hidden_states
|
334 |
+
batch, channel_dim, seq = hidden_states.shape
|
335 |
+
|
336 |
+
hidden_states = self.group_norm(hidden_states)
|
337 |
+
hidden_states = hidden_states.transpose(1, 2)
|
338 |
+
|
339 |
+
query_proj = self.query(hidden_states)
|
340 |
+
key_proj = self.key(hidden_states)
|
341 |
+
value_proj = self.value(hidden_states)
|
342 |
+
|
343 |
+
query_states = self.transpose_for_scores(query_proj)
|
344 |
+
key_states = self.transpose_for_scores(key_proj)
|
345 |
+
value_states = self.transpose_for_scores(value_proj)
|
346 |
+
|
347 |
+
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
348 |
+
|
349 |
+
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
350 |
+
attention_probs = torch.softmax(attention_scores, dim=-1)
|
351 |
+
|
352 |
+
# compute attention output
|
353 |
+
hidden_states = torch.matmul(attention_probs, value_states)
|
354 |
+
|
355 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
356 |
+
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
357 |
+
hidden_states = hidden_states.view(new_hidden_states_shape)
|
358 |
+
|
359 |
+
# compute next hidden_states
|
360 |
+
hidden_states = self.proj_attn(hidden_states)
|
361 |
+
hidden_states = hidden_states.transpose(1, 2)
|
362 |
+
hidden_states = self.dropout(hidden_states)
|
363 |
+
|
364 |
+
output = hidden_states + residual
|
365 |
+
|
366 |
+
return output
|
367 |
+
|
368 |
+
|
369 |
+
class ResConvBlock(nn.Module):
|
370 |
+
def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
|
371 |
+
super().__init__()
|
372 |
+
self.is_last = is_last
|
373 |
+
self.has_conv_skip = in_channels != out_channels
|
374 |
+
|
375 |
+
if self.has_conv_skip:
|
376 |
+
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
377 |
+
|
378 |
+
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
379 |
+
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
380 |
+
self.gelu_1 = nn.GELU()
|
381 |
+
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
382 |
+
|
383 |
+
if not self.is_last:
|
384 |
+
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
385 |
+
self.gelu_2 = nn.GELU()
|
386 |
+
|
387 |
+
def forward(self, hidden_states):
|
388 |
+
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
389 |
+
|
390 |
+
hidden_states = self.conv_1(hidden_states)
|
391 |
+
hidden_states = self.group_norm_1(hidden_states)
|
392 |
+
hidden_states = self.gelu_1(hidden_states)
|
393 |
+
hidden_states = self.conv_2(hidden_states)
|
394 |
+
|
395 |
+
if not self.is_last:
|
396 |
+
hidden_states = self.group_norm_2(hidden_states)
|
397 |
+
hidden_states = self.gelu_2(hidden_states)
|
398 |
+
|
399 |
+
output = hidden_states + residual
|
400 |
+
return output
|
401 |
+
|
402 |
+
|
403 |
+
class UNetMidBlock1D(nn.Module):
|
404 |
+
def __init__(self, mid_channels, in_channels, out_channels=None):
|
405 |
+
super().__init__()
|
406 |
+
|
407 |
+
out_channels = in_channels if out_channels is None else out_channels
|
408 |
+
|
409 |
+
# there is always at least one resnet
|
410 |
+
self.down = Downsample1d("cubic")
|
411 |
+
resnets = [
|
412 |
+
ResConvBlock(in_channels, mid_channels, mid_channels),
|
413 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
414 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
415 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
416 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
417 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
418 |
+
]
|
419 |
+
attentions = [
|
420 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
421 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
422 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
423 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
424 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
425 |
+
SelfAttention1d(out_channels, out_channels // 32),
|
426 |
+
]
|
427 |
+
self.up = Upsample1d(kernel="cubic")
|
428 |
+
|
429 |
+
self.attentions = nn.ModuleList(attentions)
|
430 |
+
self.resnets = nn.ModuleList(resnets)
|
431 |
+
|
432 |
+
def forward(self, hidden_states, temb=None):
|
433 |
+
hidden_states = self.down(hidden_states)
|
434 |
+
for attn, resnet in zip(self.attentions, self.resnets):
|
435 |
+
hidden_states = resnet(hidden_states)
|
436 |
+
hidden_states = attn(hidden_states)
|
437 |
+
|
438 |
+
hidden_states = self.up(hidden_states)
|
439 |
+
|
440 |
+
return hidden_states
|
441 |
+
|
442 |
+
|
443 |
+
class AttnDownBlock1D(nn.Module):
|
444 |
+
def __init__(self, out_channels, in_channels, mid_channels=None):
|
445 |
+
super().__init__()
|
446 |
+
mid_channels = out_channels if mid_channels is None else mid_channels
|
447 |
+
|
448 |
+
self.down = Downsample1d("cubic")
|
449 |
+
resnets = [
|
450 |
+
ResConvBlock(in_channels, mid_channels, mid_channels),
|
451 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
452 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
453 |
+
]
|
454 |
+
attentions = [
|
455 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
456 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
457 |
+
SelfAttention1d(out_channels, out_channels // 32),
|
458 |
+
]
|
459 |
+
|
460 |
+
self.attentions = nn.ModuleList(attentions)
|
461 |
+
self.resnets = nn.ModuleList(resnets)
|
462 |
+
|
463 |
+
def forward(self, hidden_states, temb=None):
|
464 |
+
hidden_states = self.down(hidden_states)
|
465 |
+
|
466 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
467 |
+
hidden_states = resnet(hidden_states)
|
468 |
+
hidden_states = attn(hidden_states)
|
469 |
+
|
470 |
+
return hidden_states, (hidden_states,)
|
471 |
+
|
472 |
+
|
473 |
+
class DownBlock1D(nn.Module):
|
474 |
+
def __init__(self, out_channels, in_channels, mid_channels=None):
|
475 |
+
super().__init__()
|
476 |
+
mid_channels = out_channels if mid_channels is None else mid_channels
|
477 |
+
|
478 |
+
self.down = Downsample1d("cubic")
|
479 |
+
resnets = [
|
480 |
+
ResConvBlock(in_channels, mid_channels, mid_channels),
|
481 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
482 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
483 |
+
]
|
484 |
+
|
485 |
+
self.resnets = nn.ModuleList(resnets)
|
486 |
+
|
487 |
+
def forward(self, hidden_states, temb=None):
|
488 |
+
hidden_states = self.down(hidden_states)
|
489 |
+
|
490 |
+
for resnet in self.resnets:
|
491 |
+
hidden_states = resnet(hidden_states)
|
492 |
+
|
493 |
+
return hidden_states, (hidden_states,)
|
494 |
+
|
495 |
+
|
496 |
+
class DownBlock1DNoSkip(nn.Module):
|
497 |
+
def __init__(self, out_channels, in_channels, mid_channels=None):
|
498 |
+
super().__init__()
|
499 |
+
mid_channels = out_channels if mid_channels is None else mid_channels
|
500 |
+
|
501 |
+
resnets = [
|
502 |
+
ResConvBlock(in_channels, mid_channels, mid_channels),
|
503 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
504 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
505 |
+
]
|
506 |
+
|
507 |
+
self.resnets = nn.ModuleList(resnets)
|
508 |
+
|
509 |
+
def forward(self, hidden_states, temb=None):
|
510 |
+
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
511 |
+
for resnet in self.resnets:
|
512 |
+
hidden_states = resnet(hidden_states)
|
513 |
+
|
514 |
+
return hidden_states, (hidden_states,)
|
515 |
+
|
516 |
+
|
517 |
+
class AttnUpBlock1D(nn.Module):
|
518 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
519 |
+
super().__init__()
|
520 |
+
mid_channels = out_channels if mid_channels is None else mid_channels
|
521 |
+
|
522 |
+
resnets = [
|
523 |
+
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
524 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
525 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
526 |
+
]
|
527 |
+
attentions = [
|
528 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
529 |
+
SelfAttention1d(mid_channels, mid_channels // 32),
|
530 |
+
SelfAttention1d(out_channels, out_channels // 32),
|
531 |
+
]
|
532 |
+
|
533 |
+
self.attentions = nn.ModuleList(attentions)
|
534 |
+
self.resnets = nn.ModuleList(resnets)
|
535 |
+
self.up = Upsample1d(kernel="cubic")
|
536 |
+
|
537 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
538 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
539 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
540 |
+
|
541 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
542 |
+
hidden_states = resnet(hidden_states)
|
543 |
+
hidden_states = attn(hidden_states)
|
544 |
+
|
545 |
+
hidden_states = self.up(hidden_states)
|
546 |
+
|
547 |
+
return hidden_states
|
548 |
+
|
549 |
+
|
550 |
+
class UpBlock1D(nn.Module):
|
551 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
552 |
+
super().__init__()
|
553 |
+
mid_channels = in_channels if mid_channels is None else mid_channels
|
554 |
+
|
555 |
+
resnets = [
|
556 |
+
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
557 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
558 |
+
ResConvBlock(mid_channels, mid_channels, out_channels),
|
559 |
+
]
|
560 |
+
|
561 |
+
self.resnets = nn.ModuleList(resnets)
|
562 |
+
self.up = Upsample1d(kernel="cubic")
|
563 |
+
|
564 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
565 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
566 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
567 |
+
|
568 |
+
for resnet in self.resnets:
|
569 |
+
hidden_states = resnet(hidden_states)
|
570 |
+
|
571 |
+
hidden_states = self.up(hidden_states)
|
572 |
+
|
573 |
+
return hidden_states
|
574 |
+
|
575 |
+
|
576 |
+
class UpBlock1DNoSkip(nn.Module):
|
577 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
578 |
+
super().__init__()
|
579 |
+
mid_channels = in_channels if mid_channels is None else mid_channels
|
580 |
+
|
581 |
+
resnets = [
|
582 |
+
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
583 |
+
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
584 |
+
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
585 |
+
]
|
586 |
+
|
587 |
+
self.resnets = nn.ModuleList(resnets)
|
588 |
+
|
589 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
590 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
591 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
592 |
+
|
593 |
+
for resnet in self.resnets:
|
594 |
+
hidden_states = resnet(hidden_states)
|
595 |
+
|
596 |
+
return hidden_states
|
597 |
+
|
598 |
+
|
599 |
+
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
|
600 |
+
if down_block_type == "DownResnetBlock1D":
|
601 |
+
return DownResnetBlock1D(
|
602 |
+
in_channels=in_channels,
|
603 |
+
num_layers=num_layers,
|
604 |
+
out_channels=out_channels,
|
605 |
+
temb_channels=temb_channels,
|
606 |
+
add_downsample=add_downsample,
|
607 |
+
)
|
608 |
+
elif down_block_type == "DownBlock1D":
|
609 |
+
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
610 |
+
elif down_block_type == "AttnDownBlock1D":
|
611 |
+
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
612 |
+
elif down_block_type == "DownBlock1DNoSkip":
|
613 |
+
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
614 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
615 |
+
|
616 |
+
|
617 |
+
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
|
618 |
+
if up_block_type == "UpResnetBlock1D":
|
619 |
+
return UpResnetBlock1D(
|
620 |
+
in_channels=in_channels,
|
621 |
+
num_layers=num_layers,
|
622 |
+
out_channels=out_channels,
|
623 |
+
temb_channels=temb_channels,
|
624 |
+
add_upsample=add_upsample,
|
625 |
+
)
|
626 |
+
elif up_block_type == "UpBlock1D":
|
627 |
+
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
628 |
+
elif up_block_type == "AttnUpBlock1D":
|
629 |
+
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
630 |
+
elif up_block_type == "UpBlock1DNoSkip":
|
631 |
+
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
632 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
633 |
+
|
634 |
+
|
635 |
+
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
|
636 |
+
if mid_block_type == "MidResTemporalBlock1D":
|
637 |
+
return MidResTemporalBlock1D(
|
638 |
+
num_layers=num_layers,
|
639 |
+
in_channels=in_channels,
|
640 |
+
out_channels=out_channels,
|
641 |
+
embed_dim=embed_dim,
|
642 |
+
add_downsample=add_downsample,
|
643 |
+
)
|
644 |
+
elif mid_block_type == "ValueFunctionMidBlock1D":
|
645 |
+
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
646 |
+
elif mid_block_type == "UNetMidBlock1D":
|
647 |
+
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
648 |
+
raise ValueError(f"{mid_block_type} does not exist.")
|
649 |
+
|
650 |
+
|
651 |
+
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
|
652 |
+
if out_block_type == "OutConv1DBlock":
|
653 |
+
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
654 |
+
elif out_block_type == "ValueFunction":
|
655 |
+
return OutValueFunctionBlock(fc_dim, embed_dim)
|
656 |
+
return None
|
6DoF/diffusers/models/unet_2d.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils import BaseOutput
|
22 |
+
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
23 |
+
from .modeling_utils import ModelMixin
|
24 |
+
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class UNet2DOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
The output of [`UNet2DModel`].
|
31 |
+
|
32 |
+
Args:
|
33 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
34 |
+
The hidden states output from the last layer of the model.
|
35 |
+
"""
|
36 |
+
|
37 |
+
sample: torch.FloatTensor
|
38 |
+
|
39 |
+
|
40 |
+
class UNet2DModel(ModelMixin, ConfigMixin):
|
41 |
+
r"""
|
42 |
+
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
43 |
+
|
44 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
45 |
+
for all models (such as downloading or saving).
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
49 |
+
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
50 |
+
1)`.
|
51 |
+
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
52 |
+
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
53 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
54 |
+
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
55 |
+
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
56 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
57 |
+
Whether to flip sin to cos for Fourier time embedding.
|
58 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
59 |
+
Tuple of downsample block types.
|
60 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
61 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
62 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
63 |
+
Tuple of upsample block types.
|
64 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
65 |
+
Tuple of block output channels.
|
66 |
+
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
67 |
+
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
68 |
+
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
69 |
+
downsample_type (`str`, *optional*, defaults to `conv`):
|
70 |
+
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
71 |
+
upsample_type (`str`, *optional*, defaults to `conv`):
|
72 |
+
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
73 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
74 |
+
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
75 |
+
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
76 |
+
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
77 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
78 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
79 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
80 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
81 |
+
`"timestep"`, or `"identity"`.
|
82 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
83 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
84 |
+
conditioning with `class_embed_type` equal to `None`.
|
85 |
+
"""
|
86 |
+
|
87 |
+
@register_to_config
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
91 |
+
in_channels: int = 3,
|
92 |
+
out_channels: int = 3,
|
93 |
+
center_input_sample: bool = False,
|
94 |
+
time_embedding_type: str = "positional",
|
95 |
+
freq_shift: int = 0,
|
96 |
+
flip_sin_to_cos: bool = True,
|
97 |
+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
98 |
+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
99 |
+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
100 |
+
layers_per_block: int = 2,
|
101 |
+
mid_block_scale_factor: float = 1,
|
102 |
+
downsample_padding: int = 1,
|
103 |
+
downsample_type: str = "conv",
|
104 |
+
upsample_type: str = "conv",
|
105 |
+
act_fn: str = "silu",
|
106 |
+
attention_head_dim: Optional[int] = 8,
|
107 |
+
norm_num_groups: int = 32,
|
108 |
+
norm_eps: float = 1e-5,
|
109 |
+
resnet_time_scale_shift: str = "default",
|
110 |
+
add_attention: bool = True,
|
111 |
+
class_embed_type: Optional[str] = None,
|
112 |
+
num_class_embeds: Optional[int] = None,
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
|
116 |
+
self.sample_size = sample_size
|
117 |
+
time_embed_dim = block_out_channels[0] * 4
|
118 |
+
|
119 |
+
# Check inputs
|
120 |
+
if len(down_block_types) != len(up_block_types):
|
121 |
+
raise ValueError(
|
122 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
123 |
+
)
|
124 |
+
|
125 |
+
if len(block_out_channels) != len(down_block_types):
|
126 |
+
raise ValueError(
|
127 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
128 |
+
)
|
129 |
+
|
130 |
+
# input
|
131 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
132 |
+
|
133 |
+
# time
|
134 |
+
if time_embedding_type == "fourier":
|
135 |
+
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
136 |
+
timestep_input_dim = 2 * block_out_channels[0]
|
137 |
+
elif time_embedding_type == "positional":
|
138 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
139 |
+
timestep_input_dim = block_out_channels[0]
|
140 |
+
|
141 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
142 |
+
|
143 |
+
# class embedding
|
144 |
+
if class_embed_type is None and num_class_embeds is not None:
|
145 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
146 |
+
elif class_embed_type == "timestep":
|
147 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
148 |
+
elif class_embed_type == "identity":
|
149 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
150 |
+
else:
|
151 |
+
self.class_embedding = None
|
152 |
+
|
153 |
+
self.down_blocks = nn.ModuleList([])
|
154 |
+
self.mid_block = None
|
155 |
+
self.up_blocks = nn.ModuleList([])
|
156 |
+
|
157 |
+
# down
|
158 |
+
output_channel = block_out_channels[0]
|
159 |
+
for i, down_block_type in enumerate(down_block_types):
|
160 |
+
input_channel = output_channel
|
161 |
+
output_channel = block_out_channels[i]
|
162 |
+
is_final_block = i == len(block_out_channels) - 1
|
163 |
+
|
164 |
+
down_block = get_down_block(
|
165 |
+
down_block_type,
|
166 |
+
num_layers=layers_per_block,
|
167 |
+
in_channels=input_channel,
|
168 |
+
out_channels=output_channel,
|
169 |
+
temb_channels=time_embed_dim,
|
170 |
+
add_downsample=not is_final_block,
|
171 |
+
resnet_eps=norm_eps,
|
172 |
+
resnet_act_fn=act_fn,
|
173 |
+
resnet_groups=norm_num_groups,
|
174 |
+
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
175 |
+
downsample_padding=downsample_padding,
|
176 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
177 |
+
downsample_type=downsample_type,
|
178 |
+
)
|
179 |
+
self.down_blocks.append(down_block)
|
180 |
+
|
181 |
+
# mid
|
182 |
+
self.mid_block = UNetMidBlock2D(
|
183 |
+
in_channels=block_out_channels[-1],
|
184 |
+
temb_channels=time_embed_dim,
|
185 |
+
resnet_eps=norm_eps,
|
186 |
+
resnet_act_fn=act_fn,
|
187 |
+
output_scale_factor=mid_block_scale_factor,
|
188 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
189 |
+
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
190 |
+
resnet_groups=norm_num_groups,
|
191 |
+
add_attention=add_attention,
|
192 |
+
)
|
193 |
+
|
194 |
+
# up
|
195 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
196 |
+
output_channel = reversed_block_out_channels[0]
|
197 |
+
for i, up_block_type in enumerate(up_block_types):
|
198 |
+
prev_output_channel = output_channel
|
199 |
+
output_channel = reversed_block_out_channels[i]
|
200 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
201 |
+
|
202 |
+
is_final_block = i == len(block_out_channels) - 1
|
203 |
+
|
204 |
+
up_block = get_up_block(
|
205 |
+
up_block_type,
|
206 |
+
num_layers=layers_per_block + 1,
|
207 |
+
in_channels=input_channel,
|
208 |
+
out_channels=output_channel,
|
209 |
+
prev_output_channel=prev_output_channel,
|
210 |
+
temb_channels=time_embed_dim,
|
211 |
+
add_upsample=not is_final_block,
|
212 |
+
resnet_eps=norm_eps,
|
213 |
+
resnet_act_fn=act_fn,
|
214 |
+
resnet_groups=norm_num_groups,
|
215 |
+
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
216 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
217 |
+
upsample_type=upsample_type,
|
218 |
+
)
|
219 |
+
self.up_blocks.append(up_block)
|
220 |
+
prev_output_channel = output_channel
|
221 |
+
|
222 |
+
# out
|
223 |
+
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
224 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
225 |
+
self.conv_act = nn.SiLU()
|
226 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
227 |
+
|
228 |
+
def forward(
|
229 |
+
self,
|
230 |
+
sample: torch.FloatTensor,
|
231 |
+
timestep: Union[torch.Tensor, float, int],
|
232 |
+
class_labels: Optional[torch.Tensor] = None,
|
233 |
+
return_dict: bool = True,
|
234 |
+
) -> Union[UNet2DOutput, Tuple]:
|
235 |
+
r"""
|
236 |
+
The [`UNet2DModel`] forward method.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
sample (`torch.FloatTensor`):
|
240 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
241 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
242 |
+
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
243 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
244 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
245 |
+
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
249 |
+
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
250 |
+
returned where the first element is the sample tensor.
|
251 |
+
"""
|
252 |
+
# 0. center input if necessary
|
253 |
+
if self.config.center_input_sample:
|
254 |
+
sample = 2 * sample - 1.0
|
255 |
+
|
256 |
+
# 1. time
|
257 |
+
timesteps = timestep
|
258 |
+
if not torch.is_tensor(timesteps):
|
259 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
260 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
261 |
+
timesteps = timesteps[None].to(sample.device)
|
262 |
+
|
263 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
264 |
+
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
265 |
+
|
266 |
+
t_emb = self.time_proj(timesteps)
|
267 |
+
|
268 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
269 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
270 |
+
# there might be better ways to encapsulate this.
|
271 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
272 |
+
emb = self.time_embedding(t_emb)
|
273 |
+
|
274 |
+
if self.class_embedding is not None:
|
275 |
+
if class_labels is None:
|
276 |
+
raise ValueError("class_labels should be provided when doing class conditioning")
|
277 |
+
|
278 |
+
if self.config.class_embed_type == "timestep":
|
279 |
+
class_labels = self.time_proj(class_labels)
|
280 |
+
|
281 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
282 |
+
emb = emb + class_emb
|
283 |
+
|
284 |
+
# 2. pre-process
|
285 |
+
skip_sample = sample
|
286 |
+
sample = self.conv_in(sample)
|
287 |
+
|
288 |
+
# 3. down
|
289 |
+
down_block_res_samples = (sample,)
|
290 |
+
for downsample_block in self.down_blocks:
|
291 |
+
if hasattr(downsample_block, "skip_conv"):
|
292 |
+
sample, res_samples, skip_sample = downsample_block(
|
293 |
+
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
297 |
+
|
298 |
+
down_block_res_samples += res_samples
|
299 |
+
|
300 |
+
# 4. mid
|
301 |
+
sample = self.mid_block(sample, emb)
|
302 |
+
|
303 |
+
# 5. up
|
304 |
+
skip_sample = None
|
305 |
+
for upsample_block in self.up_blocks:
|
306 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
307 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
308 |
+
|
309 |
+
if hasattr(upsample_block, "skip_conv"):
|
310 |
+
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
311 |
+
else:
|
312 |
+
sample = upsample_block(sample, res_samples, emb)
|
313 |
+
|
314 |
+
# 6. post-process
|
315 |
+
sample = self.conv_norm_out(sample)
|
316 |
+
sample = self.conv_act(sample)
|
317 |
+
sample = self.conv_out(sample)
|
318 |
+
|
319 |
+
if skip_sample is not None:
|
320 |
+
sample += skip_sample
|
321 |
+
|
322 |
+
if self.config.time_embedding_type == "fourier":
|
323 |
+
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
324 |
+
sample = sample / timesteps
|
325 |
+
|
326 |
+
if not return_dict:
|
327 |
+
return (sample,)
|
328 |
+
|
329 |
+
return UNet2DOutput(sample=sample)
|
6DoF/diffusers/models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
6DoF/diffusers/models/unet_2d_blocks_flax.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import flax.linen as nn
|
16 |
+
import jax.numpy as jnp
|
17 |
+
|
18 |
+
from .attention_flax import FlaxTransformer2DModel
|
19 |
+
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
20 |
+
|
21 |
+
|
22 |
+
class FlaxCrossAttnDownBlock2D(nn.Module):
|
23 |
+
r"""
|
24 |
+
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
|
25 |
+
https://arxiv.org/abs/2103.06104
|
26 |
+
|
27 |
+
Parameters:
|
28 |
+
in_channels (:obj:`int`):
|
29 |
+
Input channels
|
30 |
+
out_channels (:obj:`int`):
|
31 |
+
Output channels
|
32 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
33 |
+
Dropout rate
|
34 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
35 |
+
Number of attention blocks layers
|
36 |
+
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
|
37 |
+
Number of attention heads of each spatial transformer block
|
38 |
+
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
39 |
+
Whether to add downsampling layer before each final output
|
40 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
41 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
42 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
43 |
+
Parameters `dtype`
|
44 |
+
"""
|
45 |
+
in_channels: int
|
46 |
+
out_channels: int
|
47 |
+
dropout: float = 0.0
|
48 |
+
num_layers: int = 1
|
49 |
+
num_attention_heads: int = 1
|
50 |
+
add_downsample: bool = True
|
51 |
+
use_linear_projection: bool = False
|
52 |
+
only_cross_attention: bool = False
|
53 |
+
use_memory_efficient_attention: bool = False
|
54 |
+
dtype: jnp.dtype = jnp.float32
|
55 |
+
|
56 |
+
def setup(self):
|
57 |
+
resnets = []
|
58 |
+
attentions = []
|
59 |
+
|
60 |
+
for i in range(self.num_layers):
|
61 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
62 |
+
|
63 |
+
res_block = FlaxResnetBlock2D(
|
64 |
+
in_channels=in_channels,
|
65 |
+
out_channels=self.out_channels,
|
66 |
+
dropout_prob=self.dropout,
|
67 |
+
dtype=self.dtype,
|
68 |
+
)
|
69 |
+
resnets.append(res_block)
|
70 |
+
|
71 |
+
attn_block = FlaxTransformer2DModel(
|
72 |
+
in_channels=self.out_channels,
|
73 |
+
n_heads=self.num_attention_heads,
|
74 |
+
d_head=self.out_channels // self.num_attention_heads,
|
75 |
+
depth=1,
|
76 |
+
use_linear_projection=self.use_linear_projection,
|
77 |
+
only_cross_attention=self.only_cross_attention,
|
78 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
79 |
+
dtype=self.dtype,
|
80 |
+
)
|
81 |
+
attentions.append(attn_block)
|
82 |
+
|
83 |
+
self.resnets = resnets
|
84 |
+
self.attentions = attentions
|
85 |
+
|
86 |
+
if self.add_downsample:
|
87 |
+
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
88 |
+
|
89 |
+
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
|
90 |
+
output_states = ()
|
91 |
+
|
92 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
93 |
+
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
94 |
+
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
|
95 |
+
output_states += (hidden_states,)
|
96 |
+
|
97 |
+
if self.add_downsample:
|
98 |
+
hidden_states = self.downsamplers_0(hidden_states)
|
99 |
+
output_states += (hidden_states,)
|
100 |
+
|
101 |
+
return hidden_states, output_states
|
102 |
+
|
103 |
+
|
104 |
+
class FlaxDownBlock2D(nn.Module):
|
105 |
+
r"""
|
106 |
+
Flax 2D downsizing block
|
107 |
+
|
108 |
+
Parameters:
|
109 |
+
in_channels (:obj:`int`):
|
110 |
+
Input channels
|
111 |
+
out_channels (:obj:`int`):
|
112 |
+
Output channels
|
113 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
114 |
+
Dropout rate
|
115 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
116 |
+
Number of attention blocks layers
|
117 |
+
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
118 |
+
Whether to add downsampling layer before each final output
|
119 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
120 |
+
Parameters `dtype`
|
121 |
+
"""
|
122 |
+
in_channels: int
|
123 |
+
out_channels: int
|
124 |
+
dropout: float = 0.0
|
125 |
+
num_layers: int = 1
|
126 |
+
add_downsample: bool = True
|
127 |
+
dtype: jnp.dtype = jnp.float32
|
128 |
+
|
129 |
+
def setup(self):
|
130 |
+
resnets = []
|
131 |
+
|
132 |
+
for i in range(self.num_layers):
|
133 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
134 |
+
|
135 |
+
res_block = FlaxResnetBlock2D(
|
136 |
+
in_channels=in_channels,
|
137 |
+
out_channels=self.out_channels,
|
138 |
+
dropout_prob=self.dropout,
|
139 |
+
dtype=self.dtype,
|
140 |
+
)
|
141 |
+
resnets.append(res_block)
|
142 |
+
self.resnets = resnets
|
143 |
+
|
144 |
+
if self.add_downsample:
|
145 |
+
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
146 |
+
|
147 |
+
def __call__(self, hidden_states, temb, deterministic=True):
|
148 |
+
output_states = ()
|
149 |
+
|
150 |
+
for resnet in self.resnets:
|
151 |
+
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
152 |
+
output_states += (hidden_states,)
|
153 |
+
|
154 |
+
if self.add_downsample:
|
155 |
+
hidden_states = self.downsamplers_0(hidden_states)
|
156 |
+
output_states += (hidden_states,)
|
157 |
+
|
158 |
+
return hidden_states, output_states
|
159 |
+
|
160 |
+
|
161 |
+
class FlaxCrossAttnUpBlock2D(nn.Module):
|
162 |
+
r"""
|
163 |
+
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
|
164 |
+
https://arxiv.org/abs/2103.06104
|
165 |
+
|
166 |
+
Parameters:
|
167 |
+
in_channels (:obj:`int`):
|
168 |
+
Input channels
|
169 |
+
out_channels (:obj:`int`):
|
170 |
+
Output channels
|
171 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
172 |
+
Dropout rate
|
173 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
174 |
+
Number of attention blocks layers
|
175 |
+
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
|
176 |
+
Number of attention heads of each spatial transformer block
|
177 |
+
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
|
178 |
+
Whether to add upsampling layer before each final output
|
179 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
180 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
181 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
182 |
+
Parameters `dtype`
|
183 |
+
"""
|
184 |
+
in_channels: int
|
185 |
+
out_channels: int
|
186 |
+
prev_output_channel: int
|
187 |
+
dropout: float = 0.0
|
188 |
+
num_layers: int = 1
|
189 |
+
num_attention_heads: int = 1
|
190 |
+
add_upsample: bool = True
|
191 |
+
use_linear_projection: bool = False
|
192 |
+
only_cross_attention: bool = False
|
193 |
+
use_memory_efficient_attention: bool = False
|
194 |
+
dtype: jnp.dtype = jnp.float32
|
195 |
+
|
196 |
+
def setup(self):
|
197 |
+
resnets = []
|
198 |
+
attentions = []
|
199 |
+
|
200 |
+
for i in range(self.num_layers):
|
201 |
+
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
|
202 |
+
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
|
203 |
+
|
204 |
+
res_block = FlaxResnetBlock2D(
|
205 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
206 |
+
out_channels=self.out_channels,
|
207 |
+
dropout_prob=self.dropout,
|
208 |
+
dtype=self.dtype,
|
209 |
+
)
|
210 |
+
resnets.append(res_block)
|
211 |
+
|
212 |
+
attn_block = FlaxTransformer2DModel(
|
213 |
+
in_channels=self.out_channels,
|
214 |
+
n_heads=self.num_attention_heads,
|
215 |
+
d_head=self.out_channels // self.num_attention_heads,
|
216 |
+
depth=1,
|
217 |
+
use_linear_projection=self.use_linear_projection,
|
218 |
+
only_cross_attention=self.only_cross_attention,
|
219 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
220 |
+
dtype=self.dtype,
|
221 |
+
)
|
222 |
+
attentions.append(attn_block)
|
223 |
+
|
224 |
+
self.resnets = resnets
|
225 |
+
self.attentions = attentions
|
226 |
+
|
227 |
+
if self.add_upsample:
|
228 |
+
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
229 |
+
|
230 |
+
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
|
231 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
232 |
+
# pop res hidden states
|
233 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
234 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
235 |
+
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
|
236 |
+
|
237 |
+
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
238 |
+
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
|
239 |
+
|
240 |
+
if self.add_upsample:
|
241 |
+
hidden_states = self.upsamplers_0(hidden_states)
|
242 |
+
|
243 |
+
return hidden_states
|
244 |
+
|
245 |
+
|
246 |
+
class FlaxUpBlock2D(nn.Module):
|
247 |
+
r"""
|
248 |
+
Flax 2D upsampling block
|
249 |
+
|
250 |
+
Parameters:
|
251 |
+
in_channels (:obj:`int`):
|
252 |
+
Input channels
|
253 |
+
out_channels (:obj:`int`):
|
254 |
+
Output channels
|
255 |
+
prev_output_channel (:obj:`int`):
|
256 |
+
Output channels from the previous block
|
257 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
258 |
+
Dropout rate
|
259 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
260 |
+
Number of attention blocks layers
|
261 |
+
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
262 |
+
Whether to add downsampling layer before each final output
|
263 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
264 |
+
Parameters `dtype`
|
265 |
+
"""
|
266 |
+
in_channels: int
|
267 |
+
out_channels: int
|
268 |
+
prev_output_channel: int
|
269 |
+
dropout: float = 0.0
|
270 |
+
num_layers: int = 1
|
271 |
+
add_upsample: bool = True
|
272 |
+
dtype: jnp.dtype = jnp.float32
|
273 |
+
|
274 |
+
def setup(self):
|
275 |
+
resnets = []
|
276 |
+
|
277 |
+
for i in range(self.num_layers):
|
278 |
+
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
|
279 |
+
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
|
280 |
+
|
281 |
+
res_block = FlaxResnetBlock2D(
|
282 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
283 |
+
out_channels=self.out_channels,
|
284 |
+
dropout_prob=self.dropout,
|
285 |
+
dtype=self.dtype,
|
286 |
+
)
|
287 |
+
resnets.append(res_block)
|
288 |
+
|
289 |
+
self.resnets = resnets
|
290 |
+
|
291 |
+
if self.add_upsample:
|
292 |
+
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
293 |
+
|
294 |
+
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
|
295 |
+
for resnet in self.resnets:
|
296 |
+
# pop res hidden states
|
297 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
298 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
299 |
+
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
|
300 |
+
|
301 |
+
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
302 |
+
|
303 |
+
if self.add_upsample:
|
304 |
+
hidden_states = self.upsamplers_0(hidden_states)
|
305 |
+
|
306 |
+
return hidden_states
|
307 |
+
|
308 |
+
|
309 |
+
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
310 |
+
r"""
|
311 |
+
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
|
312 |
+
|
313 |
+
Parameters:
|
314 |
+
in_channels (:obj:`int`):
|
315 |
+
Input channels
|
316 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
317 |
+
Dropout rate
|
318 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
319 |
+
Number of attention blocks layers
|
320 |
+
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
|
321 |
+
Number of attention heads of each spatial transformer block
|
322 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
323 |
+
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
324 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
325 |
+
Parameters `dtype`
|
326 |
+
"""
|
327 |
+
in_channels: int
|
328 |
+
dropout: float = 0.0
|
329 |
+
num_layers: int = 1
|
330 |
+
num_attention_heads: int = 1
|
331 |
+
use_linear_projection: bool = False
|
332 |
+
use_memory_efficient_attention: bool = False
|
333 |
+
dtype: jnp.dtype = jnp.float32
|
334 |
+
|
335 |
+
def setup(self):
|
336 |
+
# there is always at least one resnet
|
337 |
+
resnets = [
|
338 |
+
FlaxResnetBlock2D(
|
339 |
+
in_channels=self.in_channels,
|
340 |
+
out_channels=self.in_channels,
|
341 |
+
dropout_prob=self.dropout,
|
342 |
+
dtype=self.dtype,
|
343 |
+
)
|
344 |
+
]
|
345 |
+
|
346 |
+
attentions = []
|
347 |
+
|
348 |
+
for _ in range(self.num_layers):
|
349 |
+
attn_block = FlaxTransformer2DModel(
|
350 |
+
in_channels=self.in_channels,
|
351 |
+
n_heads=self.num_attention_heads,
|
352 |
+
d_head=self.in_channels // self.num_attention_heads,
|
353 |
+
depth=1,
|
354 |
+
use_linear_projection=self.use_linear_projection,
|
355 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
356 |
+
dtype=self.dtype,
|
357 |
+
)
|
358 |
+
attentions.append(attn_block)
|
359 |
+
|
360 |
+
res_block = FlaxResnetBlock2D(
|
361 |
+
in_channels=self.in_channels,
|
362 |
+
out_channels=self.in_channels,
|
363 |
+
dropout_prob=self.dropout,
|
364 |
+
dtype=self.dtype,
|
365 |
+
)
|
366 |
+
resnets.append(res_block)
|
367 |
+
|
368 |
+
self.resnets = resnets
|
369 |
+
self.attentions = attentions
|
370 |
+
|
371 |
+
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
|
372 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
373 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
374 |
+
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
|
375 |
+
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
376 |
+
|
377 |
+
return hidden_states
|
6DoF/diffusers/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,980 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
|
21 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from ..loaders import UNet2DConditionLoadersMixin
|
23 |
+
from ..utils import BaseOutput, logging
|
24 |
+
from .activations import get_activation
|
25 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
26 |
+
from .embeddings import (
|
27 |
+
GaussianFourierProjection,
|
28 |
+
ImageHintTimeEmbedding,
|
29 |
+
ImageProjection,
|
30 |
+
ImageTimeEmbedding,
|
31 |
+
TextImageProjection,
|
32 |
+
TextImageTimeEmbedding,
|
33 |
+
TextTimeEmbedding,
|
34 |
+
TimestepEmbedding,
|
35 |
+
Timesteps,
|
36 |
+
)
|
37 |
+
from .modeling_utils import ModelMixin
|
38 |
+
from .unet_2d_blocks import (
|
39 |
+
CrossAttnDownBlock2D,
|
40 |
+
CrossAttnUpBlock2D,
|
41 |
+
DownBlock2D,
|
42 |
+
UNetMidBlock2DCrossAttn,
|
43 |
+
UNetMidBlock2DSimpleCrossAttn,
|
44 |
+
UpBlock2D,
|
45 |
+
get_down_block,
|
46 |
+
get_up_block,
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
51 |
+
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class UNet2DConditionOutput(BaseOutput):
|
55 |
+
"""
|
56 |
+
The output of [`UNet2DConditionModel`].
|
57 |
+
|
58 |
+
Args:
|
59 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
60 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
61 |
+
"""
|
62 |
+
|
63 |
+
sample: torch.FloatTensor = None
|
64 |
+
|
65 |
+
|
66 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
67 |
+
r"""
|
68 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
69 |
+
shaped output.
|
70 |
+
|
71 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
72 |
+
for all models (such as downloading or saving).
|
73 |
+
|
74 |
+
Parameters:
|
75 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
76 |
+
Height and width of input/output sample.
|
77 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
78 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
79 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
80 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
81 |
+
Whether to flip the sin to cos in the time embedding.
|
82 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
83 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
84 |
+
The tuple of downsample blocks to use.
|
85 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
86 |
+
Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
|
87 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
88 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
89 |
+
The tuple of upsample blocks to use.
|
90 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
91 |
+
Whether to include self-attention in the basic transformer blocks, see
|
92 |
+
[`~models.attention.BasicTransformerBlock`].
|
93 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
94 |
+
The tuple of output channels for each block.
|
95 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
96 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
97 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
98 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
99 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
100 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
101 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
102 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
103 |
+
The dimension of the cross attention features.
|
104 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
105 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
106 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
107 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
108 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
109 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
110 |
+
dimension to `cross_attention_dim`.
|
111 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
112 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
113 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
114 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
115 |
+
num_attention_heads (`int`, *optional*):
|
116 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
117 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
118 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
119 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
120 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
121 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
122 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
123 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
124 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
125 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
126 |
+
Dimension for the timestep embeddings.
|
127 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
128 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
129 |
+
class conditioning with `class_embed_type` equal to `None`.
|
130 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
131 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
132 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
133 |
+
An optional override for the dimension of the projected time embedding.
|
134 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
135 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
136 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
137 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
138 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
139 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
140 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
141 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
142 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
143 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
144 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
145 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
146 |
+
embeddings with the class embeddings.
|
147 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
148 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
149 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
150 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
151 |
+
otherwise.
|
152 |
+
"""
|
153 |
+
|
154 |
+
_supports_gradient_checkpointing = True
|
155 |
+
|
156 |
+
@register_to_config
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
sample_size: Optional[int] = None,
|
160 |
+
in_channels: int = 4,
|
161 |
+
out_channels: int = 4,
|
162 |
+
center_input_sample: bool = False,
|
163 |
+
flip_sin_to_cos: bool = True,
|
164 |
+
freq_shift: int = 0,
|
165 |
+
down_block_types: Tuple[str] = (
|
166 |
+
"CrossAttnDownBlock2D",
|
167 |
+
"CrossAttnDownBlock2D",
|
168 |
+
"CrossAttnDownBlock2D",
|
169 |
+
"DownBlock2D",
|
170 |
+
),
|
171 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
172 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
173 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
174 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
175 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
176 |
+
downsample_padding: int = 1,
|
177 |
+
mid_block_scale_factor: float = 1,
|
178 |
+
act_fn: str = "silu",
|
179 |
+
norm_num_groups: Optional[int] = 32,
|
180 |
+
norm_eps: float = 1e-5,
|
181 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
182 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
183 |
+
encoder_hid_dim: Optional[int] = None,
|
184 |
+
encoder_hid_dim_type: Optional[str] = None,
|
185 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
186 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
187 |
+
dual_cross_attention: bool = False,
|
188 |
+
use_linear_projection: bool = False,
|
189 |
+
class_embed_type: Optional[str] = None,
|
190 |
+
addition_embed_type: Optional[str] = None,
|
191 |
+
addition_time_embed_dim: Optional[int] = None,
|
192 |
+
num_class_embeds: Optional[int] = None,
|
193 |
+
upcast_attention: bool = False,
|
194 |
+
resnet_time_scale_shift: str = "default",
|
195 |
+
resnet_skip_time_act: bool = False,
|
196 |
+
resnet_out_scale_factor: int = 1.0,
|
197 |
+
time_embedding_type: str = "positional",
|
198 |
+
time_embedding_dim: Optional[int] = None,
|
199 |
+
time_embedding_act_fn: Optional[str] = None,
|
200 |
+
timestep_post_act: Optional[str] = None,
|
201 |
+
time_cond_proj_dim: Optional[int] = None,
|
202 |
+
conv_in_kernel: int = 3,
|
203 |
+
conv_out_kernel: int = 3,
|
204 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
205 |
+
class_embeddings_concat: bool = False,
|
206 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
207 |
+
cross_attention_norm: Optional[str] = None,
|
208 |
+
addition_embed_type_num_heads=64,
|
209 |
+
):
|
210 |
+
super().__init__()
|
211 |
+
|
212 |
+
self.sample_size = sample_size
|
213 |
+
|
214 |
+
if num_attention_heads is not None:
|
215 |
+
raise ValueError(
|
216 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
217 |
+
)
|
218 |
+
|
219 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
220 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
221 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
222 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
223 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
224 |
+
# which is why we correct for the naming here.
|
225 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
226 |
+
|
227 |
+
# Check inputs
|
228 |
+
if len(down_block_types) != len(up_block_types):
|
229 |
+
raise ValueError(
|
230 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
231 |
+
)
|
232 |
+
|
233 |
+
if len(block_out_channels) != len(down_block_types):
|
234 |
+
raise ValueError(
|
235 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
236 |
+
)
|
237 |
+
|
238 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
239 |
+
raise ValueError(
|
240 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
241 |
+
)
|
242 |
+
|
243 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
244 |
+
raise ValueError(
|
245 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
246 |
+
)
|
247 |
+
|
248 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
249 |
+
raise ValueError(
|
250 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
251 |
+
)
|
252 |
+
|
253 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
254 |
+
raise ValueError(
|
255 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
256 |
+
)
|
257 |
+
|
258 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
259 |
+
raise ValueError(
|
260 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
261 |
+
)
|
262 |
+
|
263 |
+
# input
|
264 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
265 |
+
self.conv_in = nn.Conv2d(
|
266 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
267 |
+
)
|
268 |
+
|
269 |
+
# time
|
270 |
+
if time_embedding_type == "fourier":
|
271 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
272 |
+
if time_embed_dim % 2 != 0:
|
273 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
274 |
+
self.time_proj = GaussianFourierProjection(
|
275 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
276 |
+
)
|
277 |
+
timestep_input_dim = time_embed_dim
|
278 |
+
elif time_embedding_type == "positional":
|
279 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
280 |
+
|
281 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
282 |
+
timestep_input_dim = block_out_channels[0]
|
283 |
+
else:
|
284 |
+
raise ValueError(
|
285 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
286 |
+
)
|
287 |
+
|
288 |
+
self.time_embedding = TimestepEmbedding(
|
289 |
+
timestep_input_dim,
|
290 |
+
time_embed_dim,
|
291 |
+
act_fn=act_fn,
|
292 |
+
post_act_fn=timestep_post_act,
|
293 |
+
cond_proj_dim=time_cond_proj_dim,
|
294 |
+
)
|
295 |
+
|
296 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
297 |
+
encoder_hid_dim_type = "text_proj"
|
298 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
299 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
300 |
+
|
301 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
302 |
+
raise ValueError(
|
303 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
304 |
+
)
|
305 |
+
|
306 |
+
if encoder_hid_dim_type == "text_proj":
|
307 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
308 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
309 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
310 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
311 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
312 |
+
self.encoder_hid_proj = TextImageProjection(
|
313 |
+
text_embed_dim=encoder_hid_dim,
|
314 |
+
image_embed_dim=cross_attention_dim,
|
315 |
+
cross_attention_dim=cross_attention_dim,
|
316 |
+
)
|
317 |
+
elif encoder_hid_dim_type == "image_proj":
|
318 |
+
# Kandinsky 2.2
|
319 |
+
self.encoder_hid_proj = ImageProjection(
|
320 |
+
image_embed_dim=encoder_hid_dim,
|
321 |
+
cross_attention_dim=cross_attention_dim,
|
322 |
+
)
|
323 |
+
elif encoder_hid_dim_type is not None:
|
324 |
+
raise ValueError(
|
325 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
self.encoder_hid_proj = None
|
329 |
+
|
330 |
+
# class embedding
|
331 |
+
if class_embed_type is None and num_class_embeds is not None:
|
332 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
333 |
+
elif class_embed_type == "timestep":
|
334 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
335 |
+
elif class_embed_type == "identity":
|
336 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
337 |
+
elif class_embed_type == "projection":
|
338 |
+
if projection_class_embeddings_input_dim is None:
|
339 |
+
raise ValueError(
|
340 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
341 |
+
)
|
342 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
343 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
344 |
+
# 2. it projects from an arbitrary input dimension.
|
345 |
+
#
|
346 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
347 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
348 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
349 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
350 |
+
elif class_embed_type == "simple_projection":
|
351 |
+
if projection_class_embeddings_input_dim is None:
|
352 |
+
raise ValueError(
|
353 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
354 |
+
)
|
355 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
356 |
+
else:
|
357 |
+
self.class_embedding = None
|
358 |
+
|
359 |
+
if addition_embed_type == "text":
|
360 |
+
if encoder_hid_dim is not None:
|
361 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
362 |
+
else:
|
363 |
+
text_time_embedding_from_dim = cross_attention_dim
|
364 |
+
|
365 |
+
self.add_embedding = TextTimeEmbedding(
|
366 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
367 |
+
)
|
368 |
+
elif addition_embed_type == "text_image":
|
369 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
370 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
371 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
372 |
+
self.add_embedding = TextImageTimeEmbedding(
|
373 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
374 |
+
)
|
375 |
+
elif addition_embed_type == "text_time":
|
376 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
377 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
378 |
+
elif addition_embed_type == "image":
|
379 |
+
# Kandinsky 2.2
|
380 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
381 |
+
elif addition_embed_type == "image_hint":
|
382 |
+
# Kandinsky 2.2 ControlNet
|
383 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
384 |
+
elif addition_embed_type is not None:
|
385 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
386 |
+
|
387 |
+
if time_embedding_act_fn is None:
|
388 |
+
self.time_embed_act = None
|
389 |
+
else:
|
390 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
391 |
+
|
392 |
+
self.down_blocks = nn.ModuleList([])
|
393 |
+
self.up_blocks = nn.ModuleList([])
|
394 |
+
|
395 |
+
if isinstance(only_cross_attention, bool):
|
396 |
+
if mid_block_only_cross_attention is None:
|
397 |
+
mid_block_only_cross_attention = only_cross_attention
|
398 |
+
|
399 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
400 |
+
|
401 |
+
if mid_block_only_cross_attention is None:
|
402 |
+
mid_block_only_cross_attention = False
|
403 |
+
|
404 |
+
if isinstance(num_attention_heads, int):
|
405 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
406 |
+
|
407 |
+
if isinstance(attention_head_dim, int):
|
408 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
409 |
+
|
410 |
+
if isinstance(cross_attention_dim, int):
|
411 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
412 |
+
|
413 |
+
if isinstance(layers_per_block, int):
|
414 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
415 |
+
|
416 |
+
if isinstance(transformer_layers_per_block, int):
|
417 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
418 |
+
|
419 |
+
if class_embeddings_concat:
|
420 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
421 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
422 |
+
# regular time embeddings
|
423 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
424 |
+
else:
|
425 |
+
blocks_time_embed_dim = time_embed_dim
|
426 |
+
|
427 |
+
# down
|
428 |
+
output_channel = block_out_channels[0]
|
429 |
+
for i, down_block_type in enumerate(down_block_types):
|
430 |
+
input_channel = output_channel
|
431 |
+
output_channel = block_out_channels[i]
|
432 |
+
is_final_block = i == len(block_out_channels) - 1
|
433 |
+
|
434 |
+
down_block = get_down_block(
|
435 |
+
down_block_type,
|
436 |
+
num_layers=layers_per_block[i],
|
437 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
438 |
+
in_channels=input_channel,
|
439 |
+
out_channels=output_channel,
|
440 |
+
temb_channels=blocks_time_embed_dim,
|
441 |
+
add_downsample=not is_final_block,
|
442 |
+
resnet_eps=norm_eps,
|
443 |
+
resnet_act_fn=act_fn,
|
444 |
+
resnet_groups=norm_num_groups,
|
445 |
+
cross_attention_dim=cross_attention_dim[i],
|
446 |
+
num_attention_heads=num_attention_heads[i],
|
447 |
+
downsample_padding=downsample_padding,
|
448 |
+
dual_cross_attention=dual_cross_attention,
|
449 |
+
use_linear_projection=use_linear_projection,
|
450 |
+
only_cross_attention=only_cross_attention[i],
|
451 |
+
upcast_attention=upcast_attention,
|
452 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
453 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
454 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
455 |
+
cross_attention_norm=cross_attention_norm,
|
456 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
457 |
+
)
|
458 |
+
self.down_blocks.append(down_block)
|
459 |
+
|
460 |
+
# mid
|
461 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
462 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
463 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
464 |
+
in_channels=block_out_channels[-1],
|
465 |
+
temb_channels=blocks_time_embed_dim,
|
466 |
+
resnet_eps=norm_eps,
|
467 |
+
resnet_act_fn=act_fn,
|
468 |
+
output_scale_factor=mid_block_scale_factor,
|
469 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
470 |
+
cross_attention_dim=cross_attention_dim[-1],
|
471 |
+
num_attention_heads=num_attention_heads[-1],
|
472 |
+
resnet_groups=norm_num_groups,
|
473 |
+
dual_cross_attention=dual_cross_attention,
|
474 |
+
use_linear_projection=use_linear_projection,
|
475 |
+
upcast_attention=upcast_attention,
|
476 |
+
)
|
477 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
478 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
479 |
+
in_channels=block_out_channels[-1],
|
480 |
+
temb_channels=blocks_time_embed_dim,
|
481 |
+
resnet_eps=norm_eps,
|
482 |
+
resnet_act_fn=act_fn,
|
483 |
+
output_scale_factor=mid_block_scale_factor,
|
484 |
+
cross_attention_dim=cross_attention_dim[-1],
|
485 |
+
attention_head_dim=attention_head_dim[-1],
|
486 |
+
resnet_groups=norm_num_groups,
|
487 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
488 |
+
skip_time_act=resnet_skip_time_act,
|
489 |
+
only_cross_attention=mid_block_only_cross_attention,
|
490 |
+
cross_attention_norm=cross_attention_norm,
|
491 |
+
)
|
492 |
+
elif mid_block_type is None:
|
493 |
+
self.mid_block = None
|
494 |
+
else:
|
495 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
496 |
+
|
497 |
+
# count how many layers upsample the images
|
498 |
+
self.num_upsamplers = 0
|
499 |
+
|
500 |
+
# up
|
501 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
502 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
503 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
504 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
505 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
506 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
507 |
+
|
508 |
+
output_channel = reversed_block_out_channels[0]
|
509 |
+
for i, up_block_type in enumerate(up_block_types):
|
510 |
+
is_final_block = i == len(block_out_channels) - 1
|
511 |
+
|
512 |
+
prev_output_channel = output_channel
|
513 |
+
output_channel = reversed_block_out_channels[i]
|
514 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
515 |
+
|
516 |
+
# add upsample block for all BUT final layer
|
517 |
+
if not is_final_block:
|
518 |
+
add_upsample = True
|
519 |
+
self.num_upsamplers += 1
|
520 |
+
else:
|
521 |
+
add_upsample = False
|
522 |
+
|
523 |
+
up_block = get_up_block(
|
524 |
+
up_block_type,
|
525 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
526 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
527 |
+
in_channels=input_channel,
|
528 |
+
out_channels=output_channel,
|
529 |
+
prev_output_channel=prev_output_channel,
|
530 |
+
temb_channels=blocks_time_embed_dim,
|
531 |
+
add_upsample=add_upsample,
|
532 |
+
resnet_eps=norm_eps,
|
533 |
+
resnet_act_fn=act_fn,
|
534 |
+
resnet_groups=norm_num_groups,
|
535 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
536 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
537 |
+
dual_cross_attention=dual_cross_attention,
|
538 |
+
use_linear_projection=use_linear_projection,
|
539 |
+
only_cross_attention=only_cross_attention[i],
|
540 |
+
upcast_attention=upcast_attention,
|
541 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
542 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
543 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
544 |
+
cross_attention_norm=cross_attention_norm,
|
545 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
546 |
+
)
|
547 |
+
self.up_blocks.append(up_block)
|
548 |
+
prev_output_channel = output_channel
|
549 |
+
|
550 |
+
# out
|
551 |
+
if norm_num_groups is not None:
|
552 |
+
self.conv_norm_out = nn.GroupNorm(
|
553 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
554 |
+
)
|
555 |
+
|
556 |
+
self.conv_act = get_activation(act_fn)
|
557 |
+
|
558 |
+
else:
|
559 |
+
self.conv_norm_out = None
|
560 |
+
self.conv_act = None
|
561 |
+
|
562 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
563 |
+
self.conv_out = nn.Conv2d(
|
564 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
565 |
+
)
|
566 |
+
|
567 |
+
@property
|
568 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
569 |
+
r"""
|
570 |
+
Returns:
|
571 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
572 |
+
indexed by its weight name.
|
573 |
+
"""
|
574 |
+
# set recursively
|
575 |
+
processors = {}
|
576 |
+
|
577 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
578 |
+
if hasattr(module, "set_processor"):
|
579 |
+
processors[f"{name}.processor"] = module.processor
|
580 |
+
|
581 |
+
for sub_name, child in module.named_children():
|
582 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
583 |
+
|
584 |
+
return processors
|
585 |
+
|
586 |
+
for name, module in self.named_children():
|
587 |
+
fn_recursive_add_processors(name, module, processors)
|
588 |
+
|
589 |
+
return processors
|
590 |
+
|
591 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
592 |
+
r"""
|
593 |
+
Sets the attention processor to use to compute attention.
|
594 |
+
|
595 |
+
Parameters:
|
596 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
597 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
598 |
+
for **all** `Attention` layers.
|
599 |
+
|
600 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
601 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
602 |
+
|
603 |
+
"""
|
604 |
+
count = len(self.attn_processors.keys())
|
605 |
+
|
606 |
+
if isinstance(processor, dict) and len(processor) != count:
|
607 |
+
raise ValueError(
|
608 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
609 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
610 |
+
)
|
611 |
+
|
612 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
613 |
+
if hasattr(module, "set_processor"):
|
614 |
+
if not isinstance(processor, dict):
|
615 |
+
module.set_processor(processor)
|
616 |
+
else:
|
617 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
618 |
+
|
619 |
+
for sub_name, child in module.named_children():
|
620 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
621 |
+
|
622 |
+
for name, module in self.named_children():
|
623 |
+
fn_recursive_attn_processor(name, module, processor)
|
624 |
+
|
625 |
+
def set_default_attn_processor(self):
|
626 |
+
"""
|
627 |
+
Disables custom attention processors and sets the default attention implementation.
|
628 |
+
"""
|
629 |
+
self.set_attn_processor(AttnProcessor())
|
630 |
+
|
631 |
+
def set_attention_slice(self, slice_size):
|
632 |
+
r"""
|
633 |
+
Enable sliced attention computation.
|
634 |
+
|
635 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
636 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
637 |
+
|
638 |
+
Args:
|
639 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
640 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
641 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
642 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
643 |
+
must be a multiple of `slice_size`.
|
644 |
+
"""
|
645 |
+
sliceable_head_dims = []
|
646 |
+
|
647 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
648 |
+
if hasattr(module, "set_attention_slice"):
|
649 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
650 |
+
|
651 |
+
for child in module.children():
|
652 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
653 |
+
|
654 |
+
# retrieve number of attention layers
|
655 |
+
for module in self.children():
|
656 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
657 |
+
|
658 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
659 |
+
|
660 |
+
if slice_size == "auto":
|
661 |
+
# half the attention head size is usually a good trade-off between
|
662 |
+
# speed and memory
|
663 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
664 |
+
elif slice_size == "max":
|
665 |
+
# make smallest slice possible
|
666 |
+
slice_size = num_sliceable_layers * [1]
|
667 |
+
|
668 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
669 |
+
|
670 |
+
if len(slice_size) != len(sliceable_head_dims):
|
671 |
+
raise ValueError(
|
672 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
673 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
674 |
+
)
|
675 |
+
|
676 |
+
for i in range(len(slice_size)):
|
677 |
+
size = slice_size[i]
|
678 |
+
dim = sliceable_head_dims[i]
|
679 |
+
if size is not None and size > dim:
|
680 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
681 |
+
|
682 |
+
# Recursively walk through all the children.
|
683 |
+
# Any children which exposes the set_attention_slice method
|
684 |
+
# gets the message
|
685 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
686 |
+
if hasattr(module, "set_attention_slice"):
|
687 |
+
module.set_attention_slice(slice_size.pop())
|
688 |
+
|
689 |
+
for child in module.children():
|
690 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
691 |
+
|
692 |
+
reversed_slice_size = list(reversed(slice_size))
|
693 |
+
for module in self.children():
|
694 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
695 |
+
|
696 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
697 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
698 |
+
module.gradient_checkpointing = value
|
699 |
+
|
700 |
+
def forward(
|
701 |
+
self,
|
702 |
+
sample: torch.FloatTensor,
|
703 |
+
timestep: Union[torch.Tensor, float, int],
|
704 |
+
encoder_hidden_states: torch.Tensor,
|
705 |
+
class_labels: Optional[torch.Tensor] = None,
|
706 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
707 |
+
attention_mask: Optional[torch.Tensor] = None,
|
708 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
709 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
710 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
711 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
712 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
713 |
+
return_dict: bool = True,
|
714 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
715 |
+
r"""
|
716 |
+
The [`UNet2DConditionModel`] forward method.
|
717 |
+
|
718 |
+
Args:
|
719 |
+
sample (`torch.FloatTensor`):
|
720 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
721 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
722 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
723 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
724 |
+
encoder_attention_mask (`torch.Tensor`):
|
725 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
726 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
727 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
728 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
729 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
730 |
+
tuple.
|
731 |
+
cross_attention_kwargs (`dict`, *optional*):
|
732 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
733 |
+
added_cond_kwargs: (`dict`, *optional*):
|
734 |
+
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
735 |
+
are passed along to the UNet blocks.
|
736 |
+
|
737 |
+
Returns:
|
738 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
739 |
+
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
740 |
+
a `tuple` is returned where the first element is the sample tensor.
|
741 |
+
"""
|
742 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
743 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
744 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
745 |
+
# on the fly if necessary.
|
746 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
747 |
+
|
748 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
749 |
+
forward_upsample_size = False
|
750 |
+
upsample_size = None
|
751 |
+
|
752 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
753 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
754 |
+
forward_upsample_size = True
|
755 |
+
|
756 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
757 |
+
# expects mask of shape:
|
758 |
+
# [batch, key_tokens]
|
759 |
+
# adds singleton query_tokens dimension:
|
760 |
+
# [batch, 1, key_tokens]
|
761 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
762 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
763 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
764 |
+
if attention_mask is not None:
|
765 |
+
# assume that mask is expressed as:
|
766 |
+
# (1 = keep, 0 = discard)
|
767 |
+
# convert mask into a bias that can be added to attention scores:
|
768 |
+
# (keep = +0, discard = -10000.0)
|
769 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
770 |
+
attention_mask = attention_mask.unsqueeze(1)
|
771 |
+
|
772 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
773 |
+
if encoder_attention_mask is not None:
|
774 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
775 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
776 |
+
|
777 |
+
# 0. center input if necessary
|
778 |
+
if self.config.center_input_sample:
|
779 |
+
sample = 2 * sample - 1.0
|
780 |
+
|
781 |
+
# 1. time
|
782 |
+
timesteps = timestep
|
783 |
+
if not torch.is_tensor(timesteps):
|
784 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
785 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
786 |
+
is_mps = sample.device.type == "mps"
|
787 |
+
if isinstance(timestep, float):
|
788 |
+
dtype = torch.float32 if is_mps else torch.float64
|
789 |
+
else:
|
790 |
+
dtype = torch.int32 if is_mps else torch.int64
|
791 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
792 |
+
elif len(timesteps.shape) == 0:
|
793 |
+
timesteps = timesteps[None].to(sample.device)
|
794 |
+
|
795 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
796 |
+
timesteps = timesteps.expand(sample.shape[0])
|
797 |
+
|
798 |
+
t_emb = self.time_proj(timesteps)
|
799 |
+
|
800 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
801 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
802 |
+
# there might be better ways to encapsulate this.
|
803 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
804 |
+
|
805 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
806 |
+
aug_emb = None
|
807 |
+
|
808 |
+
if self.class_embedding is not None:
|
809 |
+
if class_labels is None:
|
810 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
811 |
+
|
812 |
+
if self.config.class_embed_type == "timestep":
|
813 |
+
class_labels = self.time_proj(class_labels)
|
814 |
+
|
815 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
816 |
+
# there might be better ways to encapsulate this.
|
817 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
818 |
+
|
819 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
820 |
+
|
821 |
+
if self.config.class_embeddings_concat:
|
822 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
823 |
+
else:
|
824 |
+
emb = emb + class_emb
|
825 |
+
|
826 |
+
if self.config.addition_embed_type == "text":
|
827 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
828 |
+
elif self.config.addition_embed_type == "text_image":
|
829 |
+
# Kandinsky 2.1 - style
|
830 |
+
if "image_embeds" not in added_cond_kwargs:
|
831 |
+
raise ValueError(
|
832 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
833 |
+
)
|
834 |
+
|
835 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
836 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
837 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
838 |
+
elif self.config.addition_embed_type == "text_time":
|
839 |
+
if "text_embeds" not in added_cond_kwargs:
|
840 |
+
raise ValueError(
|
841 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
842 |
+
)
|
843 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
844 |
+
if "time_ids" not in added_cond_kwargs:
|
845 |
+
raise ValueError(
|
846 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
847 |
+
)
|
848 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
849 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
850 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
851 |
+
|
852 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
853 |
+
add_embeds = add_embeds.to(emb.dtype)
|
854 |
+
aug_emb = self.add_embedding(add_embeds)
|
855 |
+
elif self.config.addition_embed_type == "image":
|
856 |
+
# Kandinsky 2.2 - style
|
857 |
+
if "image_embeds" not in added_cond_kwargs:
|
858 |
+
raise ValueError(
|
859 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
860 |
+
)
|
861 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
862 |
+
aug_emb = self.add_embedding(image_embs)
|
863 |
+
elif self.config.addition_embed_type == "image_hint":
|
864 |
+
# Kandinsky 2.2 - style
|
865 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
866 |
+
raise ValueError(
|
867 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
868 |
+
)
|
869 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
870 |
+
hint = added_cond_kwargs.get("hint")
|
871 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
872 |
+
sample = torch.cat([sample, hint], dim=1)
|
873 |
+
|
874 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
875 |
+
|
876 |
+
if self.time_embed_act is not None:
|
877 |
+
emb = self.time_embed_act(emb)
|
878 |
+
|
879 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
880 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
881 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
882 |
+
# Kadinsky 2.1 - style
|
883 |
+
if "image_embeds" not in added_cond_kwargs:
|
884 |
+
raise ValueError(
|
885 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
886 |
+
)
|
887 |
+
|
888 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
889 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
890 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
891 |
+
# Kandinsky 2.2 - style
|
892 |
+
if "image_embeds" not in added_cond_kwargs:
|
893 |
+
raise ValueError(
|
894 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
895 |
+
)
|
896 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
897 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
898 |
+
# 2. pre-process
|
899 |
+
sample = self.conv_in(sample)
|
900 |
+
|
901 |
+
# 3. down
|
902 |
+
down_block_res_samples = (sample,)
|
903 |
+
for downsample_block in self.down_blocks:
|
904 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
905 |
+
sample, res_samples = downsample_block(
|
906 |
+
hidden_states=sample,
|
907 |
+
temb=emb,
|
908 |
+
encoder_hidden_states=encoder_hidden_states,
|
909 |
+
attention_mask=attention_mask,
|
910 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
911 |
+
encoder_attention_mask=encoder_attention_mask,
|
912 |
+
)
|
913 |
+
else:
|
914 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
915 |
+
|
916 |
+
down_block_res_samples += res_samples
|
917 |
+
|
918 |
+
if down_block_additional_residuals is not None:
|
919 |
+
new_down_block_res_samples = ()
|
920 |
+
|
921 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
922 |
+
down_block_res_samples, down_block_additional_residuals
|
923 |
+
):
|
924 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
925 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
926 |
+
|
927 |
+
down_block_res_samples = new_down_block_res_samples
|
928 |
+
|
929 |
+
# 4. mid
|
930 |
+
if self.mid_block is not None:
|
931 |
+
sample = self.mid_block(
|
932 |
+
sample,
|
933 |
+
emb,
|
934 |
+
encoder_hidden_states=encoder_hidden_states,
|
935 |
+
attention_mask=attention_mask,
|
936 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
937 |
+
encoder_attention_mask=encoder_attention_mask,
|
938 |
+
)
|
939 |
+
|
940 |
+
if mid_block_additional_residual is not None:
|
941 |
+
sample = sample + mid_block_additional_residual
|
942 |
+
|
943 |
+
# 5. up
|
944 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
945 |
+
is_final_block = i == len(self.up_blocks) - 1
|
946 |
+
|
947 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
948 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
949 |
+
|
950 |
+
# if we have not reached the final block and need to forward the
|
951 |
+
# upsample size, we do it here
|
952 |
+
if not is_final_block and forward_upsample_size:
|
953 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
954 |
+
|
955 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
956 |
+
sample = upsample_block(
|
957 |
+
hidden_states=sample,
|
958 |
+
temb=emb,
|
959 |
+
res_hidden_states_tuple=res_samples,
|
960 |
+
encoder_hidden_states=encoder_hidden_states,
|
961 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
962 |
+
upsample_size=upsample_size,
|
963 |
+
attention_mask=attention_mask,
|
964 |
+
encoder_attention_mask=encoder_attention_mask,
|
965 |
+
)
|
966 |
+
else:
|
967 |
+
sample = upsample_block(
|
968 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
969 |
+
)
|
970 |
+
|
971 |
+
# 6. post-process
|
972 |
+
if self.conv_norm_out:
|
973 |
+
sample = self.conv_norm_out(sample)
|
974 |
+
sample = self.conv_act(sample)
|
975 |
+
sample = self.conv_out(sample)
|
976 |
+
|
977 |
+
if not return_dict:
|
978 |
+
return (sample,)
|
979 |
+
|
980 |
+
return UNet2DConditionOutput(sample=sample)
|
6DoF/diffusers/models/unet_2d_condition_flax.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple, Union
|
15 |
+
|
16 |
+
import flax
|
17 |
+
import flax.linen as nn
|
18 |
+
import jax
|
19 |
+
import jax.numpy as jnp
|
20 |
+
from flax.core.frozen_dict import FrozenDict
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
23 |
+
from ..utils import BaseOutput
|
24 |
+
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
25 |
+
from .modeling_flax_utils import FlaxModelMixin
|
26 |
+
from .unet_2d_blocks_flax import (
|
27 |
+
FlaxCrossAttnDownBlock2D,
|
28 |
+
FlaxCrossAttnUpBlock2D,
|
29 |
+
FlaxDownBlock2D,
|
30 |
+
FlaxUNetMidBlock2DCrossAttn,
|
31 |
+
FlaxUpBlock2D,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
@flax.struct.dataclass
|
36 |
+
class FlaxUNet2DConditionOutput(BaseOutput):
|
37 |
+
"""
|
38 |
+
The output of [`FlaxUNet2DConditionModel`].
|
39 |
+
|
40 |
+
Args:
|
41 |
+
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
42 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
43 |
+
"""
|
44 |
+
|
45 |
+
sample: jnp.ndarray
|
46 |
+
|
47 |
+
|
48 |
+
@flax_register_to_config
|
49 |
+
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
50 |
+
r"""
|
51 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
52 |
+
shaped output.
|
53 |
+
|
54 |
+
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
|
55 |
+
implemented for all models (such as downloading or saving).
|
56 |
+
|
57 |
+
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
58 |
+
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
|
59 |
+
general usage and behavior.
|
60 |
+
|
61 |
+
Inherent JAX features such as the following are supported:
|
62 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
63 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
64 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
65 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
sample_size (`int`, *optional*):
|
69 |
+
The size of the input sample.
|
70 |
+
in_channels (`int`, *optional*, defaults to 4):
|
71 |
+
The number of channels in the input sample.
|
72 |
+
out_channels (`int`, *optional*, defaults to 4):
|
73 |
+
The number of channels in the output.
|
74 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
|
75 |
+
The tuple of downsample blocks to use.
|
76 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
|
77 |
+
The tuple of upsample blocks to use.
|
78 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
79 |
+
The tuple of output channels for each block.
|
80 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
81 |
+
The number of layers per block.
|
82 |
+
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
|
83 |
+
The dimension of the attention heads.
|
84 |
+
num_attention_heads (`int` or `Tuple[int]`, *optional*):
|
85 |
+
The number of attention heads.
|
86 |
+
cross_attention_dim (`int`, *optional*, defaults to 768):
|
87 |
+
The dimension of the cross attention features.
|
88 |
+
dropout (`float`, *optional*, defaults to 0):
|
89 |
+
Dropout probability for down, up and bottleneck blocks.
|
90 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
91 |
+
Whether to flip the sin to cos in the time embedding.
|
92 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
93 |
+
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
94 |
+
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
|
95 |
+
"""
|
96 |
+
|
97 |
+
sample_size: int = 32
|
98 |
+
in_channels: int = 4
|
99 |
+
out_channels: int = 4
|
100 |
+
down_block_types: Tuple[str] = (
|
101 |
+
"CrossAttnDownBlock2D",
|
102 |
+
"CrossAttnDownBlock2D",
|
103 |
+
"CrossAttnDownBlock2D",
|
104 |
+
"DownBlock2D",
|
105 |
+
)
|
106 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
|
107 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False
|
108 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
109 |
+
layers_per_block: int = 2
|
110 |
+
attention_head_dim: Union[int, Tuple[int]] = 8
|
111 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None
|
112 |
+
cross_attention_dim: int = 1280
|
113 |
+
dropout: float = 0.0
|
114 |
+
use_linear_projection: bool = False
|
115 |
+
dtype: jnp.dtype = jnp.float32
|
116 |
+
flip_sin_to_cos: bool = True
|
117 |
+
freq_shift: int = 0
|
118 |
+
use_memory_efficient_attention: bool = False
|
119 |
+
|
120 |
+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
121 |
+
# init input tensors
|
122 |
+
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
123 |
+
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
124 |
+
timesteps = jnp.ones((1,), dtype=jnp.int32)
|
125 |
+
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
|
126 |
+
|
127 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
128 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
129 |
+
|
130 |
+
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
|
131 |
+
|
132 |
+
def setup(self):
|
133 |
+
block_out_channels = self.block_out_channels
|
134 |
+
time_embed_dim = block_out_channels[0] * 4
|
135 |
+
|
136 |
+
if self.num_attention_heads is not None:
|
137 |
+
raise ValueError(
|
138 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
139 |
+
)
|
140 |
+
|
141 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
142 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
143 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
144 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
145 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
146 |
+
# which is why we correct for the naming here.
|
147 |
+
num_attention_heads = self.num_attention_heads or self.attention_head_dim
|
148 |
+
|
149 |
+
# input
|
150 |
+
self.conv_in = nn.Conv(
|
151 |
+
block_out_channels[0],
|
152 |
+
kernel_size=(3, 3),
|
153 |
+
strides=(1, 1),
|
154 |
+
padding=((1, 1), (1, 1)),
|
155 |
+
dtype=self.dtype,
|
156 |
+
)
|
157 |
+
|
158 |
+
# time
|
159 |
+
self.time_proj = FlaxTimesteps(
|
160 |
+
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
|
161 |
+
)
|
162 |
+
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
163 |
+
|
164 |
+
only_cross_attention = self.only_cross_attention
|
165 |
+
if isinstance(only_cross_attention, bool):
|
166 |
+
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
|
167 |
+
|
168 |
+
if isinstance(num_attention_heads, int):
|
169 |
+
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
|
170 |
+
|
171 |
+
# down
|
172 |
+
down_blocks = []
|
173 |
+
output_channel = block_out_channels[0]
|
174 |
+
for i, down_block_type in enumerate(self.down_block_types):
|
175 |
+
input_channel = output_channel
|
176 |
+
output_channel = block_out_channels[i]
|
177 |
+
is_final_block = i == len(block_out_channels) - 1
|
178 |
+
|
179 |
+
if down_block_type == "CrossAttnDownBlock2D":
|
180 |
+
down_block = FlaxCrossAttnDownBlock2D(
|
181 |
+
in_channels=input_channel,
|
182 |
+
out_channels=output_channel,
|
183 |
+
dropout=self.dropout,
|
184 |
+
num_layers=self.layers_per_block,
|
185 |
+
num_attention_heads=num_attention_heads[i],
|
186 |
+
add_downsample=not is_final_block,
|
187 |
+
use_linear_projection=self.use_linear_projection,
|
188 |
+
only_cross_attention=only_cross_attention[i],
|
189 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
190 |
+
dtype=self.dtype,
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
down_block = FlaxDownBlock2D(
|
194 |
+
in_channels=input_channel,
|
195 |
+
out_channels=output_channel,
|
196 |
+
dropout=self.dropout,
|
197 |
+
num_layers=self.layers_per_block,
|
198 |
+
add_downsample=not is_final_block,
|
199 |
+
dtype=self.dtype,
|
200 |
+
)
|
201 |
+
|
202 |
+
down_blocks.append(down_block)
|
203 |
+
self.down_blocks = down_blocks
|
204 |
+
|
205 |
+
# mid
|
206 |
+
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
|
207 |
+
in_channels=block_out_channels[-1],
|
208 |
+
dropout=self.dropout,
|
209 |
+
num_attention_heads=num_attention_heads[-1],
|
210 |
+
use_linear_projection=self.use_linear_projection,
|
211 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
212 |
+
dtype=self.dtype,
|
213 |
+
)
|
214 |
+
|
215 |
+
# up
|
216 |
+
up_blocks = []
|
217 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
218 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
219 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
220 |
+
output_channel = reversed_block_out_channels[0]
|
221 |
+
for i, up_block_type in enumerate(self.up_block_types):
|
222 |
+
prev_output_channel = output_channel
|
223 |
+
output_channel = reversed_block_out_channels[i]
|
224 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
225 |
+
|
226 |
+
is_final_block = i == len(block_out_channels) - 1
|
227 |
+
|
228 |
+
if up_block_type == "CrossAttnUpBlock2D":
|
229 |
+
up_block = FlaxCrossAttnUpBlock2D(
|
230 |
+
in_channels=input_channel,
|
231 |
+
out_channels=output_channel,
|
232 |
+
prev_output_channel=prev_output_channel,
|
233 |
+
num_layers=self.layers_per_block + 1,
|
234 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
235 |
+
add_upsample=not is_final_block,
|
236 |
+
dropout=self.dropout,
|
237 |
+
use_linear_projection=self.use_linear_projection,
|
238 |
+
only_cross_attention=only_cross_attention[i],
|
239 |
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
240 |
+
dtype=self.dtype,
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
up_block = FlaxUpBlock2D(
|
244 |
+
in_channels=input_channel,
|
245 |
+
out_channels=output_channel,
|
246 |
+
prev_output_channel=prev_output_channel,
|
247 |
+
num_layers=self.layers_per_block + 1,
|
248 |
+
add_upsample=not is_final_block,
|
249 |
+
dropout=self.dropout,
|
250 |
+
dtype=self.dtype,
|
251 |
+
)
|
252 |
+
|
253 |
+
up_blocks.append(up_block)
|
254 |
+
prev_output_channel = output_channel
|
255 |
+
self.up_blocks = up_blocks
|
256 |
+
|
257 |
+
# out
|
258 |
+
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
259 |
+
self.conv_out = nn.Conv(
|
260 |
+
self.out_channels,
|
261 |
+
kernel_size=(3, 3),
|
262 |
+
strides=(1, 1),
|
263 |
+
padding=((1, 1), (1, 1)),
|
264 |
+
dtype=self.dtype,
|
265 |
+
)
|
266 |
+
|
267 |
+
def __call__(
|
268 |
+
self,
|
269 |
+
sample,
|
270 |
+
timesteps,
|
271 |
+
encoder_hidden_states,
|
272 |
+
down_block_additional_residuals=None,
|
273 |
+
mid_block_additional_residual=None,
|
274 |
+
return_dict: bool = True,
|
275 |
+
train: bool = False,
|
276 |
+
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
|
277 |
+
r"""
|
278 |
+
Args:
|
279 |
+
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
280 |
+
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
281 |
+
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
|
282 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
283 |
+
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
284 |
+
plain tuple.
|
285 |
+
train (`bool`, *optional*, defaults to `False`):
|
286 |
+
Use deterministic functions and disable dropout when not training.
|
287 |
+
|
288 |
+
Returns:
|
289 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
290 |
+
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
291 |
+
When returning a tuple, the first element is the sample tensor.
|
292 |
+
"""
|
293 |
+
# 1. time
|
294 |
+
if not isinstance(timesteps, jnp.ndarray):
|
295 |
+
timesteps = jnp.array([timesteps], dtype=jnp.int32)
|
296 |
+
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
|
297 |
+
timesteps = timesteps.astype(dtype=jnp.float32)
|
298 |
+
timesteps = jnp.expand_dims(timesteps, 0)
|
299 |
+
|
300 |
+
t_emb = self.time_proj(timesteps)
|
301 |
+
t_emb = self.time_embedding(t_emb)
|
302 |
+
|
303 |
+
# 2. pre-process
|
304 |
+
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
305 |
+
sample = self.conv_in(sample)
|
306 |
+
|
307 |
+
# 3. down
|
308 |
+
down_block_res_samples = (sample,)
|
309 |
+
for down_block in self.down_blocks:
|
310 |
+
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
|
311 |
+
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
312 |
+
else:
|
313 |
+
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
|
314 |
+
down_block_res_samples += res_samples
|
315 |
+
|
316 |
+
if down_block_additional_residuals is not None:
|
317 |
+
new_down_block_res_samples = ()
|
318 |
+
|
319 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
320 |
+
down_block_res_samples, down_block_additional_residuals
|
321 |
+
):
|
322 |
+
down_block_res_sample += down_block_additional_residual
|
323 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
324 |
+
|
325 |
+
down_block_res_samples = new_down_block_res_samples
|
326 |
+
|
327 |
+
# 4. mid
|
328 |
+
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
|
329 |
+
|
330 |
+
if mid_block_additional_residual is not None:
|
331 |
+
sample += mid_block_additional_residual
|
332 |
+
|
333 |
+
# 5. up
|
334 |
+
for up_block in self.up_blocks:
|
335 |
+
res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
|
336 |
+
down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
|
337 |
+
if isinstance(up_block, FlaxCrossAttnUpBlock2D):
|
338 |
+
sample = up_block(
|
339 |
+
sample,
|
340 |
+
temb=t_emb,
|
341 |
+
encoder_hidden_states=encoder_hidden_states,
|
342 |
+
res_hidden_states_tuple=res_samples,
|
343 |
+
deterministic=not train,
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
|
347 |
+
|
348 |
+
# 6. post-process
|
349 |
+
sample = self.conv_norm_out(sample)
|
350 |
+
sample = nn.silu(sample)
|
351 |
+
sample = self.conv_out(sample)
|
352 |
+
sample = jnp.transpose(sample, (0, 3, 1, 2))
|
353 |
+
|
354 |
+
if not return_dict:
|
355 |
+
return (sample,)
|
356 |
+
|
357 |
+
return FlaxUNet2DConditionOutput(sample=sample)
|
6DoF/diffusers/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
19 |
+
from .transformer_2d import Transformer2DModel
|
20 |
+
from .transformer_temporal import TransformerTemporalModel
|
21 |
+
|
22 |
+
|
23 |
+
def get_down_block(
|
24 |
+
down_block_type,
|
25 |
+
num_layers,
|
26 |
+
in_channels,
|
27 |
+
out_channels,
|
28 |
+
temb_channels,
|
29 |
+
add_downsample,
|
30 |
+
resnet_eps,
|
31 |
+
resnet_act_fn,
|
32 |
+
num_attention_heads,
|
33 |
+
resnet_groups=None,
|
34 |
+
cross_attention_dim=None,
|
35 |
+
downsample_padding=None,
|
36 |
+
dual_cross_attention=False,
|
37 |
+
use_linear_projection=True,
|
38 |
+
only_cross_attention=False,
|
39 |
+
upcast_attention=False,
|
40 |
+
resnet_time_scale_shift="default",
|
41 |
+
):
|
42 |
+
if down_block_type == "DownBlock3D":
|
43 |
+
return DownBlock3D(
|
44 |
+
num_layers=num_layers,
|
45 |
+
in_channels=in_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
temb_channels=temb_channels,
|
48 |
+
add_downsample=add_downsample,
|
49 |
+
resnet_eps=resnet_eps,
|
50 |
+
resnet_act_fn=resnet_act_fn,
|
51 |
+
resnet_groups=resnet_groups,
|
52 |
+
downsample_padding=downsample_padding,
|
53 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
54 |
+
)
|
55 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
56 |
+
if cross_attention_dim is None:
|
57 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
58 |
+
return CrossAttnDownBlock3D(
|
59 |
+
num_layers=num_layers,
|
60 |
+
in_channels=in_channels,
|
61 |
+
out_channels=out_channels,
|
62 |
+
temb_channels=temb_channels,
|
63 |
+
add_downsample=add_downsample,
|
64 |
+
resnet_eps=resnet_eps,
|
65 |
+
resnet_act_fn=resnet_act_fn,
|
66 |
+
resnet_groups=resnet_groups,
|
67 |
+
downsample_padding=downsample_padding,
|
68 |
+
cross_attention_dim=cross_attention_dim,
|
69 |
+
num_attention_heads=num_attention_heads,
|
70 |
+
dual_cross_attention=dual_cross_attention,
|
71 |
+
use_linear_projection=use_linear_projection,
|
72 |
+
only_cross_attention=only_cross_attention,
|
73 |
+
upcast_attention=upcast_attention,
|
74 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
75 |
+
)
|
76 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
77 |
+
|
78 |
+
|
79 |
+
def get_up_block(
|
80 |
+
up_block_type,
|
81 |
+
num_layers,
|
82 |
+
in_channels,
|
83 |
+
out_channels,
|
84 |
+
prev_output_channel,
|
85 |
+
temb_channels,
|
86 |
+
add_upsample,
|
87 |
+
resnet_eps,
|
88 |
+
resnet_act_fn,
|
89 |
+
num_attention_heads,
|
90 |
+
resnet_groups=None,
|
91 |
+
cross_attention_dim=None,
|
92 |
+
dual_cross_attention=False,
|
93 |
+
use_linear_projection=True,
|
94 |
+
only_cross_attention=False,
|
95 |
+
upcast_attention=False,
|
96 |
+
resnet_time_scale_shift="default",
|
97 |
+
):
|
98 |
+
if up_block_type == "UpBlock3D":
|
99 |
+
return UpBlock3D(
|
100 |
+
num_layers=num_layers,
|
101 |
+
in_channels=in_channels,
|
102 |
+
out_channels=out_channels,
|
103 |
+
prev_output_channel=prev_output_channel,
|
104 |
+
temb_channels=temb_channels,
|
105 |
+
add_upsample=add_upsample,
|
106 |
+
resnet_eps=resnet_eps,
|
107 |
+
resnet_act_fn=resnet_act_fn,
|
108 |
+
resnet_groups=resnet_groups,
|
109 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
110 |
+
)
|
111 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
112 |
+
if cross_attention_dim is None:
|
113 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
114 |
+
return CrossAttnUpBlock3D(
|
115 |
+
num_layers=num_layers,
|
116 |
+
in_channels=in_channels,
|
117 |
+
out_channels=out_channels,
|
118 |
+
prev_output_channel=prev_output_channel,
|
119 |
+
temb_channels=temb_channels,
|
120 |
+
add_upsample=add_upsample,
|
121 |
+
resnet_eps=resnet_eps,
|
122 |
+
resnet_act_fn=resnet_act_fn,
|
123 |
+
resnet_groups=resnet_groups,
|
124 |
+
cross_attention_dim=cross_attention_dim,
|
125 |
+
num_attention_heads=num_attention_heads,
|
126 |
+
dual_cross_attention=dual_cross_attention,
|
127 |
+
use_linear_projection=use_linear_projection,
|
128 |
+
only_cross_attention=only_cross_attention,
|
129 |
+
upcast_attention=upcast_attention,
|
130 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
131 |
+
)
|
132 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
133 |
+
|
134 |
+
|
135 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
136 |
+
def __init__(
|
137 |
+
self,
|
138 |
+
in_channels: int,
|
139 |
+
temb_channels: int,
|
140 |
+
dropout: float = 0.0,
|
141 |
+
num_layers: int = 1,
|
142 |
+
resnet_eps: float = 1e-6,
|
143 |
+
resnet_time_scale_shift: str = "default",
|
144 |
+
resnet_act_fn: str = "swish",
|
145 |
+
resnet_groups: int = 32,
|
146 |
+
resnet_pre_norm: bool = True,
|
147 |
+
num_attention_heads=1,
|
148 |
+
output_scale_factor=1.0,
|
149 |
+
cross_attention_dim=1280,
|
150 |
+
dual_cross_attention=False,
|
151 |
+
use_linear_projection=True,
|
152 |
+
upcast_attention=False,
|
153 |
+
):
|
154 |
+
super().__init__()
|
155 |
+
|
156 |
+
self.has_cross_attention = True
|
157 |
+
self.num_attention_heads = num_attention_heads
|
158 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
159 |
+
|
160 |
+
# there is always at least one resnet
|
161 |
+
resnets = [
|
162 |
+
ResnetBlock2D(
|
163 |
+
in_channels=in_channels,
|
164 |
+
out_channels=in_channels,
|
165 |
+
temb_channels=temb_channels,
|
166 |
+
eps=resnet_eps,
|
167 |
+
groups=resnet_groups,
|
168 |
+
dropout=dropout,
|
169 |
+
time_embedding_norm=resnet_time_scale_shift,
|
170 |
+
non_linearity=resnet_act_fn,
|
171 |
+
output_scale_factor=output_scale_factor,
|
172 |
+
pre_norm=resnet_pre_norm,
|
173 |
+
)
|
174 |
+
]
|
175 |
+
temp_convs = [
|
176 |
+
TemporalConvLayer(
|
177 |
+
in_channels,
|
178 |
+
in_channels,
|
179 |
+
dropout=0.1,
|
180 |
+
)
|
181 |
+
]
|
182 |
+
attentions = []
|
183 |
+
temp_attentions = []
|
184 |
+
|
185 |
+
for _ in range(num_layers):
|
186 |
+
attentions.append(
|
187 |
+
Transformer2DModel(
|
188 |
+
in_channels // num_attention_heads,
|
189 |
+
num_attention_heads,
|
190 |
+
in_channels=in_channels,
|
191 |
+
num_layers=1,
|
192 |
+
cross_attention_dim=cross_attention_dim,
|
193 |
+
norm_num_groups=resnet_groups,
|
194 |
+
use_linear_projection=use_linear_projection,
|
195 |
+
upcast_attention=upcast_attention,
|
196 |
+
)
|
197 |
+
)
|
198 |
+
temp_attentions.append(
|
199 |
+
TransformerTemporalModel(
|
200 |
+
in_channels // num_attention_heads,
|
201 |
+
num_attention_heads,
|
202 |
+
in_channels=in_channels,
|
203 |
+
num_layers=1,
|
204 |
+
cross_attention_dim=cross_attention_dim,
|
205 |
+
norm_num_groups=resnet_groups,
|
206 |
+
)
|
207 |
+
)
|
208 |
+
resnets.append(
|
209 |
+
ResnetBlock2D(
|
210 |
+
in_channels=in_channels,
|
211 |
+
out_channels=in_channels,
|
212 |
+
temb_channels=temb_channels,
|
213 |
+
eps=resnet_eps,
|
214 |
+
groups=resnet_groups,
|
215 |
+
dropout=dropout,
|
216 |
+
time_embedding_norm=resnet_time_scale_shift,
|
217 |
+
non_linearity=resnet_act_fn,
|
218 |
+
output_scale_factor=output_scale_factor,
|
219 |
+
pre_norm=resnet_pre_norm,
|
220 |
+
)
|
221 |
+
)
|
222 |
+
temp_convs.append(
|
223 |
+
TemporalConvLayer(
|
224 |
+
in_channels,
|
225 |
+
in_channels,
|
226 |
+
dropout=0.1,
|
227 |
+
)
|
228 |
+
)
|
229 |
+
|
230 |
+
self.resnets = nn.ModuleList(resnets)
|
231 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
232 |
+
self.attentions = nn.ModuleList(attentions)
|
233 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
234 |
+
|
235 |
+
def forward(
|
236 |
+
self,
|
237 |
+
hidden_states,
|
238 |
+
temb=None,
|
239 |
+
encoder_hidden_states=None,
|
240 |
+
attention_mask=None,
|
241 |
+
num_frames=1,
|
242 |
+
cross_attention_kwargs=None,
|
243 |
+
):
|
244 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
245 |
+
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
|
246 |
+
for attn, temp_attn, resnet, temp_conv in zip(
|
247 |
+
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
|
248 |
+
):
|
249 |
+
hidden_states = attn(
|
250 |
+
hidden_states,
|
251 |
+
encoder_hidden_states=encoder_hidden_states,
|
252 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
253 |
+
return_dict=False,
|
254 |
+
)[0]
|
255 |
+
hidden_states = temp_attn(
|
256 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
257 |
+
)[0]
|
258 |
+
hidden_states = resnet(hidden_states, temb)
|
259 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
260 |
+
|
261 |
+
return hidden_states
|
262 |
+
|
263 |
+
|
264 |
+
class CrossAttnDownBlock3D(nn.Module):
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
in_channels: int,
|
268 |
+
out_channels: int,
|
269 |
+
temb_channels: int,
|
270 |
+
dropout: float = 0.0,
|
271 |
+
num_layers: int = 1,
|
272 |
+
resnet_eps: float = 1e-6,
|
273 |
+
resnet_time_scale_shift: str = "default",
|
274 |
+
resnet_act_fn: str = "swish",
|
275 |
+
resnet_groups: int = 32,
|
276 |
+
resnet_pre_norm: bool = True,
|
277 |
+
num_attention_heads=1,
|
278 |
+
cross_attention_dim=1280,
|
279 |
+
output_scale_factor=1.0,
|
280 |
+
downsample_padding=1,
|
281 |
+
add_downsample=True,
|
282 |
+
dual_cross_attention=False,
|
283 |
+
use_linear_projection=False,
|
284 |
+
only_cross_attention=False,
|
285 |
+
upcast_attention=False,
|
286 |
+
):
|
287 |
+
super().__init__()
|
288 |
+
resnets = []
|
289 |
+
attentions = []
|
290 |
+
temp_attentions = []
|
291 |
+
temp_convs = []
|
292 |
+
|
293 |
+
self.has_cross_attention = True
|
294 |
+
self.num_attention_heads = num_attention_heads
|
295 |
+
|
296 |
+
for i in range(num_layers):
|
297 |
+
in_channels = in_channels if i == 0 else out_channels
|
298 |
+
resnets.append(
|
299 |
+
ResnetBlock2D(
|
300 |
+
in_channels=in_channels,
|
301 |
+
out_channels=out_channels,
|
302 |
+
temb_channels=temb_channels,
|
303 |
+
eps=resnet_eps,
|
304 |
+
groups=resnet_groups,
|
305 |
+
dropout=dropout,
|
306 |
+
time_embedding_norm=resnet_time_scale_shift,
|
307 |
+
non_linearity=resnet_act_fn,
|
308 |
+
output_scale_factor=output_scale_factor,
|
309 |
+
pre_norm=resnet_pre_norm,
|
310 |
+
)
|
311 |
+
)
|
312 |
+
temp_convs.append(
|
313 |
+
TemporalConvLayer(
|
314 |
+
out_channels,
|
315 |
+
out_channels,
|
316 |
+
dropout=0.1,
|
317 |
+
)
|
318 |
+
)
|
319 |
+
attentions.append(
|
320 |
+
Transformer2DModel(
|
321 |
+
out_channels // num_attention_heads,
|
322 |
+
num_attention_heads,
|
323 |
+
in_channels=out_channels,
|
324 |
+
num_layers=1,
|
325 |
+
cross_attention_dim=cross_attention_dim,
|
326 |
+
norm_num_groups=resnet_groups,
|
327 |
+
use_linear_projection=use_linear_projection,
|
328 |
+
only_cross_attention=only_cross_attention,
|
329 |
+
upcast_attention=upcast_attention,
|
330 |
+
)
|
331 |
+
)
|
332 |
+
temp_attentions.append(
|
333 |
+
TransformerTemporalModel(
|
334 |
+
out_channels // num_attention_heads,
|
335 |
+
num_attention_heads,
|
336 |
+
in_channels=out_channels,
|
337 |
+
num_layers=1,
|
338 |
+
cross_attention_dim=cross_attention_dim,
|
339 |
+
norm_num_groups=resnet_groups,
|
340 |
+
)
|
341 |
+
)
|
342 |
+
self.resnets = nn.ModuleList(resnets)
|
343 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
344 |
+
self.attentions = nn.ModuleList(attentions)
|
345 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
346 |
+
|
347 |
+
if add_downsample:
|
348 |
+
self.downsamplers = nn.ModuleList(
|
349 |
+
[
|
350 |
+
Downsample2D(
|
351 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
352 |
+
)
|
353 |
+
]
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
self.downsamplers = None
|
357 |
+
|
358 |
+
self.gradient_checkpointing = False
|
359 |
+
|
360 |
+
def forward(
|
361 |
+
self,
|
362 |
+
hidden_states,
|
363 |
+
temb=None,
|
364 |
+
encoder_hidden_states=None,
|
365 |
+
attention_mask=None,
|
366 |
+
num_frames=1,
|
367 |
+
cross_attention_kwargs=None,
|
368 |
+
):
|
369 |
+
# TODO(Patrick, William) - attention mask is not used
|
370 |
+
output_states = ()
|
371 |
+
|
372 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
373 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
374 |
+
):
|
375 |
+
hidden_states = resnet(hidden_states, temb)
|
376 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
377 |
+
hidden_states = attn(
|
378 |
+
hidden_states,
|
379 |
+
encoder_hidden_states=encoder_hidden_states,
|
380 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
381 |
+
return_dict=False,
|
382 |
+
)[0]
|
383 |
+
hidden_states = temp_attn(
|
384 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
385 |
+
)[0]
|
386 |
+
|
387 |
+
output_states += (hidden_states,)
|
388 |
+
|
389 |
+
if self.downsamplers is not None:
|
390 |
+
for downsampler in self.downsamplers:
|
391 |
+
hidden_states = downsampler(hidden_states)
|
392 |
+
|
393 |
+
output_states += (hidden_states,)
|
394 |
+
|
395 |
+
return hidden_states, output_states
|
396 |
+
|
397 |
+
|
398 |
+
class DownBlock3D(nn.Module):
|
399 |
+
def __init__(
|
400 |
+
self,
|
401 |
+
in_channels: int,
|
402 |
+
out_channels: int,
|
403 |
+
temb_channels: int,
|
404 |
+
dropout: float = 0.0,
|
405 |
+
num_layers: int = 1,
|
406 |
+
resnet_eps: float = 1e-6,
|
407 |
+
resnet_time_scale_shift: str = "default",
|
408 |
+
resnet_act_fn: str = "swish",
|
409 |
+
resnet_groups: int = 32,
|
410 |
+
resnet_pre_norm: bool = True,
|
411 |
+
output_scale_factor=1.0,
|
412 |
+
add_downsample=True,
|
413 |
+
downsample_padding=1,
|
414 |
+
):
|
415 |
+
super().__init__()
|
416 |
+
resnets = []
|
417 |
+
temp_convs = []
|
418 |
+
|
419 |
+
for i in range(num_layers):
|
420 |
+
in_channels = in_channels if i == 0 else out_channels
|
421 |
+
resnets.append(
|
422 |
+
ResnetBlock2D(
|
423 |
+
in_channels=in_channels,
|
424 |
+
out_channels=out_channels,
|
425 |
+
temb_channels=temb_channels,
|
426 |
+
eps=resnet_eps,
|
427 |
+
groups=resnet_groups,
|
428 |
+
dropout=dropout,
|
429 |
+
time_embedding_norm=resnet_time_scale_shift,
|
430 |
+
non_linearity=resnet_act_fn,
|
431 |
+
output_scale_factor=output_scale_factor,
|
432 |
+
pre_norm=resnet_pre_norm,
|
433 |
+
)
|
434 |
+
)
|
435 |
+
temp_convs.append(
|
436 |
+
TemporalConvLayer(
|
437 |
+
out_channels,
|
438 |
+
out_channels,
|
439 |
+
dropout=0.1,
|
440 |
+
)
|
441 |
+
)
|
442 |
+
|
443 |
+
self.resnets = nn.ModuleList(resnets)
|
444 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
445 |
+
|
446 |
+
if add_downsample:
|
447 |
+
self.downsamplers = nn.ModuleList(
|
448 |
+
[
|
449 |
+
Downsample2D(
|
450 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
451 |
+
)
|
452 |
+
]
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
self.downsamplers = None
|
456 |
+
|
457 |
+
self.gradient_checkpointing = False
|
458 |
+
|
459 |
+
def forward(self, hidden_states, temb=None, num_frames=1):
|
460 |
+
output_states = ()
|
461 |
+
|
462 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
463 |
+
hidden_states = resnet(hidden_states, temb)
|
464 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
465 |
+
|
466 |
+
output_states += (hidden_states,)
|
467 |
+
|
468 |
+
if self.downsamplers is not None:
|
469 |
+
for downsampler in self.downsamplers:
|
470 |
+
hidden_states = downsampler(hidden_states)
|
471 |
+
|
472 |
+
output_states += (hidden_states,)
|
473 |
+
|
474 |
+
return hidden_states, output_states
|
475 |
+
|
476 |
+
|
477 |
+
class CrossAttnUpBlock3D(nn.Module):
|
478 |
+
def __init__(
|
479 |
+
self,
|
480 |
+
in_channels: int,
|
481 |
+
out_channels: int,
|
482 |
+
prev_output_channel: int,
|
483 |
+
temb_channels: int,
|
484 |
+
dropout: float = 0.0,
|
485 |
+
num_layers: int = 1,
|
486 |
+
resnet_eps: float = 1e-6,
|
487 |
+
resnet_time_scale_shift: str = "default",
|
488 |
+
resnet_act_fn: str = "swish",
|
489 |
+
resnet_groups: int = 32,
|
490 |
+
resnet_pre_norm: bool = True,
|
491 |
+
num_attention_heads=1,
|
492 |
+
cross_attention_dim=1280,
|
493 |
+
output_scale_factor=1.0,
|
494 |
+
add_upsample=True,
|
495 |
+
dual_cross_attention=False,
|
496 |
+
use_linear_projection=False,
|
497 |
+
only_cross_attention=False,
|
498 |
+
upcast_attention=False,
|
499 |
+
):
|
500 |
+
super().__init__()
|
501 |
+
resnets = []
|
502 |
+
temp_convs = []
|
503 |
+
attentions = []
|
504 |
+
temp_attentions = []
|
505 |
+
|
506 |
+
self.has_cross_attention = True
|
507 |
+
self.num_attention_heads = num_attention_heads
|
508 |
+
|
509 |
+
for i in range(num_layers):
|
510 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
511 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
512 |
+
|
513 |
+
resnets.append(
|
514 |
+
ResnetBlock2D(
|
515 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
516 |
+
out_channels=out_channels,
|
517 |
+
temb_channels=temb_channels,
|
518 |
+
eps=resnet_eps,
|
519 |
+
groups=resnet_groups,
|
520 |
+
dropout=dropout,
|
521 |
+
time_embedding_norm=resnet_time_scale_shift,
|
522 |
+
non_linearity=resnet_act_fn,
|
523 |
+
output_scale_factor=output_scale_factor,
|
524 |
+
pre_norm=resnet_pre_norm,
|
525 |
+
)
|
526 |
+
)
|
527 |
+
temp_convs.append(
|
528 |
+
TemporalConvLayer(
|
529 |
+
out_channels,
|
530 |
+
out_channels,
|
531 |
+
dropout=0.1,
|
532 |
+
)
|
533 |
+
)
|
534 |
+
attentions.append(
|
535 |
+
Transformer2DModel(
|
536 |
+
out_channels // num_attention_heads,
|
537 |
+
num_attention_heads,
|
538 |
+
in_channels=out_channels,
|
539 |
+
num_layers=1,
|
540 |
+
cross_attention_dim=cross_attention_dim,
|
541 |
+
norm_num_groups=resnet_groups,
|
542 |
+
use_linear_projection=use_linear_projection,
|
543 |
+
only_cross_attention=only_cross_attention,
|
544 |
+
upcast_attention=upcast_attention,
|
545 |
+
)
|
546 |
+
)
|
547 |
+
temp_attentions.append(
|
548 |
+
TransformerTemporalModel(
|
549 |
+
out_channels // num_attention_heads,
|
550 |
+
num_attention_heads,
|
551 |
+
in_channels=out_channels,
|
552 |
+
num_layers=1,
|
553 |
+
cross_attention_dim=cross_attention_dim,
|
554 |
+
norm_num_groups=resnet_groups,
|
555 |
+
)
|
556 |
+
)
|
557 |
+
self.resnets = nn.ModuleList(resnets)
|
558 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
559 |
+
self.attentions = nn.ModuleList(attentions)
|
560 |
+
self.temp_attentions = nn.ModuleList(temp_attentions)
|
561 |
+
|
562 |
+
if add_upsample:
|
563 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
564 |
+
else:
|
565 |
+
self.upsamplers = None
|
566 |
+
|
567 |
+
self.gradient_checkpointing = False
|
568 |
+
|
569 |
+
def forward(
|
570 |
+
self,
|
571 |
+
hidden_states,
|
572 |
+
res_hidden_states_tuple,
|
573 |
+
temb=None,
|
574 |
+
encoder_hidden_states=None,
|
575 |
+
upsample_size=None,
|
576 |
+
attention_mask=None,
|
577 |
+
num_frames=1,
|
578 |
+
cross_attention_kwargs=None,
|
579 |
+
):
|
580 |
+
# TODO(Patrick, William) - attention mask is not used
|
581 |
+
for resnet, temp_conv, attn, temp_attn in zip(
|
582 |
+
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
583 |
+
):
|
584 |
+
# pop res hidden states
|
585 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
586 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
587 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
588 |
+
|
589 |
+
hidden_states = resnet(hidden_states, temb)
|
590 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
591 |
+
hidden_states = attn(
|
592 |
+
hidden_states,
|
593 |
+
encoder_hidden_states=encoder_hidden_states,
|
594 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
595 |
+
return_dict=False,
|
596 |
+
)[0]
|
597 |
+
hidden_states = temp_attn(
|
598 |
+
hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
|
599 |
+
)[0]
|
600 |
+
|
601 |
+
if self.upsamplers is not None:
|
602 |
+
for upsampler in self.upsamplers:
|
603 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
604 |
+
|
605 |
+
return hidden_states
|
606 |
+
|
607 |
+
|
608 |
+
class UpBlock3D(nn.Module):
|
609 |
+
def __init__(
|
610 |
+
self,
|
611 |
+
in_channels: int,
|
612 |
+
prev_output_channel: int,
|
613 |
+
out_channels: int,
|
614 |
+
temb_channels: int,
|
615 |
+
dropout: float = 0.0,
|
616 |
+
num_layers: int = 1,
|
617 |
+
resnet_eps: float = 1e-6,
|
618 |
+
resnet_time_scale_shift: str = "default",
|
619 |
+
resnet_act_fn: str = "swish",
|
620 |
+
resnet_groups: int = 32,
|
621 |
+
resnet_pre_norm: bool = True,
|
622 |
+
output_scale_factor=1.0,
|
623 |
+
add_upsample=True,
|
624 |
+
):
|
625 |
+
super().__init__()
|
626 |
+
resnets = []
|
627 |
+
temp_convs = []
|
628 |
+
|
629 |
+
for i in range(num_layers):
|
630 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
631 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
632 |
+
|
633 |
+
resnets.append(
|
634 |
+
ResnetBlock2D(
|
635 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
636 |
+
out_channels=out_channels,
|
637 |
+
temb_channels=temb_channels,
|
638 |
+
eps=resnet_eps,
|
639 |
+
groups=resnet_groups,
|
640 |
+
dropout=dropout,
|
641 |
+
time_embedding_norm=resnet_time_scale_shift,
|
642 |
+
non_linearity=resnet_act_fn,
|
643 |
+
output_scale_factor=output_scale_factor,
|
644 |
+
pre_norm=resnet_pre_norm,
|
645 |
+
)
|
646 |
+
)
|
647 |
+
temp_convs.append(
|
648 |
+
TemporalConvLayer(
|
649 |
+
out_channels,
|
650 |
+
out_channels,
|
651 |
+
dropout=0.1,
|
652 |
+
)
|
653 |
+
)
|
654 |
+
|
655 |
+
self.resnets = nn.ModuleList(resnets)
|
656 |
+
self.temp_convs = nn.ModuleList(temp_convs)
|
657 |
+
|
658 |
+
if add_upsample:
|
659 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
660 |
+
else:
|
661 |
+
self.upsamplers = None
|
662 |
+
|
663 |
+
self.gradient_checkpointing = False
|
664 |
+
|
665 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
666 |
+
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
667 |
+
# pop res hidden states
|
668 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
669 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
670 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
671 |
+
|
672 |
+
hidden_states = resnet(hidden_states, temb)
|
673 |
+
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
|
674 |
+
|
675 |
+
if self.upsamplers is not None:
|
676 |
+
for upsampler in self.upsamplers:
|
677 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
678 |
+
|
679 |
+
return hidden_states
|
6DoF/diffusers/models/unet_3d_condition.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
|
2 |
+
# Copyright 2023 The ModelScope Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..loaders import UNet2DConditionLoadersMixin
|
24 |
+
from ..utils import BaseOutput, logging
|
25 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
26 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
27 |
+
from .modeling_utils import ModelMixin
|
28 |
+
from .transformer_temporal import TransformerTemporalModel
|
29 |
+
from .unet_3d_blocks import (
|
30 |
+
CrossAttnDownBlock3D,
|
31 |
+
CrossAttnUpBlock3D,
|
32 |
+
DownBlock3D,
|
33 |
+
UNetMidBlock3DCrossAttn,
|
34 |
+
UpBlock3D,
|
35 |
+
get_down_block,
|
36 |
+
get_up_block,
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class UNet3DConditionOutput(BaseOutput):
|
45 |
+
"""
|
46 |
+
The output of [`UNet3DConditionModel`].
|
47 |
+
|
48 |
+
Args:
|
49 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
50 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
51 |
+
"""
|
52 |
+
|
53 |
+
sample: torch.FloatTensor
|
54 |
+
|
55 |
+
|
56 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
57 |
+
r"""
|
58 |
+
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
59 |
+
shaped output.
|
60 |
+
|
61 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
62 |
+
for all models (such as downloading or saving).
|
63 |
+
|
64 |
+
Parameters:
|
65 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
66 |
+
Height and width of input/output sample.
|
67 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
68 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
69 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
70 |
+
The tuple of downsample blocks to use.
|
71 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
72 |
+
The tuple of upsample blocks to use.
|
73 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
74 |
+
The tuple of output channels for each block.
|
75 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
76 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
77 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
78 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
79 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
80 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
81 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
82 |
+
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
83 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
84 |
+
num_attention_heads (`int`, *optional*): The number of attention heads.
|
85 |
+
"""
|
86 |
+
|
87 |
+
_supports_gradient_checkpointing = False
|
88 |
+
|
89 |
+
@register_to_config
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
sample_size: Optional[int] = None,
|
93 |
+
in_channels: int = 4,
|
94 |
+
out_channels: int = 4,
|
95 |
+
down_block_types: Tuple[str] = (
|
96 |
+
"CrossAttnDownBlock3D",
|
97 |
+
"CrossAttnDownBlock3D",
|
98 |
+
"CrossAttnDownBlock3D",
|
99 |
+
"DownBlock3D",
|
100 |
+
),
|
101 |
+
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
102 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
103 |
+
layers_per_block: int = 2,
|
104 |
+
downsample_padding: int = 1,
|
105 |
+
mid_block_scale_factor: float = 1,
|
106 |
+
act_fn: str = "silu",
|
107 |
+
norm_num_groups: Optional[int] = 32,
|
108 |
+
norm_eps: float = 1e-5,
|
109 |
+
cross_attention_dim: int = 1024,
|
110 |
+
attention_head_dim: Union[int, Tuple[int]] = 64,
|
111 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
|
115 |
+
self.sample_size = sample_size
|
116 |
+
|
117 |
+
if num_attention_heads is not None:
|
118 |
+
raise NotImplementedError(
|
119 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
120 |
+
)
|
121 |
+
|
122 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
123 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
124 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
125 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
126 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
127 |
+
# which is why we correct for the naming here.
|
128 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
129 |
+
|
130 |
+
# Check inputs
|
131 |
+
if len(down_block_types) != len(up_block_types):
|
132 |
+
raise ValueError(
|
133 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
134 |
+
)
|
135 |
+
|
136 |
+
if len(block_out_channels) != len(down_block_types):
|
137 |
+
raise ValueError(
|
138 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
139 |
+
)
|
140 |
+
|
141 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
142 |
+
raise ValueError(
|
143 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
144 |
+
)
|
145 |
+
|
146 |
+
# input
|
147 |
+
conv_in_kernel = 3
|
148 |
+
conv_out_kernel = 3
|
149 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
150 |
+
self.conv_in = nn.Conv2d(
|
151 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
152 |
+
)
|
153 |
+
|
154 |
+
# time
|
155 |
+
time_embed_dim = block_out_channels[0] * 4
|
156 |
+
self.time_proj = Timesteps(block_out_channels[0], True, 0)
|
157 |
+
timestep_input_dim = block_out_channels[0]
|
158 |
+
|
159 |
+
self.time_embedding = TimestepEmbedding(
|
160 |
+
timestep_input_dim,
|
161 |
+
time_embed_dim,
|
162 |
+
act_fn=act_fn,
|
163 |
+
)
|
164 |
+
|
165 |
+
self.transformer_in = TransformerTemporalModel(
|
166 |
+
num_attention_heads=8,
|
167 |
+
attention_head_dim=attention_head_dim,
|
168 |
+
in_channels=block_out_channels[0],
|
169 |
+
num_layers=1,
|
170 |
+
)
|
171 |
+
|
172 |
+
# class embedding
|
173 |
+
self.down_blocks = nn.ModuleList([])
|
174 |
+
self.up_blocks = nn.ModuleList([])
|
175 |
+
|
176 |
+
if isinstance(num_attention_heads, int):
|
177 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
178 |
+
|
179 |
+
# down
|
180 |
+
output_channel = block_out_channels[0]
|
181 |
+
for i, down_block_type in enumerate(down_block_types):
|
182 |
+
input_channel = output_channel
|
183 |
+
output_channel = block_out_channels[i]
|
184 |
+
is_final_block = i == len(block_out_channels) - 1
|
185 |
+
|
186 |
+
down_block = get_down_block(
|
187 |
+
down_block_type,
|
188 |
+
num_layers=layers_per_block,
|
189 |
+
in_channels=input_channel,
|
190 |
+
out_channels=output_channel,
|
191 |
+
temb_channels=time_embed_dim,
|
192 |
+
add_downsample=not is_final_block,
|
193 |
+
resnet_eps=norm_eps,
|
194 |
+
resnet_act_fn=act_fn,
|
195 |
+
resnet_groups=norm_num_groups,
|
196 |
+
cross_attention_dim=cross_attention_dim,
|
197 |
+
num_attention_heads=num_attention_heads[i],
|
198 |
+
downsample_padding=downsample_padding,
|
199 |
+
dual_cross_attention=False,
|
200 |
+
)
|
201 |
+
self.down_blocks.append(down_block)
|
202 |
+
|
203 |
+
# mid
|
204 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
205 |
+
in_channels=block_out_channels[-1],
|
206 |
+
temb_channels=time_embed_dim,
|
207 |
+
resnet_eps=norm_eps,
|
208 |
+
resnet_act_fn=act_fn,
|
209 |
+
output_scale_factor=mid_block_scale_factor,
|
210 |
+
cross_attention_dim=cross_attention_dim,
|
211 |
+
num_attention_heads=num_attention_heads[-1],
|
212 |
+
resnet_groups=norm_num_groups,
|
213 |
+
dual_cross_attention=False,
|
214 |
+
)
|
215 |
+
|
216 |
+
# count how many layers upsample the images
|
217 |
+
self.num_upsamplers = 0
|
218 |
+
|
219 |
+
# up
|
220 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
221 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
222 |
+
|
223 |
+
output_channel = reversed_block_out_channels[0]
|
224 |
+
for i, up_block_type in enumerate(up_block_types):
|
225 |
+
is_final_block = i == len(block_out_channels) - 1
|
226 |
+
|
227 |
+
prev_output_channel = output_channel
|
228 |
+
output_channel = reversed_block_out_channels[i]
|
229 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
230 |
+
|
231 |
+
# add upsample block for all BUT final layer
|
232 |
+
if not is_final_block:
|
233 |
+
add_upsample = True
|
234 |
+
self.num_upsamplers += 1
|
235 |
+
else:
|
236 |
+
add_upsample = False
|
237 |
+
|
238 |
+
up_block = get_up_block(
|
239 |
+
up_block_type,
|
240 |
+
num_layers=layers_per_block + 1,
|
241 |
+
in_channels=input_channel,
|
242 |
+
out_channels=output_channel,
|
243 |
+
prev_output_channel=prev_output_channel,
|
244 |
+
temb_channels=time_embed_dim,
|
245 |
+
add_upsample=add_upsample,
|
246 |
+
resnet_eps=norm_eps,
|
247 |
+
resnet_act_fn=act_fn,
|
248 |
+
resnet_groups=norm_num_groups,
|
249 |
+
cross_attention_dim=cross_attention_dim,
|
250 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
251 |
+
dual_cross_attention=False,
|
252 |
+
)
|
253 |
+
self.up_blocks.append(up_block)
|
254 |
+
prev_output_channel = output_channel
|
255 |
+
|
256 |
+
# out
|
257 |
+
if norm_num_groups is not None:
|
258 |
+
self.conv_norm_out = nn.GroupNorm(
|
259 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
260 |
+
)
|
261 |
+
self.conv_act = nn.SiLU()
|
262 |
+
else:
|
263 |
+
self.conv_norm_out = None
|
264 |
+
self.conv_act = None
|
265 |
+
|
266 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
267 |
+
self.conv_out = nn.Conv2d(
|
268 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
269 |
+
)
|
270 |
+
|
271 |
+
@property
|
272 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
273 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
274 |
+
r"""
|
275 |
+
Returns:
|
276 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
277 |
+
indexed by its weight name.
|
278 |
+
"""
|
279 |
+
# set recursively
|
280 |
+
processors = {}
|
281 |
+
|
282 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
283 |
+
if hasattr(module, "set_processor"):
|
284 |
+
processors[f"{name}.processor"] = module.processor
|
285 |
+
|
286 |
+
for sub_name, child in module.named_children():
|
287 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
288 |
+
|
289 |
+
return processors
|
290 |
+
|
291 |
+
for name, module in self.named_children():
|
292 |
+
fn_recursive_add_processors(name, module, processors)
|
293 |
+
|
294 |
+
return processors
|
295 |
+
|
296 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
297 |
+
def set_attention_slice(self, slice_size):
|
298 |
+
r"""
|
299 |
+
Enable sliced attention computation.
|
300 |
+
|
301 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
302 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
306 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
307 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
308 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
309 |
+
must be a multiple of `slice_size`.
|
310 |
+
"""
|
311 |
+
sliceable_head_dims = []
|
312 |
+
|
313 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
314 |
+
if hasattr(module, "set_attention_slice"):
|
315 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
316 |
+
|
317 |
+
for child in module.children():
|
318 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
319 |
+
|
320 |
+
# retrieve number of attention layers
|
321 |
+
for module in self.children():
|
322 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
323 |
+
|
324 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
325 |
+
|
326 |
+
if slice_size == "auto":
|
327 |
+
# half the attention head size is usually a good trade-off between
|
328 |
+
# speed and memory
|
329 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
330 |
+
elif slice_size == "max":
|
331 |
+
# make smallest slice possible
|
332 |
+
slice_size = num_sliceable_layers * [1]
|
333 |
+
|
334 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
335 |
+
|
336 |
+
if len(slice_size) != len(sliceable_head_dims):
|
337 |
+
raise ValueError(
|
338 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
339 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
340 |
+
)
|
341 |
+
|
342 |
+
for i in range(len(slice_size)):
|
343 |
+
size = slice_size[i]
|
344 |
+
dim = sliceable_head_dims[i]
|
345 |
+
if size is not None and size > dim:
|
346 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
347 |
+
|
348 |
+
# Recursively walk through all the children.
|
349 |
+
# Any children which exposes the set_attention_slice method
|
350 |
+
# gets the message
|
351 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
352 |
+
if hasattr(module, "set_attention_slice"):
|
353 |
+
module.set_attention_slice(slice_size.pop())
|
354 |
+
|
355 |
+
for child in module.children():
|
356 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
357 |
+
|
358 |
+
reversed_slice_size = list(reversed(slice_size))
|
359 |
+
for module in self.children():
|
360 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
361 |
+
|
362 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
363 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
364 |
+
r"""
|
365 |
+
Sets the attention processor to use to compute attention.
|
366 |
+
|
367 |
+
Parameters:
|
368 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
369 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
370 |
+
for **all** `Attention` layers.
|
371 |
+
|
372 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
373 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
374 |
+
|
375 |
+
"""
|
376 |
+
count = len(self.attn_processors.keys())
|
377 |
+
|
378 |
+
if isinstance(processor, dict) and len(processor) != count:
|
379 |
+
raise ValueError(
|
380 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
381 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
382 |
+
)
|
383 |
+
|
384 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
385 |
+
if hasattr(module, "set_processor"):
|
386 |
+
if not isinstance(processor, dict):
|
387 |
+
module.set_processor(processor)
|
388 |
+
else:
|
389 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
390 |
+
|
391 |
+
for sub_name, child in module.named_children():
|
392 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
393 |
+
|
394 |
+
for name, module in self.named_children():
|
395 |
+
fn_recursive_attn_processor(name, module, processor)
|
396 |
+
|
397 |
+
def enable_forward_chunking(self, chunk_size=None, dim=0):
|
398 |
+
"""
|
399 |
+
Sets the attention processor to use [feed forward
|
400 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
401 |
+
|
402 |
+
Parameters:
|
403 |
+
chunk_size (`int`, *optional*):
|
404 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
405 |
+
over each tensor of dim=`dim`.
|
406 |
+
dim (`int`, *optional*, defaults to `0`):
|
407 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
408 |
+
or dim=1 (sequence length).
|
409 |
+
"""
|
410 |
+
if dim not in [0, 1]:
|
411 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
412 |
+
|
413 |
+
# By default chunk size is 1
|
414 |
+
chunk_size = chunk_size or 1
|
415 |
+
|
416 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
417 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
418 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
419 |
+
|
420 |
+
for child in module.children():
|
421 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
422 |
+
|
423 |
+
for module in self.children():
|
424 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
425 |
+
|
426 |
+
def disable_forward_chunking(self):
|
427 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
428 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
429 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
430 |
+
|
431 |
+
for child in module.children():
|
432 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
433 |
+
|
434 |
+
for module in self.children():
|
435 |
+
fn_recursive_feed_forward(module, None, 0)
|
436 |
+
|
437 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
438 |
+
def set_default_attn_processor(self):
|
439 |
+
"""
|
440 |
+
Disables custom attention processors and sets the default attention implementation.
|
441 |
+
"""
|
442 |
+
self.set_attn_processor(AttnProcessor())
|
443 |
+
|
444 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
445 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
446 |
+
module.gradient_checkpointing = value
|
447 |
+
|
448 |
+
def forward(
|
449 |
+
self,
|
450 |
+
sample: torch.FloatTensor,
|
451 |
+
timestep: Union[torch.Tensor, float, int],
|
452 |
+
encoder_hidden_states: torch.Tensor,
|
453 |
+
class_labels: Optional[torch.Tensor] = None,
|
454 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
455 |
+
attention_mask: Optional[torch.Tensor] = None,
|
456 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
457 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
458 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
459 |
+
return_dict: bool = True,
|
460 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
461 |
+
r"""
|
462 |
+
The [`UNet3DConditionModel`] forward method.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
sample (`torch.FloatTensor`):
|
466 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
|
467 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
468 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
469 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
470 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
471 |
+
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
|
472 |
+
tuple.
|
473 |
+
cross_attention_kwargs (`dict`, *optional*):
|
474 |
+
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
475 |
+
|
476 |
+
Returns:
|
477 |
+
[`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
|
478 |
+
If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
|
479 |
+
a `tuple` is returned where the first element is the sample tensor.
|
480 |
+
"""
|
481 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
482 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
483 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
484 |
+
# on the fly if necessary.
|
485 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
486 |
+
|
487 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
488 |
+
forward_upsample_size = False
|
489 |
+
upsample_size = None
|
490 |
+
|
491 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
492 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
493 |
+
forward_upsample_size = True
|
494 |
+
|
495 |
+
# prepare attention_mask
|
496 |
+
if attention_mask is not None:
|
497 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
498 |
+
attention_mask = attention_mask.unsqueeze(1)
|
499 |
+
|
500 |
+
# 1. time
|
501 |
+
timesteps = timestep
|
502 |
+
if not torch.is_tensor(timesteps):
|
503 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
504 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
505 |
+
is_mps = sample.device.type == "mps"
|
506 |
+
if isinstance(timestep, float):
|
507 |
+
dtype = torch.float32 if is_mps else torch.float64
|
508 |
+
else:
|
509 |
+
dtype = torch.int32 if is_mps else torch.int64
|
510 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
511 |
+
elif len(timesteps.shape) == 0:
|
512 |
+
timesteps = timesteps[None].to(sample.device)
|
513 |
+
|
514 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
515 |
+
num_frames = sample.shape[2]
|
516 |
+
timesteps = timesteps.expand(sample.shape[0])
|
517 |
+
|
518 |
+
t_emb = self.time_proj(timesteps)
|
519 |
+
|
520 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
521 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
522 |
+
# there might be better ways to encapsulate this.
|
523 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
524 |
+
|
525 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
526 |
+
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
|
527 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
|
528 |
+
|
529 |
+
# 2. pre-process
|
530 |
+
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
|
531 |
+
sample = self.conv_in(sample)
|
532 |
+
|
533 |
+
sample = self.transformer_in(
|
534 |
+
sample,
|
535 |
+
num_frames=num_frames,
|
536 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
537 |
+
return_dict=False,
|
538 |
+
)[0]
|
539 |
+
|
540 |
+
# 3. down
|
541 |
+
down_block_res_samples = (sample,)
|
542 |
+
for downsample_block in self.down_blocks:
|
543 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
544 |
+
sample, res_samples = downsample_block(
|
545 |
+
hidden_states=sample,
|
546 |
+
temb=emb,
|
547 |
+
encoder_hidden_states=encoder_hidden_states,
|
548 |
+
attention_mask=attention_mask,
|
549 |
+
num_frames=num_frames,
|
550 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
551 |
+
)
|
552 |
+
else:
|
553 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)
|
554 |
+
|
555 |
+
down_block_res_samples += res_samples
|
556 |
+
|
557 |
+
if down_block_additional_residuals is not None:
|
558 |
+
new_down_block_res_samples = ()
|
559 |
+
|
560 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
561 |
+
down_block_res_samples, down_block_additional_residuals
|
562 |
+
):
|
563 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
564 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
565 |
+
|
566 |
+
down_block_res_samples = new_down_block_res_samples
|
567 |
+
|
568 |
+
# 4. mid
|
569 |
+
if self.mid_block is not None:
|
570 |
+
sample = self.mid_block(
|
571 |
+
sample,
|
572 |
+
emb,
|
573 |
+
encoder_hidden_states=encoder_hidden_states,
|
574 |
+
attention_mask=attention_mask,
|
575 |
+
num_frames=num_frames,
|
576 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
577 |
+
)
|
578 |
+
|
579 |
+
if mid_block_additional_residual is not None:
|
580 |
+
sample = sample + mid_block_additional_residual
|
581 |
+
|
582 |
+
# 5. up
|
583 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
584 |
+
is_final_block = i == len(self.up_blocks) - 1
|
585 |
+
|
586 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
587 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
588 |
+
|
589 |
+
# if we have not reached the final block and need to forward the
|
590 |
+
# upsample size, we do it here
|
591 |
+
if not is_final_block and forward_upsample_size:
|
592 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
593 |
+
|
594 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
595 |
+
sample = upsample_block(
|
596 |
+
hidden_states=sample,
|
597 |
+
temb=emb,
|
598 |
+
res_hidden_states_tuple=res_samples,
|
599 |
+
encoder_hidden_states=encoder_hidden_states,
|
600 |
+
upsample_size=upsample_size,
|
601 |
+
attention_mask=attention_mask,
|
602 |
+
num_frames=num_frames,
|
603 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
604 |
+
)
|
605 |
+
else:
|
606 |
+
sample = upsample_block(
|
607 |
+
hidden_states=sample,
|
608 |
+
temb=emb,
|
609 |
+
res_hidden_states_tuple=res_samples,
|
610 |
+
upsample_size=upsample_size,
|
611 |
+
num_frames=num_frames,
|
612 |
+
)
|
613 |
+
|
614 |
+
# 6. post-process
|
615 |
+
if self.conv_norm_out:
|
616 |
+
sample = self.conv_norm_out(sample)
|
617 |
+
sample = self.conv_act(sample)
|
618 |
+
|
619 |
+
sample = self.conv_out(sample)
|
620 |
+
|
621 |
+
# reshape to (batch, channel, framerate, width, height)
|
622 |
+
sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
|
623 |
+
|
624 |
+
if not return_dict:
|
625 |
+
return (sample,)
|
626 |
+
|
627 |
+
return UNet3DConditionOutput(sample=sample)
|
6DoF/diffusers/models/vae.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
|
21 |
+
from ..utils import BaseOutput, is_torch_version, randn_tensor
|
22 |
+
from .attention_processor import SpatialNorm
|
23 |
+
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class DecoderOutput(BaseOutput):
|
28 |
+
"""
|
29 |
+
Output of decoding method.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
33 |
+
The decoded output sample from the last layer of the model.
|
34 |
+
"""
|
35 |
+
|
36 |
+
sample: torch.FloatTensor
|
37 |
+
|
38 |
+
|
39 |
+
class Encoder(nn.Module):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
in_channels=3,
|
43 |
+
out_channels=3,
|
44 |
+
down_block_types=("DownEncoderBlock2D",),
|
45 |
+
block_out_channels=(64,),
|
46 |
+
layers_per_block=2,
|
47 |
+
norm_num_groups=32,
|
48 |
+
act_fn="silu",
|
49 |
+
double_z=True,
|
50 |
+
):
|
51 |
+
super().__init__()
|
52 |
+
self.layers_per_block = layers_per_block
|
53 |
+
|
54 |
+
self.conv_in = torch.nn.Conv2d(
|
55 |
+
in_channels,
|
56 |
+
block_out_channels[0],
|
57 |
+
kernel_size=3,
|
58 |
+
stride=1,
|
59 |
+
padding=1,
|
60 |
+
)
|
61 |
+
|
62 |
+
self.mid_block = None
|
63 |
+
self.down_blocks = nn.ModuleList([])
|
64 |
+
|
65 |
+
# down
|
66 |
+
output_channel = block_out_channels[0]
|
67 |
+
for i, down_block_type in enumerate(down_block_types):
|
68 |
+
input_channel = output_channel
|
69 |
+
output_channel = block_out_channels[i]
|
70 |
+
is_final_block = i == len(block_out_channels) - 1
|
71 |
+
|
72 |
+
down_block = get_down_block(
|
73 |
+
down_block_type,
|
74 |
+
num_layers=self.layers_per_block,
|
75 |
+
in_channels=input_channel,
|
76 |
+
out_channels=output_channel,
|
77 |
+
add_downsample=not is_final_block,
|
78 |
+
resnet_eps=1e-6,
|
79 |
+
downsample_padding=0,
|
80 |
+
resnet_act_fn=act_fn,
|
81 |
+
resnet_groups=norm_num_groups,
|
82 |
+
attention_head_dim=output_channel,
|
83 |
+
temb_channels=None,
|
84 |
+
)
|
85 |
+
self.down_blocks.append(down_block)
|
86 |
+
|
87 |
+
# mid
|
88 |
+
self.mid_block = UNetMidBlock2D(
|
89 |
+
in_channels=block_out_channels[-1],
|
90 |
+
resnet_eps=1e-6,
|
91 |
+
resnet_act_fn=act_fn,
|
92 |
+
output_scale_factor=1,
|
93 |
+
resnet_time_scale_shift="default",
|
94 |
+
attention_head_dim=block_out_channels[-1],
|
95 |
+
resnet_groups=norm_num_groups,
|
96 |
+
temb_channels=None,
|
97 |
+
)
|
98 |
+
|
99 |
+
# out
|
100 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
101 |
+
self.conv_act = nn.SiLU()
|
102 |
+
|
103 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
104 |
+
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
105 |
+
|
106 |
+
self.gradient_checkpointing = False
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
sample = x
|
110 |
+
sample = self.conv_in(sample)
|
111 |
+
|
112 |
+
if self.training and self.gradient_checkpointing:
|
113 |
+
|
114 |
+
def create_custom_forward(module):
|
115 |
+
def custom_forward(*inputs):
|
116 |
+
return module(*inputs)
|
117 |
+
|
118 |
+
return custom_forward
|
119 |
+
|
120 |
+
# down
|
121 |
+
if is_torch_version(">=", "1.11.0"):
|
122 |
+
for down_block in self.down_blocks:
|
123 |
+
sample = torch.utils.checkpoint.checkpoint(
|
124 |
+
create_custom_forward(down_block), sample, use_reentrant=False
|
125 |
+
)
|
126 |
+
# middle
|
127 |
+
sample = torch.utils.checkpoint.checkpoint(
|
128 |
+
create_custom_forward(self.mid_block), sample, use_reentrant=False
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
for down_block in self.down_blocks:
|
132 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
133 |
+
# middle
|
134 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
135 |
+
|
136 |
+
else:
|
137 |
+
# down
|
138 |
+
for down_block in self.down_blocks:
|
139 |
+
sample = down_block(sample)
|
140 |
+
|
141 |
+
# middle
|
142 |
+
sample = self.mid_block(sample)
|
143 |
+
|
144 |
+
# post-process
|
145 |
+
sample = self.conv_norm_out(sample)
|
146 |
+
sample = self.conv_act(sample)
|
147 |
+
sample = self.conv_out(sample)
|
148 |
+
|
149 |
+
return sample
|
150 |
+
|
151 |
+
|
152 |
+
class Decoder(nn.Module):
|
153 |
+
def __init__(
|
154 |
+
self,
|
155 |
+
in_channels=3,
|
156 |
+
out_channels=3,
|
157 |
+
up_block_types=("UpDecoderBlock2D",),
|
158 |
+
block_out_channels=(64,),
|
159 |
+
layers_per_block=2,
|
160 |
+
norm_num_groups=32,
|
161 |
+
act_fn="silu",
|
162 |
+
norm_type="group", # group, spatial
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.layers_per_block = layers_per_block
|
166 |
+
|
167 |
+
self.conv_in = nn.Conv2d(
|
168 |
+
in_channels,
|
169 |
+
block_out_channels[-1],
|
170 |
+
kernel_size=3,
|
171 |
+
stride=1,
|
172 |
+
padding=1,
|
173 |
+
)
|
174 |
+
|
175 |
+
self.mid_block = None
|
176 |
+
self.up_blocks = nn.ModuleList([])
|
177 |
+
|
178 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
179 |
+
|
180 |
+
# mid
|
181 |
+
self.mid_block = UNetMidBlock2D(
|
182 |
+
in_channels=block_out_channels[-1],
|
183 |
+
resnet_eps=1e-6,
|
184 |
+
resnet_act_fn=act_fn,
|
185 |
+
output_scale_factor=1,
|
186 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
187 |
+
attention_head_dim=block_out_channels[-1],
|
188 |
+
resnet_groups=norm_num_groups,
|
189 |
+
temb_channels=temb_channels,
|
190 |
+
)
|
191 |
+
|
192 |
+
# up
|
193 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
194 |
+
output_channel = reversed_block_out_channels[0]
|
195 |
+
for i, up_block_type in enumerate(up_block_types):
|
196 |
+
prev_output_channel = output_channel
|
197 |
+
output_channel = reversed_block_out_channels[i]
|
198 |
+
|
199 |
+
is_final_block = i == len(block_out_channels) - 1
|
200 |
+
|
201 |
+
up_block = get_up_block(
|
202 |
+
up_block_type,
|
203 |
+
num_layers=self.layers_per_block + 1,
|
204 |
+
in_channels=prev_output_channel,
|
205 |
+
out_channels=output_channel,
|
206 |
+
prev_output_channel=None,
|
207 |
+
add_upsample=not is_final_block,
|
208 |
+
resnet_eps=1e-6,
|
209 |
+
resnet_act_fn=act_fn,
|
210 |
+
resnet_groups=norm_num_groups,
|
211 |
+
attention_head_dim=output_channel,
|
212 |
+
temb_channels=temb_channels,
|
213 |
+
resnet_time_scale_shift=norm_type,
|
214 |
+
)
|
215 |
+
self.up_blocks.append(up_block)
|
216 |
+
prev_output_channel = output_channel
|
217 |
+
|
218 |
+
# out
|
219 |
+
if norm_type == "spatial":
|
220 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
221 |
+
else:
|
222 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
223 |
+
self.conv_act = nn.SiLU()
|
224 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
225 |
+
|
226 |
+
self.gradient_checkpointing = False
|
227 |
+
|
228 |
+
def forward(self, z, latent_embeds=None):
|
229 |
+
sample = z
|
230 |
+
sample = self.conv_in(sample)
|
231 |
+
|
232 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
233 |
+
if self.training and self.gradient_checkpointing:
|
234 |
+
|
235 |
+
def create_custom_forward(module):
|
236 |
+
def custom_forward(*inputs):
|
237 |
+
return module(*inputs)
|
238 |
+
|
239 |
+
return custom_forward
|
240 |
+
|
241 |
+
if is_torch_version(">=", "1.11.0"):
|
242 |
+
# middle
|
243 |
+
sample = torch.utils.checkpoint.checkpoint(
|
244 |
+
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
|
245 |
+
)
|
246 |
+
sample = sample.to(upscale_dtype)
|
247 |
+
|
248 |
+
# up
|
249 |
+
for up_block in self.up_blocks:
|
250 |
+
sample = torch.utils.checkpoint.checkpoint(
|
251 |
+
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
# middle
|
255 |
+
sample = torch.utils.checkpoint.checkpoint(
|
256 |
+
create_custom_forward(self.mid_block), sample, latent_embeds
|
257 |
+
)
|
258 |
+
sample = sample.to(upscale_dtype)
|
259 |
+
|
260 |
+
# up
|
261 |
+
for up_block in self.up_blocks:
|
262 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
263 |
+
else:
|
264 |
+
# middle
|
265 |
+
sample = self.mid_block(sample, latent_embeds)
|
266 |
+
sample = sample.to(upscale_dtype)
|
267 |
+
|
268 |
+
# up
|
269 |
+
for up_block in self.up_blocks:
|
270 |
+
sample = up_block(sample, latent_embeds)
|
271 |
+
|
272 |
+
# post-process
|
273 |
+
if latent_embeds is None:
|
274 |
+
sample = self.conv_norm_out(sample)
|
275 |
+
else:
|
276 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
277 |
+
sample = self.conv_act(sample)
|
278 |
+
sample = self.conv_out(sample)
|
279 |
+
|
280 |
+
return sample
|
281 |
+
|
282 |
+
|
283 |
+
class VectorQuantizer(nn.Module):
|
284 |
+
"""
|
285 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
|
286 |
+
multiplications and allows for post-hoc remapping of indices.
|
287 |
+
"""
|
288 |
+
|
289 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
290 |
+
# backwards compatibility we use the buggy version by default, but you can
|
291 |
+
# specify legacy=False to fix it.
|
292 |
+
def __init__(
|
293 |
+
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
|
294 |
+
):
|
295 |
+
super().__init__()
|
296 |
+
self.n_e = n_e
|
297 |
+
self.vq_embed_dim = vq_embed_dim
|
298 |
+
self.beta = beta
|
299 |
+
self.legacy = legacy
|
300 |
+
|
301 |
+
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
|
302 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
303 |
+
|
304 |
+
self.remap = remap
|
305 |
+
if self.remap is not None:
|
306 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
307 |
+
self.re_embed = self.used.shape[0]
|
308 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
309 |
+
if self.unknown_index == "extra":
|
310 |
+
self.unknown_index = self.re_embed
|
311 |
+
self.re_embed = self.re_embed + 1
|
312 |
+
print(
|
313 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
314 |
+
f"Using {self.unknown_index} for unknown indices."
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
self.re_embed = n_e
|
318 |
+
|
319 |
+
self.sane_index_shape = sane_index_shape
|
320 |
+
|
321 |
+
def remap_to_used(self, inds):
|
322 |
+
ishape = inds.shape
|
323 |
+
assert len(ishape) > 1
|
324 |
+
inds = inds.reshape(ishape[0], -1)
|
325 |
+
used = self.used.to(inds)
|
326 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
327 |
+
new = match.argmax(-1)
|
328 |
+
unknown = match.sum(2) < 1
|
329 |
+
if self.unknown_index == "random":
|
330 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
331 |
+
else:
|
332 |
+
new[unknown] = self.unknown_index
|
333 |
+
return new.reshape(ishape)
|
334 |
+
|
335 |
+
def unmap_to_all(self, inds):
|
336 |
+
ishape = inds.shape
|
337 |
+
assert len(ishape) > 1
|
338 |
+
inds = inds.reshape(ishape[0], -1)
|
339 |
+
used = self.used.to(inds)
|
340 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
341 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
342 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
343 |
+
return back.reshape(ishape)
|
344 |
+
|
345 |
+
def forward(self, z):
|
346 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
347 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
348 |
+
z_flattened = z.view(-1, self.vq_embed_dim)
|
349 |
+
|
350 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
351 |
+
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
|
352 |
+
|
353 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
354 |
+
perplexity = None
|
355 |
+
min_encodings = None
|
356 |
+
|
357 |
+
# compute loss for embedding
|
358 |
+
if not self.legacy:
|
359 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
360 |
+
else:
|
361 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
362 |
+
|
363 |
+
# preserve gradients
|
364 |
+
z_q = z + (z_q - z).detach()
|
365 |
+
|
366 |
+
# reshape back to match original input shape
|
367 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
368 |
+
|
369 |
+
if self.remap is not None:
|
370 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
371 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
372 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
373 |
+
|
374 |
+
if self.sane_index_shape:
|
375 |
+
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
376 |
+
|
377 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
378 |
+
|
379 |
+
def get_codebook_entry(self, indices, shape):
|
380 |
+
# shape specifying (batch, height, width, channel)
|
381 |
+
if self.remap is not None:
|
382 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
383 |
+
indices = self.unmap_to_all(indices)
|
384 |
+
indices = indices.reshape(-1) # flatten again
|
385 |
+
|
386 |
+
# get quantized latent vectors
|
387 |
+
z_q = self.embedding(indices)
|
388 |
+
|
389 |
+
if shape is not None:
|
390 |
+
z_q = z_q.view(shape)
|
391 |
+
# reshape back to match original input shape
|
392 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
393 |
+
|
394 |
+
return z_q
|
395 |
+
|
396 |
+
|
397 |
+
class DiagonalGaussianDistribution(object):
|
398 |
+
def __init__(self, parameters, deterministic=False):
|
399 |
+
self.parameters = parameters
|
400 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
401 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
402 |
+
self.deterministic = deterministic
|
403 |
+
self.std = torch.exp(0.5 * self.logvar)
|
404 |
+
self.var = torch.exp(self.logvar)
|
405 |
+
if self.deterministic:
|
406 |
+
self.var = self.std = torch.zeros_like(
|
407 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
408 |
+
)
|
409 |
+
|
410 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
411 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
412 |
+
sample = randn_tensor(
|
413 |
+
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
|
414 |
+
)
|
415 |
+
x = self.mean + self.std * sample
|
416 |
+
return x
|
417 |
+
|
418 |
+
def kl(self, other=None):
|
419 |
+
if self.deterministic:
|
420 |
+
return torch.Tensor([0.0])
|
421 |
+
else:
|
422 |
+
if other is None:
|
423 |
+
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
|
424 |
+
else:
|
425 |
+
return 0.5 * torch.sum(
|
426 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
427 |
+
+ self.var / other.var
|
428 |
+
- 1.0
|
429 |
+
- self.logvar
|
430 |
+
+ other.logvar,
|
431 |
+
dim=[1, 2, 3],
|
432 |
+
)
|
433 |
+
|
434 |
+
def nll(self, sample, dims=[1, 2, 3]):
|
435 |
+
if self.deterministic:
|
436 |
+
return torch.Tensor([0.0])
|
437 |
+
logtwopi = np.log(2.0 * np.pi)
|
438 |
+
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
|
439 |
+
|
440 |
+
def mode(self):
|
441 |
+
return self.mean
|
6DoF/diffusers/models/vae_flax.py
ADDED
@@ -0,0 +1,869 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
|
16 |
+
|
17 |
+
import math
|
18 |
+
from functools import partial
|
19 |
+
from typing import Tuple
|
20 |
+
|
21 |
+
import flax
|
22 |
+
import flax.linen as nn
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
from flax.core.frozen_dict import FrozenDict
|
26 |
+
|
27 |
+
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
28 |
+
from ..utils import BaseOutput
|
29 |
+
from .modeling_flax_utils import FlaxModelMixin
|
30 |
+
|
31 |
+
|
32 |
+
@flax.struct.dataclass
|
33 |
+
class FlaxDecoderOutput(BaseOutput):
|
34 |
+
"""
|
35 |
+
Output of decoding method.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
39 |
+
The decoded output sample from the last layer of the model.
|
40 |
+
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
41 |
+
The `dtype` of the parameters.
|
42 |
+
"""
|
43 |
+
|
44 |
+
sample: jnp.ndarray
|
45 |
+
|
46 |
+
|
47 |
+
@flax.struct.dataclass
|
48 |
+
class FlaxAutoencoderKLOutput(BaseOutput):
|
49 |
+
"""
|
50 |
+
Output of AutoencoderKL encoding method.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
latent_dist (`FlaxDiagonalGaussianDistribution`):
|
54 |
+
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
|
55 |
+
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
56 |
+
"""
|
57 |
+
|
58 |
+
latent_dist: "FlaxDiagonalGaussianDistribution"
|
59 |
+
|
60 |
+
|
61 |
+
class FlaxUpsample2D(nn.Module):
|
62 |
+
"""
|
63 |
+
Flax implementation of 2D Upsample layer
|
64 |
+
|
65 |
+
Args:
|
66 |
+
in_channels (`int`):
|
67 |
+
Input channels
|
68 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
69 |
+
Parameters `dtype`
|
70 |
+
"""
|
71 |
+
|
72 |
+
in_channels: int
|
73 |
+
dtype: jnp.dtype = jnp.float32
|
74 |
+
|
75 |
+
def setup(self):
|
76 |
+
self.conv = nn.Conv(
|
77 |
+
self.in_channels,
|
78 |
+
kernel_size=(3, 3),
|
79 |
+
strides=(1, 1),
|
80 |
+
padding=((1, 1), (1, 1)),
|
81 |
+
dtype=self.dtype,
|
82 |
+
)
|
83 |
+
|
84 |
+
def __call__(self, hidden_states):
|
85 |
+
batch, height, width, channels = hidden_states.shape
|
86 |
+
hidden_states = jax.image.resize(
|
87 |
+
hidden_states,
|
88 |
+
shape=(batch, height * 2, width * 2, channels),
|
89 |
+
method="nearest",
|
90 |
+
)
|
91 |
+
hidden_states = self.conv(hidden_states)
|
92 |
+
return hidden_states
|
93 |
+
|
94 |
+
|
95 |
+
class FlaxDownsample2D(nn.Module):
|
96 |
+
"""
|
97 |
+
Flax implementation of 2D Downsample layer
|
98 |
+
|
99 |
+
Args:
|
100 |
+
in_channels (`int`):
|
101 |
+
Input channels
|
102 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
103 |
+
Parameters `dtype`
|
104 |
+
"""
|
105 |
+
|
106 |
+
in_channels: int
|
107 |
+
dtype: jnp.dtype = jnp.float32
|
108 |
+
|
109 |
+
def setup(self):
|
110 |
+
self.conv = nn.Conv(
|
111 |
+
self.in_channels,
|
112 |
+
kernel_size=(3, 3),
|
113 |
+
strides=(2, 2),
|
114 |
+
padding="VALID",
|
115 |
+
dtype=self.dtype,
|
116 |
+
)
|
117 |
+
|
118 |
+
def __call__(self, hidden_states):
|
119 |
+
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
|
120 |
+
hidden_states = jnp.pad(hidden_states, pad_width=pad)
|
121 |
+
hidden_states = self.conv(hidden_states)
|
122 |
+
return hidden_states
|
123 |
+
|
124 |
+
|
125 |
+
class FlaxResnetBlock2D(nn.Module):
|
126 |
+
"""
|
127 |
+
Flax implementation of 2D Resnet Block.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
in_channels (`int`):
|
131 |
+
Input channels
|
132 |
+
out_channels (`int`):
|
133 |
+
Output channels
|
134 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
135 |
+
Dropout rate
|
136 |
+
groups (:obj:`int`, *optional*, defaults to `32`):
|
137 |
+
The number of groups to use for group norm.
|
138 |
+
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
|
139 |
+
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
|
140 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
141 |
+
Parameters `dtype`
|
142 |
+
"""
|
143 |
+
|
144 |
+
in_channels: int
|
145 |
+
out_channels: int = None
|
146 |
+
dropout: float = 0.0
|
147 |
+
groups: int = 32
|
148 |
+
use_nin_shortcut: bool = None
|
149 |
+
dtype: jnp.dtype = jnp.float32
|
150 |
+
|
151 |
+
def setup(self):
|
152 |
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
153 |
+
|
154 |
+
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
155 |
+
self.conv1 = nn.Conv(
|
156 |
+
out_channels,
|
157 |
+
kernel_size=(3, 3),
|
158 |
+
strides=(1, 1),
|
159 |
+
padding=((1, 1), (1, 1)),
|
160 |
+
dtype=self.dtype,
|
161 |
+
)
|
162 |
+
|
163 |
+
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
|
164 |
+
self.dropout_layer = nn.Dropout(self.dropout)
|
165 |
+
self.conv2 = nn.Conv(
|
166 |
+
out_channels,
|
167 |
+
kernel_size=(3, 3),
|
168 |
+
strides=(1, 1),
|
169 |
+
padding=((1, 1), (1, 1)),
|
170 |
+
dtype=self.dtype,
|
171 |
+
)
|
172 |
+
|
173 |
+
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
|
174 |
+
|
175 |
+
self.conv_shortcut = None
|
176 |
+
if use_nin_shortcut:
|
177 |
+
self.conv_shortcut = nn.Conv(
|
178 |
+
out_channels,
|
179 |
+
kernel_size=(1, 1),
|
180 |
+
strides=(1, 1),
|
181 |
+
padding="VALID",
|
182 |
+
dtype=self.dtype,
|
183 |
+
)
|
184 |
+
|
185 |
+
def __call__(self, hidden_states, deterministic=True):
|
186 |
+
residual = hidden_states
|
187 |
+
hidden_states = self.norm1(hidden_states)
|
188 |
+
hidden_states = nn.swish(hidden_states)
|
189 |
+
hidden_states = self.conv1(hidden_states)
|
190 |
+
|
191 |
+
hidden_states = self.norm2(hidden_states)
|
192 |
+
hidden_states = nn.swish(hidden_states)
|
193 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic)
|
194 |
+
hidden_states = self.conv2(hidden_states)
|
195 |
+
|
196 |
+
if self.conv_shortcut is not None:
|
197 |
+
residual = self.conv_shortcut(residual)
|
198 |
+
|
199 |
+
return hidden_states + residual
|
200 |
+
|
201 |
+
|
202 |
+
class FlaxAttentionBlock(nn.Module):
|
203 |
+
r"""
|
204 |
+
Flax Convolutional based multi-head attention block for diffusion-based VAE.
|
205 |
+
|
206 |
+
Parameters:
|
207 |
+
channels (:obj:`int`):
|
208 |
+
Input channels
|
209 |
+
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
|
210 |
+
Number of attention heads
|
211 |
+
num_groups (:obj:`int`, *optional*, defaults to `32`):
|
212 |
+
The number of groups to use for group norm
|
213 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
214 |
+
Parameters `dtype`
|
215 |
+
|
216 |
+
"""
|
217 |
+
channels: int
|
218 |
+
num_head_channels: int = None
|
219 |
+
num_groups: int = 32
|
220 |
+
dtype: jnp.dtype = jnp.float32
|
221 |
+
|
222 |
+
def setup(self):
|
223 |
+
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
|
224 |
+
|
225 |
+
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
|
226 |
+
|
227 |
+
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
|
228 |
+
self.query, self.key, self.value = dense(), dense(), dense()
|
229 |
+
self.proj_attn = dense()
|
230 |
+
|
231 |
+
def transpose_for_scores(self, projection):
|
232 |
+
new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
|
233 |
+
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
|
234 |
+
new_projection = projection.reshape(new_projection_shape)
|
235 |
+
# (B, T, H, D) -> (B, H, T, D)
|
236 |
+
new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
|
237 |
+
return new_projection
|
238 |
+
|
239 |
+
def __call__(self, hidden_states):
|
240 |
+
residual = hidden_states
|
241 |
+
batch, height, width, channels = hidden_states.shape
|
242 |
+
|
243 |
+
hidden_states = self.group_norm(hidden_states)
|
244 |
+
|
245 |
+
hidden_states = hidden_states.reshape((batch, height * width, channels))
|
246 |
+
|
247 |
+
query = self.query(hidden_states)
|
248 |
+
key = self.key(hidden_states)
|
249 |
+
value = self.value(hidden_states)
|
250 |
+
|
251 |
+
# transpose
|
252 |
+
query = self.transpose_for_scores(query)
|
253 |
+
key = self.transpose_for_scores(key)
|
254 |
+
value = self.transpose_for_scores(value)
|
255 |
+
|
256 |
+
# compute attentions
|
257 |
+
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
258 |
+
attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
|
259 |
+
attn_weights = nn.softmax(attn_weights, axis=-1)
|
260 |
+
|
261 |
+
# attend to values
|
262 |
+
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
|
263 |
+
|
264 |
+
hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
|
265 |
+
new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
|
266 |
+
hidden_states = hidden_states.reshape(new_hidden_states_shape)
|
267 |
+
|
268 |
+
hidden_states = self.proj_attn(hidden_states)
|
269 |
+
hidden_states = hidden_states.reshape((batch, height, width, channels))
|
270 |
+
hidden_states = hidden_states + residual
|
271 |
+
return hidden_states
|
272 |
+
|
273 |
+
|
274 |
+
class FlaxDownEncoderBlock2D(nn.Module):
|
275 |
+
r"""
|
276 |
+
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
|
277 |
+
|
278 |
+
Parameters:
|
279 |
+
in_channels (:obj:`int`):
|
280 |
+
Input channels
|
281 |
+
out_channels (:obj:`int`):
|
282 |
+
Output channels
|
283 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
284 |
+
Dropout rate
|
285 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
286 |
+
Number of Resnet layer block
|
287 |
+
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
288 |
+
The number of groups to use for the Resnet block group norm
|
289 |
+
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
290 |
+
Whether to add downsample layer
|
291 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
292 |
+
Parameters `dtype`
|
293 |
+
"""
|
294 |
+
in_channels: int
|
295 |
+
out_channels: int
|
296 |
+
dropout: float = 0.0
|
297 |
+
num_layers: int = 1
|
298 |
+
resnet_groups: int = 32
|
299 |
+
add_downsample: bool = True
|
300 |
+
dtype: jnp.dtype = jnp.float32
|
301 |
+
|
302 |
+
def setup(self):
|
303 |
+
resnets = []
|
304 |
+
for i in range(self.num_layers):
|
305 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
306 |
+
|
307 |
+
res_block = FlaxResnetBlock2D(
|
308 |
+
in_channels=in_channels,
|
309 |
+
out_channels=self.out_channels,
|
310 |
+
dropout=self.dropout,
|
311 |
+
groups=self.resnet_groups,
|
312 |
+
dtype=self.dtype,
|
313 |
+
)
|
314 |
+
resnets.append(res_block)
|
315 |
+
self.resnets = resnets
|
316 |
+
|
317 |
+
if self.add_downsample:
|
318 |
+
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
319 |
+
|
320 |
+
def __call__(self, hidden_states, deterministic=True):
|
321 |
+
for resnet in self.resnets:
|
322 |
+
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
323 |
+
|
324 |
+
if self.add_downsample:
|
325 |
+
hidden_states = self.downsamplers_0(hidden_states)
|
326 |
+
|
327 |
+
return hidden_states
|
328 |
+
|
329 |
+
|
330 |
+
class FlaxUpDecoderBlock2D(nn.Module):
|
331 |
+
r"""
|
332 |
+
Flax Resnet blocks-based Decoder block for diffusion-based VAE.
|
333 |
+
|
334 |
+
Parameters:
|
335 |
+
in_channels (:obj:`int`):
|
336 |
+
Input channels
|
337 |
+
out_channels (:obj:`int`):
|
338 |
+
Output channels
|
339 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
340 |
+
Dropout rate
|
341 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
342 |
+
Number of Resnet layer block
|
343 |
+
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
344 |
+
The number of groups to use for the Resnet block group norm
|
345 |
+
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
|
346 |
+
Whether to add upsample layer
|
347 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
348 |
+
Parameters `dtype`
|
349 |
+
"""
|
350 |
+
in_channels: int
|
351 |
+
out_channels: int
|
352 |
+
dropout: float = 0.0
|
353 |
+
num_layers: int = 1
|
354 |
+
resnet_groups: int = 32
|
355 |
+
add_upsample: bool = True
|
356 |
+
dtype: jnp.dtype = jnp.float32
|
357 |
+
|
358 |
+
def setup(self):
|
359 |
+
resnets = []
|
360 |
+
for i in range(self.num_layers):
|
361 |
+
in_channels = self.in_channels if i == 0 else self.out_channels
|
362 |
+
res_block = FlaxResnetBlock2D(
|
363 |
+
in_channels=in_channels,
|
364 |
+
out_channels=self.out_channels,
|
365 |
+
dropout=self.dropout,
|
366 |
+
groups=self.resnet_groups,
|
367 |
+
dtype=self.dtype,
|
368 |
+
)
|
369 |
+
resnets.append(res_block)
|
370 |
+
|
371 |
+
self.resnets = resnets
|
372 |
+
|
373 |
+
if self.add_upsample:
|
374 |
+
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
375 |
+
|
376 |
+
def __call__(self, hidden_states, deterministic=True):
|
377 |
+
for resnet in self.resnets:
|
378 |
+
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
379 |
+
|
380 |
+
if self.add_upsample:
|
381 |
+
hidden_states = self.upsamplers_0(hidden_states)
|
382 |
+
|
383 |
+
return hidden_states
|
384 |
+
|
385 |
+
|
386 |
+
class FlaxUNetMidBlock2D(nn.Module):
|
387 |
+
r"""
|
388 |
+
Flax Unet Mid-Block module.
|
389 |
+
|
390 |
+
Parameters:
|
391 |
+
in_channels (:obj:`int`):
|
392 |
+
Input channels
|
393 |
+
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
394 |
+
Dropout rate
|
395 |
+
num_layers (:obj:`int`, *optional*, defaults to 1):
|
396 |
+
Number of Resnet layer block
|
397 |
+
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
|
398 |
+
The number of groups to use for the Resnet and Attention block group norm
|
399 |
+
num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
|
400 |
+
Number of attention heads for each attention block
|
401 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
402 |
+
Parameters `dtype`
|
403 |
+
"""
|
404 |
+
in_channels: int
|
405 |
+
dropout: float = 0.0
|
406 |
+
num_layers: int = 1
|
407 |
+
resnet_groups: int = 32
|
408 |
+
num_attention_heads: int = 1
|
409 |
+
dtype: jnp.dtype = jnp.float32
|
410 |
+
|
411 |
+
def setup(self):
|
412 |
+
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
|
413 |
+
|
414 |
+
# there is always at least one resnet
|
415 |
+
resnets = [
|
416 |
+
FlaxResnetBlock2D(
|
417 |
+
in_channels=self.in_channels,
|
418 |
+
out_channels=self.in_channels,
|
419 |
+
dropout=self.dropout,
|
420 |
+
groups=resnet_groups,
|
421 |
+
dtype=self.dtype,
|
422 |
+
)
|
423 |
+
]
|
424 |
+
|
425 |
+
attentions = []
|
426 |
+
|
427 |
+
for _ in range(self.num_layers):
|
428 |
+
attn_block = FlaxAttentionBlock(
|
429 |
+
channels=self.in_channels,
|
430 |
+
num_head_channels=self.num_attention_heads,
|
431 |
+
num_groups=resnet_groups,
|
432 |
+
dtype=self.dtype,
|
433 |
+
)
|
434 |
+
attentions.append(attn_block)
|
435 |
+
|
436 |
+
res_block = FlaxResnetBlock2D(
|
437 |
+
in_channels=self.in_channels,
|
438 |
+
out_channels=self.in_channels,
|
439 |
+
dropout=self.dropout,
|
440 |
+
groups=resnet_groups,
|
441 |
+
dtype=self.dtype,
|
442 |
+
)
|
443 |
+
resnets.append(res_block)
|
444 |
+
|
445 |
+
self.resnets = resnets
|
446 |
+
self.attentions = attentions
|
447 |
+
|
448 |
+
def __call__(self, hidden_states, deterministic=True):
|
449 |
+
hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
|
450 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
451 |
+
hidden_states = attn(hidden_states)
|
452 |
+
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
453 |
+
|
454 |
+
return hidden_states
|
455 |
+
|
456 |
+
|
457 |
+
class FlaxEncoder(nn.Module):
|
458 |
+
r"""
|
459 |
+
Flax Implementation of VAE Encoder.
|
460 |
+
|
461 |
+
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
462 |
+
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
463 |
+
general usage and behavior.
|
464 |
+
|
465 |
+
Finally, this model supports inherent JAX features such as:
|
466 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
467 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
468 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
469 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
470 |
+
|
471 |
+
Parameters:
|
472 |
+
in_channels (:obj:`int`, *optional*, defaults to 3):
|
473 |
+
Input channels
|
474 |
+
out_channels (:obj:`int`, *optional*, defaults to 3):
|
475 |
+
Output channels
|
476 |
+
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
|
477 |
+
DownEncoder block type
|
478 |
+
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
|
479 |
+
Tuple containing the number of output channels for each block
|
480 |
+
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
481 |
+
Number of Resnet layer for each block
|
482 |
+
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
483 |
+
norm num group
|
484 |
+
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
485 |
+
Activation function
|
486 |
+
double_z (:obj:`bool`, *optional*, defaults to `False`):
|
487 |
+
Whether to double the last output channels
|
488 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
489 |
+
Parameters `dtype`
|
490 |
+
"""
|
491 |
+
in_channels: int = 3
|
492 |
+
out_channels: int = 3
|
493 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
494 |
+
block_out_channels: Tuple[int] = (64,)
|
495 |
+
layers_per_block: int = 2
|
496 |
+
norm_num_groups: int = 32
|
497 |
+
act_fn: str = "silu"
|
498 |
+
double_z: bool = False
|
499 |
+
dtype: jnp.dtype = jnp.float32
|
500 |
+
|
501 |
+
def setup(self):
|
502 |
+
block_out_channels = self.block_out_channels
|
503 |
+
# in
|
504 |
+
self.conv_in = nn.Conv(
|
505 |
+
block_out_channels[0],
|
506 |
+
kernel_size=(3, 3),
|
507 |
+
strides=(1, 1),
|
508 |
+
padding=((1, 1), (1, 1)),
|
509 |
+
dtype=self.dtype,
|
510 |
+
)
|
511 |
+
|
512 |
+
# downsampling
|
513 |
+
down_blocks = []
|
514 |
+
output_channel = block_out_channels[0]
|
515 |
+
for i, _ in enumerate(self.down_block_types):
|
516 |
+
input_channel = output_channel
|
517 |
+
output_channel = block_out_channels[i]
|
518 |
+
is_final_block = i == len(block_out_channels) - 1
|
519 |
+
|
520 |
+
down_block = FlaxDownEncoderBlock2D(
|
521 |
+
in_channels=input_channel,
|
522 |
+
out_channels=output_channel,
|
523 |
+
num_layers=self.layers_per_block,
|
524 |
+
resnet_groups=self.norm_num_groups,
|
525 |
+
add_downsample=not is_final_block,
|
526 |
+
dtype=self.dtype,
|
527 |
+
)
|
528 |
+
down_blocks.append(down_block)
|
529 |
+
self.down_blocks = down_blocks
|
530 |
+
|
531 |
+
# middle
|
532 |
+
self.mid_block = FlaxUNetMidBlock2D(
|
533 |
+
in_channels=block_out_channels[-1],
|
534 |
+
resnet_groups=self.norm_num_groups,
|
535 |
+
num_attention_heads=None,
|
536 |
+
dtype=self.dtype,
|
537 |
+
)
|
538 |
+
|
539 |
+
# end
|
540 |
+
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
|
541 |
+
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
542 |
+
self.conv_out = nn.Conv(
|
543 |
+
conv_out_channels,
|
544 |
+
kernel_size=(3, 3),
|
545 |
+
strides=(1, 1),
|
546 |
+
padding=((1, 1), (1, 1)),
|
547 |
+
dtype=self.dtype,
|
548 |
+
)
|
549 |
+
|
550 |
+
def __call__(self, sample, deterministic: bool = True):
|
551 |
+
# in
|
552 |
+
sample = self.conv_in(sample)
|
553 |
+
|
554 |
+
# downsampling
|
555 |
+
for block in self.down_blocks:
|
556 |
+
sample = block(sample, deterministic=deterministic)
|
557 |
+
|
558 |
+
# middle
|
559 |
+
sample = self.mid_block(sample, deterministic=deterministic)
|
560 |
+
|
561 |
+
# end
|
562 |
+
sample = self.conv_norm_out(sample)
|
563 |
+
sample = nn.swish(sample)
|
564 |
+
sample = self.conv_out(sample)
|
565 |
+
|
566 |
+
return sample
|
567 |
+
|
568 |
+
|
569 |
+
class FlaxDecoder(nn.Module):
|
570 |
+
r"""
|
571 |
+
Flax Implementation of VAE Decoder.
|
572 |
+
|
573 |
+
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
574 |
+
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
575 |
+
general usage and behavior.
|
576 |
+
|
577 |
+
Finally, this model supports inherent JAX features such as:
|
578 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
579 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
580 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
581 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
582 |
+
|
583 |
+
Parameters:
|
584 |
+
in_channels (:obj:`int`, *optional*, defaults to 3):
|
585 |
+
Input channels
|
586 |
+
out_channels (:obj:`int`, *optional*, defaults to 3):
|
587 |
+
Output channels
|
588 |
+
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
|
589 |
+
UpDecoder block type
|
590 |
+
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
|
591 |
+
Tuple containing the number of output channels for each block
|
592 |
+
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
593 |
+
Number of Resnet layer for each block
|
594 |
+
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
595 |
+
norm num group
|
596 |
+
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
597 |
+
Activation function
|
598 |
+
double_z (:obj:`bool`, *optional*, defaults to `False`):
|
599 |
+
Whether to double the last output channels
|
600 |
+
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
601 |
+
parameters `dtype`
|
602 |
+
"""
|
603 |
+
in_channels: int = 3
|
604 |
+
out_channels: int = 3
|
605 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
606 |
+
block_out_channels: int = (64,)
|
607 |
+
layers_per_block: int = 2
|
608 |
+
norm_num_groups: int = 32
|
609 |
+
act_fn: str = "silu"
|
610 |
+
dtype: jnp.dtype = jnp.float32
|
611 |
+
|
612 |
+
def setup(self):
|
613 |
+
block_out_channels = self.block_out_channels
|
614 |
+
|
615 |
+
# z to block_in
|
616 |
+
self.conv_in = nn.Conv(
|
617 |
+
block_out_channels[-1],
|
618 |
+
kernel_size=(3, 3),
|
619 |
+
strides=(1, 1),
|
620 |
+
padding=((1, 1), (1, 1)),
|
621 |
+
dtype=self.dtype,
|
622 |
+
)
|
623 |
+
|
624 |
+
# middle
|
625 |
+
self.mid_block = FlaxUNetMidBlock2D(
|
626 |
+
in_channels=block_out_channels[-1],
|
627 |
+
resnet_groups=self.norm_num_groups,
|
628 |
+
num_attention_heads=None,
|
629 |
+
dtype=self.dtype,
|
630 |
+
)
|
631 |
+
|
632 |
+
# upsampling
|
633 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
634 |
+
output_channel = reversed_block_out_channels[0]
|
635 |
+
up_blocks = []
|
636 |
+
for i, _ in enumerate(self.up_block_types):
|
637 |
+
prev_output_channel = output_channel
|
638 |
+
output_channel = reversed_block_out_channels[i]
|
639 |
+
|
640 |
+
is_final_block = i == len(block_out_channels) - 1
|
641 |
+
|
642 |
+
up_block = FlaxUpDecoderBlock2D(
|
643 |
+
in_channels=prev_output_channel,
|
644 |
+
out_channels=output_channel,
|
645 |
+
num_layers=self.layers_per_block + 1,
|
646 |
+
resnet_groups=self.norm_num_groups,
|
647 |
+
add_upsample=not is_final_block,
|
648 |
+
dtype=self.dtype,
|
649 |
+
)
|
650 |
+
up_blocks.append(up_block)
|
651 |
+
prev_output_channel = output_channel
|
652 |
+
|
653 |
+
self.up_blocks = up_blocks
|
654 |
+
|
655 |
+
# end
|
656 |
+
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
|
657 |
+
self.conv_out = nn.Conv(
|
658 |
+
self.out_channels,
|
659 |
+
kernel_size=(3, 3),
|
660 |
+
strides=(1, 1),
|
661 |
+
padding=((1, 1), (1, 1)),
|
662 |
+
dtype=self.dtype,
|
663 |
+
)
|
664 |
+
|
665 |
+
def __call__(self, sample, deterministic: bool = True):
|
666 |
+
# z to block_in
|
667 |
+
sample = self.conv_in(sample)
|
668 |
+
|
669 |
+
# middle
|
670 |
+
sample = self.mid_block(sample, deterministic=deterministic)
|
671 |
+
|
672 |
+
# upsampling
|
673 |
+
for block in self.up_blocks:
|
674 |
+
sample = block(sample, deterministic=deterministic)
|
675 |
+
|
676 |
+
sample = self.conv_norm_out(sample)
|
677 |
+
sample = nn.swish(sample)
|
678 |
+
sample = self.conv_out(sample)
|
679 |
+
|
680 |
+
return sample
|
681 |
+
|
682 |
+
|
683 |
+
class FlaxDiagonalGaussianDistribution(object):
|
684 |
+
def __init__(self, parameters, deterministic=False):
|
685 |
+
# Last axis to account for channels-last
|
686 |
+
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
|
687 |
+
self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
|
688 |
+
self.deterministic = deterministic
|
689 |
+
self.std = jnp.exp(0.5 * self.logvar)
|
690 |
+
self.var = jnp.exp(self.logvar)
|
691 |
+
if self.deterministic:
|
692 |
+
self.var = self.std = jnp.zeros_like(self.mean)
|
693 |
+
|
694 |
+
def sample(self, key):
|
695 |
+
return self.mean + self.std * jax.random.normal(key, self.mean.shape)
|
696 |
+
|
697 |
+
def kl(self, other=None):
|
698 |
+
if self.deterministic:
|
699 |
+
return jnp.array([0.0])
|
700 |
+
|
701 |
+
if other is None:
|
702 |
+
return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
|
703 |
+
|
704 |
+
return 0.5 * jnp.sum(
|
705 |
+
jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
|
706 |
+
axis=[1, 2, 3],
|
707 |
+
)
|
708 |
+
|
709 |
+
def nll(self, sample, axis=[1, 2, 3]):
|
710 |
+
if self.deterministic:
|
711 |
+
return jnp.array([0.0])
|
712 |
+
|
713 |
+
logtwopi = jnp.log(2.0 * jnp.pi)
|
714 |
+
return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
|
715 |
+
|
716 |
+
def mode(self):
|
717 |
+
return self.mean
|
718 |
+
|
719 |
+
|
720 |
+
@flax_register_to_config
|
721 |
+
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
722 |
+
r"""
|
723 |
+
Flax implementation of a VAE model with KL loss for decoding latent representations.
|
724 |
+
|
725 |
+
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods
|
726 |
+
implemented for all models (such as downloading or saving).
|
727 |
+
|
728 |
+
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
729 |
+
subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its
|
730 |
+
general usage and behavior.
|
731 |
+
|
732 |
+
Inherent JAX features such as the following are supported:
|
733 |
+
|
734 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
735 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
736 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
737 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
738 |
+
|
739 |
+
Parameters:
|
740 |
+
in_channels (`int`, *optional*, defaults to 3):
|
741 |
+
Number of channels in the input image.
|
742 |
+
out_channels (`int`, *optional*, defaults to 3):
|
743 |
+
Number of channels in the output.
|
744 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
|
745 |
+
Tuple of downsample block types.
|
746 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
|
747 |
+
Tuple of upsample block types.
|
748 |
+
block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`):
|
749 |
+
Tuple of block output channels.
|
750 |
+
layers_per_block (`int`, *optional*, defaults to `2`):
|
751 |
+
Number of ResNet layer for each block.
|
752 |
+
act_fn (`str`, *optional*, defaults to `silu`):
|
753 |
+
The activation function to use.
|
754 |
+
latent_channels (`int`, *optional*, defaults to `4`):
|
755 |
+
Number of channels in the latent space.
|
756 |
+
norm_num_groups (`int`, *optional*, defaults to `32`):
|
757 |
+
The number of groups for normalization.
|
758 |
+
sample_size (`int`, *optional*, defaults to 32):
|
759 |
+
Sample input size.
|
760 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
761 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
762 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
763 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
764 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
765 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
766 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
767 |
+
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
768 |
+
The `dtype` of the parameters.
|
769 |
+
"""
|
770 |
+
in_channels: int = 3
|
771 |
+
out_channels: int = 3
|
772 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
773 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
774 |
+
block_out_channels: Tuple[int] = (64,)
|
775 |
+
layers_per_block: int = 1
|
776 |
+
act_fn: str = "silu"
|
777 |
+
latent_channels: int = 4
|
778 |
+
norm_num_groups: int = 32
|
779 |
+
sample_size: int = 32
|
780 |
+
scaling_factor: float = 0.18215
|
781 |
+
dtype: jnp.dtype = jnp.float32
|
782 |
+
|
783 |
+
def setup(self):
|
784 |
+
self.encoder = FlaxEncoder(
|
785 |
+
in_channels=self.config.in_channels,
|
786 |
+
out_channels=self.config.latent_channels,
|
787 |
+
down_block_types=self.config.down_block_types,
|
788 |
+
block_out_channels=self.config.block_out_channels,
|
789 |
+
layers_per_block=self.config.layers_per_block,
|
790 |
+
act_fn=self.config.act_fn,
|
791 |
+
norm_num_groups=self.config.norm_num_groups,
|
792 |
+
double_z=True,
|
793 |
+
dtype=self.dtype,
|
794 |
+
)
|
795 |
+
self.decoder = FlaxDecoder(
|
796 |
+
in_channels=self.config.latent_channels,
|
797 |
+
out_channels=self.config.out_channels,
|
798 |
+
up_block_types=self.config.up_block_types,
|
799 |
+
block_out_channels=self.config.block_out_channels,
|
800 |
+
layers_per_block=self.config.layers_per_block,
|
801 |
+
norm_num_groups=self.config.norm_num_groups,
|
802 |
+
act_fn=self.config.act_fn,
|
803 |
+
dtype=self.dtype,
|
804 |
+
)
|
805 |
+
self.quant_conv = nn.Conv(
|
806 |
+
2 * self.config.latent_channels,
|
807 |
+
kernel_size=(1, 1),
|
808 |
+
strides=(1, 1),
|
809 |
+
padding="VALID",
|
810 |
+
dtype=self.dtype,
|
811 |
+
)
|
812 |
+
self.post_quant_conv = nn.Conv(
|
813 |
+
self.config.latent_channels,
|
814 |
+
kernel_size=(1, 1),
|
815 |
+
strides=(1, 1),
|
816 |
+
padding="VALID",
|
817 |
+
dtype=self.dtype,
|
818 |
+
)
|
819 |
+
|
820 |
+
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
821 |
+
# init input tensors
|
822 |
+
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
823 |
+
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
824 |
+
|
825 |
+
params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
|
826 |
+
rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
|
827 |
+
|
828 |
+
return self.init(rngs, sample)["params"]
|
829 |
+
|
830 |
+
def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
|
831 |
+
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
832 |
+
|
833 |
+
hidden_states = self.encoder(sample, deterministic=deterministic)
|
834 |
+
moments = self.quant_conv(hidden_states)
|
835 |
+
posterior = FlaxDiagonalGaussianDistribution(moments)
|
836 |
+
|
837 |
+
if not return_dict:
|
838 |
+
return (posterior,)
|
839 |
+
|
840 |
+
return FlaxAutoencoderKLOutput(latent_dist=posterior)
|
841 |
+
|
842 |
+
def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
|
843 |
+
if latents.shape[-1] != self.config.latent_channels:
|
844 |
+
latents = jnp.transpose(latents, (0, 2, 3, 1))
|
845 |
+
|
846 |
+
hidden_states = self.post_quant_conv(latents)
|
847 |
+
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
|
848 |
+
|
849 |
+
hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
|
850 |
+
|
851 |
+
if not return_dict:
|
852 |
+
return (hidden_states,)
|
853 |
+
|
854 |
+
return FlaxDecoderOutput(sample=hidden_states)
|
855 |
+
|
856 |
+
def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
|
857 |
+
posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
|
858 |
+
if sample_posterior:
|
859 |
+
rng = self.make_rng("gaussian")
|
860 |
+
hidden_states = posterior.latent_dist.sample(rng)
|
861 |
+
else:
|
862 |
+
hidden_states = posterior.latent_dist.mode()
|
863 |
+
|
864 |
+
sample = self.decode(hidden_states, return_dict=return_dict).sample
|
865 |
+
|
866 |
+
if not return_dict:
|
867 |
+
return (sample,)
|
868 |
+
|
869 |
+
return FlaxDecoderOutput(sample=sample)
|
6DoF/diffusers/models/vq_model.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils import BaseOutput, apply_forward_hook
|
22 |
+
from .modeling_utils import ModelMixin
|
23 |
+
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class VQEncoderOutput(BaseOutput):
|
28 |
+
"""
|
29 |
+
Output of VQModel encoding method.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
33 |
+
The encoded output sample from the last layer of the model.
|
34 |
+
"""
|
35 |
+
|
36 |
+
latents: torch.FloatTensor
|
37 |
+
|
38 |
+
|
39 |
+
class VQModel(ModelMixin, ConfigMixin):
|
40 |
+
r"""
|
41 |
+
A VQ-VAE model for decoding latent representations.
|
42 |
+
|
43 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
44 |
+
for all models (such as downloading or saving).
|
45 |
+
|
46 |
+
Parameters:
|
47 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
48 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
49 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
50 |
+
Tuple of downsample block types.
|
51 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
52 |
+
Tuple of upsample block types.
|
53 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
54 |
+
Tuple of block output channels.
|
55 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
56 |
+
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
57 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
58 |
+
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
59 |
+
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
60 |
+
scaling_factor (`float`, *optional*, defaults to `0.18215`):
|
61 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
62 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
63 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
64 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
65 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
66 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
67 |
+
"""
|
68 |
+
|
69 |
+
@register_to_config
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
in_channels: int = 3,
|
73 |
+
out_channels: int = 3,
|
74 |
+
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
75 |
+
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
76 |
+
block_out_channels: Tuple[int] = (64,),
|
77 |
+
layers_per_block: int = 1,
|
78 |
+
act_fn: str = "silu",
|
79 |
+
latent_channels: int = 3,
|
80 |
+
sample_size: int = 32,
|
81 |
+
num_vq_embeddings: int = 256,
|
82 |
+
norm_num_groups: int = 32,
|
83 |
+
vq_embed_dim: Optional[int] = None,
|
84 |
+
scaling_factor: float = 0.18215,
|
85 |
+
norm_type: str = "group", # group, spatial
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
# pass init params to Encoder
|
90 |
+
self.encoder = Encoder(
|
91 |
+
in_channels=in_channels,
|
92 |
+
out_channels=latent_channels,
|
93 |
+
down_block_types=down_block_types,
|
94 |
+
block_out_channels=block_out_channels,
|
95 |
+
layers_per_block=layers_per_block,
|
96 |
+
act_fn=act_fn,
|
97 |
+
norm_num_groups=norm_num_groups,
|
98 |
+
double_z=False,
|
99 |
+
)
|
100 |
+
|
101 |
+
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
102 |
+
|
103 |
+
self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
104 |
+
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
105 |
+
self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
106 |
+
|
107 |
+
# pass init params to Decoder
|
108 |
+
self.decoder = Decoder(
|
109 |
+
in_channels=latent_channels,
|
110 |
+
out_channels=out_channels,
|
111 |
+
up_block_types=up_block_types,
|
112 |
+
block_out_channels=block_out_channels,
|
113 |
+
layers_per_block=layers_per_block,
|
114 |
+
act_fn=act_fn,
|
115 |
+
norm_num_groups=norm_num_groups,
|
116 |
+
norm_type=norm_type,
|
117 |
+
)
|
118 |
+
|
119 |
+
@apply_forward_hook
|
120 |
+
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
|
121 |
+
h = self.encoder(x)
|
122 |
+
h = self.quant_conv(h)
|
123 |
+
|
124 |
+
if not return_dict:
|
125 |
+
return (h,)
|
126 |
+
|
127 |
+
return VQEncoderOutput(latents=h)
|
128 |
+
|
129 |
+
@apply_forward_hook
|
130 |
+
def decode(
|
131 |
+
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
|
132 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
133 |
+
# also go through quantization layer
|
134 |
+
if not force_not_quantize:
|
135 |
+
quant, emb_loss, info = self.quantize(h)
|
136 |
+
else:
|
137 |
+
quant = h
|
138 |
+
quant2 = self.post_quant_conv(quant)
|
139 |
+
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
|
140 |
+
|
141 |
+
if not return_dict:
|
142 |
+
return (dec,)
|
143 |
+
|
144 |
+
return DecoderOutput(sample=dec)
|
145 |
+
|
146 |
+
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
147 |
+
r"""
|
148 |
+
The [`VQModel`] forward method.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
sample (`torch.FloatTensor`): Input sample.
|
152 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
153 |
+
Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
|
154 |
+
|
155 |
+
Returns:
|
156 |
+
[`~models.vq_model.VQEncoderOutput`] or `tuple`:
|
157 |
+
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
|
158 |
+
is returned.
|
159 |
+
"""
|
160 |
+
x = sample
|
161 |
+
h = self.encode(x).latents
|
162 |
+
dec = self.decode(h).sample
|
163 |
+
|
164 |
+
if not return_dict:
|
165 |
+
return (dec,)
|
166 |
+
|
167 |
+
return DecoderOutput(sample=dec)
|
6DoF/diffusers/optimization.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch optimization for diffusion models."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
from enum import Enum
|
19 |
+
from typing import Optional, Union
|
20 |
+
|
21 |
+
from torch.optim import Optimizer
|
22 |
+
from torch.optim.lr_scheduler import LambdaLR
|
23 |
+
|
24 |
+
from .utils import logging
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
class SchedulerType(Enum):
|
31 |
+
LINEAR = "linear"
|
32 |
+
COSINE = "cosine"
|
33 |
+
COSINE_WITH_RESTARTS = "cosine_with_restarts"
|
34 |
+
POLYNOMIAL = "polynomial"
|
35 |
+
CONSTANT = "constant"
|
36 |
+
CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
37 |
+
PIECEWISE_CONSTANT = "piecewise_constant"
|
38 |
+
|
39 |
+
|
40 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
41 |
+
"""
|
42 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
46 |
+
The optimizer for which to schedule the learning rate.
|
47 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
48 |
+
The index of the last epoch when resuming training.
|
49 |
+
|
50 |
+
Return:
|
51 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
52 |
+
"""
|
53 |
+
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
54 |
+
|
55 |
+
|
56 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
57 |
+
"""
|
58 |
+
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
59 |
+
increases linearly between 0 and the initial lr set in the optimizer.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
63 |
+
The optimizer for which to schedule the learning rate.
|
64 |
+
num_warmup_steps (`int`):
|
65 |
+
The number of steps for the warmup phase.
|
66 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
67 |
+
The index of the last epoch when resuming training.
|
68 |
+
|
69 |
+
Return:
|
70 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
71 |
+
"""
|
72 |
+
|
73 |
+
def lr_lambda(current_step: int):
|
74 |
+
if current_step < num_warmup_steps:
|
75 |
+
return float(current_step) / float(max(1.0, num_warmup_steps))
|
76 |
+
return 1.0
|
77 |
+
|
78 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
79 |
+
|
80 |
+
|
81 |
+
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
|
82 |
+
"""
|
83 |
+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
87 |
+
The optimizer for which to schedule the learning rate.
|
88 |
+
step_rules (`string`):
|
89 |
+
The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
|
90 |
+
if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
|
91 |
+
steps and multiple 0.005 for the other steps.
|
92 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
93 |
+
The index of the last epoch when resuming training.
|
94 |
+
|
95 |
+
Return:
|
96 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
97 |
+
"""
|
98 |
+
|
99 |
+
rules_dict = {}
|
100 |
+
rule_list = step_rules.split(",")
|
101 |
+
for rule_str in rule_list[:-1]:
|
102 |
+
value_str, steps_str = rule_str.split(":")
|
103 |
+
steps = int(steps_str)
|
104 |
+
value = float(value_str)
|
105 |
+
rules_dict[steps] = value
|
106 |
+
last_lr_multiple = float(rule_list[-1])
|
107 |
+
|
108 |
+
def create_rules_function(rules_dict, last_lr_multiple):
|
109 |
+
def rule_func(steps: int) -> float:
|
110 |
+
sorted_steps = sorted(rules_dict.keys())
|
111 |
+
for i, sorted_step in enumerate(sorted_steps):
|
112 |
+
if steps < sorted_step:
|
113 |
+
return rules_dict[sorted_steps[i]]
|
114 |
+
return last_lr_multiple
|
115 |
+
|
116 |
+
return rule_func
|
117 |
+
|
118 |
+
rules_func = create_rules_function(rules_dict, last_lr_multiple)
|
119 |
+
|
120 |
+
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
|
121 |
+
|
122 |
+
|
123 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
124 |
+
"""
|
125 |
+
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
|
126 |
+
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
130 |
+
The optimizer for which to schedule the learning rate.
|
131 |
+
num_warmup_steps (`int`):
|
132 |
+
The number of steps for the warmup phase.
|
133 |
+
num_training_steps (`int`):
|
134 |
+
The total number of training steps.
|
135 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
136 |
+
The index of the last epoch when resuming training.
|
137 |
+
|
138 |
+
Return:
|
139 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
140 |
+
"""
|
141 |
+
|
142 |
+
def lr_lambda(current_step: int):
|
143 |
+
if current_step < num_warmup_steps:
|
144 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
145 |
+
return max(
|
146 |
+
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
147 |
+
)
|
148 |
+
|
149 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
150 |
+
|
151 |
+
|
152 |
+
def get_cosine_schedule_with_warmup(
|
153 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
157 |
+
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
158 |
+
initial lr set in the optimizer.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
162 |
+
The optimizer for which to schedule the learning rate.
|
163 |
+
num_warmup_steps (`int`):
|
164 |
+
The number of steps for the warmup phase.
|
165 |
+
num_training_steps (`int`):
|
166 |
+
The total number of training steps.
|
167 |
+
num_periods (`float`, *optional*, defaults to 0.5):
|
168 |
+
The number of periods of the cosine function in a schedule (the default is to just decrease from the max
|
169 |
+
value to 0 following a half-cosine).
|
170 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
171 |
+
The index of the last epoch when resuming training.
|
172 |
+
|
173 |
+
Return:
|
174 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
175 |
+
"""
|
176 |
+
|
177 |
+
def lr_lambda(current_step):
|
178 |
+
if current_step < num_warmup_steps:
|
179 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
180 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
181 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
182 |
+
|
183 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
184 |
+
|
185 |
+
|
186 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
187 |
+
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
188 |
+
):
|
189 |
+
"""
|
190 |
+
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
191 |
+
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
192 |
+
linearly between 0 and the initial lr set in the optimizer.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
196 |
+
The optimizer for which to schedule the learning rate.
|
197 |
+
num_warmup_steps (`int`):
|
198 |
+
The number of steps for the warmup phase.
|
199 |
+
num_training_steps (`int`):
|
200 |
+
The total number of training steps.
|
201 |
+
num_cycles (`int`, *optional*, defaults to 1):
|
202 |
+
The number of hard restarts to use.
|
203 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
204 |
+
The index of the last epoch when resuming training.
|
205 |
+
|
206 |
+
Return:
|
207 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
208 |
+
"""
|
209 |
+
|
210 |
+
def lr_lambda(current_step):
|
211 |
+
if current_step < num_warmup_steps:
|
212 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
213 |
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
214 |
+
if progress >= 1.0:
|
215 |
+
return 0.0
|
216 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
|
217 |
+
|
218 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
219 |
+
|
220 |
+
|
221 |
+
def get_polynomial_decay_schedule_with_warmup(
|
222 |
+
optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
|
223 |
+
):
|
224 |
+
"""
|
225 |
+
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
|
226 |
+
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
|
227 |
+
initial lr set in the optimizer.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
optimizer ([`~torch.optim.Optimizer`]):
|
231 |
+
The optimizer for which to schedule the learning rate.
|
232 |
+
num_warmup_steps (`int`):
|
233 |
+
The number of steps for the warmup phase.
|
234 |
+
num_training_steps (`int`):
|
235 |
+
The total number of training steps.
|
236 |
+
lr_end (`float`, *optional*, defaults to 1e-7):
|
237 |
+
The end LR.
|
238 |
+
power (`float`, *optional*, defaults to 1.0):
|
239 |
+
Power factor.
|
240 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
241 |
+
The index of the last epoch when resuming training.
|
242 |
+
|
243 |
+
Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
|
244 |
+
implementation at
|
245 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
|
246 |
+
|
247 |
+
Return:
|
248 |
+
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
249 |
+
|
250 |
+
"""
|
251 |
+
|
252 |
+
lr_init = optimizer.defaults["lr"]
|
253 |
+
if not (lr_init > lr_end):
|
254 |
+
raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
|
255 |
+
|
256 |
+
def lr_lambda(current_step: int):
|
257 |
+
if current_step < num_warmup_steps:
|
258 |
+
return float(current_step) / float(max(1, num_warmup_steps))
|
259 |
+
elif current_step > num_training_steps:
|
260 |
+
return lr_end / lr_init # as LambdaLR multiplies by lr_init
|
261 |
+
else:
|
262 |
+
lr_range = lr_init - lr_end
|
263 |
+
decay_steps = num_training_steps - num_warmup_steps
|
264 |
+
pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
|
265 |
+
decay = lr_range * pct_remaining**power + lr_end
|
266 |
+
return decay / lr_init # as LambdaLR multiplies by lr_init
|
267 |
+
|
268 |
+
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
269 |
+
|
270 |
+
|
271 |
+
TYPE_TO_SCHEDULER_FUNCTION = {
|
272 |
+
SchedulerType.LINEAR: get_linear_schedule_with_warmup,
|
273 |
+
SchedulerType.COSINE: get_cosine_schedule_with_warmup,
|
274 |
+
SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
|
275 |
+
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
|
276 |
+
SchedulerType.CONSTANT: get_constant_schedule,
|
277 |
+
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
|
278 |
+
SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
|
279 |
+
}
|
280 |
+
|
281 |
+
|
282 |
+
def get_scheduler(
|
283 |
+
name: Union[str, SchedulerType],
|
284 |
+
optimizer: Optimizer,
|
285 |
+
step_rules: Optional[str] = None,
|
286 |
+
num_warmup_steps: Optional[int] = None,
|
287 |
+
num_training_steps: Optional[int] = None,
|
288 |
+
num_cycles: int = 1,
|
289 |
+
power: float = 1.0,
|
290 |
+
last_epoch: int = -1,
|
291 |
+
):
|
292 |
+
"""
|
293 |
+
Unified API to get any scheduler from its name.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
name (`str` or `SchedulerType`):
|
297 |
+
The name of the scheduler to use.
|
298 |
+
optimizer (`torch.optim.Optimizer`):
|
299 |
+
The optimizer that will be used during training.
|
300 |
+
step_rules (`str`, *optional*):
|
301 |
+
A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
|
302 |
+
num_warmup_steps (`int`, *optional*):
|
303 |
+
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
304 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
305 |
+
num_training_steps (`int``, *optional*):
|
306 |
+
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
307 |
+
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
308 |
+
num_cycles (`int`, *optional*):
|
309 |
+
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
310 |
+
power (`float`, *optional*, defaults to 1.0):
|
311 |
+
Power factor. See `POLYNOMIAL` scheduler
|
312 |
+
last_epoch (`int`, *optional*, defaults to -1):
|
313 |
+
The index of the last epoch when resuming training.
|
314 |
+
"""
|
315 |
+
name = SchedulerType(name)
|
316 |
+
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
317 |
+
if name == SchedulerType.CONSTANT:
|
318 |
+
return schedule_func(optimizer, last_epoch=last_epoch)
|
319 |
+
|
320 |
+
if name == SchedulerType.PIECEWISE_CONSTANT:
|
321 |
+
return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
|
322 |
+
|
323 |
+
# All other schedulers require `num_warmup_steps`
|
324 |
+
if num_warmup_steps is None:
|
325 |
+
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
326 |
+
|
327 |
+
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
328 |
+
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
|
329 |
+
|
330 |
+
# All other schedulers require `num_training_steps`
|
331 |
+
if num_training_steps is None:
|
332 |
+
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
333 |
+
|
334 |
+
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
335 |
+
return schedule_func(
|
336 |
+
optimizer,
|
337 |
+
num_warmup_steps=num_warmup_steps,
|
338 |
+
num_training_steps=num_training_steps,
|
339 |
+
num_cycles=num_cycles,
|
340 |
+
last_epoch=last_epoch,
|
341 |
+
)
|
342 |
+
|
343 |
+
if name == SchedulerType.POLYNOMIAL:
|
344 |
+
return schedule_func(
|
345 |
+
optimizer,
|
346 |
+
num_warmup_steps=num_warmup_steps,
|
347 |
+
num_training_steps=num_training_steps,
|
348 |
+
power=power,
|
349 |
+
last_epoch=last_epoch,
|
350 |
+
)
|
351 |
+
|
352 |
+
return schedule_func(
|
353 |
+
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
|
354 |
+
)
|
6DoF/diffusers/pipeline_utils.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# NOTE: This file is deprecated and will be removed in a future version.
|
17 |
+
# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
|
18 |
+
|
19 |
+
from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
|
20 |
+
from .utils import deprecate
|
21 |
+
|
22 |
+
|
23 |
+
deprecate(
|
24 |
+
"pipelines_utils",
|
25 |
+
"0.22.0",
|
26 |
+
"Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.",
|
27 |
+
standard_warn=False,
|
28 |
+
stacklevel=3,
|
29 |
+
)
|