upload code
- .gitignore +3 -0
- LICENSE +21 -0
- annotator/hed/__init__.py +114 -0
- annotator/lineart/LICENSE +21 -0
- annotator/lineart/__init__.py +138 -0
- annotator/util.py +98 -0
- app.py +197 -4
- ip_adapter/__init__.py +11 -0
- ip_adapter/attention_processor.py +554 -0
- ip_adapter/custom_pipelines.py +394 -0
- ip_adapter/ip_adapter.py +1086 -0
- ip_adapter/resampler.py +158 -0
- ip_adapter/style_encoder.py +246 -0
- ip_adapter/test_resampler.py +44 -0
- ip_adapter/tools.py +31 -0
- ip_adapter/utils.py +5 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
__pycache__
*/__pycache__
**/__pycache__
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 OpenMMLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
annotator/hed/__init__.py
ADDED
@@ -0,0 +1,114 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products.
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Different from official models and other implementations, this is an RGB-input model (rather than BGR),
# and in this way it works better with gradio's RGB protocol.

import os
import cv2
import torch
import numpy as np

from einops import rearrange
from annotator.util import annotator_ckpts_path, safe_step


class DoubleConvBlock(torch.nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5


class HEDdetector:
    def __init__(self):
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
        self.netNetwork.load_state_dict(torch.load(modelpath))

    def __call__(self, input_image, safe=False):
        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
            return edge


class SOFT_HEDdetector:
    def __init__(self):
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
        modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
        self.netNetwork.load_state_dict(torch.load(modelpath))

    def __call__(self, input_image, safe=False, threshold=200):
        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        content_image = edge
        content_image[content_image > threshold] = 255
        content_image[content_image < 255] = 0
        kernel = np.ones((3, 3), np.uint8)

        content_image = cv2.dilate(content_image, kernel, iterations=1)
        return content_image
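As a quick orientation for readers of this diff: both detectors above load the same ControlNetHED checkpoint; SOFT_HEDdetector additionally thresholds and dilates the edge map into a sparse black-and-white contour used as the "Contour" condition in app.py. A minimal usage sketch, assuming a CUDA device is available and using a hypothetical input path:

import cv2
from annotator.hed import SOFT_HEDdetector

detector = SOFT_HEDdetector()                   # downloads ControlNetHED.pth on first use
bgr = cv2.imread("content.png")                 # hypothetical input path, BGR uint8
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)      # the model expects RGB input
contour = detector(rgb, threshold=200)          # uint8 map: 255 on dilated edges, 0 elsewhere
cv2.imwrite("contour.png", contour)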
annotator/lineart/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Caroline Chan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
annotator/lineart/__init__.py
ADDED
@@ -0,0 +1,138 @@
# From https://github.com/carolineec/informative-drawings
# MIT License

import os
import cv2
import torch
import numpy as np

import torch.nn as nn
from einops import rearrange
from annotator.util import annotator_ckpts_path


norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features),
                      nn.ReLU(inplace=True),
                      nn.ReflectionPad2d(1),
                      nn.Conv2d(in_features, in_features, 3),
                      norm_layer(in_features)
                      ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()

        # Initial convolution block
        model0 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, 64, 7),
                  norm_layer(64),
                  nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        # Downsampling
        model1 = []
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = []
        # Residual blocks
        for _ in range(n_residual_blocks):
            model2 += [ResidualBlock(in_features)]
        self.model2 = nn.Sequential(*model2)

        # Upsampling
        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                       norm_layer(out_features),
                       nn.ReLU(inplace=True)]
            in_features = out_features
            out_features = in_features // 2
        self.model3 = nn.Sequential(*model3)

        # Output layer
        model4 = [nn.ReflectionPad2d(3),
                  nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]

        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)

        return out


class LineartDetector:
    def __init__(self):
        self.model = self.load_model('sk_model.pth')
        self.model_coarse = self.load_model('sk_model2.pth')

    def load_model(self, name):
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
        modelpath = os.path.join(annotator_ckpts_path, name)
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        model = Generator(3, 1, 3)
        model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
        model.eval()
        model = model.cuda()
        return model

    def __call__(self, input_image, coarse=False):
        model = self.model_coarse if coarse else self.model
        assert input_image.ndim == 3
        image = input_image
        # images = input_images
        # results = []
        with torch.no_grad():
            image = torch.from_numpy(image).float().cuda()
            # batch_imgs = torch.stack([torch.from_numpy(image).float().cuda() / 255.0 for image in images], dim=0)
            image = image / 255.0
            image = rearrange(image, 'h w c -> 1 c h w')
            line = model(image)[0][0]

            line = line.cpu().numpy()
            line = (line * 255.0).clip(0, 255).astype(np.uint8)

        # with torch.no_grad():
        #     # pass the batch of images through the model
        #     outputs = model(batch_imgs)

        #     for output in outputs:
        #         line = output[0][0].cpu().numpy()
        #         line = (line * 255.0).clip(0, 255).astype(np.uint8)
        #         results.append(line)

        # return results

        return line
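For orientation: LineartDetector wraps two informative-drawings checkpoints, sk_model.pth for detailed lines and sk_model2.pth for a coarser style, selected with the coarse flag. A minimal sketch under the same assumptions as the HED example above (CUDA available, hypothetical input path):

import cv2
from annotator.lineart import LineartDetector

detector = LineartDetector()                              # fetches sk_model.pth / sk_model2.pth on first use
rgb = cv2.cvtColor(cv2.imread("content.png"), cv2.COLOR_BGR2RGB)
fine = detector(rgb)                                      # detailed line map, uint8 H x W
coarse = detector(rgb, coarse=True)                       # rougher strokes from the second checkpoint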
annotator/util.py
ADDED
@@ -0,0 +1,98 @@
import random

import numpy as np
import cv2
import os


annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img


def nms(x, t, s):
    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)

    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(x)

    for f in [f1, f2, f3, f4]:
        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)

    z = np.zeros_like(y, dtype=np.uint8)
    z[y > t] = 255
    return z


def make_noise_disk(H, W, C, F):
    noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
    noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
    noise = noise[F: F + H, F: F + W]
    noise -= np.min(noise)
    noise /= np.max(noise)
    if C == 1:
        noise = noise[:, :, None]
    return noise


def min_max_norm(x):
    x -= np.min(x)
    x /= np.maximum(np.max(x), 1e-5)
    return x


def safe_step(x, step=2):
    y = x.astype(np.float32) * float(step + 1)
    y = y.astype(np.int32).astype(np.float32) / float(step)
    return y


def img2mask(img, H, W, low=10, high=90):
    assert img.ndim == 3 or img.ndim == 2
    assert img.dtype == np.uint8

    if img.ndim == 3:
        y = img[:, :, random.randrange(0, img.shape[2])]
    else:
        y = img

    y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)

    if random.uniform(0, 1) < 0.5:
        y = 255 - y

    return y < np.percentile(y, random.randrange(low, high))
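resize_image above scales the short side of the input to the requested resolution and then rounds both dimensions to multiples of 64, the granularity the diffusion models in app.py expect. A small sketch of that rounding behaviour, using a hypothetical 720p frame:

import numpy as np
from annotator.util import HWC3, resize_image

frame = np.zeros((720, 1280, 3), dtype=np.uint8)   # hypothetical 720p RGB frame
out = resize_image(HWC3(frame), 512)               # short side -> 512, both sides snapped to a multiple of 64
print(out.shape)                                   # (512, 896, 3): 1280 * 512/720 is about 910, rounded to 896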
app.py
CHANGED
@@ -1,7 +1,200 @@
+from types import MethodType
+
+import spaces
+import os
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-demo.launch()
+import torch
+import cv2
+from annotator.util import resize_image
+from annotator.hed import SOFT_HEDdetector
+from annotator.lineart import LineartDetector
+from diffusers import UNet2DConditionModel, ControlNetModel
+from transformers import CLIPVisionModelWithProjection
+from huggingface_hub import snapshot_download
+from PIL import Image
+from ip_adapter import StyleShot, StyleContentStableDiffusionControlNetPipeline
+
+device = "cuda"
+
+contour_detector = SOFT_HEDdetector()
+lineart_detector = LineartDetector()
+
+base_model_path = "runwayml/stable-diffusion-v1-5"
+transformer_block_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+styleshot_model_path = "Gaojunyao/StyleShot"
+styleshot_lineart_model_path = "Gaojunyao/StyleShot_lineart"
+
+if not os.path.isdir(base_model_path):
+    base_model_path = snapshot_download(base_model_path, local_dir=base_model_path)
+    print(f"Downloaded model to {base_model_path}")
+if not os.path.isdir(transformer_block_path):
+    transformer_block_path = snapshot_download(transformer_block_path, local_dir=transformer_block_path)
+    print(f"Downloaded model to {transformer_block_path}")
+if not os.path.isdir(styleshot_model_path):
+    styleshot_model_path = snapshot_download(styleshot_model_path, local_dir=styleshot_model_path)
+    print(f"Downloaded model to {styleshot_model_path}")
+if not os.path.isdir(styleshot_lineart_model_path):
+    styleshot_lineart_model_path = snapshot_download(styleshot_lineart_model_path, local_dir=styleshot_lineart_model_path)
+    print(f"Downloaded model to {styleshot_lineart_model_path}")
+
+
+# weights for ip-adapter and our content-fusion encoder
+contour_ip_ckpt = os.path.join(styleshot_model_path, "pretrained_weight/ip.bin")
+contour_style_aware_encoder_path = os.path.join(styleshot_model_path, "pretrained_weight/style_aware_encoder.bin")
+contour_transformer_block_path = transformer_block_path
+contour_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
+contour_content_fusion_encoder = ControlNetModel.from_unet(contour_unet)
+
+contour_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=contour_content_fusion_encoder)
+contour_styleshot = StyleShot(device, contour_pipe, contour_ip_ckpt, contour_style_aware_encoder_path, contour_transformer_block_path)
+
+lineart_ip_ckpt = os.path.join(styleshot_lineart_model_path, "pretrained_weight/ip.bin")
+lineart_style_aware_encoder_path = os.path.join(styleshot_lineart_model_path, "pretrained_weight/style_aware_encoder.bin")
+lineart_transformer_block_path = transformer_block_path
+lineart_unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
+lineart_content_fusion_encoder = ControlNetModel.from_unet(lineart_unet)
+
+lineart_pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(base_model_path, controlnet=lineart_content_fusion_encoder)
+lineart_styleshot = StyleShot(device, lineart_pipe, lineart_ip_ckpt, lineart_style_aware_encoder_path, lineart_transformer_block_path)
+
+
+@spaces.GPU
+def process(style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold=200):
+    weight_dtype = torch.float32
+
+    style_shots = []
+    btns = []
+    contour_content_images = []
+    contour_results = []
+    lineart_content_images = []
+    lineart_results = []
+
+    type1 = 'Contour'
+    type2 = 'Lineart'
+
+    if btn1 == type1 or content_image is None:
+        style_shots = [contour_styleshot]
+        btns = [type1]
+    elif btn1 == type2:
+        style_shots = [lineart_styleshot]
+        btns = [type2]
+    elif btn1 == "Both":
+        style_shots = [contour_styleshot, lineart_styleshot]
+        btns = [type1, type2]
+
+    ori_style_image = style_image.copy()
+
+    if content_image is not None:
+        ori_content_image = content_image.copy()
+    else:
+        ori_content_image = None
+
+    for styleshot, btn in zip(style_shots, btns):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        prompts = [prompt + " " + a_prompt]
+
+        style_image = Image.fromarray(ori_style_image)
+
+        if ori_content_image is not None:
+            if btn == type1:
+                content_image = resize_image(ori_content_image, image_resolution)
+                content_image = contour_detector(content_image, threshold=Contour_Threshold)
+            elif btn == type2:
+                content_image = resize_image(ori_content_image, image_resolution)
+                content_image = lineart_detector(content_image, coarse=False)
+
+            content_image = Image.fromarray(content_image)
+        else:
+            content_image = cv2.resize(ori_style_image, (image_resolution, image_resolution))
+            content_image = Image.fromarray(content_image)
+            condition_scale = 0.0
+
+        g_images = styleshot.generate(style_image=style_image,
+                                      prompt=[[prompt]],
+                                      negative_prompt=n_prompt,
+                                      scale=style_scale,
+                                      num_samples=num_samples,
+                                      seed=seed,
+                                      num_inference_steps=ddim_steps,
+                                      guidance_scale=guidance_scale,
+                                      content_image=content_image,
+                                      controlnet_conditioning_scale=float(condition_scale))
+
+        if btn == type1:
+            contour_content_images = [content_image]
+            contour_results = g_images[0]
+        elif btn == type2:
+            lineart_content_images = [content_image]
+            lineart_results = g_images[0]
+    if ori_content_image is None:
+        contour_content_images = []
+        lineart_results = []
+        lineart_content_images = []
+
+    return [contour_results, contour_content_images, lineart_results, lineart_content_images]
+
+
+block = gr.Blocks().queue()
+with block:
+    with gr.Row():
+        gr.Markdown("## StyleShot Demo")
+    with gr.Row():
+        with gr.Column():
+            style_image = gr.Image(sources=['upload'], type="numpy", label='Style Image')
+        with gr.Column():
+            with gr.Box():
+                with gr.Column():
+                    content_image = gr.Image(sources=['upload'], type="numpy", label='Content Image (optional)')
+                    btn1 = gr.Radio(
+                        choices=["Contour", "Lineart", "Both"],
+                        interactive=True,
+                        label="Preprocessor",
+                        value="Both",
+                    )
+                    gr.Markdown("We recommend 'Contour' for sparse control and 'Lineart' for detailed control. If you choose 'Both', results are produced for both types of control. If you choose 'Contour', you can adjust the 'Contour Threshold' under 'Advanced options' to set the level of detail of the control.")
+    with gr.Row():
+        prompt = gr.Textbox(label="Prompt")
+    with gr.Row():
+        run_button = gr.Button(value="Run")
+    with gr.Row():
+        with gr.Column():
+            with gr.Accordion("Advanced options", open=False):
+                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
+                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                condition_scale = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+
+                Contour_Threshold = gr.Slider(label="Contour Threshold", minimum=0, maximum=255, value=200, step=1)
+
+                style_scale = gr.Slider(label="Style Strength", minimum=0, maximum=2, value=1.0, step=0.01)
+
+                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
+
+                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                n_prompt = gr.Textbox(label="Negative Prompt",
+                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+
+    with gr.Row():
+        with gr.Box():
+            gr.Markdown("### Results for Contour")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    contour_gallery = gr.Gallery(label='Contour Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
+                with gr.Column(scale=4):
+                    image_gallery = gr.Gallery(label='Result for Contour', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
+    with gr.Row():
+        with gr.Box():
+            gr.Markdown("### Results for Lineart")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    line_gallery = gr.Gallery(label='Lineart Output', show_label=True, elem_id="gallery", columns=[1], rows=[1], height='auto')
+                with gr.Column(scale=4):
+                    line_image_gallery = gr.Gallery(label='Result for Lineart', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto')
+
+    ips = [style_image, content_image, prompt, num_samples, image_resolution, condition_scale, style_scale, ddim_steps, guidance_scale, seed, a_prompt, n_prompt, btn1, Contour_Threshold]
+    run_button.click(fn=process, inputs=ips, outputs=[image_gallery, contour_gallery, line_image_gallery, line_gallery])
+
+
+block.launch(server_name='0.0.0.0')
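The Gradio handler above can also be exercised headlessly; the sketch below mirrors its Contour branch using the contour_detector and contour_styleshot objects built in app.py. File paths and the prompt are placeholders, and the indexing of the return value follows the handler's own g_images[0] usage, treated here as an assumption about StyleShot.generate rather than a documented API:

import cv2
from PIL import Image
from annotator.util import resize_image

style = Image.open("style.png").convert("RGB")                         # placeholder style reference
content = cv2.cvtColor(cv2.imread("content.png"), cv2.COLOR_BGR2RGB)   # placeholder content image
contour = Image.fromarray(contour_detector(resize_image(content, 512), threshold=200))

g_images = contour_styleshot.generate(style_image=style,
                                      prompt=[["a cat"]],              # placeholder prompt, nested as in process()
                                      negative_prompt="lowres, bad anatomy",
                                      scale=1.0,                       # style strength
                                      num_samples=1,
                                      seed=42,
                                      num_inference_steps=50,
                                      guidance_scale=7.5,
                                      content_image=contour,
                                      controlnet_conditioning_scale=1.0)
g_images[0][0].save("result.png")                                      # assumes a list of image lists, as g_images[0] suggests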
ip_adapter/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, StyleShot, StyleContentStableDiffusionControlNetPipeline

__all__ = [
    "IPAdapter",
    "IPAdapterPlus",
    "IPAdapterPlusXL",
    "IPAdapterXL",
    "IPAdapterFull",
    "StyleShot",
    "StyleContentStableDiffusionControlNetPipeline",
]
ip_adapter/attention_processor.py
ADDED
@@ -0,0 +1,554 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # split encoder_hidden_states into text tokens and image (ip) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


## for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]  # only use text
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
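For orientation: these processors replace the UNet's default attention processors. Cross-attention layers get the IP variant, which treats the last num_tokens entries of encoder_hidden_states as image tokens, attends to them separately, and blends the result in with scale. The sketch below shows the usual IP-Adapter wiring built only from this file's interfaces and public diffusers APIs; treat the block-to-hidden-size mapping as an illustrative assumption, not this repo's exact loading code:

import torch
from diffusers import UNet2DConditionModel
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # self-attention ("attn1") has no cross_attention_dim; it keeps the default processor
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor()
    else:
        # cross-attention layers get the IP variant that reads the extra image tokens
        attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
                                           cross_attention_dim=cross_attention_dim,
                                           scale=1.0, num_tokens=4)
unet.set_attn_processor(attn_procs)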
ip_adapter/custom_pipelines.py
ADDED
@@ -0,0 +1,394 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg

from .utils import is_torch2_available

if is_torch2_available():
    from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from .attention_processor import IPAttnProcessor


class StableDiffusionXLCustomPipeline(StableDiffusionXLPipeline):
    def set_scale(self, scale):
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def __call__(  # noqa: C901
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        control_guidance_start: float = 0.0,
        control_guidance_end: float = 1.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
73 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
74 |
+
Anything below 512 pixels won't work well for
|
75 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
76 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
77 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
78 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
79 |
+
expense of slower inference.
|
80 |
+
denoising_end (`float`, *optional*):
|
81 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
82 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
83 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
84 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
85 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
86 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
87 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
88 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
89 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
90 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
91 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
92 |
+
usually at the expense of lower image quality.
|
93 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
94 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
95 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
96 |
+
less than `1`).
|
97 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
98 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
99 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
100 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
101 |
+
The number of images to generate per prompt.
|
102 |
+
eta (`float`, *optional*, defaults to 0.0):
|
103 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
104 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
105 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
106 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
107 |
+
to make generation deterministic.
|
108 |
+
latents (`torch.FloatTensor`, *optional*):
|
109 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
110 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
111 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
112 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
113 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
114 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
115 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
116 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
117 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
118 |
+
argument.
|
119 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
120 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
121 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
122 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
123 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
124 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
125 |
+
input argument.
|
126 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
127 |
+
The output format of the generate image. Choose between
|
128 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
129 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
130 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
131 |
+
of a plain tuple.
|
132 |
+
callback (`Callable`, *optional*):
|
133 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
134 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
135 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
136 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
137 |
+
called at every step.
|
138 |
+
cross_attention_kwargs (`dict`, *optional*):
|
139 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
140 |
+
`self.processor` in
|
141 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
142 |
+
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
143 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
144 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
145 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
146 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
147 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
148 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
149 |
+
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
|
150 |
+
explained in section 2.2 of
|
151 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
152 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
153 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
154 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
155 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
156 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
157 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
158 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
159 |
+
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
160 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
161 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
162 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
163 |
+
micro-conditioning as explained in section 2.2 of
|
164 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
165 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
166 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
167 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
168 |
+
micro-conditioning as explained in section 2.2 of
|
169 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
170 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
171 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
172 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
173 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
174 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
175 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
176 |
+
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
177 |
+
The percentage of total steps at which the ControlNet starts applying.
|
178 |
+
control_guidance_end (`float`, *optional*, defaults to 1.0):
|
179 |
+
The percentage of total steps at which the ControlNet stops applying.
|
180 |
+
|
181 |
+
Examples:
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
185 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
186 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
187 |
+
"""
|
188 |
+
# 0. Default height and width to unet
|
189 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
190 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
191 |
+
|
192 |
+
original_size = original_size or (height, width)
|
193 |
+
target_size = target_size or (height, width)
|
194 |
+
|
195 |
+
# 1. Check inputs. Raise error if not correct
|
196 |
+
self.check_inputs(
|
197 |
+
prompt,
|
198 |
+
prompt_2,
|
199 |
+
height,
|
200 |
+
width,
|
201 |
+
callback_steps,
|
202 |
+
negative_prompt,
|
203 |
+
negative_prompt_2,
|
204 |
+
prompt_embeds,
|
205 |
+
negative_prompt_embeds,
|
206 |
+
pooled_prompt_embeds,
|
207 |
+
negative_pooled_prompt_embeds,
|
208 |
+
)
|
209 |
+
|
210 |
+
# 2. Define call parameters
|
211 |
+
if prompt is not None and isinstance(prompt, str):
|
212 |
+
batch_size = 1
|
213 |
+
elif prompt is not None and isinstance(prompt, list):
|
214 |
+
batch_size = len(prompt)
|
215 |
+
else:
|
216 |
+
batch_size = prompt_embeds.shape[0]
|
217 |
+
|
218 |
+
device = self._execution_device
|
219 |
+
|
220 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
221 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
222 |
+
# corresponds to doing no classifier free guidance.
|
223 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
224 |
+
|
225 |
+
# 3. Encode input prompt
|
226 |
+
text_encoder_lora_scale = (
|
227 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
228 |
+
)
|
229 |
+
(
|
230 |
+
prompt_embeds,
|
231 |
+
negative_prompt_embeds,
|
232 |
+
pooled_prompt_embeds,
|
233 |
+
negative_pooled_prompt_embeds,
|
234 |
+
) = self.encode_prompt(
|
235 |
+
prompt=prompt,
|
236 |
+
prompt_2=prompt_2,
|
237 |
+
device=device,
|
238 |
+
num_images_per_prompt=num_images_per_prompt,
|
239 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
240 |
+
negative_prompt=negative_prompt,
|
241 |
+
negative_prompt_2=negative_prompt_2,
|
242 |
+
prompt_embeds=prompt_embeds,
|
243 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
244 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
245 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
246 |
+
lora_scale=text_encoder_lora_scale,
|
247 |
+
)
|
248 |
+
|
249 |
+
# 4. Prepare timesteps
|
250 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
251 |
+
|
252 |
+
timesteps = self.scheduler.timesteps
|
253 |
+
|
254 |
+
# 5. Prepare latent variables
|
255 |
+
num_channels_latents = self.unet.config.in_channels
|
256 |
+
latents = self.prepare_latents(
|
257 |
+
batch_size * num_images_per_prompt,
|
258 |
+
num_channels_latents,
|
259 |
+
height,
|
260 |
+
width,
|
261 |
+
prompt_embeds.dtype,
|
262 |
+
device,
|
263 |
+
generator,
|
264 |
+
latents,
|
265 |
+
)
|
266 |
+
|
267 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
268 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
269 |
+
|
270 |
+
# 7. Prepare added time ids & embeddings
|
271 |
+
add_text_embeds = pooled_prompt_embeds
|
272 |
+
if self.text_encoder_2 is None:
|
273 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
274 |
+
else:
|
275 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
276 |
+
|
277 |
+
add_time_ids = self._get_add_time_ids(
|
278 |
+
original_size,
|
279 |
+
crops_coords_top_left,
|
280 |
+
target_size,
|
281 |
+
dtype=prompt_embeds.dtype,
|
282 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
283 |
+
)
|
284 |
+
if negative_original_size is not None and negative_target_size is not None:
|
285 |
+
negative_add_time_ids = self._get_add_time_ids(
|
286 |
+
negative_original_size,
|
287 |
+
negative_crops_coords_top_left,
|
288 |
+
negative_target_size,
|
289 |
+
dtype=prompt_embeds.dtype,
|
290 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
negative_add_time_ids = add_time_ids
|
294 |
+
|
295 |
+
if do_classifier_free_guidance:
|
296 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
297 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
298 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
299 |
+
|
300 |
+
prompt_embeds = prompt_embeds.to(device)
|
301 |
+
add_text_embeds = add_text_embeds.to(device)
|
302 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
303 |
+
|
304 |
+
# 8. Denoising loop
|
305 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
306 |
+
|
307 |
+
# 7.1 Apply denoising_end
|
308 |
+
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
309 |
+
discrete_timestep_cutoff = int(
|
310 |
+
round(
|
311 |
+
self.scheduler.config.num_train_timesteps
|
312 |
+
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
313 |
+
)
|
314 |
+
)
|
315 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
316 |
+
timesteps = timesteps[:num_inference_steps]
|
317 |
+
|
318 |
+
# get init conditioning scale
|
319 |
+
for attn_processor in self.unet.attn_processors.values():
|
320 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
321 |
+
conditioning_scale = attn_processor.scale
|
322 |
+
break
|
323 |
+
|
324 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
325 |
+
for i, t in enumerate(timesteps):
|
326 |
+
if (i / len(timesteps) < control_guidance_start) or ((i + 1) / len(timesteps) > control_guidance_end):
|
327 |
+
self.set_scale(0.0)
|
328 |
+
else:
|
329 |
+
self.set_scale(conditioning_scale)
|
330 |
+
|
331 |
+
# expand the latents if we are doing classifier free guidance
|
332 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
333 |
+
|
334 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
335 |
+
|
336 |
+
# predict the noise residual
|
337 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
338 |
+
noise_pred = self.unet(
|
339 |
+
latent_model_input,
|
340 |
+
t,
|
341 |
+
encoder_hidden_states=prompt_embeds,
|
342 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
343 |
+
added_cond_kwargs=added_cond_kwargs,
|
344 |
+
return_dict=False,
|
345 |
+
)[0]
|
346 |
+
|
347 |
+
# perform guidance
|
348 |
+
if do_classifier_free_guidance:
|
349 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
350 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
351 |
+
|
352 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
353 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
354 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
355 |
+
|
356 |
+
# compute the previous noisy sample x_t -> x_t-1
|
357 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
358 |
+
|
359 |
+
# call the callback, if provided
|
360 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
361 |
+
progress_bar.update()
|
362 |
+
if callback is not None and i % callback_steps == 0:
|
363 |
+
callback(i, t, latents)
|
364 |
+
|
365 |
+
if not output_type == "latent":
|
366 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
367 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
368 |
+
|
369 |
+
if needs_upcasting:
|
370 |
+
self.upcast_vae()
|
371 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
372 |
+
|
373 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
374 |
+
|
375 |
+
# cast back to fp16 if needed
|
376 |
+
if needs_upcasting:
|
377 |
+
self.vae.to(dtype=torch.float16)
|
378 |
+
else:
|
379 |
+
image = latents
|
380 |
+
|
381 |
+
if output_type != "latent":
|
382 |
+
# apply watermark if available
|
383 |
+
if self.watermark is not None:
|
384 |
+
image = self.watermark.apply_watermark(image)
|
385 |
+
|
386 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
387 |
+
|
388 |
+
# Offload all models
|
389 |
+
self.maybe_free_model_hooks()
|
390 |
+
|
391 |
+
if not return_dict:
|
392 |
+
return (image,)
|
393 |
+
|
394 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
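A hypothetical usage sketch (not part of the uploaded files) showing how this custom SDXL pipeline is typically driven through the IPAdapterXL wrapper defined in ip_adapter/ip_adapter.py below; the checkpoint paths and the style image filename are placeholders, and control_guidance_start / control_guidance_end are forwarded to the __call__ above via **kwargs.

import torch
from PIL import Image

from ip_adapter.custom_pipelines import StableDiffusionXLCustomPipeline
from ip_adapter.ip_adapter import IPAdapterXL

pipe = StableDiffusionXLCustomPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Placeholder paths: point these at a CLIP image encoder and an IP-Adapter SDXL checkpoint.
ip_model = IPAdapterXL(pipe, "path/to/image_encoder", "path/to/ip-adapter.bin", "cuda")

style_image = Image.open("style.png")  # placeholder style reference
images = ip_model.generate(
    pil_image=style_image,
    prompt="a cat, best quality",
    num_samples=1,
    num_inference_steps=30,
    # forwarded to the custom __call__: apply the image conditioning only for
    # the first 60% of the denoising steps
    control_guidance_start=0.0,
    control_guidance_end=0.6,
)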
ip_adapter/ip_adapter.py
ADDED
@@ -0,0 +1,1086 @@
import os
from typing import List


import torch
from typing import Optional, Union, Any, Dict, Tuple, List, Callable
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline
from diffusers.models.controlnet import ControlNetModel
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from torchvision import transforms
from .style_encoder import Style_Aware_Encoder
from .tools import pre_processing

from .utils import is_torch2_available

if is_torch2_available():
    from .attention_processor import (
        AttnProcessor2_0 as AttnProcessor,
    )
    from .attention_processor import (
        CNAttnProcessor2_0 as CNAttnProcessor,
    )
    from .attention_processor import (
        IPAttnProcessor2_0 as IPAttnProcessor,
    )
else:
    from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
from .resampler import Resampler


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class MLPProjModel(torch.nn.Module):
    """SD model with image prompt"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim)
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def set_ip_adapter(self):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)
        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

    def load_ip_adapter(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterXL(IPAdapter):
    """SDXL"""

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterFull(IPAdapterPlus):
    """IP-Adapter with full features"""

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


class IPAdapterPlusXL(IPAdapter):
    """SDXL"""

    def init_proj(self):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_image,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        num_inference_steps=30,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


def StyleProcessor(style_image, device):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    # centercrop for style condition
    crop = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
        ]
    )
    style_image = crop(style_image)
    high_style_patch, middle_style_patch, low_style_patch = pre_processing(style_image.convert("RGB"), transform)
    # shuffling
    high_style_patch, middle_style_patch, low_style_patch = (high_style_patch[torch.randperm(high_style_patch.shape[0])],
                                                             middle_style_patch[torch.randperm(middle_style_patch.shape[0])],
                                                             low_style_patch[torch.randperm(low_style_patch.shape[0])])
    return (high_style_patch.to(device, dtype=torch.float32), middle_style_patch.to(device, dtype=torch.float32), low_style_patch.to(device, dtype=torch.float32))


class StyleShot(torch.nn.Module):
    """StyleShot generation"""
    def __init__(self, device, pipe, ip_ckpt, style_aware_encoder_ckpt, transformer_patch):
        super().__init__()
        self.num_tokens = 6
        self.device = device
        self.pipe = pipe

        self.set_ip_adapter(device)
        self.ip_ckpt = ip_ckpt

        self.style_aware_encoder = Style_Aware_Encoder(CLIPVisionModelWithProjection.from_pretrained(transformer_patch)).to(self.device, dtype=torch.float32)
        self.style_aware_encoder.load_state_dict(torch.load(style_aware_encoder_ckpt))

        self.style_image_proj_modules = self.init_proj()

        self.load_ip_adapter()
        self.pipe = self.pipe.to(self.device, dtype=torch.float32)

    def init_proj(self):
        style_image_proj_modules = torch.nn.ModuleList([
            ImageProjModel(
                cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
                clip_embeddings_dim=self.style_aware_encoder.projection_dim,
                clip_extra_context_tokens=2,
            ),
            ImageProjModel(
                cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
                clip_embeddings_dim=self.style_aware_encoder.projection_dim,
                clip_extra_context_tokens=2,
            ),
            ImageProjModel(
                cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
                clip_embeddings_dim=self.style_aware_encoder.projection_dim,
                clip_extra_context_tokens=2,
            )])
        return style_image_proj_modules.to(self.device, dtype=torch.float32)

    def load_ip_adapter(self):
        sd = torch.load(self.ip_ckpt, map_location="cpu")
        style_image_proj_sd = {}
        ip_sd = {}
        controlnet_sd = {}
        for k in sd:
            if k.startswith("unet"):
                pass
            elif k.startswith("style_image_proj_modules"):
                style_image_proj_sd[k.replace("style_image_proj_modules.", "")] = sd[k]
            elif k.startswith("adapter_modules"):
                ip_sd[k.replace("adapter_modules.", "")] = sd[k]
            elif k.startswith("controlnet"):
                controlnet_sd[k.replace("controlnet.", "")] = sd[k]
        # Load state dict for image_proj_model and adapter_modules
        self.style_image_proj_modules.load_state_dict(style_image_proj_sd, strict=True)
        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        if hasattr(self.pipe, "controlnet") and isinstance(self.pipe, StyleContentStableDiffusionControlNetPipeline):
            self.pipe.controlnet.load_state_dict(controlnet_sd, strict=True)
        ip_layers.load_state_dict(ip_sd, strict=True)

    def set_ip_adapter(self, device):
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = IPAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.num_tokens,
                ).to(device, dtype=torch.float16)
        if hasattr(self.pipe, "controlnet") and not isinstance(self.pipe, StyleContentStableDiffusionControlNetPipeline):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
        unet.set_attn_processor(attn_procs)

    @torch.inference_mode()
    def get_image_embeds(self, style_image=None):
        style_image = StyleProcessor(style_image, self.device)
        style_embeds = self.style_aware_encoder(style_image).to(self.device, dtype=torch.float32)
        style_ip_tokens = []
        uncond_style_ip_tokens = []
        for idx, style_embed in enumerate([style_embeds[:, 0, :], style_embeds[:, 1, :], style_embeds[:, 2, :]]):
            style_ip_tokens.append(self.style_image_proj_modules[idx](style_embed))
            uncond_style_ip_tokens.append(self.style_image_proj_modules[idx](torch.zeros_like(style_embed)))
        style_ip_tokens = torch.cat(style_ip_tokens, dim=1)
        uncond_style_ip_tokens = torch.cat(uncond_style_ip_tokens, dim=1)
        return style_ip_tokens, uncond_style_ip_tokens

    def set_scale(self, scale):
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    def samples(self, image_prompt_embeds, uncond_image_prompt_embeds, num_samples, device, prompt, negative_prompt,
                seed, guidance_scale, num_inference_steps, content_image, **kwargs, ):
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
        if content_image is None:
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                **kwargs,
            ).images
        else:
            images = self.pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
                image=content_image,
                style_embeddings=image_prompt_embeds,
                negative_style_embeddings=uncond_image_prompt_embeds,
                **kwargs,
            ).images
        return images

    def generate(
        self,
        style_image=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=1,
        seed=42,
        guidance_scale=7.5,
        num_inference_steps=50,
        content_image=None,
        **kwargs,
    ):
        self.set_scale(scale)

        num_prompts = 1

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, List):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, List):
            negative_prompt = [negative_prompt] * num_prompts

        style_ip_tokens, uncond_style_ip_tokens = self.get_image_embeds(style_image)
        generate_images = []
        for p in prompt:
            images = self.samples(style_ip_tokens, uncond_style_ip_tokens, num_samples, self.device, p * num_prompts, negative_prompt, seed, guidance_scale, num_inference_steps, content_image, **kwargs, )
            generate_images.append(images)
        return generate_images


class StyleContentStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        style_embeddings: Optional[torch.FloatTensor] = None,
        negative_style_embeddings: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
                `init`, images must be passed as a list such that each element of the list can be correctly batched for
                input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
                each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
                where a list of image lists can be passed to batch for each prompt and each ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
724 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
725 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
726 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
727 |
+
Pre-generated image embeddings for IP-Adapter. The list should have the same length as the number of IP-Adapters.
|
728 |
+
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
|
729 |
+
if `do_classifier_free_guidance` is set to `True`.
|
730 |
+
If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
731 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
732 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
733 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
734 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
735 |
+
plain tuple.
|
736 |
+
callback (`Callable`, *optional*):
|
737 |
+
A function that is called every `callback_steps` steps during inference. The function is called with the
|
738 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
739 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
740 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
741 |
+
every step.
|
742 |
+
cross_attention_kwargs (`dict`, *optional*):
|
743 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
744 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
745 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
746 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
747 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
748 |
+
the corresponding scale as a list.
|
749 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
750 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
751 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
752 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
753 |
+
The percentage of total steps at which the ControlNet starts applying.
|
754 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
755 |
+
The percentage of total steps at which the ControlNet stops applying.
|
756 |
+
clip_skip (`int`, *optional*):
|
757 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
758 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
759 |
+
callback_on_step_end (`Callable`, *optional*):
|
760 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
761 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
762 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
763 |
+
`callback_on_step_end_tensor_inputs`.
|
764 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
765 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
766 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
767 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
768 |
+
|
769 |
+
Examples:
|
770 |
+
|
771 |
+
Returns:
|
772 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
773 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
774 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
775 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
776 |
+
"not-safe-for-work" (nsfw) content.
|
777 |
+
"""
|
778 |
+
|
779 |
+
callback = kwargs.pop("callback", None)
|
780 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
781 |
+
|
782 |
+
if callback is not None:
|
783 |
+
deprecate(
|
784 |
+
"callback",
|
785 |
+
"1.0.0",
|
786 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
787 |
+
)
|
788 |
+
if callback_steps is not None:
|
789 |
+
deprecate(
|
790 |
+
"callback_steps",
|
791 |
+
"1.0.0",
|
792 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
793 |
+
)
|
794 |
+
|
795 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
796 |
+
|
797 |
+
# align format for control guidance
|
798 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
799 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
800 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
801 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
802 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
803 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
804 |
+
control_guidance_start, control_guidance_end = (
|
805 |
+
mult * [control_guidance_start],
|
806 |
+
mult * [control_guidance_end],
|
807 |
+
)
|
808 |
+
|
809 |
+
# 1. Check inputs. Raise error if not correct
|
810 |
+
self.check_inputs(
|
811 |
+
prompt,
|
812 |
+
image,
|
813 |
+
callback_steps,
|
814 |
+
negative_prompt,
|
815 |
+
prompt_embeds,
|
816 |
+
negative_prompt_embeds,
|
817 |
+
ip_adapter_image,
|
818 |
+
ip_adapter_image_embeds,
|
819 |
+
controlnet_conditioning_scale,
|
820 |
+
control_guidance_start,
|
821 |
+
control_guidance_end,
|
822 |
+
callback_on_step_end_tensor_inputs,
|
823 |
+
)
|
824 |
+
|
825 |
+
self._guidance_scale = guidance_scale
|
826 |
+
self._clip_skip = clip_skip
|
827 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
828 |
+
|
829 |
+
# 2. Define call parameters
|
830 |
+
if prompt is not None and isinstance(prompt, str):
|
831 |
+
batch_size = 1
|
832 |
+
elif prompt is not None and isinstance(prompt, list):
|
833 |
+
batch_size = len(prompt)
|
834 |
+
else:
|
835 |
+
batch_size = prompt_embeds.shape[0]
|
836 |
+
|
837 |
+
device = self._execution_device
|
838 |
+
|
839 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
840 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
841 |
+
|
842 |
+
global_pool_conditions = (
|
843 |
+
controlnet.config.global_pool_conditions
|
844 |
+
if isinstance(controlnet, ControlNetModel)
|
845 |
+
else controlnet.nets[0].config.global_pool_conditions
|
846 |
+
)
|
847 |
+
guess_mode = guess_mode or global_pool_conditions
|
848 |
+
|
849 |
+
# 3. Encode input prompt
|
850 |
+
text_encoder_lora_scale = (
|
851 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
852 |
+
)
|
853 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
854 |
+
prompt,
|
855 |
+
device,
|
856 |
+
num_images_per_prompt,
|
857 |
+
self.do_classifier_free_guidance,
|
858 |
+
negative_prompt,
|
859 |
+
prompt_embeds=prompt_embeds,
|
860 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
861 |
+
lora_scale=text_encoder_lora_scale,
|
862 |
+
clip_skip=self.clip_skip,
|
863 |
+
)
|
864 |
+
# For classifier free guidance, we need to do two forward passes.
|
865 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
866 |
+
# to avoid doing two forward passes
|
867 |
+
if self.do_classifier_free_guidance:
|
868 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
869 |
+
|
870 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
871 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
872 |
+
ip_adapter_image,
|
873 |
+
ip_adapter_image_embeds,
|
874 |
+
device,
|
875 |
+
batch_size * num_images_per_prompt,
|
876 |
+
self.do_classifier_free_guidance,
|
877 |
+
)
|
878 |
+
|
879 |
+
# 4. Prepare image
|
880 |
+
if isinstance(controlnet, ControlNetModel):
|
881 |
+
image = self.prepare_image(
|
882 |
+
image=image,
|
883 |
+
width=width,
|
884 |
+
height=height,
|
885 |
+
batch_size=batch_size * num_images_per_prompt,
|
886 |
+
num_images_per_prompt=num_images_per_prompt,
|
887 |
+
device=device,
|
888 |
+
dtype=controlnet.dtype,
|
889 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
890 |
+
guess_mode=guess_mode,
|
891 |
+
)
|
892 |
+
height, width = image.shape[-2:]
|
893 |
+
|
894 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
895 |
+
images = []
|
896 |
+
|
897 |
+
# Nested lists as ControlNet condition
|
898 |
+
if isinstance(image[0], list):
|
899 |
+
# Transpose the nested image list
|
900 |
+
image = [list(t) for t in zip(*image)]
|
901 |
+
|
902 |
+
for image_ in image:
|
903 |
+
image_ = self.prepare_image(
|
904 |
+
image=image_,
|
905 |
+
width=width,
|
906 |
+
height=height,
|
907 |
+
batch_size=batch_size * num_images_per_prompt,
|
908 |
+
num_images_per_prompt=num_images_per_prompt,
|
909 |
+
device=device,
|
910 |
+
dtype=controlnet.dtype,
|
911 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
912 |
+
guess_mode=guess_mode,
|
913 |
+
)
|
914 |
+
|
915 |
+
images.append(image_)
|
916 |
+
|
917 |
+
image = images
|
918 |
+
height, width = image[0].shape[-2:]
|
919 |
+
else:
|
920 |
+
assert False
|
921 |
+
|
922 |
+
# 5. Prepare timesteps
|
923 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
924 |
+
self._num_timesteps = len(timesteps)
|
925 |
+
|
926 |
+
# 6. Prepare latent variables
|
927 |
+
num_channels_latents = self.unet.config.in_channels
|
928 |
+
latents = self.prepare_latents(
|
929 |
+
batch_size * num_images_per_prompt,
|
930 |
+
num_channels_latents,
|
931 |
+
height,
|
932 |
+
width,
|
933 |
+
prompt_embeds.dtype,
|
934 |
+
device,
|
935 |
+
generator,
|
936 |
+
latents,
|
937 |
+
)
|
938 |
+
|
939 |
+
# 6.5 Optionally get Guidance Scale Embedding
|
940 |
+
timestep_cond = None
|
941 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
942 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
943 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
944 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
945 |
+
).to(device=device, dtype=latents.dtype)
|
946 |
+
|
947 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
948 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
949 |
+
|
950 |
+
# 7.1 Add image embeds for IP-Adapter
|
951 |
+
added_cond_kwargs = (
|
952 |
+
{"image_embeds": image_embeds}
|
953 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
|
954 |
+
else None
|
955 |
+
)
|
956 |
+
|
957 |
+
# 7.2 Create tensor stating which controlnets to keep
|
958 |
+
controlnet_keep = []
|
959 |
+
for i in range(len(timesteps)):
|
960 |
+
keeps = [
|
961 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
962 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
963 |
+
]
|
964 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
965 |
+
|
966 |
+
# 8. Denoising loop
|
967 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
968 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
969 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
970 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
971 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
972 |
+
for i, t in enumerate(timesteps):
|
973 |
+
# Relevant thread:
|
974 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
975 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
976 |
+
torch._inductor.cudagraph_mark_step_begin()
|
977 |
+
# expand the latents if we are doing classifier free guidance
|
978 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
979 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
980 |
+
|
981 |
+
# controlnet(s) inference
|
982 |
+
if guess_mode and self.do_classifier_free_guidance:
|
983 |
+
# Infer ControlNet only for the conditional batch.
|
984 |
+
control_model_input = latents
|
985 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
986 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
987 |
+
else:
|
988 |
+
control_model_input = latent_model_input
|
989 |
+
controlnet_prompt_embeds = prompt_embeds
|
990 |
+
|
991 |
+
if isinstance(controlnet_keep[i], list):
|
992 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
993 |
+
else:
|
994 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
995 |
+
if isinstance(controlnet_cond_scale, list):
|
996 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
997 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
998 |
+
|
999 |
+
if self.do_classifier_free_guidance:
|
1000 |
+
style_embeddings_input = torch.cat([negative_style_embeddings, style_embeddings])
|
1001 |
+
|
1002 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
1003 |
+
control_model_input,
|
1004 |
+
t,
|
1005 |
+
encoder_hidden_states=style_embeddings_input,
|
1006 |
+
controlnet_cond=image,
|
1007 |
+
conditioning_scale=cond_scale,
|
1008 |
+
guess_mode=guess_mode,
|
1009 |
+
return_dict=False,
|
1010 |
+
)
|
1011 |
+
|
1012 |
+
if guess_mode and self.do_classifier_free_guidance:
|
1013 |
+
# Inferred ControlNet only for the conditional batch.
|
1014 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
1015 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
1016 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
1017 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
1018 |
+
|
1019 |
+
# predict the noise residual
|
1020 |
+
noise_pred = self.unet(
|
1021 |
+
latent_model_input,
|
1022 |
+
t,
|
1023 |
+
encoder_hidden_states=prompt_embeds,
|
1024 |
+
timestep_cond=timestep_cond,
|
1025 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
1026 |
+
down_block_additional_residuals=down_block_res_samples,
|
1027 |
+
mid_block_additional_residual=mid_block_res_sample,
|
1028 |
+
added_cond_kwargs=added_cond_kwargs,
|
1029 |
+
return_dict=False,
|
1030 |
+
)[0]
|
1031 |
+
|
1032 |
+
# perform guidance
|
1033 |
+
if self.do_classifier_free_guidance:
|
1034 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1035 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1036 |
+
|
1037 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1038 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1039 |
+
|
1040 |
+
if callback_on_step_end is not None:
|
1041 |
+
callback_kwargs = {}
|
1042 |
+
for k in callback_on_step_end_tensor_inputs:
|
1043 |
+
callback_kwargs[k] = locals()[k]
|
1044 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1045 |
+
|
1046 |
+
latents = callback_outputs.pop("latents", latents)
|
1047 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1048 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1049 |
+
|
1050 |
+
# call the callback, if provided
|
1051 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1052 |
+
progress_bar.update()
|
1053 |
+
if callback is not None and i % callback_steps == 0:
|
1054 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1055 |
+
callback(step_idx, t, latents)
|
1056 |
+
|
1057 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
1058 |
+
# manually for max memory savings
|
1059 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
1060 |
+
self.unet.to("cpu")
|
1061 |
+
self.controlnet.to("cpu")
|
1062 |
+
torch.cuda.empty_cache()
|
1063 |
+
|
1064 |
+
if not output_type == "latent":
|
1065 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
1066 |
+
0
|
1067 |
+
]
|
1068 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
1069 |
+
else:
|
1070 |
+
image = latents
|
1071 |
+
has_nsfw_concept = None
|
1072 |
+
|
1073 |
+
if has_nsfw_concept is None:
|
1074 |
+
do_denormalize = [True] * image.shape[0]
|
1075 |
+
else:
|
1076 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
1077 |
+
|
1078 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
1079 |
+
|
1080 |
+
# Offload all models
|
1081 |
+
self.maybe_free_model_hooks()
|
1082 |
+
|
1083 |
+
if not return_dict:
|
1084 |
+
return (image, has_nsfw_concept)
|
1085 |
+
|
1086 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
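Note: a minimal usage sketch for the pipeline above, not part of this commit. The checkpoint names, the `lineart_image` condition, and the style embedding tensors are placeholders (app.py in this Space is the authoritative wiring); the only point illustrated is that `style_embeddings` / `negative_style_embeddings` are passed as extra keyword arguments alongside the usual ControlNet inputs.

```python
# Hedged sketch: checkpoint names and inputs below are placeholders, not part of this commit.
import torch
from diffusers import ControlNetModel
from ip_adapter.custom_pipelines import StyleContentStableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_lineart", torch_dtype=torch.float16  # assumed checkpoint
)
pipe = StyleContentStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")


def generate_with_style(lineart_image, style_embeddings, negative_style_embeddings):
    """Run the pipeline with style embeddings produced by the style-aware encoder."""
    return pipe(
        prompt="best quality, high quality",
        image=lineart_image,                  # ControlNet condition (e.g. a lineart map)
        num_inference_steps=50,
        guidance_scale=7.5,
        style_embeddings=style_embeddings,
        negative_style_embeddings=negative_style_embeddings,
    ).images
```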
ip_adapter/resampler.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
2 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
|
11 |
+
|
12 |
+
# FFN
|
13 |
+
def FeedForward(dim, mult=4):
|
14 |
+
inner_dim = int(dim * mult)
|
15 |
+
return nn.Sequential(
|
16 |
+
nn.LayerNorm(dim),
|
17 |
+
nn.Linear(dim, inner_dim, bias=False),
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(inner_dim, dim, bias=False),
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def reshape_tensor(x, heads):
|
24 |
+
bs, length, width = x.shape
|
25 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
26 |
+
x = x.view(bs, length, heads, -1)
|
27 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
28 |
+
x = x.transpose(1, 2)
|
29 |
+
# (bs, n_heads, length, dim_per_head), reshaped to be contiguous (shape unchanged)
|
30 |
+
x = x.reshape(bs, heads, length, -1)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class PerceiverAttention(nn.Module):
|
35 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
36 |
+
super().__init__()
|
37 |
+
self.scale = dim_head**-0.5
|
38 |
+
self.dim_head = dim_head
|
39 |
+
self.heads = heads
|
40 |
+
inner_dim = dim_head * heads
|
41 |
+
|
42 |
+
self.norm1 = nn.LayerNorm(dim)
|
43 |
+
self.norm2 = nn.LayerNorm(dim)
|
44 |
+
|
45 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
46 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
47 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
48 |
+
|
49 |
+
def forward(self, x, latents):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): image features
|
53 |
+
shape (b, n1, D)
|
54 |
+
latents (torch.Tensor): latent features
|
55 |
+
shape (b, n2, D)
|
56 |
+
"""
|
57 |
+
x = self.norm1(x)
|
58 |
+
latents = self.norm2(latents)
|
59 |
+
|
60 |
+
b, l, _ = latents.shape
|
61 |
+
|
62 |
+
q = self.to_q(latents)
|
63 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
64 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
65 |
+
|
66 |
+
q = reshape_tensor(q, self.heads)
|
67 |
+
k = reshape_tensor(k, self.heads)
|
68 |
+
v = reshape_tensor(v, self.heads)
|
69 |
+
|
70 |
+
# attention
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
72 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
73 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
74 |
+
out = weight @ v
|
75 |
+
|
76 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
77 |
+
|
78 |
+
return self.to_out(out)
|
79 |
+
|
80 |
+
|
81 |
+
class Resampler(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
dim=1024,
|
85 |
+
depth=8,
|
86 |
+
dim_head=64,
|
87 |
+
heads=16,
|
88 |
+
num_queries=8,
|
89 |
+
embedding_dim=768,
|
90 |
+
output_dim=1024,
|
91 |
+
ff_mult=4,
|
92 |
+
max_seq_len: int = 257, # CLIP tokens + CLS token
|
93 |
+
apply_pos_emb: bool = False,
|
94 |
+
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
98 |
+
|
99 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
100 |
+
|
101 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
102 |
+
|
103 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
104 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
105 |
+
|
106 |
+
self.to_latents_from_mean_pooled_seq = (
|
107 |
+
nn.Sequential(
|
108 |
+
nn.LayerNorm(dim),
|
109 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
110 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
111 |
+
)
|
112 |
+
if num_latents_mean_pooled > 0
|
113 |
+
else None
|
114 |
+
)
|
115 |
+
|
116 |
+
self.layers = nn.ModuleList([])
|
117 |
+
for _ in range(depth):
|
118 |
+
self.layers.append(
|
119 |
+
nn.ModuleList(
|
120 |
+
[
|
121 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
122 |
+
FeedForward(dim=dim, mult=ff_mult),
|
123 |
+
]
|
124 |
+
)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
if self.pos_emb is not None:
|
129 |
+
n, device = x.shape[1], x.device
|
130 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
131 |
+
x = x + pos_emb
|
132 |
+
|
133 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
134 |
+
|
135 |
+
x = self.proj_in(x)
|
136 |
+
|
137 |
+
if self.to_latents_from_mean_pooled_seq:
|
138 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
139 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
140 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
141 |
+
|
142 |
+
for attn, ff in self.layers:
|
143 |
+
latents = attn(x, latents) + latents
|
144 |
+
latents = ff(latents) + latents
|
145 |
+
|
146 |
+
latents = self.proj_out(latents)
|
147 |
+
return self.norm_out(latents)
|
148 |
+
|
149 |
+
|
150 |
+
def masked_mean(t, *, dim, mask=None):
|
151 |
+
if mask is None:
|
152 |
+
return t.mean(dim=dim)
|
153 |
+
|
154 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
155 |
+
mask = rearrange(mask, "b n -> b n 1")
|
156 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
157 |
+
|
158 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
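Note: the Resampler maps a variable-length sequence of image features onto a fixed set of `num_queries` learned tokens via Perceiver-style cross-attention. A quick shape check with random features standing in for CLIP hidden states (the settings are illustrative; `ip_adapter/test_resampler.py` below runs the same exercise against a real CLIP vision tower):

```python
import torch
from ip_adapter.resampler import Resampler

# Random features stand in for CLIP penultimate hidden states: (batch, tokens, embedding_dim).
resampler = Resampler(dim=1024, depth=2, dim_head=64, heads=16,
                      num_queries=8, embedding_dim=768, output_dim=1024)
features = torch.randn(2, 257, 768)
tokens = resampler(features)
print(tokens.shape)  # torch.Size([2, 8, 1024]) -- num_queries tokens of output_dim
```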
ip_adapter/style_encoder.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from collections import OrderedDict
from typing import List, Tuple
from torch import Tensor  # needed by the type hints in get_parameter_dtype
|
5 |
+
|
6 |
+
|
7 |
+
def conv_nd(dims, *args, **kwargs):
|
8 |
+
"""
|
9 |
+
Create a 1D, 2D, or 3D convolution module.
|
10 |
+
"""
|
11 |
+
if dims == 1:
|
12 |
+
return nn.Conv1d(*args, **kwargs)
|
13 |
+
elif dims == 2:
|
14 |
+
return nn.Conv2d(*args, **kwargs)
|
15 |
+
elif dims == 3:
|
16 |
+
return nn.Conv3d(*args, **kwargs)
|
17 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
18 |
+
|
19 |
+
|
20 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
21 |
+
"""
|
22 |
+
Create a 1D, 2D, or 3D average pooling module.
|
23 |
+
"""
|
24 |
+
if dims == 1:
|
25 |
+
return nn.AvgPool1d(*args, **kwargs)
|
26 |
+
elif dims == 2:
|
27 |
+
return nn.AvgPool2d(*args, **kwargs)
|
28 |
+
elif dims == 3:
|
29 |
+
return nn.AvgPool3d(*args, **kwargs)
|
30 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
31 |
+
|
32 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
33 |
+
try:
|
34 |
+
params = tuple(parameter.parameters())
|
35 |
+
if len(params) > 0:
|
36 |
+
return params[0].dtype
|
37 |
+
|
38 |
+
buffers = tuple(parameter.buffers())
|
39 |
+
if len(buffers) > 0:
|
40 |
+
return buffers[0].dtype
|
41 |
+
|
42 |
+
except StopIteration:
|
43 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
44 |
+
|
45 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
46 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
47 |
+
return tuples
|
48 |
+
|
49 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
50 |
+
first_tuple = next(gen)
|
51 |
+
return first_tuple[1].dtype
|
52 |
+
|
53 |
+
class Downsample(nn.Module):
|
54 |
+
"""
|
55 |
+
A downsampling layer with an optional convolution.
|
56 |
+
:param channels: channels in the inputs and outputs.
|
57 |
+
:param use_conv: a bool determining if a convolution is applied.
|
58 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
59 |
+
downsampling occurs in the inner-two dimensions.
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
63 |
+
super().__init__()
|
64 |
+
self.channels = channels
|
65 |
+
self.out_channels = out_channels or channels
|
66 |
+
self.use_conv = use_conv
|
67 |
+
self.dims = dims
|
68 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
69 |
+
if use_conv:
|
70 |
+
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
71 |
+
else:
|
72 |
+
assert self.channels == self.out_channels
|
73 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
assert x.shape[1] == self.channels
|
77 |
+
return self.op(x)
|
78 |
+
|
79 |
+
|
80 |
+
class ResnetBlock(nn.Module):
|
81 |
+
|
82 |
+
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
83 |
+
super().__init__()
|
84 |
+
ps = ksize // 2
|
85 |
+
if in_c != out_c or sk == False:
|
86 |
+
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
87 |
+
else:
|
88 |
+
self.in_conv = None
|
89 |
+
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
90 |
+
self.act = nn.ReLU()
|
91 |
+
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
92 |
+
if sk == False:
|
93 |
+
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
94 |
+
else:
|
95 |
+
self.skep = None
|
96 |
+
|
97 |
+
self.down = down
|
98 |
+
if self.down == True:
|
99 |
+
self.down_opt = Downsample(in_c, use_conv=use_conv)
|
100 |
+
|
101 |
+
def forward(self, x):
|
102 |
+
if self.down == True:
|
103 |
+
x = self.down_opt(x)
|
104 |
+
if self.in_conv is not None: # edit
|
105 |
+
x = self.in_conv(x)
|
106 |
+
|
107 |
+
h = self.block1(x)
|
108 |
+
h = self.act(h)
|
109 |
+
h = self.block2(h)
|
110 |
+
if self.skep is not None:
|
111 |
+
return h + self.skep(x)
|
112 |
+
else:
|
113 |
+
return h + x
|
114 |
+
|
115 |
+
class Low_CNN(nn.Module):
|
116 |
+
def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
|
117 |
+
super(Low_CNN, self).__init__()
|
118 |
+
self.unshuffle = nn.PixelUnshuffle(8)
|
119 |
+
self.body = nn.Sequential(ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
120 |
+
ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
121 |
+
ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
122 |
+
ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
123 |
+
self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
|
124 |
+
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
125 |
+
self.adapter = nn.Linear(1280, 1280)
|
126 |
+
|
127 |
+
@property
|
128 |
+
def dtype(self) -> torch.dtype:
|
129 |
+
"""
|
130 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
131 |
+
"""
|
132 |
+
return get_parameter_dtype(self)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
x = self.unshuffle(x)
|
136 |
+
x = self.conv_in(x)
|
137 |
+
x = self.body(x)
|
138 |
+
x = self.pool(x)
|
139 |
+
x = x.flatten(start_dim=1, end_dim=-1)
|
140 |
+
x = self.adapter(x)
|
141 |
+
return x
|
142 |
+
|
143 |
+
|
144 |
+
class Middle_CNN(nn.Module):
|
145 |
+
def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
|
146 |
+
super(Middle_CNN, self).__init__()
|
147 |
+
self.unshuffle = nn.PixelUnshuffle(8)
|
148 |
+
self.body = nn.Sequential(ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
149 |
+
ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
150 |
+
ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
151 |
+
ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
152 |
+
ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
153 |
+
self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
|
154 |
+
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
155 |
+
self.adapter = nn.Linear(1280, 1280)
|
156 |
+
|
157 |
+
@property
|
158 |
+
def dtype(self) -> torch.dtype:
|
159 |
+
"""
|
160 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
161 |
+
"""
|
162 |
+
return get_parameter_dtype(self)
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
x = self.unshuffle(x)
|
166 |
+
x = self.conv_in(x)
|
167 |
+
x = self.body(x)
|
168 |
+
x = self.pool(x)
|
169 |
+
x = x.flatten(start_dim=1, end_dim=-1)
|
170 |
+
x = self.adapter(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class High_CNN(nn.Module):
|
175 |
+
def __init__(self, cin=192, ksize=1, sk=False, use_conv=True):
|
176 |
+
super(High_CNN, self).__init__()
|
177 |
+
self.unshuffle = nn.PixelUnshuffle(8)
|
178 |
+
self.body = nn.Sequential(ResnetBlock(320, 320, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
179 |
+
ResnetBlock(320, 640, down=False, ksize=ksize, sk=sk, use_conv=use_conv),
|
180 |
+
ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
181 |
+
ResnetBlock(640, 640, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
182 |
+
ResnetBlock(640, 1280, down=True, ksize=ksize, sk=sk, use_conv=use_conv),
|
183 |
+
ResnetBlock(1280, 1280, down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
184 |
+
self.conv_in = nn.Conv2d(cin, 320, 3, 1, 1)
|
185 |
+
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
|
186 |
+
self.adapter = nn.Linear(1280, 1280)
|
187 |
+
|
188 |
+
@property
|
189 |
+
def dtype(self) -> torch.dtype:
|
190 |
+
"""
|
191 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
192 |
+
"""
|
193 |
+
return get_parameter_dtype(self)
|
194 |
+
|
195 |
+
def forward(self, x):
|
196 |
+
x = self.unshuffle(x)
|
197 |
+
x = self.conv_in(x)
|
198 |
+
x = self.body(x)
|
199 |
+
x = self.pool(x)
|
200 |
+
x = x.flatten(start_dim=1, end_dim=-1)
|
201 |
+
x = self.adapter(x)
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class Style_Aware_Encoder(torch.nn.Module):
|
206 |
+
def __init__(self, image_encoder):
|
207 |
+
super().__init__()
|
208 |
+
self.image_encoder = image_encoder
|
209 |
+
self.projection_dim = self.image_encoder.config.projection_dim
|
210 |
+
self.num_positions = 59
|
211 |
+
self.embed_dim = 1280
|
212 |
+
self.cnn = nn.ModuleList(
|
213 |
+
[High_CNN(sk=True, use_conv=False),
|
214 |
+
Middle_CNN(sk=True, use_conv=False),
|
215 |
+
Low_CNN(sk=True, use_conv=False)]
|
216 |
+
)
|
217 |
+
self.style_embeddings = nn.ParameterList(
|
218 |
+
[nn.Parameter(torch.randn(self.embed_dim)),
|
219 |
+
nn.Parameter(torch.randn(self.embed_dim)),
|
220 |
+
nn.Parameter(torch.randn(self.embed_dim))]
|
221 |
+
)
|
222 |
+
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
223 |
+
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
224 |
+
|
225 |
+
def forward(self, inputs, batch_size=1):
|
226 |
+
embeddings = []
|
227 |
+
for idx, x in enumerate(inputs):
|
228 |
+
class_embed = self.style_embeddings[idx].expand(batch_size, 1, -1)
|
229 |
+
patch_embed = self.cnn[idx](x)
|
230 |
+
patch_embed = patch_embed.view(batch_size, -1, patch_embed.shape[1])
|
231 |
+
embedding = torch.cat([class_embed, patch_embed], dim=1)
|
232 |
+
embeddings.append(embedding)
|
233 |
+
embeddings = torch.cat(embeddings, dim=1)
|
234 |
+
embeddings = embeddings + self.position_embedding(self.position_ids)  # [B, 59, 1280]: 3 class tokens + 8 + 16 + 32 patch tokens
|
235 |
+
embeddings = self.image_encoder.vision_model.pre_layrnorm(embeddings)
|
236 |
+
encoder_outputs = self.image_encoder.vision_model.encoder(
|
237 |
+
inputs_embeds=embeddings,
|
238 |
+
output_attentions=None,
|
239 |
+
output_hidden_states=None,
|
240 |
+
return_dict=None,
|
241 |
+
)
|
242 |
+
last_hidden_state = encoder_outputs[0]
|
243 |
+
pooled_output = last_hidden_state[:, [0, 9, 26], :]  # the three style class tokens (high / middle / low)
|
244 |
+
pooled_output = self.image_encoder.vision_model.post_layernorm(pooled_output)
|
245 |
+
out = self.image_encoder.visual_projection(pooled_output)
|
246 |
+
return out
|
ip_adapter/test_resampler.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
import torch
|
2 |
+
from resampler import Resampler
|
3 |
+
from transformers import CLIPVisionModel
|
4 |
+
|
5 |
+
BATCH_SIZE = 2
|
6 |
+
OUTPUT_DIM = 1280
|
7 |
+
NUM_QUERIES = 8
|
8 |
+
NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior)
|
9 |
+
APPLY_POS_EMB = True # False for no positional embeddings (previous behavior)
|
10 |
+
IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
11 |
+
|
12 |
+
|
13 |
+
def main():
|
14 |
+
image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
|
15 |
+
embedding_dim = image_encoder.config.hidden_size
|
16 |
+
print(f"image_encoder hidden size: ", embedding_dim)
|
17 |
+
|
18 |
+
image_proj_model = Resampler(
|
19 |
+
dim=1024,
|
20 |
+
depth=2,
|
21 |
+
dim_head=64,
|
22 |
+
heads=16,
|
23 |
+
num_queries=NUM_QUERIES,
|
24 |
+
embedding_dim=embedding_dim,
|
25 |
+
output_dim=OUTPUT_DIM,
|
26 |
+
ff_mult=2,
|
27 |
+
max_seq_len=257,
|
28 |
+
apply_pos_emb=APPLY_POS_EMB,
|
29 |
+
num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
|
30 |
+
)
|
31 |
+
|
32 |
+
dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
|
33 |
+
with torch.no_grad():
|
34 |
+
image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
|
35 |
+
print("image_embds shape: ", image_embeds.shape)
|
36 |
+
|
37 |
+
with torch.no_grad():
|
38 |
+
ip_tokens = image_proj_model(image_embeds)
|
39 |
+
print("ip_tokens shape:", ip_tokens.shape)
|
40 |
+
assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == "__main__":
|
44 |
+
main()
|
ip_adapter/tools.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
import random
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
def crop_4_patches(image):
|
9 |
+
crop_size = int(image.size[0]/2)
|
10 |
+
return (image.crop((0, 0, crop_size, crop_size)), image.crop((0, crop_size, crop_size, 2*crop_size)),
|
11 |
+
image.crop((crop_size, 0, 2*crop_size, crop_size)), image.crop((crop_size, crop_size, 2*crop_size, 2*crop_size)))
|
12 |
+
|
13 |
+
|
14 |
+
def pre_processing(image, transform):
|
15 |
+
high_level = []
|
16 |
+
middle_level = []
|
17 |
+
low_level = []
|
18 |
+
crops_4 = crop_4_patches(image)
|
19 |
+
for c_4 in crops_4:
|
20 |
+
crops_8 = crop_4_patches(c_4)
|
21 |
+
high_level.append(transform(crops_8[0]))
|
22 |
+
high_level.append(transform(crops_8[3]))
|
23 |
+
for c_8 in [crops_8[1], crops_8[2]]:
|
24 |
+
crops_16 = crop_4_patches(c_8)
|
25 |
+
middle_level.append(transform(crops_16[0]))
|
26 |
+
middle_level.append(transform(crops_16[3]))
|
27 |
+
for c_16 in [crops_16[1], crops_16[2]]:
|
28 |
+
crops_32 = crop_4_patches(c_16)
|
29 |
+
low_level.append(transform(crops_32[0]))
|
30 |
+
low_level.append(transform(crops_32[3]))
|
31 |
+
return torch.stack(high_level), torch.stack(middle_level), torch.stack(low_level)
|
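Note: `pre_processing` quarters the style image recursively into 8 high-level, 16 middle-level, and 32 low-level patches, which matches the three CNN branches of `Style_Aware_Encoder` above (8 + 16 + 32 patch tokens plus 3 class tokens = the 59 positions of its position embedding). A hedged sketch of the pairing; the 512x512 input size, the CLIP normalization constants, and the ViT-H checkpoint are assumptions, and app.py is the authoritative wiring:

```python
import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from ip_adapter.style_encoder import Style_Aware_Encoder
from ip_adapter.tools import pre_processing

# Patches keep their native resolution (128 / 64 / 32 px for an assumed 512x512 input),
# so every CNN branch pools down to a single 1280-d vector per patch.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.48145466, 0.4578275, 0.40821073],
                         [0.26862954, 0.26130258, 0.27577711]),
])

style_image = Image.open("style.png").convert("RGB").resize((512, 512))  # placeholder path
high, middle, low = pre_processing(style_image, transform)  # 8 / 16 / 32 patches

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K")  # assumed ViT-H checkpoint (hidden size 1280)
style_encoder = Style_Aware_Encoder(image_encoder)
with torch.no_grad():
    style_embeds = style_encoder((high, middle, low), batch_size=1)
print(style_embeds.shape)  # e.g. torch.Size([1, 3, 1024]): one projected token per style level
```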
ip_adapter/utils.py
ADDED
@@ -0,0 +1,5 @@
1 |
+
import torch.nn.functional as F
|
2 |
+
|
3 |
+
|
4 |
+
def is_torch2_available():
|
5 |
+
return hasattr(F, "scaled_dot_product_attention")
|
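Note: `is_torch2_available` simply probes for `F.scaled_dot_product_attention`. The usual IP-Adapter pattern is to gate the attention-processor imports on it; the class names below mirror the upstream IP-Adapter convention and are assumptions about what `ip_adapter/attention_processor.py` in this repo actually exports.

```python
from ip_adapter.utils import is_torch2_available

# Assumed to follow upstream IP-Adapter naming; adjust to the names actually
# defined in ip_adapter/attention_processor.py.
if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
```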