update : inference
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitignore +3 -1
- external/briarmbg.py +460 -0
- external/llite/library/custom_train_functions.py +529 -529
- external/midas/__init__.py +0 -39
- external/midas/base_model.py +16 -0
- external/midas/blocks.py +342 -0
- external/midas/dpt_depth.py +109 -0
- external/midas/midas_net.py +76 -0
- external/midas/midas_net_custom.py +128 -0
- external/midas/transforms.py +234 -0
- external/midas/vit.py +491 -0
- external/realesrgan/__init__.py +6 -0
- external/realesrgan/archs/__init__.py +10 -0
- external/realesrgan/archs/discriminator_arch.py +67 -0
- external/realesrgan/archs/srvgg_arch.py +69 -0
- external/realesrgan/data/__init__.py +10 -0
- external/realesrgan/data/realesrgan_dataset.py +192 -0
- external/realesrgan/data/realesrgan_paired_dataset.py +117 -0
- external/realesrgan/models/__init__.py +10 -0
- external/realesrgan/models/realesrgan_model.py +258 -0
- external/realesrgan/models/realesrnet_model.py +188 -0
- external/realesrgan/train.py +11 -0
- external/realesrgan/utils.py +302 -0
- handler.py +6 -1
- inference.py +241 -96
- internals/data/task.py +17 -1
- internals/pipelines/commons.py +23 -4
- internals/pipelines/controlnets.py +277 -61
- internals/pipelines/high_res.py +32 -3
- internals/pipelines/inpaint_imageprocessor.py +976 -0
- internals/pipelines/inpainter.py +35 -4
- internals/pipelines/prompt_modifier.py +3 -1
- internals/pipelines/realtime_draw.py +13 -3
- internals/pipelines/remove_background.py +55 -5
- internals/pipelines/replace_background.py +8 -8
- internals/pipelines/safety_checker.py +3 -2
- internals/pipelines/sdxl_llite_pipeline.py +3 -1
- internals/pipelines/sdxl_tile_upscale.py +85 -14
- internals/pipelines/upscaler.py +25 -8
- internals/util/__init__.py +6 -0
- internals/util/cache.py +2 -0
- internals/util/commons.py +4 -4
- internals/util/config.py +19 -5
- internals/util/failure_hander.py +7 -4
- internals/util/image.py +18 -0
- internals/util/lora_style.py +29 -5
- internals/util/model_loader.py +8 -0
- internals/util/prompt.py +6 -1
- internals/util/sdxl_lightning.py +74 -0
- internals/util/slack.py +3 -0
.gitignore
CHANGED
@@ -3,6 +3,8 @@
 .ipynb_checkpoints
 .vscode
 env
-test
+test*.py
 *.jpeg
 __pycache__
+sample_task.txt
+.idea
external/briarmbg.py
ADDED
@@ -0,0 +1,460 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")

    return src


### RSU-7 ###
class RSU7(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        b, c, h, w = x.shape

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


class myrebnconv(nn.Module):
    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
    ):
        super(myrebnconv, self).__init__()

        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))


class BriaRMBG(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
        super(BriaRMBG, self).__init__()
        in_ch = config["in_ch"]
        out_ch = config["out_ch"]
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        return [
            F.sigmoid(d1),
            F.sigmoid(d2),
            F.sigmoid(d3),
            F.sigmoid(d4),
            F.sigmoid(d5),
            F.sigmoid(d6),
        ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
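The new BriaRMBG module is a U²-Net-style encoder/decoder that returns six sigmoid saliency maps (finest first) plus the decoder features; the finest map is what a background-removal pipeline would use as an alpha matte. Below is a minimal usage sketch, not the repo's actual wiring in internals/pipelines/remove_background.py: the "briaai/RMBG-1.4" checkpoint id, the 1024x1024 input size, and the 0.5 mean / 1.0 std normalization are assumptions taken from the model's published usage, and the file paths are placeholders.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import normalize, to_tensor

from external.briarmbg import BriaRMBG

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed checkpoint id; PyTorchModelHubMixin supplies from_pretrained().
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device).eval()

image = Image.open("input.png").convert("RGB")  # placeholder path
w, h = image.size

# Assumed preprocessing: fixed 1024x1024 input, mean 0.5 / std 1.0 normalization.
x = to_tensor(image.resize((1024, 1024)))
x = normalize(x, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]).unsqueeze(0).to(device)

with torch.no_grad():
    side_outputs, _ = net(x)  # six sigmoid maps, finest first
    mask = side_outputs[0]

# Resize the finest map back to the source resolution and use it as alpha.
mask = F.interpolate(mask, size=(h, w), mode="bilinear")[0, 0]
mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
alpha = Image.fromarray((mask * 255).byte().cpu().numpy())
image.putalpha(alpha)
image.save("output.png")  # placeholder path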
external/llite/library/custom_train_functions.py
CHANGED
@@ -1,529 +1,529 @@
(All 529 lines are marked as rewritten, but the removed and added versions are textually identical in this view, so the file content is shown once.)

import torch
import argparse
import random
import re
from typing import List, Optional, Union


def prepare_scheduler_for_custom_training(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return

    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2

    noise_scheduler.all_snr = all_snr.to(device)


def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    # fix beta: zero terminal SNR
    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    betas = noise_scheduler.betas
    betas = enforce_zero_terminal_snr(betas)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # print("original:", noise_scheduler.betas)
    # print("fixed:", betas)

    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss


def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
    scale = get_snr_scale(timesteps, noise_scheduler)
    # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss

def apply_debiased_estimation(loss, timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    weight = 1/torch.sqrt(snr_t)
    loss = weight * loss
    return loss

# TODO train_utilと分散しているのでどちらかに寄せる


def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # 最後に普通の文字がある
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # BOSだけであとはPAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = text_encoder(text_input_chunk)[0]
            else:
                enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = text_encoder(text_input)[0]
        else:
            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings


def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
            ending token in each of the chunk in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # Rather than always going 2x,
        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
        if wn == 1 or hn == 1:
            break  # Lowest resolution is 1x1
    return noise / noise.std()  # Scaled back to roughly unit variance


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
    if noise_offset is None:
        return noise
    if adaptive_noise_scale is not None:
        # latent shape: (batch_size, channels, height, width)
        # abs mean value for each channel
        latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))

        # multiply adaptive noise scale to the mean value and add it to the noise offset
        noise_offset = noise_offset + adaptive_noise_scale * latent_mean
        noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case of adaptive noise scale is negative

    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    return noise


"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
    delta = (res[0] / shape[0], res[1] / shape[1])
    d = (shape[0] // res[0], shape[1] // res[1])

    grid = (
        torch.stack(
            torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
            dim=-1,
        )
        % 1
    )
    angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    tile_grads = (
        lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
        .repeat_interleave(d[0], 0)
        .repeat_interleave(d[1], 1)
    )
    dot = lambda grad, shift: (
        torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
        * grad[: shape[0], : shape[1]]
    ).sum(dim=-1)

    n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
    n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
    n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
    n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
    t = fade(grid[: shape[0], : shape[1]])
    return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])


def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
    noise = torch.zeros(shape, device=device)
    frequency = 1
    amplitude = 1
    for _ in range(octaves):
        noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
        frequency *= 2
        amplitude *= persistence
    return noise


def perlin_noise(noise, device, octaves):
    _, c, w, h = noise.shape
    perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
    noise_perlin = []
    for _ in range(c):
        noise_perlin.append(perlin())
    noise_perlin = torch.stack(noise_perlin).unsqueeze(0)  # (1, c, w, h)
    noise += noise_perlin  # broadcast for each batch
    return noise / noise.std()  # Scaled back to roughly unit variance
"""
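As a quick illustration of the weighted-prompt path above, the sketch below runs get_weighted_text_embeddings against a plain CLIP text encoder. This is not the repo's inference path; the openai/clip-vit-large-patch14 model id is a stand-in for whatever text encoder the pipeline actually loads.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from external.llite.library.custom_train_functions import get_weighted_text_embeddings

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/clip-vit-large-patch14"  # placeholder text encoder
tokenizer = CLIPTokenizer.from_pretrained(model_id)
text_encoder = CLIPTextModel.from_pretrained(model_id).to(device)

# '(token:1.3)' boosts a span, '[token]' attenuates it; prompts longer than one
# 77-token window are chunked up to max_embeddings_multiples windows.
embeds = get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    "a photo of a (red:1.3) sports car, [blurry] background",
    device,
    max_embeddings_multiples=3,
)
print(embeds.shape)  # e.g. torch.Size([1, 77, 768]) for a short prompt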
1 |
+
import torch
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
|
7 |
+
|
8 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
9 |
+
if hasattr(noise_scheduler, "all_snr"):
|
10 |
+
return
|
11 |
+
|
12 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
13 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
14 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
15 |
+
alpha = sqrt_alphas_cumprod
|
16 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
17 |
+
all_snr = (alpha / sigma) ** 2
|
18 |
+
|
19 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
20 |
+
|
21 |
+
|
22 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
23 |
+
# fix beta: zero terminal SNR
|
24 |
+
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
25 |
+
|
26 |
+
def enforce_zero_terminal_snr(betas):
|
27 |
+
# Convert betas to alphas_bar_sqrt
|
28 |
+
alphas = 1 - betas
|
29 |
+
alphas_bar = alphas.cumprod(0)
|
30 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
31 |
+
|
32 |
+
# Store old values.
|
33 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
34 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
35 |
+
# Shift so last timestep is zero.
|
36 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
37 |
+
# Scale so first timestep is back to old value.
|
38 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
39 |
+
|
40 |
+
# Convert alphas_bar_sqrt to betas
|
41 |
+
alphas_bar = alphas_bar_sqrt**2
|
42 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
43 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
44 |
+
betas = 1 - alphas
|
45 |
+
return betas
|
46 |
+
|
47 |
+
betas = noise_scheduler.betas
|
48 |
+
betas = enforce_zero_terminal_snr(betas)
|
49 |
+
alphas = 1.0 - betas
|
50 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
51 |
+
|
52 |
+
# print("original:", noise_scheduler.betas)
|
53 |
+
# print("fixed:", betas)
|
54 |
+
|
55 |
+
noise_scheduler.betas = betas
|
56 |
+
noise_scheduler.alphas = alphas
|
57 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
58 |
+
|
59 |
+
|
60 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
61 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
62 |
+
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
63 |
+
if v_prediction:
|
64 |
+
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
65 |
+
else:
|
66 |
+
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
67 |
+
loss = loss * snr_weight
|
68 |
+
return loss
|
69 |
+
|
70 |
+
|
71 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
72 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
73 |
+
loss = loss * scale
|
74 |
+
return loss
|
75 |
+
|
76 |
+
|
77 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
78 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
79 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
80 |
+
scale = snr_t / (snr_t + 1)
|
81 |
+
# # show debug info
|
82 |
+
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
83 |
+
return scale
|
84 |
+
|
85 |
+
|
86 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
87 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
88 |
+
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
89 |
+
loss = loss + loss / scale * v_pred_like_loss
|
90 |
+
return loss
|
91 |
+
|
92 |
+
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
93 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
94 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
95 |
+
weight = 1/torch.sqrt(snr_t)
|
96 |
+
loss = weight * loss
|
97 |
+
return loss
|
98 |
+
|
99 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
100 |
+
|
101 |
+
|
102 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
103 |
+
parser.add_argument(
|
104 |
+
"--min_snr_gamma",
|
105 |
+
type=float,
|
106 |
+
default=None,
|
107 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--scale_v_pred_loss_like_noise_pred",
|
111 |
+
action="store_true",
|
112 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--v_pred_like_loss",
|
116 |
+
type=float,
|
117 |
+
default=None,
|
118 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--debiased_estimation_loss",
|
122 |
+
action="store_true",
|
123 |
+
help="debiased estimation loss / debiased estimation loss",
|
124 |
+
)
|
125 |
+
if support_weighted_captions:
|
126 |
+
parser.add_argument(
|
127 |
+
"--weighted_captions",
|
128 |
+
action="store_true",
|
129 |
+
default=False,
|
130 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
re_attention = re.compile(
|
135 |
+
r"""
|
136 |
+
\\\(|
|
137 |
+
\\\)|
|
138 |
+
\\\[|
|
139 |
+
\\]|
|
140 |
+
\\\\|
|
141 |
+
\\|
|
142 |
+
\(|
|
143 |
+
\[|
|
144 |
+
:([+-]?[.\d]+)\)|
|
145 |
+
\)|
|
146 |
+
]|
|
147 |
+
[^\\()\[\]:]+|
|
148 |
+
:
|
149 |
+
""",
|
150 |
+
re.X,
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def parse_prompt_attention(text):
|
155 |
+
"""
|
156 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
157 |
+
Accepted tokens are:
|
158 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
159 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
160 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
161 |
+
\( - literal character '('
|
162 |
+
\[ - literal character '['
|
163 |
+
\) - literal character ')'
|
164 |
+
\] - literal character ']'
|
165 |
+
\\ - literal character '\'
|
166 |
+
anything else - just text
|
167 |
+
>>> parse_prompt_attention('normal text')
|
168 |
+
[['normal text', 1.0]]
|
169 |
+
>>> parse_prompt_attention('an (important) word')
|
170 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
171 |
+
>>> parse_prompt_attention('(unbalanced')
|
172 |
+
[['unbalanced', 1.1]]
|
173 |
+
>>> parse_prompt_attention('\(literal\]')
|
174 |
+
[['(literal]', 1.0]]
|
175 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
176 |
+
[['unnecessaryparens', 1.1]]
|
177 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
178 |
+
[['a ', 1.0],
|
179 |
+
['house', 1.5730000000000004],
|
180 |
+
[' ', 1.1],
|
181 |
+
['on', 1.0],
|
182 |
+
[' a ', 1.1],
|
183 |
+
['hill', 0.55],
|
184 |
+
[', sun, ', 1.1],
|
185 |
+
['sky', 1.4641000000000006],
|
186 |
+
['.', 1.1]]
|
187 |
+
"""
|
188 |
+
|
189 |
+
res = []
|
190 |
+
round_brackets = []
|
191 |
+
square_brackets = []
|
192 |
+
|
193 |
+
round_bracket_multiplier = 1.1
|
194 |
+
square_bracket_multiplier = 1 / 1.1
|
195 |
+
|
196 |
+
def multiply_range(start_position, multiplier):
|
197 |
+
for p in range(start_position, len(res)):
|
198 |
+
res[p][1] *= multiplier
|
199 |
+
|
200 |
+
for m in re_attention.finditer(text):
|
201 |
+
text = m.group(0)
|
202 |
+
weight = m.group(1)
|
203 |
+
|
204 |
+
if text.startswith("\\"):
|
205 |
+
res.append([text[1:], 1.0])
|
206 |
+
elif text == "(":
|
207 |
+
round_brackets.append(len(res))
|
208 |
+
elif text == "[":
|
209 |
+
square_brackets.append(len(res))
|
210 |
+
elif weight is not None and len(round_brackets) > 0:
|
211 |
+
multiply_range(round_brackets.pop(), float(weight))
|
212 |
+
elif text == ")" and len(round_brackets) > 0:
|
213 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
214 |
+
elif text == "]" and len(square_brackets) > 0:
|
215 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
216 |
+
else:
|
217 |
+
res.append([text, 1.0])
|
218 |
+
|
219 |
+
for pos in round_brackets:
|
220 |
+
multiply_range(pos, round_bracket_multiplier)
|
221 |
+
|
222 |
+
for pos in square_brackets:
|
223 |
+
multiply_range(pos, square_bracket_multiplier)
|
224 |
+
|
225 |
+
if len(res) == 0:
|
226 |
+
res = [["", 1.0]]
|
227 |
+
|
228 |
+
# merge runs of identical weights
|
229 |
+
i = 0
|
230 |
+
while i + 1 < len(res):
|
231 |
+
if res[i][1] == res[i + 1][1]:
|
232 |
+
res[i][0] += res[i + 1][0]
|
233 |
+
res.pop(i + 1)
|
234 |
+
else:
|
235 |
+
i += 1
|
236 |
+
|
237 |
+
return res
|
238 |
+
|
239 |
+
|
240 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
241 |
+
r"""
|
242 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
243 |
+
|
244 |
+
No padding, starting or ending token is included.
|
245 |
+
"""
|
246 |
+
tokens = []
|
247 |
+
weights = []
|
248 |
+
truncated = False
|
249 |
+
for text in prompt:
|
250 |
+
texts_and_weights = parse_prompt_attention(text)
|
251 |
+
text_token = []
|
252 |
+
text_weight = []
|
253 |
+
for word, weight in texts_and_weights:
|
254 |
+
# tokenize and discard the starting and the ending token
|
255 |
+
token = tokenizer(word).input_ids[1:-1]
|
256 |
+
text_token += token
|
257 |
+
# copy the weight by length of token
|
258 |
+
text_weight += [weight] * len(token)
|
259 |
+
# stop if the text is too long (longer than truncation limit)
|
260 |
+
if len(text_token) > max_length:
|
261 |
+
truncated = True
|
262 |
+
break
|
263 |
+
# truncate
|
264 |
+
if len(text_token) > max_length:
|
265 |
+
truncated = True
|
266 |
+
text_token = text_token[:max_length]
|
267 |
+
text_weight = text_weight[:max_length]
|
268 |
+
tokens.append(text_token)
|
269 |
+
weights.append(text_weight)
|
270 |
+
if truncated:
|
271 |
+
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
272 |
+
return tokens, weights
|
273 |
+
|
274 |
+
|
275 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
276 |
+
r"""
|
277 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
278 |
+
"""
|
279 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
280 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
281 |
+
for i in range(len(tokens)):
|
282 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
283 |
+
if no_boseos_middle:
|
284 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
285 |
+
else:
|
286 |
+
w = []
|
287 |
+
if len(weights[i]) == 0:
|
288 |
+
w = [1.0] * weights_length
|
289 |
+
else:
|
290 |
+
for j in range(max_embeddings_multiples):
|
291 |
+
w.append(1.0) # weight for starting token in this chunk
|
292 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
293 |
+
w.append(1.0) # weight for ending token in this chunk
|
294 |
+
w += [1.0] * (weights_length - len(w))
|
295 |
+
weights[i] = w[:]
|
296 |
+
|
297 |
+
return tokens, weights
|
298 |
+
|
299 |
+
|
300 |
+
def get_unweighted_text_embeddings(
|
301 |
+
tokenizer,
|
302 |
+
text_encoder,
|
303 |
+
text_input: torch.Tensor,
|
304 |
+
chunk_length: int,
|
305 |
+
clip_skip: int,
|
306 |
+
eos: int,
|
307 |
+
pad: int,
|
308 |
+
no_boseos_middle: Optional[bool] = True,
|
309 |
+
):
|
310 |
+
"""
|
311 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
312 |
+
it should be split into chunks and sent to the text encoder individually.
|
313 |
+
"""
|
314 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
315 |
+
if max_embeddings_multiples > 1:
|
316 |
+
text_embeddings = []
|
317 |
+
for i in range(max_embeddings_multiples):
|
318 |
+
# extract the i-th chunk
|
319 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
320 |
+
|
321 |
+
# cover the head and the tail by the starting and the ending tokens
|
322 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
323 |
+
if pad == eos: # v1
|
324 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
325 |
+
else: # v2
|
326 |
+
for j in range(len(text_input_chunk)):
|
327 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
328 |
+
text_input_chunk[j, -1] = eos
|
329 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
330 |
+
text_input_chunk[j, 1] = eos
|
331 |
+
|
332 |
+
if clip_skip is None or clip_skip == 1:
|
333 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
334 |
+
else:
|
335 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
336 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
337 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
338 |
+
|
339 |
+
if no_boseos_middle:
|
340 |
+
if i == 0:
|
341 |
+
# discard the ending token
|
342 |
+
text_embedding = text_embedding[:, :-1]
|
343 |
+
elif i == max_embeddings_multiples - 1:
|
344 |
+
# discard the starting token
|
345 |
+
text_embedding = text_embedding[:, 1:]
|
346 |
+
else:
|
347 |
+
# discard both starting and ending tokens
|
348 |
+
text_embedding = text_embedding[:, 1:-1]
|
349 |
+
|
350 |
+
text_embeddings.append(text_embedding)
|
351 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
352 |
+
else:
|
353 |
+
if clip_skip is None or clip_skip == 1:
|
354 |
+
text_embeddings = text_encoder(text_input)[0]
|
355 |
+
else:
|
356 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
357 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
358 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
359 |
+
return text_embeddings
|
360 |
+
|
361 |
+
|
362 |
+
+def get_weighted_text_embeddings(
+    tokenizer,
+    text_encoder,
+    prompt: Union[str, List[str]],
+    device,
+    max_embeddings_multiples: Optional[int] = 3,
+    no_boseos_middle: Optional[bool] = False,
+    clip_skip=None,
+):
+    r"""
+    Prompts can be assigned with local weights using brackets. For example,
+    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
+
+    Args:
+        prompt (`str` or `List[str]`):
+            The prompt or prompts to guide the image generation.
+        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+            The max multiple length of prompt embeddings compared to the max output length of text encoder.
+        no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
+            ending token in each of the chunk in the middle.
+        skip_parsing (`bool`, *optional*, defaults to `False`):
+            Skip the parsing of brackets.
+        skip_weighting (`bool`, *optional*, defaults to `False`):
+            Skip the weighting. When the parsing is skipped, it is forced True.
+    """
+    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+    if isinstance(prompt, str):
+        prompt = [prompt]
+
+    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
+
+    # round up the longest length of tokens to a multiple of (model_max_length - 2)
+    max_length = max([len(token) for token in prompt_tokens])
+
+    max_embeddings_multiples = min(
+        max_embeddings_multiples,
+        (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
+    )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
+    max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+    # pad the length of tokens and weights
+    bos = tokenizer.bos_token_id
+    eos = tokenizer.eos_token_id
+    pad = tokenizer.pad_token_id
+    prompt_tokens, prompt_weights = pad_tokens_and_weights(
+        prompt_tokens,
+        prompt_weights,
+        max_length,
+        bos,
+        eos,
+        no_boseos_middle=no_boseos_middle,
+        chunk_length=tokenizer.model_max_length,
+    )
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
+
+    # get the embeddings
+    text_embeddings = get_unweighted_text_embeddings(
+        tokenizer,
+        text_encoder,
+        prompt_tokens,
+        tokenizer.model_max_length,
+        clip_skip,
+        eos,
+        pad,
+        no_boseos_middle=no_boseos_middle,
+    )
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
+
+    # assign weights to the prompts and normalize in the sense of mean
+    previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
+    current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+    text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+    return text_embeddings
+
+
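For readers skimming the hunk above: `get_weighted_text_embeddings` parses `(words)` emphasis, splits long prompts into 75-token chunks, encodes each chunk, and rescales the result to keep the original mean. A minimal usage sketch, assuming a Stable Diffusion v1.x text encoder loaded via `transformers`; the checkpoint id and prompt are illustrative placeholders, and the helper itself comes from this module.

```python
# Minimal sketch: weighted prompt embeddings with the helper above.
# Assumes this module is importable and that a CLIP tokenizer/text
# encoder from an SD v1.x checkpoint is available (ids are placeholders).
import torch
from transformers import CLIPTextModel, CLIPTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
repo = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint id
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder").to(device)

# '(very beautiful)' is up-weighted by 1.1; prompts longer than 75 tokens
# are split into chunks instead of being silently truncated.
embeds = get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt="A (very beautiful) masterpiece, highly detailed oil painting",
    device=device,
    max_embeddings_multiples=3,
)
print(embeds.shape)  # torch.Size([1, 77, 768]) for a short prompt on SD v1.x
```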
+# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
+def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
+    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
+    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
+    for i in range(iterations):
+        r = random.random() * 2 + 2  # Rather than always going 2x,
+        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
+        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
+        if wn == 1 or hn == 1:
+            break  # Lowest resolution is 1x1
+    return noise / noise.std()  # Scaled back to roughly unit variance
+
+
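A quick illustration of the multi-resolution noise helper above: it keeps adding upsampled coarse noise with a geometric discount, then rescales back to roughly unit variance. The latent shape below is just an illustrative SD-sized latent, not something fixed by this repo.

```python
# Minimal sketch: multi-resolution ("pyramid") noise on a fake latent,
# using pyramid_noise_like from this module.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
noise = torch.randn(1, 4, 64, 64, device=device)  # placeholder latent-shaped noise
noisy = pyramid_noise_like(noise, device, iterations=6, discount=0.4)
print(noisy.std())  # close to 1.0 after the final rescale
```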
+# https://www.crosslabs.org//blog/diffusion-with-offset-noise
+def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
+    if noise_offset is None:
+        return noise
+    if adaptive_noise_scale is not None:
+        # latent shape: (batch_size, channels, height, width)
+        # abs mean value for each channel
+        latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
+
+        # multiply adaptive noise scale to the mean value and add it to the noise offset
+        noise_offset = noise_offset + adaptive_noise_scale * latent_mean
+        noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case of adaptive noise scale is negative
+
+    noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+    return noise
+
+
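And for the offset-noise helper above: it adds a per-sample, per-channel constant to the sampled noise, optionally scaled by the latent's absolute mean when `adaptive_noise_scale` is set. A short sketch with placeholder latents and made-up strengths:

```python
# Minimal sketch: offset noise applied to an illustrative latent batch,
# using apply_noise_offset from this module.
import torch

latents = torch.randn(2, 4, 64, 64)  # placeholder VAE latents
noise = torch.randn_like(latents)
noise = apply_noise_offset(latents, noise, noise_offset=0.05, adaptive_noise_scale=0.01)
print(noise.shape)  # unchanged: torch.Size([2, 4, 64, 64])
```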
"""
|
475 |
+
##########################################
|
476 |
+
# Perlin Noise
|
477 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
478 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
479 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
480 |
+
|
481 |
+
grid = (
|
482 |
+
torch.stack(
|
483 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
484 |
+
dim=-1,
|
485 |
+
)
|
486 |
+
% 1
|
487 |
+
)
|
488 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
489 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
490 |
+
|
491 |
+
tile_grads = (
|
492 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
493 |
+
.repeat_interleave(d[0], 0)
|
494 |
+
.repeat_interleave(d[1], 1)
|
495 |
+
)
|
496 |
+
dot = lambda grad, shift: (
|
497 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
498 |
+
* grad[: shape[0], : shape[1]]
|
499 |
+
).sum(dim=-1)
|
500 |
+
|
501 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
502 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
503 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
504 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
505 |
+
t = fade(grid[: shape[0], : shape[1]])
|
506 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
507 |
+
|
508 |
+
|
509 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
510 |
+
noise = torch.zeros(shape, device=device)
|
511 |
+
frequency = 1
|
512 |
+
amplitude = 1
|
513 |
+
for _ in range(octaves):
|
514 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
515 |
+
frequency *= 2
|
516 |
+
amplitude *= persistence
|
517 |
+
return noise
|
518 |
+
|
519 |
+
|
520 |
+
def perlin_noise(noise, device, octaves):
|
521 |
+
_, c, w, h = noise.shape
|
522 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
523 |
+
noise_perlin = []
|
524 |
+
for _ in range(c):
|
525 |
+
noise_perlin.append(perlin())
|
526 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
527 |
+
noise += noise_perlin # broadcast for each batch
|
528 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
529 |
+
"""
|
external/midas/__init__.py
CHANGED
@@ -1,39 +0,0 @@
-import cv2
-import numpy as np
-import torch
-from einops import rearrange
-
-from .api import MiDaSInference
-
-model = None
-
-
-def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
-    global model
-    if not model:
-        model = MiDaSInference(model_type="dpt_hybrid").cuda()
-    assert input_image.ndim == 3
-    image_depth = input_image
-    with torch.no_grad():
-        image_depth = torch.from_numpy(image_depth).float().cuda()
-        image_depth = image_depth / 127.5 - 1.0
-        image_depth = rearrange(image_depth, "h w c -> 1 c h w")
-        depth = model(image_depth)[0]
-
-        depth_pt = depth.clone()
-        depth_pt -= torch.min(depth_pt)
-        depth_pt /= torch.max(depth_pt)
-        depth_pt = depth_pt.cpu().numpy()
-        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
-
-        depth_np = depth.cpu().numpy()
-        x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
-        y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
-        z = np.ones_like(x) * a
-        x[depth_pt < bg_th] = 0
-        y[depth_pt < bg_th] = 0
-        normal = np.stack([x, y, z], axis=2)
-        normal /= np.sum(normal**2.0, axis=2, keepdims=True) ** 0.5
-        normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
-
-        return depth_image, normal_image
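The hunk above deletes the old `apply_midas` convenience wrapper (depth plus Sobel-based pseudo-normal maps), presumably in favour of the rewritten MiDaS modules added below. For reference only, the removed helper was invoked roughly as follows; `apply_midas` no longer exists after this commit and the input path is a placeholder.

```python
# Historical usage sketch of the removed helper (not runnable against
# this commit): it returned a uint8 depth map and a pseudo normal map
# for an HWC uint8 RGB image.
import numpy as np
from PIL import Image

image = np.array(Image.open("input.jpg").convert("RGB"))  # placeholder path
depth_image, normal_image = apply_midas(image, bg_th=0.1)  # removed in this commit
Image.fromarray(depth_image).save("depth.png")
```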
external/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
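`BaseModel.load` above accepts both a plain `state_dict` on disk and a training checkpoint that wraps the weights under a `"model"` key next to an `"optimizer"` entry. A small sketch of both cases; the subclass and file names are illustrative, not part of this repo.

```python
# Minimal sketch: BaseModel.load handles raw and wrapped checkpoints.
import torch


class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)


model = TinyModel()

# Case 1: a plain state_dict on disk.
torch.save(model.state_dict(), "plain.pt")
model.load("plain.pt")

# Case 2: a training checkpoint that also stores optimizer state;
# load() unwraps the "model" entry before calling load_state_dict.
torch.save({"model": model.state_dict(), "optimizer": {}}, "wrapped.pt")
model.load("wrapped.pt")
```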
external/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
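The blocks in `external/midas/blocks.py` compose in a fixed pattern: `_make_encoder` returns a backbone plus "scratch" 3x3 convolutions that project the four encoder stages to a common width, and the feature-fusion blocks then merge them coarse-to-fine with 2x upsampling at each step. A rough sketch of that wiring, using random tensors as stand-ins for real backbone activations (the shapes are illustrative, roughly what a ResNeXt encoder would produce at 384x384 input):

```python
# Rough sketch of the RefineNet-style decoder wiring used by the MiDaS
# models in this diff; feature maps are random stand-ins for encoder output.
import torch

features = 256
scratch = _make_scratch([256, 512, 1024, 2048], features)
fuse4, fuse3, fuse2, fuse1 = (FeatureFusionBlock(features) for _ in range(4))

# Four fake encoder stages at decreasing resolution.
l1 = torch.randn(1, 256, 96, 96)
l2 = torch.randn(1, 512, 48, 48)
l3 = torch.randn(1, 1024, 24, 24)
l4 = torch.randn(1, 2048, 12, 12)

with torch.no_grad():
    r1, r2 = scratch.layer1_rn(l1), scratch.layer2_rn(l2)
    r3, r4 = scratch.layer3_rn(l3), scratch.layer4_rn(l4)
    p4 = fuse4(r4)      # upsample the deepest features
    p3 = fuse3(p4, r3)  # fuse with the next stage, upsample again
    p2 = fuse2(p3, r2)
    p1 = fuse1(p2, r1)

print(p1.shape)  # torch.Size([1, 256, 192, 192]), 2x the shallowest stage
```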
external/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
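Putting `dpt_depth.py` together: `DPTDepthModel` runs the DPT encoder/decoder, applies a small convolutional head, and squeezes the channel dimension, so the output is a single-channel relative inverse-depth map at the input resolution. A minimal inference sketch; the checkpoint path is a placeholder (weights such as the official DPT-hybrid MiDaS checkpoint must be downloaded separately), and a `timm` version compatible with the bundled `vit.py` is assumed to be installed.

```python
# Minimal inference sketch for the DPT depth model added above.
import torch

model = DPTDepthModel(
    path=None,                 # e.g. "weights/dpt_hybrid-midas-501f0c75.pt" (placeholder)
    backbone="vitb_rn50_384",
    non_negative=True,
)
model.eval()

# Real inputs should go through the Resize/Normalize transforms from
# transforms.py; a random 384x384 tensor is enough to show the shapes.
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    depth = model(x)
print(depth.shape)  # torch.Size([1, 384, 384])
```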
external/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
external/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
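`fuse_model` at the bottom of `midas_net_custom.py` walks the module tree and fuses Conv2d, BatchNorm2d (and a trailing ReLU, when present) into single modules via `torch.quantization.fuse_modules`, which helps when exporting or quantizing the small MiDaS variant. A hedged sketch on a toy network that contains such a sequence; the toy layers are illustrative, not part of this repo, and `MidasNet_small` would be fused the same way after switching to eval mode.

```python
# Sketch: Conv+BN(+ReLU) fusion with the fuse_model helper above.
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1),
    nn.BatchNorm2d(16),
)
net.eval()       # fusion folds BN statistics, so switch to eval mode first
fuse_model(net)  # the first Conv/BN/ReLU triple becomes one fused module

x = torch.randn(1, 3, 64, 64)
print(net(x).shape)  # same output shape, fewer separate BN/ReLU modules
```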
external/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
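The three transforms above are meant to be chained: `Resize` scales the image (optionally snapping to a multiple of 32 via `ensure_multiple_of`), `NormalizeImage` applies a fixed mean/std, and `PrepareForNet` converts HWC float arrays into contiguous CHW float32. A usage sketch mirroring the usual MiDaS preprocessing; the input path is a placeholder and the mean/std constants are an assumption here, since the exact values depend on which MiDaS variant is used.

```python
# Sketch: compose the transforms above into a MiDaS-style preprocessing
# pipeline for a single RGB image.
import cv2
import torch
from torchvision.transforms import Compose

transform = Compose(
    [
        Resize(
            384,
            384,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="lower_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ]
)

img = cv2.imread("input.jpg")                       # placeholder path
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0  # HWC float in [0, 1]
sample = transform({"image": img})
batch = torch.from_numpy(sample["image"]).unsqueeze(0)  # (1, 3, H, W)
print(batch.shape)
```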
external/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model(
        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
    )

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
        start_index=2,
    )


def _make_vit_b_rn50_backbone(
    model,
    features=[256, 512, 768, 768],
    size=[384, 384],
    hooks=[0, 1, 8, 11],
    vit_features=768,
    use_vit_only=False,
    use_readout="ignore",
    start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    if use_vit_only == True:
        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    else:
        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
            get_activation("1")
        )
        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
            get_activation("2")
        )

    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    if use_vit_only == True:
        pretrained.act_postprocess1 = nn.Sequential(
            readout_oper[0],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[0],
                out_channels=features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        pretrained.act_postprocess2 = nn.Sequential(
            readout_oper[1],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=features[1],
                out_channels=features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )
    else:
        pretrained.act_postprocess1 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )
        pretrained.act_postprocess2 = nn.Sequential(
            nn.Identity(), nn.Identity(), nn.Identity()
        )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )
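For orientation, a minimal smoke test of the builders above; this is a sketch, not part of the commit, and it assumes timm is installed (with a version that still exposes "vit_base_patch16_384") and that this file is importable as external.midas.vit.

# Hypothetical smoke test; pretrained=False avoids any weight download.
from external.midas.vit import _make_pretrained_vitb16_384

backbone = _make_pretrained_vitb16_384(pretrained=False, use_readout="project")
print(backbone.model.patch_size)   # [16, 16], set by _make_vit_b16_backbone
print(backbone.act_postprocess1)   # readout -> 1x1 conv -> 4x ConvTranspose2d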
external/realesrgan/__init__.py
ADDED
@@ -0,0 +1,6 @@
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
#from .version import *
external/realesrgan/archs/__init__.py
ADDED
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
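The scan above is what lets a new `*_arch.py` file register itself just by being dropped into this folder. A rough illustration of the lookup side, based on basicsr's registry API (an assumption, not code added by this commit):

# Hypothetical: importing the package runs the scan, after which registered
# architectures can be fetched from basicsr's ARCH_REGISTRY by class name.
from basicsr.utils.registry import ARCH_REGISTRY
import realesrgan.archs  # noqa: F401  (triggers the auto-import above)

disc_cls = ARCH_REGISTRY.get('UNetDiscriminatorSN')
discriminator = disc_cls(num_in_ch=3)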
external/realesrgan/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,67 @@
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN)

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Arg:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
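A minimal shape check for the discriminator above; a sketch under the assumption that torch and basicsr are installed. Input height and width should be divisible by 8 so the three stride-2 downsamples line up with the skip connections.

import torch

d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)
x = torch.randn(1, 3, 256, 256)  # 256 is divisible by 8
with torch.no_grad():
    logits = d(x)
print(logits.shape)  # torch.Size([1, 1, 256, 256]) -- one realness logit per pixel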
external/realesrgan/archs/srvgg_arch.py
ADDED
@@ -0,0 +1,69 @@
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F


@ARCH_REGISTRY.register()
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last layer and no convolution is
    conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
    """

    def __init__(self, num_in_ch = 3, num_out_ch = 3, num_feat = 64, num_conv = 16, upscale = 4, act_type = 'prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace = True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters = num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace = True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters = num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope = 0.1, inplace = True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor = self.upscale, mode = 'nearest')
        out += base
        return out
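A minimal usage sketch for the compact SR network above (assumes torch and basicsr are installed; the tensor sizes are illustrative):

import torch

net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4)
lr = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 256, 256]): PixelShuffle(4) output plus the nearest-upsampled input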
external/realesrgan/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
external/realesrgan/data/realesrgan_dataset.py
ADDED
@@ -0,0 +1,192 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torch.utils import data as data


@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
                self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                index = random.randint(0, self.__len__())
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
external/realesrgan/data/realesrgan_paired_dataset.py
ADDED
@@ -0,0 +1,117 @@
import os
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient
from basicsr.utils.img_util import imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torch.utils import data as data
from torchvision.transforms.functional import normalize


@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # mean and std for normalizing the input images
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        in_channels = opt['in_channels'] if 'in_channels' in opt else 3
        if in_channels == 1:
            self.flag = 'grayscale'
        elif in_channels == 3:
            self.flag = 'color'
        else:
            self.flag = 'unchanged'

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip() for line in fin]
            self.paths = []
            for path in paths:
                gt_path, lq_path = path.split(', ')
                gt_path = os.path.join(self.gt_folder, gt_path)
                lq_path = os.path.join(self.lq_folder, lq_path)
                self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
        else:
            # disk backend
            # it will scan the whole folder to get meta info
            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, flag = self.flag, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, flag = self.flag, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
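For reference, a sketch of an option dict that exercises the 'folder' branch above; the paths and values are placeholders and are not part of this commit.

opt = {
    'dataroot_gt': 'datasets/val/HR',      # placeholder path
    'dataroot_lq': 'datasets/val/LR_x4',   # placeholder path
    'io_backend': {'type': 'disk'},
    'scale': 4,
    'phase': 'val',                        # skips the random crop / flip branch
}
dataset = RealESRGANPairedDataset(opt)
sample = dataset[0]   # {'lq': tensor, 'gt': tensor, 'lq_path': ..., 'gt_path': ...}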
external/realesrgan/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import model modules for registry
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
external/realesrgan/models/realesrgan_model.py
ADDED
@@ -0,0 +1,258 @@
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F


@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        super(RealESRGANModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            # 1. [resize back + sinc filter] + JPEG compression
            # 2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        # usm sharpening
        l1_gt = self.gt_usm
        percep_gt = self.gt_usm
        gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt

        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
ADDED
@@ -0,0 +1,188 @@
|
import numpy as np
import random
import torch
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from basicsr.data.transforms import paired_random_crop
from basicsr.models.sr_model import SRModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from torch.nn import functional as F


@MODEL_REGISTRY.register()
class RealESRNetModel(SRModel):
    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It is trained without GAN losses.
    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        super(RealESRNetModel, self).__init__(opt)
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            # USM sharpen the GT images
            if self.opt['gt_usm'] is True:
                self.gt = self.usm_sharpener(self.gt)

            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            # 1. [resize back + sinc filter] + JPEG compression
            # 2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True
external/realesrgan/train.py
ADDED
@@ -0,0 +1,11 @@
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

import realesrgan.archs
import realesrgan.data
import realesrgan.models

if __name__ == '__main__':
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
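basicsr's `train_pipeline` reads the experiment configuration from a YAML option file passed on the command line, so launching this script typically looks like the following; the option-file path is a placeholder, as no YAML is added by this commit.

# Hypothetical launch command (hedged: assumes basicsr's `-opt` flag):
#   python external/realesrgan/train.py -opt <path/to/train_option>.yml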
external/realesrgan/utils.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can be a URL (it will be downloaded automatically).
        model (nn.Module): The defined network. Default: None.
        tile (int): Since large images can cause GPU out-of-memory errors, this tile option first crops the
            input image into tiles, processes each tile, and finally merges them back into one image.
            0 means tiling is not used. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self, scale, model_path, dni_weight=None, model=None, tile=0, tile_pad=10, pre_pad=10, half=False, device=None, gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        if isinstance(model_path, list):
            # dni
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, num_out_ch=3, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            if num_out_ch != 3:
                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
            else:
                alpha = img[:, :, 3]
                img = img[:, :, 0:3]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                if alpha_upsampler == 'realesrgan':
                    alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        img_struct_list = []
        for i in range(3, num_out_ch):
            img_struct_list.append(i)
        output_img = output_img[[2, 1, 0] + img_struct_list, :, :]
        output_img = np.transpose(output_img, (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA' and num_out_ch == 3:
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(output, (int(w_input * outscale), int(h_input * outscale)), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode


class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class IOConsumer(threading.Thread):

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
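RealESRGANer above is the building block the repo's upscaler pipeline wraps: it loads a Real-ESRGAN checkpoint, optionally splits large inputs into overlapping tiles to bound GPU memory, and stitches the upscaled tiles back together in enhance(). A minimal usage sketch follows; the RRDBNet configuration, weight path and tile size are assumptions for illustration, not values taken from this commit, and the import path assumes handler.py has already put external/ on sys.path:

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet

from realesrgan.utils import RealESRGANer  # resolvable once handler.py adds external/ to sys.path

# Assumed x4 RRDBNet config matching the official RealESRGAN_x4plus weights.
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)

upsampler = RealESRGANer(
    scale=4,
    model_path="weights/RealESRGAN_x4plus.pth",  # assumed local path; an https URL would be auto-downloaded
    model=model,
    tile=512,      # upscale in 512px tiles to avoid CUDA OOM on large inputs
    tile_pad=10,   # overlap between tiles so seams stay hidden
    pre_pad=10,
    half=True,     # fp16 inference
)

img = cv2.imread("input.jpg", cv2.IMREAD_UNCHANGED)   # BGR uint8, as enhance() expects
output, img_mode = upsampler.enhance(img, outscale=2)  # outscale != scale triggers a final cv2.resize
cv2.imwrite("output.png", output)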
handler.py
CHANGED
@@ -1,5 +1,10 @@
-import json
 import os
+import sys
+
+path = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(1, os.path.join(path, "external"))
+
+
 from pathlib import Path
 from typing import Any, Dict, List
 
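The inserted sys.path line is what makes the vendored external/ tree importable as top-level packages (for example realesrgan and midas) without installing them, and it has to run before any module under internals/ pulls them in. Roughly, the effect is:

# after sys.path.insert(1, "<repo>/external")
import realesrgan                          # -> external/realesrgan/__init__.py
from realesrgan.utils import RealESRGANer  # -> external/realesrgan/utils.py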
inference.py
CHANGED
@@ -17,10 +17,9 @@ from internals.pipelines.img_classifier import ImageClassifier
|
|
17 |
from internals.pipelines.img_to_text import Image2Text
|
18 |
from internals.pipelines.inpainter import InPainter
|
19 |
from internals.pipelines.object_remove import ObjectRemoval
|
20 |
-
from internals.pipelines.pose_detector import PoseDetector
|
21 |
from internals.pipelines.prompt_modifier import PromptModifier
|
22 |
from internals.pipelines.realtime_draw import RealtimeDraw
|
23 |
-
from internals.pipelines.remove_background import
|
24 |
from internals.pipelines.replace_background import ReplaceBackground
|
25 |
from internals.pipelines.safety_checker import SafetyChecker
|
26 |
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
|
@@ -45,7 +44,6 @@ from internals.util.config import (
|
|
45 |
set_model_config,
|
46 |
set_root_dir,
|
47 |
)
|
48 |
-
from internals.util.failure_hander import FailureHandler
|
49 |
from internals.util.lora_style import LoraStyle
|
50 |
from internals.util.model_loader import load_model_from_config
|
51 |
from internals.util.slack import Slack
|
@@ -57,14 +55,13 @@ auto_mode = False
|
|
57 |
|
58 |
prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
|
59 |
upscaler = Upscaler()
|
60 |
-
pose_detector = PoseDetector()
|
61 |
inpainter = InPainter()
|
62 |
high_res = HighRes()
|
63 |
img2text = Image2Text()
|
64 |
img_classifier = ImageClassifier()
|
65 |
object_removal = ObjectRemoval()
|
66 |
replace_background = ReplaceBackground()
|
67 |
-
|
68 |
replace_background = ReplaceBackground()
|
69 |
controlnet = ControlNet()
|
70 |
lora_style = LoraStyle()
|
@@ -92,7 +89,7 @@ def get_patched_prompt_text2img(task: Task):
|
|
92 |
|
93 |
def get_patched_prompt_tile_upscale(task: Task):
|
94 |
return prompt_util.get_patched_prompt_tile_upscale(
|
95 |
-
task, avatar, lora_style, img_classifier, img2text
|
96 |
)
|
97 |
|
98 |
|
@@ -126,20 +123,19 @@ def canny(task: Task):
|
|
126 |
"num_inference_steps": task.get_steps(),
|
127 |
"width": width,
|
128 |
"height": height,
|
129 |
-
"negative_prompt": [
|
130 |
-
|
131 |
-
]
|
132 |
-
* get_num_return_sequences(),
|
133 |
**task.cnc_kwargs(),
|
134 |
**lora_patcher.kwargs(),
|
135 |
}
|
136 |
-
images, has_nsfw = controlnet.process(**kwargs)
|
137 |
if task.get_high_res_fix():
|
138 |
kwargs = {
|
139 |
"prompt": prompt,
|
140 |
"negative_prompt": [task.get_negative_prompt()]
|
141 |
* get_num_return_sequences(),
|
142 |
"images": images,
|
|
|
143 |
"width": task.get_width(),
|
144 |
"height": task.get_height(),
|
145 |
"num_inference_steps": task.get_steps(),
|
@@ -147,6 +143,9 @@ def canny(task: Task):
|
|
147 |
}
|
148 |
images, _ = high_res.apply(**kwargs)
|
149 |
|
|
|
|
|
|
|
150 |
generated_image_urls = upload_images(images, "_canny", task.get_taskId())
|
151 |
|
152 |
lora_patcher.cleanup()
|
@@ -162,48 +161,102 @@ def canny(task: Task):
|
|
162 |
@update_db
|
163 |
@auto_clear_cuda_and_gc(controlnet)
|
164 |
@slack.auto_send_alert
|
165 |
-
def
|
166 |
-
|
167 |
-
|
168 |
-
prompt = get_patched_prompt_tile_upscale(task)
|
169 |
-
|
170 |
-
if get_is_sdxl():
|
171 |
-
lora_patcher = lora_style.get_patcher(
|
172 |
-
[sdxl_tileupscaler.pipe, high_res.pipe], task.get_style()
|
173 |
-
)
|
174 |
-
lora_patcher.patch()
|
175 |
|
176 |
-
|
177 |
-
prompt=prompt,
|
178 |
-
imageUrl=task.get_imageUrl(),
|
179 |
-
resize_dimension=task.get_resize_dimension(),
|
180 |
-
negative_prompt=task.get_negative_prompt(),
|
181 |
-
width=task.get_width(),
|
182 |
-
height=task.get_height(),
|
183 |
-
model_id=task.get_model_id(),
|
184 |
-
)
|
185 |
|
186 |
-
|
187 |
-
else:
|
188 |
-
controlnet.load_model("tile_upscaler")
|
189 |
|
190 |
-
|
191 |
-
|
|
|
|
|
192 |
193 |
kwargs = {
|
194 |
-
"
|
|
|
|
|
|
|
195 |
"seed": task.get_seed(),
|
196 |
-
"num_inference_steps": task.get_steps(),
|
197 |
-
"negative_prompt": task.get_negative_prompt(),
|
198 |
"width": task.get_width(),
|
199 |
"height": task.get_height(),
|
200 |
-
"
|
201 |
-
|
202 |
-
**task.cnt_kwargs(),
|
203 |
}
|
204 |
-
images,
|
205 |
-
|
206 |
-
|
207 |
|
208 |
generated_image_url = upload_image(images[0], output_key)
|
209 |
|
@@ -229,12 +282,7 @@ def scribble(task: Task):
|
|
229 |
)
|
230 |
lora_patcher.patch()
|
231 |
|
232 |
-
image =
|
233 |
-
if get_is_sdxl():
|
234 |
-
# We use sketch in SDXL
|
235 |
-
image = ControlNet.pidinet_image(image)
|
236 |
-
else:
|
237 |
-
image = ControlNet.scribble_image(image)
|
238 |
|
239 |
kwargs = {
|
240 |
"image": [image] * get_num_return_sequences(),
|
@@ -244,9 +292,10 @@ def scribble(task: Task):
|
|
244 |
"height": height,
|
245 |
"prompt": prompt,
|
246 |
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
|
|
247 |
**task.cns_kwargs(),
|
248 |
}
|
249 |
-
images, has_nsfw = controlnet.process(**kwargs)
|
250 |
|
251 |
if task.get_high_res_fix():
|
252 |
kwargs = {
|
@@ -256,11 +305,15 @@ def scribble(task: Task):
|
|
256 |
"images": images,
|
257 |
"width": task.get_width(),
|
258 |
"height": task.get_height(),
|
|
|
259 |
"num_inference_steps": task.get_steps(),
|
260 |
**task.high_res_kwargs(),
|
261 |
}
|
262 |
images, _ = high_res.apply(**kwargs)
|
263 |
|
|
|
|
|
|
|
264 |
generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
|
265 |
|
266 |
lora_patcher.cleanup()
|
@@ -296,16 +349,21 @@ def linearart(task: Task):
|
|
296 |
"height": height,
|
297 |
"prompt": prompt,
|
298 |
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
|
|
299 |
**task.cnl_kwargs(),
|
300 |
}
|
301 |
-
images, has_nsfw = controlnet.process(**kwargs)
|
302 |
|
303 |
if task.get_high_res_fix():
|
|
|
|
|
|
|
304 |
kwargs = {
|
305 |
"prompt": prompt,
|
306 |
"negative_prompt": [task.get_negative_prompt()]
|
307 |
* get_num_return_sequences(),
|
308 |
"images": images,
|
|
|
309 |
"width": task.get_width(),
|
310 |
"height": task.get_height(),
|
311 |
"num_inference_steps": task.get_steps(),
|
@@ -313,6 +371,22 @@ def linearart(task: Task):
|
|
313 |
}
|
314 |
images, _ = high_res.apply(**kwargs)
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
|
317 |
|
318 |
lora_patcher.cleanup()
|
@@ -341,20 +415,14 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
341 |
)
|
342 |
lora_patcher.patch()
|
343 |
|
344 |
-
if not task.
|
|
|
|
|
345 |
print("Not detecting pose")
|
346 |
pose = download_image(task.get_imageUrl()).resize(
|
347 |
(task.get_width(), task.get_height())
|
348 |
)
|
349 |
poses = [pose] * get_num_return_sequences()
|
350 |
-
elif task.get_pose_coordinates():
|
351 |
-
infered_pose = pose_detector.transform(
|
352 |
-
image=task.get_imageUrl(),
|
353 |
-
client_coordinates=task.get_pose_coordinates(),
|
354 |
-
width=task.get_width(),
|
355 |
-
height=task.get_height(),
|
356 |
-
)
|
357 |
-
poses = [infered_pose] * get_num_return_sequences()
|
358 |
else:
|
359 |
poses = [
|
360 |
controlnet.detect_pose(task.get_imageUrl())
|
@@ -370,8 +438,11 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
370 |
|
371 |
upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))
|
372 |
|
|
|
|
|
373 |
kwargs = {
|
374 |
-
"
|
|
|
375 |
}
|
376 |
else:
|
377 |
images = poses[0]
|
@@ -389,7 +460,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
389 |
**task.cnp_kwargs(),
|
390 |
**lora_patcher.kwargs(),
|
391 |
}
|
392 |
-
images, has_nsfw = controlnet.process(**kwargs)
|
393 |
|
394 |
if task.get_high_res_fix():
|
395 |
kwargs = {
|
@@ -400,11 +471,12 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
|
|
400 |
"width": task.get_width(),
|
401 |
"height": task.get_height(),
|
402 |
"num_inference_steps": task.get_steps(),
|
|
|
403 |
**task.high_res_kwargs(),
|
404 |
}
|
405 |
images, _ = high_res.apply(**kwargs)
|
406 |
|
407 |
-
upload_image(poses[0], "crecoAI/{}
|
408 |
|
409 |
generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
|
410 |
|
@@ -431,12 +503,11 @@ def text2img(task: Task):
|
|
431 |
)
|
432 |
lora_patcher.patch()
|
433 |
|
434 |
-
torch.manual_seed(task.get_seed())
|
435 |
-
|
436 |
kwargs = {
|
437 |
"params": params,
|
438 |
"num_inference_steps": task.get_steps(),
|
439 |
"height": height,
|
|
|
440 |
"width": width,
|
441 |
"negative_prompt": task.get_negative_prompt(),
|
442 |
**task.t2i_kwargs(),
|
@@ -455,6 +526,7 @@ def text2img(task: Task):
|
|
455 |
"width": task.get_width(),
|
456 |
"height": task.get_height(),
|
457 |
"num_inference_steps": task.get_steps(),
|
|
|
458 |
**task.high_res_kwargs(),
|
459 |
}
|
460 |
images, _ = high_res.apply(**kwargs)
|
@@ -478,11 +550,9 @@ def img2img(task: Task):
|
|
478 |
|
479 |
width, height = get_intermediate_dimension(task)
|
480 |
|
481 |
-
torch.manual_seed(task.get_seed())
|
482 |
-
|
483 |
if get_is_sdxl():
|
484 |
# we run lineart for img2img
|
485 |
-
controlnet.load_model("
|
486 |
|
487 |
lora_patcher = lora_style.get_patcher(
|
488 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
@@ -498,10 +568,11 @@ def img2img(task: Task):
|
|
498 |
"prompt": prompt,
|
499 |
"negative_prompt": [task.get_negative_prompt()]
|
500 |
* get_num_return_sequences(),
|
501 |
-
|
502 |
-
"adapter_conditioning_scale": 0.3,
|
|
|
503 |
}
|
504 |
-
images, has_nsfw = controlnet.process(**kwargs)
|
505 |
else:
|
506 |
lora_patcher = lora_style.get_patcher(
|
507 |
[img2img_pipe.pipe, high_res.pipe], task.get_style()
|
@@ -516,6 +587,7 @@ def img2img(task: Task):
|
|
516 |
"num_inference_steps": task.get_steps(),
|
517 |
"width": width,
|
518 |
"height": height,
|
|
|
519 |
**task.i2i_kwargs(),
|
520 |
**lora_patcher.kwargs(),
|
521 |
}
|
@@ -530,6 +602,7 @@ def img2img(task: Task):
|
|
530 |
"width": task.get_width(),
|
531 |
"height": task.get_height(),
|
532 |
"num_inference_steps": task.get_steps(),
|
|
|
533 |
**task.high_res_kwargs(),
|
534 |
}
|
535 |
images, _ = high_res.apply(**kwargs)
|
@@ -568,7 +641,9 @@ def inpaint(task: Task):
|
|
568 |
"num_inference_steps": task.get_steps(),
|
569 |
**task.ip_kwargs(),
|
570 |
}
|
571 |
-
images = inpainter.process(**kwargs)
|
|
|
|
|
572 |
|
573 |
generated_image_urls = upload_images(images, key, task.get_taskId())
|
574 |
|
@@ -617,9 +692,7 @@ def replace_bg(task: Task):
|
|
617 |
@update_db
|
618 |
@slack.auto_send_alert
|
619 |
def remove_bg(task: Task):
|
620 |
-
output_image =
|
621 |
-
task.get_imageUrl(), model_type=task.get_modelType()
|
622 |
-
)
|
623 |
|
624 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
625 |
image_url = upload_image(output_image, output_key)
|
@@ -732,6 +805,67 @@ def rt_draw_img(task: Task):
|
|
732 |
return {"image": base64_image}
|
733 |
|
734 |
735 |
def custom_action(task: Task):
|
736 |
from external.scripts import __scripts__
|
737 |
|
@@ -759,6 +893,14 @@ def custom_action(task: Task):
|
|
759 |
|
760 |
|
761 |
def load_model_by_task(task_type: TaskType, model_id=-1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
762 |
if not text2img_pipe.is_loaded():
|
763 |
text2img_pipe.load(get_model_dir())
|
764 |
img2img_pipe.create(text2img_pipe)
|
@@ -782,12 +924,14 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
|
|
782 |
upscaler.load()
|
783 |
else:
|
784 |
if task_type == TaskType.TILE_UPSCALE:
|
785 |
-
if get_is_sdxl():
|
786 |
-
|
787 |
-
else:
|
788 |
-
|
789 |
elif task_type == TaskType.CANNY:
|
790 |
controlnet.load_model("canny")
|
|
|
|
|
791 |
elif task_type == TaskType.SCRIBBLE:
|
792 |
controlnet.load_model("scribble")
|
793 |
elif task_type == TaskType.LINEARART:
|
@@ -798,23 +942,24 @@ def load_model_by_task(task_type: TaskType, model_id=-1):
|
|
798 |
|
799 |
def unload_model_by_task(task_type: TaskType):
|
800 |
if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
|
801 |
-
inpainter.unload()
|
|
|
802 |
elif task_type == TaskType.REPLACE_BG:
|
803 |
replace_background.unload()
|
804 |
elif task_type == TaskType.OBJECT_REMOVAL:
|
805 |
object_removal.unload()
|
806 |
elif task_type == TaskType.TILE_UPSCALE:
|
807 |
-
if get_is_sdxl():
|
808 |
-
|
809 |
-
else:
|
810 |
-
controlnet.unload()
|
811 |
-
elif task_type == TaskType.CANNY:
|
812 |
controlnet.unload()
|
813 |
-
elif
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
|
|
|
|
818 |
controlnet.unload()
|
819 |
|
820 |
|
@@ -831,8 +976,6 @@ def model_fn(model_dir):
|
|
831 |
set_model_config(config)
|
832 |
set_root_dir(__file__)
|
833 |
|
834 |
-
FailureHandler.register()
|
835 |
-
|
836 |
avatar.load_local(model_dir)
|
837 |
|
838 |
lora_style.load(model_dir)
|
@@ -855,15 +998,12 @@ def auto_unload_task(func):
|
|
855 |
|
856 |
|
857 |
@auto_unload_task
|
858 |
-
@FailureHandler.clear
|
859 |
def predict_fn(data, pipe):
|
860 |
task = Task(data)
|
861 |
print("task is ", data)
|
862 |
|
863 |
clear_cuda_and_gc()
|
864 |
|
865 |
-
FailureHandler.handle(task)
|
866 |
-
|
867 |
try:
|
868 |
task_type = task.get_type()
|
869 |
|
@@ -894,11 +1034,16 @@ def predict_fn(data, pipe):
|
|
894 |
avatar.fetch_from_network(task.get_model_id())
|
895 |
|
896 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
|
|
|
|
|
|
897 |
return text2img(task)
|
898 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
899 |
return img2img(task)
|
900 |
elif task_type == TaskType.CANNY:
|
901 |
return canny(task)
|
|
|
|
|
902 |
elif task_type == TaskType.POSE:
|
903 |
return pose(task)
|
904 |
elif task_type == TaskType.TILE_UPSCALE:
|
|
|
17 |
from internals.pipelines.img_to_text import Image2Text
|
18 |
from internals.pipelines.inpainter import InPainter
|
19 |
from internals.pipelines.object_remove import ObjectRemoval
|
|
|
20 |
from internals.pipelines.prompt_modifier import PromptModifier
|
21 |
from internals.pipelines.realtime_draw import RealtimeDraw
|
22 |
+
from internals.pipelines.remove_background import RemoveBackgroundV3
|
23 |
from internals.pipelines.replace_background import ReplaceBackground
|
24 |
from internals.pipelines.safety_checker import SafetyChecker
|
25 |
from internals.pipelines.sdxl_tile_upscale import SDXLTileUpscaler
|
|
|
44 |
set_model_config,
|
45 |
set_root_dir,
|
46 |
)
|
|
|
47 |
from internals.util.lora_style import LoraStyle
|
48 |
from internals.util.model_loader import load_model_from_config
|
49 |
from internals.util.slack import Slack
|
|
|
55 |
|
56 |
prompt_modifier = PromptModifier(num_of_sequences=get_num_return_sequences())
|
57 |
upscaler = Upscaler()
|
|
|
58 |
inpainter = InPainter()
|
59 |
high_res = HighRes()
|
60 |
img2text = Image2Text()
|
61 |
img_classifier = ImageClassifier()
|
62 |
object_removal = ObjectRemoval()
|
63 |
replace_background = ReplaceBackground()
|
64 |
+
remove_background_v3 = RemoveBackgroundV3()
|
65 |
replace_background = ReplaceBackground()
|
66 |
controlnet = ControlNet()
|
67 |
lora_style = LoraStyle()
|
|
|
89 |
|
90 |
def get_patched_prompt_tile_upscale(task: Task):
|
91 |
return prompt_util.get_patched_prompt_tile_upscale(
|
92 |
+
task, avatar, lora_style, img_classifier, img2text, is_sdxl=get_is_sdxl()
|
93 |
)
|
94 |
|
95 |
|
|
|
123 |
"num_inference_steps": task.get_steps(),
|
124 |
"width": width,
|
125 |
"height": height,
|
126 |
+
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
127 |
+
"apply_preprocess": task.get_apply_preprocess(),
|
|
|
|
|
128 |
**task.cnc_kwargs(),
|
129 |
**lora_patcher.kwargs(),
|
130 |
}
|
131 |
+
(images, has_nsfw), control_image = controlnet.process(**kwargs)
|
132 |
if task.get_high_res_fix():
|
133 |
kwargs = {
|
134 |
"prompt": prompt,
|
135 |
"negative_prompt": [task.get_negative_prompt()]
|
136 |
* get_num_return_sequences(),
|
137 |
"images": images,
|
138 |
+
"seed": task.get_seed(),
|
139 |
"width": task.get_width(),
|
140 |
"height": task.get_height(),
|
141 |
"num_inference_steps": task.get_steps(),
|
|
|
143 |
}
|
144 |
images, _ = high_res.apply(**kwargs)
|
145 |
|
146 |
+
upload_image(
|
147 |
+
control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
|
148 |
+
)
|
149 |
generated_image_urls = upload_images(images, "_canny", task.get_taskId())
|
150 |
|
151 |
lora_patcher.cleanup()
|
|
|
161 |
@update_db
|
162 |
@auto_clear_cuda_and_gc(controlnet)
|
163 |
@slack.auto_send_alert
|
164 |
+
def canny_img2img(task: Task):
|
165 |
+
prompt, _ = get_patched_prompt(task)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
+
width, height = get_intermediate_dimension(task)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
controlnet.load_model("canny_2x")
|
|
|
|
|
170 |
|
171 |
+
lora_patcher = lora_style.get_patcher(
|
172 |
+
[controlnet.pipe, high_res.pipe], task.get_style()
|
173 |
+
)
|
174 |
+
lora_patcher.patch()
|
175 |
|
176 |
+
kwargs = {
|
177 |
+
"prompt": prompt,
|
178 |
+
"imageUrl": task.get_imageUrl(),
|
179 |
+
"seed": task.get_seed(),
|
180 |
+
"num_inference_steps": task.get_steps(),
|
181 |
+
"width": width,
|
182 |
+
"height": height,
|
183 |
+
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
184 |
+
**task.cnci2i_kwargs(),
|
185 |
+
**lora_patcher.kwargs(),
|
186 |
+
}
|
187 |
+
(images, has_nsfw), control_image = controlnet.process(**kwargs)
|
188 |
+
if task.get_high_res_fix():
|
189 |
+
# we run both here normal upscaler and highres
|
190 |
+
# and show normal upscaler image as output
|
191 |
+
# but use highres image for tile upscale
|
192 |
kwargs = {
|
193 |
+
"prompt": prompt,
|
194 |
+
"negative_prompt": [task.get_negative_prompt()]
|
195 |
+
* get_num_return_sequences(),
|
196 |
+
"images": images,
|
197 |
"seed": task.get_seed(),
|
|
|
|
|
198 |
"width": task.get_width(),
|
199 |
"height": task.get_height(),
|
200 |
+
"num_inference_steps": task.get_steps(),
|
201 |
+
**task.high_res_kwargs(),
|
|
|
202 |
}
|
203 |
+
images, _ = high_res.apply(**kwargs)
|
204 |
+
|
205 |
+
# upload_images(images_high_res, "_canny_2x_highres", task.get_taskId())
|
206 |
+
|
207 |
+
for i, image in enumerate(images):
|
208 |
+
img = upscaler.upscale(
|
209 |
+
image=image,
|
210 |
+
width=task.get_width(),
|
211 |
+
height=task.get_height(),
|
212 |
+
face_enhance=task.get_face_enhance(),
|
213 |
+
resize_dimension=None,
|
214 |
+
)
|
215 |
+
img = Upscaler.to_pil(img)
|
216 |
+
images[i] = img.resize((task.get_width(), task.get_height()))
|
217 |
+
|
218 |
+
upload_image(
|
219 |
+
control_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
|
220 |
+
)
|
221 |
+
generated_image_urls = upload_images(images, "_canny_2x", task.get_taskId())
|
222 |
+
|
223 |
+
lora_patcher.cleanup()
|
224 |
+
controlnet.cleanup()
|
225 |
+
|
226 |
+
return {
|
227 |
+
"modified_prompts": prompt,
|
228 |
+
"generated_image_urls": generated_image_urls,
|
229 |
+
"has_nsfw": has_nsfw,
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
@update_db
|
234 |
+
@auto_clear_cuda_and_gc(controlnet)
|
235 |
+
@slack.auto_send_alert
|
236 |
+
def tile_upscale(task: Task):
|
237 |
+
output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
|
238 |
+
|
239 |
+
prompt = get_patched_prompt_tile_upscale(task)
|
240 |
+
|
241 |
+
controlnet.load_model("tile_upscaler")
|
242 |
+
|
243 |
+
lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
|
244 |
+
lora_patcher.patch()
|
245 |
+
|
246 |
+
kwargs = {
|
247 |
+
"imageUrl": task.get_imageUrl(),
|
248 |
+
"seed": task.get_seed(),
|
249 |
+
"num_inference_steps": task.get_steps(),
|
250 |
+
"negative_prompt": task.get_negative_prompt(),
|
251 |
+
"width": task.get_width(),
|
252 |
+
"height": task.get_height(),
|
253 |
+
"prompt": prompt,
|
254 |
+
"resize_dimension": task.get_resize_dimension(),
|
255 |
+
**task.cnt_kwargs(),
|
256 |
+
}
|
257 |
+
(images, has_nsfw), _ = controlnet.process(**kwargs)
|
258 |
+
lora_patcher.cleanup()
|
259 |
+
controlnet.cleanup()
|
260 |
|
261 |
generated_image_url = upload_image(images[0], output_key)
|
262 |
|
|
|
282 |
)
|
283 |
lora_patcher.patch()
|
284 |
|
285 |
+
image = controlnet.preprocess_image(task.get_imageUrl(), width, height)
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
kwargs = {
|
288 |
"image": [image] * get_num_return_sequences(),
|
|
|
292 |
"height": height,
|
293 |
"prompt": prompt,
|
294 |
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
295 |
+
"apply_preprocess": task.get_apply_preprocess(),
|
296 |
**task.cns_kwargs(),
|
297 |
}
|
298 |
+
(images, has_nsfw), condition_image = controlnet.process(**kwargs)
|
299 |
|
300 |
if task.get_high_res_fix():
|
301 |
kwargs = {
|
|
|
305 |
"images": images,
|
306 |
"width": task.get_width(),
|
307 |
"height": task.get_height(),
|
308 |
+
"seed": task.get_seed(),
|
309 |
"num_inference_steps": task.get_steps(),
|
310 |
**task.high_res_kwargs(),
|
311 |
}
|
312 |
images, _ = high_res.apply(**kwargs)
|
313 |
|
314 |
+
upload_image(
|
315 |
+
condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
|
316 |
+
)
|
317 |
generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
|
318 |
|
319 |
lora_patcher.cleanup()
|
|
|
349 |
"height": height,
|
350 |
"prompt": prompt,
|
351 |
"negative_prompt": [task.get_negative_prompt()] * get_num_return_sequences(),
|
352 |
+
"apply_preprocess": task.get_apply_preprocess(),
|
353 |
**task.cnl_kwargs(),
|
354 |
}
|
355 |
+
(images, has_nsfw), condition_image = controlnet.process(**kwargs)
|
356 |
|
357 |
if task.get_high_res_fix():
|
358 |
+
# we run both here normal upscaler and highres
|
359 |
+
# and show normal upscaler image as output
|
360 |
+
# but use highres image for tile upscale
|
361 |
kwargs = {
|
362 |
"prompt": prompt,
|
363 |
"negative_prompt": [task.get_negative_prompt()]
|
364 |
* get_num_return_sequences(),
|
365 |
"images": images,
|
366 |
+
"seed": task.get_seed(),
|
367 |
"width": task.get_width(),
|
368 |
"height": task.get_height(),
|
369 |
"num_inference_steps": task.get_steps(),
|
|
|
371 |
}
|
372 |
images, _ = high_res.apply(**kwargs)
|
373 |
|
374 |
+
# upload_images(images_high_res, "_linearart_highres", task.get_taskId())
|
375 |
+
#
|
376 |
+
# for i, image in enumerate(images):
|
377 |
+
# img = upscaler.upscale(
|
378 |
+
# image=image,
|
379 |
+
# width=task.get_width(),
|
380 |
+
# height=task.get_height(),
|
381 |
+
# face_enhance=task.get_face_enhance(),
|
382 |
+
# resize_dimension=None,
|
383 |
+
# )
|
384 |
+
# img = Upscaler.to_pil(img)
|
385 |
+
# images[i] = img
|
386 |
+
|
387 |
+
upload_image(
|
388 |
+
condition_image, f"crecoAI/{task.get_taskId()}_condition.png" # pyright: ignore
|
389 |
+
)
|
390 |
generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
|
391 |
|
392 |
lora_patcher.cleanup()
|
|
|
415 |
)
|
416 |
lora_patcher.patch()
|
417 |
|
418 |
+
if not task.get_apply_preprocess():
|
419 |
+
poses = [download_image(task.get_imageUrl()).resize((width, height))]
|
420 |
+
elif not task.get_pose_estimation():
|
421 |
print("Not detecting pose")
|
422 |
pose = download_image(task.get_imageUrl()).resize(
|
423 |
(task.get_width(), task.get_height())
|
424 |
)
|
425 |
poses = [pose] * get_num_return_sequences()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
else:
|
427 |
poses = [
|
428 |
controlnet.detect_pose(task.get_imageUrl())
|
|
|
438 |
|
439 |
upload_image(depth, "crecoAI/{}_depth.png".format(task.get_taskId()))
|
440 |
|
441 |
+
scale = task.cnp_kwargs().pop("controlnet_conditioning_scale", None)
|
442 |
+
factor = task.cnp_kwargs().pop("control_guidance_end", None)
|
443 |
kwargs = {
|
444 |
+
"controlnet_conditioning_scale": [1.0, scale or 1.0],
|
445 |
+
"control_guidance_end": [0.5, factor or 1.0],
|
446 |
}
|
447 |
else:
|
448 |
images = poses[0]
|
|
|
460 |
**task.cnp_kwargs(),
|
461 |
**lora_patcher.kwargs(),
|
462 |
}
|
463 |
+
(images, has_nsfw), _ = controlnet.process(**kwargs)
|
464 |
|
465 |
if task.get_high_res_fix():
|
466 |
kwargs = {
|
|
|
471 |
"width": task.get_width(),
|
472 |
"height": task.get_height(),
|
473 |
"num_inference_steps": task.get_steps(),
|
474 |
+
"seed": task.get_seed(),
|
475 |
**task.high_res_kwargs(),
|
476 |
}
|
477 |
images, _ = high_res.apply(**kwargs)
|
478 |
|
479 |
+
upload_image(poses[0], "crecoAI/{}_condition.png".format(task.get_taskId()))
|
480 |
|
481 |
generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
|
482 |
|
|
|
503 |
)
|
504 |
lora_patcher.patch()
|
505 |
|
|
|
|
|
506 |
kwargs = {
|
507 |
"params": params,
|
508 |
"num_inference_steps": task.get_steps(),
|
509 |
"height": height,
|
510 |
+
"seed": task.get_seed(),
|
511 |
"width": width,
|
512 |
"negative_prompt": task.get_negative_prompt(),
|
513 |
**task.t2i_kwargs(),
|
|
|
526 |
"width": task.get_width(),
|
527 |
"height": task.get_height(),
|
528 |
"num_inference_steps": task.get_steps(),
|
529 |
+
"seed": task.get_seed(),
|
530 |
**task.high_res_kwargs(),
|
531 |
}
|
532 |
images, _ = high_res.apply(**kwargs)
|
|
|
550 |
|
551 |
width, height = get_intermediate_dimension(task)
|
552 |
|
|
|
|
|
553 |
if get_is_sdxl():
|
554 |
# we run lineart for img2img
|
555 |
+
controlnet.load_model("canny")
|
556 |
|
557 |
lora_patcher = lora_style.get_patcher(
|
558 |
[controlnet.pipe2, high_res.pipe], task.get_style()
|
|
|
568 |
"prompt": prompt,
|
569 |
"negative_prompt": [task.get_negative_prompt()]
|
570 |
* get_num_return_sequences(),
|
571 |
+
"controlnet_conditioning_scale": 0.5,
|
572 |
+
# "adapter_conditioning_scale": 0.3,
|
573 |
+
**task.i2i_kwargs(),
|
574 |
}
|
575 |
+
(images, has_nsfw), _ = controlnet.process(**kwargs)
|
576 |
else:
|
577 |
lora_patcher = lora_style.get_patcher(
|
578 |
[img2img_pipe.pipe, high_res.pipe], task.get_style()
|
|
|
587 |
"num_inference_steps": task.get_steps(),
|
588 |
"width": width,
|
589 |
"height": height,
|
590 |
+
"seed": task.get_seed(),
|
591 |
**task.i2i_kwargs(),
|
592 |
**lora_patcher.kwargs(),
|
593 |
}
|
|
|
602 |
"width": task.get_width(),
|
603 |
"height": task.get_height(),
|
604 |
"num_inference_steps": task.get_steps(),
|
605 |
+
"seed": task.get_seed(),
|
606 |
**task.high_res_kwargs(),
|
607 |
}
|
608 |
images, _ = high_res.apply(**kwargs)
|
|
|
641 |
"num_inference_steps": task.get_steps(),
|
642 |
**task.ip_kwargs(),
|
643 |
}
|
644 |
+
images, mask = inpainter.process(**kwargs)
|
645 |
+
|
646 |
+
upload_image(mask, "crecoAI/{}_mask.png".format(task.get_taskId()))
|
647 |
|
648 |
generated_image_urls = upload_images(images, key, task.get_taskId())
|
649 |
|
|
|
692 |
@update_db
|
693 |
@slack.auto_send_alert
|
694 |
def remove_bg(task: Task):
|
695 |
+
output_image = remove_background_v3.remove(task.get_imageUrl())
|
|
|
|
|
696 |
|
697 |
output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
|
698 |
image_url = upload_image(output_image, output_key)
|
|
|
805 |
return {"image": base64_image}
|
806 |
|
807 |
|
808 |
+
@update_db
|
809 |
+
@auto_clear_cuda_and_gc(controlnet)
|
810 |
+
@slack.auto_send_alert
|
811 |
+
def depth_rig(task: Task):
|
812 |
+
# Note : This task is for only processing a hardcoded character rig model using depth controlnet
|
813 |
+
# Hack : This model requires hardcoded depth images for optimal processing, so we pass it by default
|
814 |
+
default_depth_url = "https://s3.ap-south-1.amazonaws.com/assets.autodraft.in/character-sheet/rigs/character-rig-depth-map.png"
|
815 |
+
|
816 |
+
params = get_patched_prompt_text2img(task)
|
817 |
+
|
818 |
+
width, height = get_intermediate_dimension(task)
|
819 |
+
|
820 |
+
controlnet.load_model("depth")
|
821 |
+
|
822 |
+
lora_patcher = lora_style.get_patcher(
|
823 |
+
[controlnet.pipe2, high_res.pipe], task.get_style()
|
824 |
+
)
|
825 |
+
lora_patcher.patch()
|
826 |
+
|
827 |
+
kwargs = {
|
828 |
+
"params": params,
|
829 |
+
"prompt": params.prompt,
|
830 |
+
"num_inference_steps": task.get_steps(),
|
831 |
+
"imageUrl": default_depth_url,
|
832 |
+
"height": height,
|
833 |
+
"seed": task.get_seed(),
|
834 |
+
"width": width,
|
835 |
+
"negative_prompt": task.get_negative_prompt(),
|
836 |
+
**task.t2i_kwargs(),
|
837 |
+
**lora_patcher.kwargs(),
|
838 |
+
}
|
839 |
+
(images, has_nsfw), condition_image = controlnet.process(**kwargs)
|
840 |
+
|
841 |
+
if task.get_high_res_fix():
|
842 |
+
kwargs = {
|
843 |
+
"prompt": params.prompt
|
844 |
+
if params.prompt
|
845 |
+
else [""] * get_num_return_sequences(),
|
846 |
+
"negative_prompt": [task.get_negative_prompt()]
|
847 |
+
* get_num_return_sequences(),
|
848 |
+
"images": images,
|
849 |
+
"width": task.get_width(),
|
850 |
+
"height": task.get_height(),
|
851 |
+
"num_inference_steps": task.get_steps(),
|
852 |
+
"seed": task.get_seed(),
|
853 |
+
**task.high_res_kwargs(),
|
854 |
+
}
|
855 |
+
images, _ = high_res.apply(**kwargs)
|
856 |
+
|
857 |
+
upload_image(condition_image, "crecoAI/{}_condition.png".format(task.get_taskId()))
|
858 |
+
generated_image_urls = upload_images(images, "", task.get_taskId())
|
859 |
+
|
860 |
+
lora_patcher.cleanup()
|
861 |
+
|
862 |
+
return {
|
863 |
+
**params.__dict__,
|
864 |
+
"generated_image_urls": generated_image_urls,
|
865 |
+
"has_nsfw": has_nsfw,
|
866 |
+
}
|
867 |
+
|
868 |
+
|
869 |
def custom_action(task: Task):
|
870 |
from external.scripts import __scripts__
|
871 |
|
|
|
893 |
|
894 |
|
895 |
def load_model_by_task(task_type: TaskType, model_id=-1):
|
896 |
+
from internals.pipelines.controlnets import clear_networks
|
897 |
+
|
898 |
+
# pre-cleanup inpaint and controlnet models
|
899 |
+
if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
|
900 |
+
clear_networks()
|
901 |
+
else:
|
902 |
+
inpainter.unload()
|
903 |
+
|
904 |
if not text2img_pipe.is_loaded():
|
905 |
text2img_pipe.load(get_model_dir())
|
906 |
img2img_pipe.create(text2img_pipe)
|
|
|
924 |
upscaler.load()
|
925 |
else:
|
926 |
if task_type == TaskType.TILE_UPSCALE:
|
927 |
+
# if get_is_sdxl():
|
928 |
+
# sdxl_tileupscaler.create(high_res, text2img_pipe, model_id)
|
929 |
+
# else:
|
930 |
+
controlnet.load_model("tile_upscaler")
|
931 |
elif task_type == TaskType.CANNY:
|
932 |
controlnet.load_model("canny")
|
933 |
+
elif task_type == TaskType.CANNY_IMG2IMG:
|
934 |
+
controlnet.load_model("canny_2x")
|
935 |
elif task_type == TaskType.SCRIBBLE:
|
936 |
controlnet.load_model("scribble")
|
937 |
elif task_type == TaskType.LINEARART:
|
|
|
942 |
|
943 |
def unload_model_by_task(task_type: TaskType):
|
944 |
if task_type == TaskType.INPAINT or task_type == TaskType.OUTPAINT:
|
945 |
+
# inpainter.unload()
|
946 |
+
pass
|
947 |
elif task_type == TaskType.REPLACE_BG:
|
948 |
replace_background.unload()
|
949 |
elif task_type == TaskType.OBJECT_REMOVAL:
|
950 |
object_removal.unload()
|
951 |
elif task_type == TaskType.TILE_UPSCALE:
|
952 |
+
# if get_is_sdxl():
|
953 |
+
# sdxl_tileupscaler.unload()
|
954 |
+
# else:
|
|
|
|
|
955 |
controlnet.unload()
|
956 |
+
elif (
|
957 |
+
task_type == TaskType.CANNY
|
958 |
+
or task_type == TaskType.CANNY_IMG2IMG
|
959 |
+
or task_type == TaskType.SCRIBBLE
|
960 |
+
or task_type == TaskType.LINEARART
|
961 |
+
or task_type == TaskType.POSE
|
962 |
+
):
|
963 |
controlnet.unload()
|
964 |
|
965 |
|
|
|
976 |
set_model_config(config)
|
977 |
set_root_dir(__file__)
|
978 |
|
|
|
|
|
979 |
avatar.load_local(model_dir)
|
980 |
|
981 |
lora_style.load(model_dir)
|
|
|
998 |
|
999 |
|
1000 |
@auto_unload_task
|
|
|
1001 |
def predict_fn(data, pipe):
|
1002 |
task = Task(data)
|
1003 |
print("task is ", data)
|
1004 |
|
1005 |
clear_cuda_and_gc()
|
1006 |
|
|
|
|
|
1007 |
try:
|
1008 |
task_type = task.get_type()
|
1009 |
|
|
|
1034 |
avatar.fetch_from_network(task.get_model_id())
|
1035 |
|
1036 |
if task_type == TaskType.TEXT_TO_IMAGE:
|
1037 |
+
# Hack : Character Rigging Model Task Redirection
|
1038 |
+
if task.get_model_id() == 2000336 or task.get_model_id() == 2000341:
|
1039 |
+
return depth_rig(task)
|
1040 |
return text2img(task)
|
1041 |
elif task_type == TaskType.IMAGE_TO_IMAGE:
|
1042 |
return img2img(task)
|
1043 |
elif task_type == TaskType.CANNY:
|
1044 |
return canny(task)
|
1045 |
+
elif task_type == TaskType.CANNY_IMG2IMG:
|
1046 |
+
return canny_img2img(task)
|
1047 |
elif task_type == TaskType.POSE:
|
1048 |
return pose(task)
|
1049 |
elif task_type == TaskType.TILE_UPSCALE:
|
internals/data/task.py
CHANGED
@@ -11,6 +11,7 @@ class TaskType(Enum):
     POSE = "POSE"
     CANNY = "CANNY"
     REMOVE_BG = "REMOVE_BG"
+    CANNY_IMG2IMG = "CANNY_IMG2IMG"
     INPAINT = "INPAINT"
     UPSCALE_IMAGE = "UPSCALE_IMAGE"
     TILE_UPSCALE = "TILE_UPSCALE"
@@ -47,12 +48,18 @@ class Task:
         elif len(prompt) > 200:
             self.__data["prompt"] = data.get("prompt", "")[:200] + ", "
 
+    def get_environment(self) -> str:
+        return self.__data.get("stage", "prod")
+
     def get_taskId(self) -> str:
         return self.__data.get("task_id")
 
     def get_sourceId(self) -> str:
         return self.__data.get("source_id")
 
+    def get_slack_url(self) -> str:
+        return self.__data.get("slack_url", None)
+
     def get_imageUrl(self) -> str:
         return self.__data.get("imageUrl", None)
 
@@ -150,12 +157,18 @@ class Task:
     def get_access_token(self) -> str:
         return self.__data.get("access_token", "")
 
+    def get_apply_preprocess(self) -> bool:
+        return self.__data.get("apply_preprocess", True)
+
     def get_high_res_fix(self) -> bool:
         return self.__data.get("high_res_fix", False)
 
     def get_base_dimension(self):
         return self.__data.get("base_dimension", None)
 
+    def get_process_mode(self):
+        return self.__data.get("process_mode", None)
+
     def get_action_data(self) -> dict:
         "If task_type is CUSTOM_ACTION, then this will return the action data with 'name' as key"
         return self.__data.get("action_data", {})
@@ -175,6 +188,9 @@ class Task:
     def cnc_kwargs(self) -> dict:
         return dict(self.__get_kwargs("cnc_"))
 
+    def cnci2i_kwargs(self) -> dict:
+        return dict(self.__get_kwargs("cnci2i_"))
+
     def cnp_kwargs(self) -> dict:
         return dict(self.__get_kwargs("cnp_"))
 
@@ -192,7 +208,7 @@ class Task:
 
     def __get_kwargs(self, prefix: str):
         for k, v in self.__data.items():
-            if k.startswith(prefix):
+            if k.startswith(prefix) and v != -1:
                 yield k[len(prefix) :], v
 
     @property
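The v != -1 guard added to __get_kwargs means clients can send -1 for any prefixed override to fall back to the pipeline default. A self-contained sketch of the resulting mapping, with made-up field names:

def strip_prefix(data: dict, prefix: str) -> dict:
    # mirrors Task.__get_kwargs: keep prefixed keys, strip the prefix, drop -1 sentinels
    return {k[len(prefix):]: v for k, v in data.items() if k.startswith(prefix) and v != -1}

assert strip_prefix(
    {"cnc_guidance_scale": 9, "cnc_low_threshold": -1, "prompt": "x"}, "cnc_"
) == {"guidance_scale": 9}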
internals/pipelines/commons.py
CHANGED
@@ -11,11 +11,14 @@ from diffusers import (
|
|
11 |
|
12 |
from internals.data.result import Result
|
13 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
|
|
14 |
from internals.util.commons import disable_safety_checker, download_image
|
15 |
from internals.util.config import (
|
|
|
16 |
get_base_model_variant,
|
17 |
get_hf_token,
|
18 |
get_is_sdxl,
|
|
|
19 |
get_num_return_sequences,
|
20 |
)
|
21 |
|
@@ -38,6 +41,9 @@ class Text2Img(AbstractPipeline):
|
|
38 |
|
39 |
def load(self, model_dir: str):
|
40 |
if get_is_sdxl():
|
|
|
|
|
|
|
41 |
vae = AutoencoderKL.from_pretrained(
|
42 |
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
|
43 |
)
|
@@ -47,6 +53,7 @@ class Text2Img(AbstractPipeline):
|
|
47 |
token=get_hf_token(),
|
48 |
use_safetensors=True,
|
49 |
variant=get_base_model_variant(),
|
|
|
50 |
)
|
51 |
pipe.vae = vae
|
52 |
pipe.to("cuda")
|
@@ -70,9 +77,9 @@ class Text2Img(AbstractPipeline):
|
|
70 |
self.__patch()
|
71 |
|
72 |
def __patch(self):
|
73 |
-
if get_is_sdxl():
|
74 |
-
self.pipe.
|
75 |
-
self.pipe.
|
76 |
self.pipe.enable_xformers_memory_efficient_attention()
|
77 |
|
78 |
@torch.inference_mode()
|
@@ -82,12 +89,15 @@ class Text2Img(AbstractPipeline):
|
|
82 |
num_inference_steps: int,
|
83 |
height: int,
|
84 |
width: int,
|
|
|
85 |
negative_prompt: str,
|
86 |
iteration: float = 3.0,
|
87 |
**kwargs,
|
88 |
):
|
89 |
prompt = params.prompt
|
90 |
|
|
|
|
|
91 |
if params.prompt_left and params.prompt_right:
|
92 |
# multi-character pipelines
|
93 |
prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
|
@@ -99,6 +109,7 @@ class Text2Img(AbstractPipeline):
|
|
99 |
"width": width,
|
100 |
"num_inference_steps": num_inference_steps,
|
101 |
"negative_prompt": [negative_prompt or ""] * len(prompt),
|
|
|
102 |
**kwargs,
|
103 |
}
|
104 |
result = self.pipe.multi_character_diffusion(**kwargs)
|
@@ -125,8 +136,11 @@ class Text2Img(AbstractPipeline):
|
|
125 |
"width": width,
|
126 |
"negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
|
127 |
"num_inference_steps": num_inference_steps,
|
|
|
|
|
128 |
**kwargs,
|
129 |
}
|
|
|
130 |
result = self.pipe.__call__(**kwargs)
|
131 |
|
132 |
return Result.from_result(result)
|
@@ -145,6 +159,7 @@ class Img2Img(AbstractPipeline):
|
|
145 |
torch_dtype=torch.float16,
|
146 |
token=get_hf_token(),
|
147 |
variant=get_base_model_variant(),
|
|
|
148 |
use_safetensors=True,
|
149 |
).to("cuda")
|
150 |
else:
|
@@ -183,20 +198,24 @@ class Img2Img(AbstractPipeline):
|
|
183 |
num_inference_steps: int,
|
184 |
width: int,
|
185 |
height: int,
|
|
|
186 |
strength: float = 0.75,
|
187 |
guidance_scale: float = 7.5,
|
188 |
**kwargs,
|
189 |
):
|
190 |
image = download_image(imageUrl).resize((width, height))
|
191 |
|
|
|
|
|
192 |
kwargs = {
|
193 |
"prompt": prompt,
|
194 |
-
"image": image,
|
195 |
"strength": strength,
|
196 |
"negative_prompt": negative_prompt,
|
197 |
"guidance_scale": guidance_scale,
|
198 |
"num_images_per_prompt": 1,
|
199 |
"num_inference_steps": num_inference_steps,
|
|
|
200 |
**kwargs,
|
201 |
}
|
202 |
result = self.pipe.__call__(**kwargs)
|
|
|
11 |
|
12 |
from internals.data.result import Result
|
13 |
from internals.pipelines.twoStepPipeline import two_step_pipeline
|
14 |
+
from internals.util import get_generators
|
15 |
from internals.util.commons import disable_safety_checker, download_image
|
16 |
from internals.util.config import (
|
17 |
+
get_base_model_revision,
|
18 |
get_base_model_variant,
|
19 |
get_hf_token,
|
20 |
get_is_sdxl,
|
21 |
+
get_low_gpu_mem,
|
22 |
get_num_return_sequences,
|
23 |
)
|
24 |
|
|
|
41 |
|
42 |
def load(self, model_dir: str):
|
43 |
if get_is_sdxl():
|
44 |
+
print(
|
45 |
+
f"Loading model {model_dir} - {get_base_model_variant()}, {get_base_model_revision()}"
|
46 |
+
)
|
47 |
vae = AutoencoderKL.from_pretrained(
|
48 |
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
|
49 |
)
|
|
|
53 |
token=get_hf_token(),
|
54 |
use_safetensors=True,
|
55 |
variant=get_base_model_variant(),
|
56 |
+
revision=get_base_model_revision(),
|
57 |
)
|
58 |
pipe.vae = vae
|
59 |
pipe.to("cuda")
|
|
|
77 |
self.__patch()
|
78 |
|
79 |
def __patch(self):
|
80 |
+
if get_is_sdxl() or get_low_gpu_mem():
|
81 |
+
self.pipe.vae.enable_tiling()
|
82 |
+
self.pipe.vae.enable_slicing()
|
83 |
self.pipe.enable_xformers_memory_efficient_attention()
|
84 |
|
85 |
@torch.inference_mode()
|
|
|
89 |
num_inference_steps: int,
|
90 |
height: int,
|
91 |
width: int,
|
92 |
+
seed: int,
|
93 |
negative_prompt: str,
|
94 |
iteration: float = 3.0,
|
95 |
**kwargs,
|
96 |
):
|
97 |
prompt = params.prompt
|
98 |
|
99 |
+
generator = get_generators(seed, get_num_return_sequences())
|
100 |
+
|
101 |
if params.prompt_left and params.prompt_right:
|
102 |
# multi-character pipelines
|
103 |
prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
|
|
|
109 |
"width": width,
|
110 |
"num_inference_steps": num_inference_steps,
|
111 |
"negative_prompt": [negative_prompt or ""] * len(prompt),
|
112 |
+
"generator": generator,
|
113 |
**kwargs,
|
114 |
}
|
115 |
result = self.pipe.multi_character_diffusion(**kwargs)
|
|
|
136 |
"width": width,
|
137 |
"negative_prompt": [negative_prompt or ""] * get_num_return_sequences(),
|
138 |
"num_inference_steps": num_inference_steps,
|
139 |
+
"guidance_scale": 7.5,
|
140 |
+
"generator": generator,
|
141 |
**kwargs,
|
142 |
}
|
143 |
+
print(kwargs)
|
144 |
result = self.pipe.__call__(**kwargs)
|
145 |
|
146 |
return Result.from_result(result)
|
|
|
159 |
torch_dtype=torch.float16,
|
160 |
token=get_hf_token(),
|
161 |
variant=get_base_model_variant(),
|
162 |
+
revision=get_base_model_revision(),
|
163 |
use_safetensors=True,
|
164 |
).to("cuda")
|
165 |
else:
|
|
|
198 |
num_inference_steps: int,
|
199 |
width: int,
|
200 |
height: int,
|
201 |
+
seed: int,
|
202 |
strength: float = 0.75,
|
203 |
guidance_scale: float = 7.5,
|
204 |
**kwargs,
|
205 |
):
|
206 |
image = download_image(imageUrl).resize((width, height))
|
207 |
|
208 |
+
generator = get_generators(seed, get_num_return_sequences())
|
209 |
+
|
210 |
kwargs = {
|
211 |
"prompt": prompt,
|
212 |
+
"image": [image] * get_num_return_sequences(),
|
213 |
"strength": strength,
|
214 |
"negative_prompt": negative_prompt,
|
215 |
"guidance_scale": guidance_scale,
|
216 |
"num_images_per_prompt": 1,
|
217 |
"num_inference_steps": num_inference_steps,
|
218 |
+
"generator": generator,
|
219 |
**kwargs,
|
220 |
}
|
221 |
result = self.pipe.__call__(**kwargs)
|
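The commons.py changes above thread the task seed into every diffusers call through a generator kwarg built by get_generators from internals.util; its body is not shown in this section. A plausible sketch, assuming it returns one seeded torch.Generator per requested image so batched outputs stay reproducible yet distinct per image:

import torch

def get_generators(seed: int, num_generators: int = 1):
    # Assumed behaviour: one generator per output image, offset-seeded from the task seed.
    return [
        torch.Generator(device="cuda").manual_seed(seed + i)
        for i in range(num_generators)
    ]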
internals/pipelines/controlnets.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from typing import AbstractSet, List, Literal, Optional, Union
|
2 |
|
3 |
import cv2
|
@@ -17,6 +18,7 @@ from diffusers import (
|
|
17 |
StableDiffusionControlNetImg2ImgPipeline,
|
18 |
StableDiffusionControlNetPipeline,
|
19 |
StableDiffusionXLAdapterPipeline,
|
|
|
20 |
StableDiffusionXLControlNetPipeline,
|
21 |
T2IAdapter,
|
22 |
UniPCMultistepScheduler,
|
@@ -29,9 +31,9 @@ from tqdm import gui
|
|
29 |
from transformers import pipeline
|
30 |
|
31 |
import internals.util.image as ImageUtil
|
32 |
-
from external.midas import apply_midas
|
33 |
from internals.data.result import Result
|
34 |
from internals.pipelines.commons import AbstractPipeline
|
|
|
35 |
from internals.util.cache import clear_cuda_and_gc
|
36 |
from internals.util.commons import download_image
|
37 |
from internals.util.config import (
|
@@ -39,9 +41,51 @@ from internals.util.config import (
|
|
39 |
get_hf_token,
|
40 |
get_is_sdxl,
|
41 |
get_model_dir,
|
|
|
42 |
)
|
43 |
|
44 |
-
CONTROLNET_TYPES = Literal[
45 |
|
46 |
|
47 |
class StableDiffusionNetworkModelPipelineLoader:
|
@@ -57,11 +101,6 @@ class StableDiffusionNetworkModelPipelineLoader:
|
|
57 |
pipeline_type,
|
58 |
base_pipe: Optional[AbstractSet] = None,
|
59 |
):
|
60 |
-
if is_sdxl and is_img2img:
|
61 |
-
# Does not matter pipeline type but tile upscale is not supported
|
62 |
-
print("Warning: Tile upscale is not supported on SDXL")
|
63 |
-
return None
|
64 |
-
|
65 |
if base_pipe is None:
|
66 |
pretrained = True
|
67 |
kwargs = {
|
@@ -75,7 +114,17 @@ class StableDiffusionNetworkModelPipelineLoader:
|
|
75 |
kwargs = {
|
76 |
**base_pipe.pipe.components, # pyright: ignore
|
77 |
}
|
|
|
|
|
|
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
if is_sdxl and pipeline_type == "controlnet":
|
80 |
model = (
|
81 |
StableDiffusionXLControlNetPipeline.from_pretrained
|
@@ -146,9 +195,10 @@ class ControlNet(AbstractPipeline):
|
|
146 |
def load_model(self, task_name: CONTROLNET_TYPES):
|
147 |
"Appropriately loads the network module, pipelines and cache it for reuse."
|
148 |
|
149 |
-
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
|
150 |
if self.__current_task_name == task_name:
|
151 |
return
|
|
|
|
|
152 |
model = config[task_name]
|
153 |
if not model:
|
154 |
raise Exception(f"ControlNet is not supported for {task_name}")
|
@@ -176,31 +226,13 @@ class ControlNet(AbstractPipeline):
|
|
176 |
def __load_network_model(self, model_name, pipeline_type):
|
177 |
"Loads the network module, eg: ControlNet or T2I Adapters"
|
178 |
|
179 |
-
def load_controlnet(model):
|
180 |
-
return ControlNetModel.from_pretrained(
|
181 |
-
model,
|
182 |
-
torch_dtype=torch.float16,
|
183 |
-
cache_dir=get_hf_cache_dir(),
|
184 |
-
).to("cuda")
|
185 |
-
|
186 |
-
def load_t2i(model):
|
187 |
-
return T2IAdapter.from_pretrained(
|
188 |
-
model,
|
189 |
-
torch_dtype=torch.float16,
|
190 |
-
varient="fp16",
|
191 |
-
).to("cuda")
|
192 |
-
|
193 |
if type(model_name) == str:
|
194 |
-
|
195 |
-
return load_controlnet(model_name)
|
196 |
-
if pipeline_type == "t2i":
|
197 |
-
return load_t2i(model_name)
|
198 |
-
raise Exception("Invalid pipeline type")
|
199 |
elif type(model_name) == list:
|
200 |
if pipeline_type == "controlnet":
|
201 |
cns = []
|
202 |
for model in model_name:
|
203 |
-
cns.append(
|
204 |
return MultiControlNetModel(cns).to("cuda")
|
205 |
elif pipeline_type == "t2i":
|
206 |
raise Exception("Multi T2I adapters are not supported")
|
@@ -219,9 +251,10 @@ class ControlNet(AbstractPipeline):
|
|
219 |
pipe.enable_vae_slicing()
|
220 |
pipe.enable_xformers_memory_efficient_attention()
|
221 |
# this scheduler produces good outputs for t2i adapters
|
222 |
-
|
223 |
-
pipe.scheduler.
|
224 |
-
|
|
|
225 |
else:
|
226 |
pipe.enable_xformers_memory_efficient_attention()
|
227 |
return pipe
|
@@ -229,7 +262,7 @@ class ControlNet(AbstractPipeline):
|
|
229 |
# If the pipeline type is changed we should reload all
|
230 |
# the pipelines
|
231 |
if not self.__loaded or self.__pipe_type != pipeline_type:
|
232 |
-
# controlnet pipeline for tile upscaler
|
233 |
pipe = StableDiffusionNetworkModelPipelineLoader(
|
234 |
is_sdxl=get_is_sdxl(),
|
235 |
is_img2img=True,
|
@@ -278,6 +311,8 @@ class ControlNet(AbstractPipeline):
|
|
278 |
def process(self, **kwargs):
|
279 |
if self.__current_task_name == "pose":
|
280 |
return self.process_pose(**kwargs)
|
|
|
|
|
281 |
if self.__current_task_name == "canny":
|
282 |
return self.process_canny(**kwargs)
|
283 |
if self.__current_task_name == "scribble":
|
@@ -286,6 +321,8 @@ class ControlNet(AbstractPipeline):
|
|
286 |
return self.process_linearart(**kwargs)
|
287 |
if self.__current_task_name == "tile_upscaler":
|
288 |
return self.process_tile_upscaler(**kwargs)
|
|
|
|
|
289 |
raise Exception("ControlNet is not loaded with any model")
|
290 |
|
291 |
@torch.inference_mode()
|
@@ -298,16 +335,22 @@ class ControlNet(AbstractPipeline):
|
|
298 |
negative_prompt: List[str],
|
299 |
height: int,
|
300 |
width: int,
|
301 |
-
guidance_scale: float =
|
|
|
302 |
**kwargs,
|
303 |
):
|
304 |
if self.__current_task_name != "canny":
|
305 |
raise Exception("ControlNet is not loaded with canny model")
|
306 |
|
307 |
-
|
|
|
|
|
|
|
|
|
|
|
308 |
|
309 |
-
|
310 |
-
|
311 |
|
312 |
kwargs = {
|
313 |
"prompt": prompt,
|
@@ -318,11 +361,67 @@ class ControlNet(AbstractPipeline):
|
|
318 |
"num_inference_steps": num_inference_steps,
|
319 |
"height": height,
|
320 |
"width": width,
|
|
|
321 |
**kwargs,
|
322 |
}
|
323 |
|
|
|
324 |
result = self.pipe2.__call__(**kwargs)
|
325 |
-
return Result.from_result(result)
|
326 |
|
327 |
@torch.inference_mode()
|
328 |
def process_pose(
|
@@ -340,22 +439,23 @@ class ControlNet(AbstractPipeline):
|
|
340 |
if self.__current_task_name != "pose":
|
341 |
raise Exception("ControlNet is not loaded with pose model")
|
342 |
|
343 |
-
|
344 |
|
345 |
kwargs = {
|
346 |
"prompt": prompt[0],
|
347 |
"image": image,
|
348 |
-
"num_images_per_prompt":
|
349 |
"num_inference_steps": num_inference_steps,
|
350 |
"negative_prompt": negative_prompt[0],
|
351 |
"guidance_scale": guidance_scale,
|
352 |
"height": height,
|
353 |
"width": width,
|
|
|
354 |
**kwargs,
|
355 |
}
|
356 |
print(kwargs)
|
357 |
result = self.pipe2.__call__(**kwargs)
|
358 |
-
return Result.from_result(result)
|
359 |
|
360 |
@torch.inference_mode()
|
361 |
def process_tile_upscaler(
|
@@ -374,26 +474,60 @@ class ControlNet(AbstractPipeline):
|
|
374 |
if self.__current_task_name != "tile_upscaler":
|
375 |
raise Exception("ControlNet is not loaded with tile_upscaler model")
|
376 |
|
377 |
-
|
|
|
|
|
|
|
|
|
|
|
378 |
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
|
|
|
|
|
383 |
|
384 |
kwargs = {
|
385 |
-
"image": condition_image,
|
386 |
"prompt": prompt,
|
387 |
"control_image": condition_image,
|
388 |
"num_inference_steps": num_inference_steps,
|
389 |
"negative_prompt": negative_prompt,
|
390 |
"height": condition_image.size[1],
|
391 |
"width": condition_image.size[0],
|
392 |
-
"
|
393 |
**kwargs,
|
394 |
}
|
395 |
result = self.pipe.__call__(**kwargs)
|
396 |
-
return Result.from_result(result)
|
397 |
|
398 |
@torch.inference_mode()
|
399 |
def process_scribble(
|
@@ -406,16 +540,28 @@ class ControlNet(AbstractPipeline):
|
|
406 |
height: int,
|
407 |
width: int,
|
408 |
guidance_scale: float = 7.5,
|
|
|
409 |
**kwargs,
|
410 |
):
|
411 |
if self.__current_task_name != "scribble":
|
412 |
raise Exception("ControlNet is not loaded with scribble model")
|
413 |
|
414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
|
416 |
sdxl_args = (
|
417 |
{
|
418 |
-
"guidance_scale":
|
419 |
"adapter_conditioning_scale": 1.0,
|
420 |
"adapter_conditioning_factor": 1.0,
|
421 |
}
|
@@ -431,11 +577,12 @@ class ControlNet(AbstractPipeline):
|
|
431 |
"height": height,
|
432 |
"width": width,
|
433 |
"guidance_scale": guidance_scale,
|
|
|
434 |
**sdxl_args,
|
435 |
**kwargs,
|
436 |
}
|
437 |
result = self.pipe2.__call__(**kwargs)
|
438 |
-
return Result.from_result(result)
|
439 |
|
440 |
@torch.inference_mode()
|
441 |
def process_linearart(
|
@@ -448,20 +595,26 @@ class ControlNet(AbstractPipeline):
|
|
448 |
height: int,
|
449 |
width: int,
|
450 |
guidance_scale: float = 7.5,
|
|
|
451 |
**kwargs,
|
452 |
):
|
453 |
if self.__current_task_name != "linearart":
|
454 |
raise Exception("ControlNet is not loaded with linearart model")
|
455 |
|
456 |
-
|
457 |
|
458 |
-
init_image =
|
459 |
-
|
|
|
|
|
|
|
|
|
|
|
460 |
|
461 |
# we use t2i adapter and the conditioning scale should always be 0.8
|
462 |
sdxl_args = (
|
463 |
{
|
464 |
-
"guidance_scale":
|
465 |
"adapter_conditioning_scale": 1.0,
|
466 |
"adapter_conditioning_factor": 1.0,
|
467 |
}
|
@@ -470,18 +623,68 @@ class ControlNet(AbstractPipeline):
|
|
470 |
)
|
471 |
|
472 |
kwargs = {
|
473 |
-
"image": [condition_image] *
|
|
|
|
|
|
|
474 |
"prompt": prompt,
|
475 |
"num_inference_steps": num_inference_steps,
|
476 |
"negative_prompt": negative_prompt,
|
477 |
"height": height,
|
478 |
"width": width,
|
479 |
"guidance_scale": guidance_scale,
|
|
|
480 |
**sdxl_args,
|
481 |
**kwargs,
|
482 |
}
|
483 |
result = self.pipe2.__call__(**kwargs)
|
484 |
-
return Result.from_result(result)
|
485 |
|
486 |
def cleanup(self):
|
487 |
"""Doesn't do anything considering new diffusers has itself a cleanup mechanism
|
@@ -504,12 +707,15 @@ class ControlNet(AbstractPipeline):
|
|
504 |
def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
|
505 |
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
506 |
if get_is_sdxl():
|
507 |
-
kwargs = {"detect_resolution": 384, **kwargs}
|
|
|
|
|
508 |
|
509 |
image = processor.__call__(input_image=image, **kwargs)
|
510 |
return image
|
511 |
|
512 |
@staticmethod
|
|
|
513 |
def depth_image(image: Image.Image) -> Image.Image:
|
514 |
global midas, midas_transforms
|
515 |
if "midas" not in globals():
|
@@ -555,6 +761,10 @@ class ControlNet(AbstractPipeline):
|
|
555 |
canny_image = Image.fromarray(image_array)
|
556 |
return canny_image
|
557 |
|
|
|
|
|
|
|
|
|
558 |
def __resize_for_condition_image(self, image: Image.Image, resolution: int):
|
559 |
input_image = image.convert("RGB")
|
560 |
W, H = input_image.size
|
@@ -572,6 +782,7 @@ class ControlNet(AbstractPipeline):
|
|
572 |
"linearart": "lllyasviel/control_v11p_sd15_lineart",
|
573 |
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
574 |
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
|
|
|
575 |
}
|
576 |
__model_normal_types = {
|
577 |
"pose": "controlnet",
|
@@ -579,19 +790,24 @@ class ControlNet(AbstractPipeline):
|
|
579 |
"linearart": "controlnet",
|
580 |
"scribble": "controlnet",
|
581 |
"tile_upscaler": "controlnet",
|
|
|
582 |
}
|
583 |
|
584 |
__model_sdxl = {
|
585 |
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
|
586 |
-
"canny": "
|
|
|
|
|
587 |
"linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
588 |
"scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
589 |
-
"tile_upscaler":
|
590 |
}
|
591 |
__model_sdxl_types = {
|
592 |
"pose": "controlnet",
|
593 |
"canny": "controlnet",
|
|
|
|
|
594 |
"linearart": "t2i",
|
595 |
"scribble": "t2i",
|
596 |
-
"tile_upscaler":
|
597 |
}
|
|
|
1 |
+
import os
|
2 |
from typing import AbstractSet, List, Literal, Optional, Union
|
3 |
|
4 |
import cv2
|
|
|
18 |
StableDiffusionControlNetImg2ImgPipeline,
|
19 |
StableDiffusionControlNetPipeline,
|
20 |
StableDiffusionXLAdapterPipeline,
|
21 |
+
StableDiffusionXLControlNetImg2ImgPipeline,
|
22 |
StableDiffusionXLControlNetPipeline,
|
23 |
T2IAdapter,
|
24 |
UniPCMultistepScheduler,
|
|
|
31 |
from transformers import pipeline
|
32 |
|
33 |
import internals.util.image as ImageUtil
|
|
|
34 |
from internals.data.result import Result
|
35 |
from internals.pipelines.commons import AbstractPipeline
|
36 |
+
from internals.util import get_generators
|
37 |
from internals.util.cache import clear_cuda_and_gc
|
38 |
from internals.util.commons import download_image
|
39 |
from internals.util.config import (
|
|
|
41 |
get_hf_token,
|
42 |
get_is_sdxl,
|
43 |
get_model_dir,
|
44 |
+
get_num_return_sequences,
|
45 |
)
|
46 |
|
47 |
+
CONTROLNET_TYPES = Literal[
|
48 |
+
"pose", "canny", "scribble", "linearart", "tile_upscaler", "canny_2x"
|
49 |
+
]
|
50 |
+
|
51 |
+
__CN_MODELS = {}
|
52 |
+
MAX_CN_MODELS = 3
|
53 |
+
|
54 |
+
|
55 |
+
def clear_networks():
|
56 |
+
global __CN_MODELS
|
57 |
+
__CN_MODELS = {}
|
58 |
+
|
59 |
+
|
60 |
+
def load_network_model_by_key(repo_id: str, pipeline_type: str):
|
61 |
+
global __CN_MODELS
|
62 |
+
|
63 |
+
if repo_id in __CN_MODELS:
|
64 |
+
return __CN_MODELS[repo_id]
|
65 |
+
|
66 |
+
if len(__CN_MODELS) >= MAX_CN_MODELS:
|
67 |
+
__CN_MODELS = {}
|
68 |
+
|
69 |
+
if pipeline_type == "controlnet":
|
70 |
+
model = ControlNetModel.from_pretrained(
|
71 |
+
repo_id,
|
72 |
+
torch_dtype=torch.float16,
|
73 |
+
cache_dir=get_hf_cache_dir(),
|
74 |
+
token=get_hf_token(),
|
75 |
+
).to("cuda")
|
76 |
+
elif pipeline_type == "t2i":
|
77 |
+
model = T2IAdapter.from_pretrained(
|
78 |
+
repo_id,
|
79 |
+
torch_dtype=torch.float16,
|
80 |
+
variant="fp16",
|
81 |
+
token=get_hf_token(),
|
82 |
+
).to("cuda")
|
83 |
+
else:
|
84 |
+
raise Exception("Invalid pipeline type")
|
85 |
+
|
86 |
+
__CN_MODELS[repo_id] = model
|
87 |
+
|
88 |
+
return model
|
89 |
|
90 |
|
91 |
class StableDiffusionNetworkModelPipelineLoader:
|
|
|
101 |
pipeline_type,
|
102 |
base_pipe: Optional[AbstractSet] = None,
|
103 |
):
|
|
|
|
|
|
|
|
|
|
|
104 |
if base_pipe is None:
|
105 |
pretrained = True
|
106 |
kwargs = {
|
|
|
114 |
kwargs = {
|
115 |
**base_pipe.pipe.components, # pyright: ignore
|
116 |
}
|
117 |
+
if get_is_sdxl():
|
118 |
+
kwargs.pop("image_encoder", None)
|
119 |
+
kwargs.pop("feature_extractor", None)
|
120 |
|
121 |
+
if is_sdxl and is_img2img and pipeline_type == "controlnet":
|
122 |
+
model = (
|
123 |
+
StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained
|
124 |
+
if pretrained
|
125 |
+
else StableDiffusionXLControlNetImg2ImgPipeline
|
126 |
+
)
|
127 |
+
return model(controlnet=network_model, **kwargs).to("cuda")
|
128 |
if is_sdxl and pipeline_type == "controlnet":
|
129 |
model = (
|
130 |
StableDiffusionXLControlNetPipeline.from_pretrained
|
|
|
195 |
def load_model(self, task_name: CONTROLNET_TYPES):
|
196 |
"Appropriately loads the network module, pipelines and cache it for reuse."
|
197 |
|
|
|
198 |
if self.__current_task_name == task_name:
|
199 |
return
|
200 |
+
|
201 |
+
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
|
202 |
model = config[task_name]
|
203 |
if not model:
|
204 |
raise Exception(f"ControlNet is not supported for {task_name}")
|
|
|
226 |
def __load_network_model(self, model_name, pipeline_type):
|
227 |
"Loads the network module, eg: ControlNet or T2I Adapters"
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
if type(model_name) == str:
|
230 |
+
return load_network_model_by_key(model_name, pipeline_type)
|
|
|
|
|
|
|
|
|
231 |
elif type(model_name) == list:
|
232 |
if pipeline_type == "controlnet":
|
233 |
cns = []
|
234 |
for model in model_name:
|
235 |
+
cns.append(load_network_model_by_key(model, pipeline_type))
|
236 |
return MultiControlNetModel(cns).to("cuda")
|
237 |
elif pipeline_type == "t2i":
|
238 |
raise Exception("Multi T2I adapters are not supported")
|
|
|
251 |
pipe.enable_vae_slicing()
|
252 |
pipe.enable_xformers_memory_efficient_attention()
|
253 |
# this scheduler produces good outputs for t2i adapters
|
254 |
+
if pipeline_type == "t2i":
|
255 |
+
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
256 |
+
pipe.scheduler.config
|
257 |
+
)
|
258 |
else:
|
259 |
pipe.enable_xformers_memory_efficient_attention()
|
260 |
return pipe
|
|
|
262 |
# If the pipeline type is changed we should reload all
|
263 |
# the pipelines
|
264 |
if not self.__loaded or self.__pipe_type != pipeline_type:
|
265 |
+
# controlnet pipeline for tile upscaler or any pipeline with img2img + network support
|
266 |
pipe = StableDiffusionNetworkModelPipelineLoader(
|
267 |
is_sdxl=get_is_sdxl(),
|
268 |
is_img2img=True,
|
|
|
311 |
def process(self, **kwargs):
|
312 |
if self.__current_task_name == "pose":
|
313 |
return self.process_pose(**kwargs)
|
314 |
+
if self.__current_task_name == "depth":
|
315 |
+
return self.process_depth(**kwargs)
|
316 |
if self.__current_task_name == "canny":
|
317 |
return self.process_canny(**kwargs)
|
318 |
if self.__current_task_name == "scribble":
|
|
|
321 |
return self.process_linearart(**kwargs)
|
322 |
if self.__current_task_name == "tile_upscaler":
|
323 |
return self.process_tile_upscaler(**kwargs)
|
324 |
+
if self.__current_task_name == "canny_2x":
|
325 |
+
return self.process_canny_2x(**kwargs)
|
326 |
raise Exception("ControlNet is not loaded with any model")
|
327 |
|
328 |
@torch.inference_mode()
|
|
|
335 |
negative_prompt: List[str],
|
336 |
height: int,
|
337 |
width: int,
|
338 |
+
guidance_scale: float = 7.5,
|
339 |
+
apply_preprocess: bool = True,
|
340 |
**kwargs,
|
341 |
):
|
342 |
if self.__current_task_name != "canny":
|
343 |
raise Exception("ControlNet is not loaded with canny model")
|
344 |
|
345 |
+
generator = get_generators(seed, get_num_return_sequences())
|
346 |
+
|
347 |
+
init_image = self.preprocess_image(imageUrl, width, height)
|
348 |
+
if apply_preprocess:
|
349 |
+
init_image = ControlNet.canny_detect_edge(init_image)
|
350 |
+
init_image = init_image.resize((width, height))
|
351 |
|
352 |
+
# if get_is_sdxl():
|
353 |
+
# kwargs["controlnet_conditioning_scale"] = 0.5
|
354 |
|
355 |
kwargs = {
|
356 |
"prompt": prompt,
|
|
|
361 |
"num_inference_steps": num_inference_steps,
|
362 |
"height": height,
|
363 |
"width": width,
|
364 |
+
"generator": generator,
|
365 |
**kwargs,
|
366 |
}
|
367 |
|
368 |
+
print(kwargs)
|
369 |
result = self.pipe2.__call__(**kwargs)
|
370 |
+
return Result.from_result(result), init_image
|
371 |
+
|
372 |
+
@torch.inference_mode()
|
373 |
+
def process_canny_2x(
|
374 |
+
self,
|
375 |
+
prompt: List[str],
|
376 |
+
imageUrl: str,
|
377 |
+
seed: int,
|
378 |
+
num_inference_steps: int,
|
379 |
+
negative_prompt: List[str],
|
380 |
+
height: int,
|
381 |
+
width: int,
|
382 |
+
guidance_scale: float = 8.5,
|
383 |
+
**kwargs,
|
384 |
+
):
|
385 |
+
if self.__current_task_name != "canny_2x":
|
386 |
+
raise Exception("ControlNet is not loaded with canny model")
|
387 |
+
|
388 |
+
generator = get_generators(seed, get_num_return_sequences())
|
389 |
+
|
390 |
+
init_image = self.preprocess_image(imageUrl, width, height)
|
391 |
+
canny_image = ControlNet.canny_detect_edge(init_image).resize((width, height))
|
392 |
+
depth_image = ControlNet.depth_image(init_image).resize((width, height))
|
393 |
+
|
394 |
+
condition_scale = kwargs.get("controlnet_conditioning_scale", None)
|
395 |
+
condition_factor = kwargs.get("control_guidance_end", None)
|
396 |
+
print("condition_scale", condition_scale)
|
397 |
+
|
398 |
+
if not get_is_sdxl():
|
399 |
+
kwargs["guidance_scale"] = 7.5
|
400 |
+
kwargs["strength"] = 0.8
|
401 |
+
kwargs["controlnet_conditioning_scale"] = [condition_scale or 1.0, 0.3]
|
402 |
+
else:
|
403 |
+
kwargs["controlnet_conditioning_scale"] = [condition_scale or 0.8, 0.3]
|
404 |
+
|
405 |
+
kwargs["control_guidance_end"] = [condition_factor or 1.0, 1.0]
|
406 |
+
|
407 |
+
kwargs = {
|
408 |
+
"prompt": prompt[0],
|
409 |
+
"image": [init_image] * get_num_return_sequences(),
|
410 |
+
"control_image": [canny_image, depth_image],
|
411 |
+
"guidance_scale": guidance_scale,
|
412 |
+
"num_images_per_prompt": get_num_return_sequences(),
|
413 |
+
"negative_prompt": negative_prompt[0],
|
414 |
+
"num_inference_steps": num_inference_steps,
|
415 |
+
"strength": 1.0,
|
416 |
+
"height": height,
|
417 |
+
"width": width,
|
418 |
+
"generator": generator,
|
419 |
+
**kwargs,
|
420 |
+
}
|
421 |
+
print(kwargs)
|
422 |
+
|
423 |
+
result = self.pipe.__call__(**kwargs)
|
424 |
+
return Result.from_result(result), canny_image
|
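Editor's note on the canny_2x path above: it drives two ControlNets at once (canny plus depth), and the list-valued control_image, controlnet_conditioning_scale and control_guidance_end entries pair up positionally with those two networks. The sketch below is illustrative only and not part of the commit; it assumes a CUDA machine, an SD 1.5 base checkpoint, and placeholder image paths, and uses the SD 1.5 repo ids from the model table at the end of this file.

# Illustrative sketch of the canny_2x pairing (not part of the diff).
# The base checkpoint id and the input paths are assumptions.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

init_image = Image.open("input.png").convert("RGB")          # image being refined
canny_image = Image.open("input_canny.png").convert("RGB")   # canny condition
depth_image = Image.open("input_depth.png").convert("RGB")   # depth condition

canny_net = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
)
depth_net = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16
)

pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",       # any SD 1.5 base checkpoint
    controlnet=[canny_net, depth_net],      # wrapped into a MultiControlNetModel
    torch_dtype=torch.float16,
).to("cuda")

# Index i of each list refers to the i-th ControlNet: full-strength canny,
# light depth guidance, both active over the whole denoising schedule.
result = pipe(
    prompt="a castle on a hill",
    image=init_image,
    control_image=[canny_image, depth_image],
    controlnet_conditioning_scale=[1.0, 0.3],
    control_guidance_end=[1.0, 1.0],
    strength=0.8,
    num_inference_steps=30,
)
result.images[0].save("canny_2x_out.png")
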
425 |
|
426 |
@torch.inference_mode()
|
427 |
def process_pose(
|
|
|
439 |
if self.__current_task_name != "pose":
|
440 |
raise Exception("ControlNet is not loaded with pose model")
|
441 |
|
442 |
+
generator = get_generators(seed, get_num_return_sequences())
|
443 |
|
444 |
kwargs = {
|
445 |
"prompt": prompt[0],
|
446 |
"image": image,
|
447 |
+
"num_images_per_prompt": get_num_return_sequences(),
|
448 |
"num_inference_steps": num_inference_steps,
|
449 |
"negative_prompt": negative_prompt[0],
|
450 |
"guidance_scale": guidance_scale,
|
451 |
"height": height,
|
452 |
"width": width,
|
453 |
+
"generator": generator,
|
454 |
**kwargs,
|
455 |
}
|
456 |
print(kwargs)
|
457 |
result = self.pipe2.__call__(**kwargs)
|
458 |
+
return Result.from_result(result), image
|
459 |
|
460 |
@torch.inference_mode()
|
461 |
def process_tile_upscaler(
|
|
|
474 |
if self.__current_task_name != "tile_upscaler":
|
475 |
raise Exception("ControlNet is not loaded with tile_upscaler model")
|
476 |
|
477 |
+
init_image = None
|
478 |
+
# find the correct seed and imageUrl from imageUrl
|
479 |
+
try:
|
480 |
+
p = os.path.splitext(imageUrl)[0]
|
481 |
+
p = p.split("/")[-1]
|
482 |
+
p = p.split("_")[-1]
|
483 |
|
484 |
+
seed = seed + int(p)
|
485 |
+
|
486 |
+
if "_canny_2x" or "_linearart" in imageUrl:
|
487 |
+
imageUrl = imageUrl.replace("_canny_2x", "_canny_2x_highres").replace(
|
488 |
+
"_linearart_highres", ""
|
489 |
+
)
|
490 |
+
init_image = download_image(imageUrl)
|
491 |
+
width, height = init_image.size
|
492 |
+
|
493 |
+
print("Setting imageUrl with width and height", imageUrl, width, height)
|
494 |
+
except Exception as e:
|
495 |
+
print("Failed to extract seed from imageUrl", e)
|
496 |
+
|
497 |
+
print("Setting seed", seed)
|
498 |
+
generator = get_generators(seed)
|
499 |
+
|
500 |
+
if not init_image:
|
501 |
+
init_image = download_image(imageUrl).resize((width, height))
|
502 |
+
|
503 |
+
condition_image = ImageUtil.resize_image(init_image, 1024)
|
504 |
+
if get_is_sdxl():
|
505 |
+
condition_image = condition_image.resize(init_image.size)
|
506 |
+
else:
|
507 |
+
condition_image = self.__resize_for_condition_image(
|
508 |
+
init_image, resize_dimension
|
509 |
+
)
|
510 |
+
|
511 |
+
if get_is_sdxl():
|
512 |
+
kwargs["strength"] = 1.0
|
513 |
+
kwargs["controlnet_conditioning_scale"] = 1.0
|
514 |
+
kwargs["image"] = init_image
|
515 |
+
else:
|
516 |
+
kwargs["image"] = condition_image
|
517 |
+
kwargs["guidance_scale"] = guidance_scale
|
518 |
|
519 |
kwargs = {
|
|
|
520 |
"prompt": prompt,
|
521 |
"control_image": condition_image,
|
522 |
"num_inference_steps": num_inference_steps,
|
523 |
"negative_prompt": negative_prompt,
|
524 |
"height": condition_image.size[1],
|
525 |
"width": condition_image.size[0],
|
526 |
+
"generator": generator,
|
527 |
**kwargs,
|
528 |
}
|
529 |
result = self.pipe.__call__(**kwargs)
|
530 |
+
return Result.from_result(result), condition_image
|
531 |
|
532 |
@torch.inference_mode()
|
533 |
def process_scribble(
|
|
|
540 |
height: int,
|
541 |
width: int,
|
542 |
guidance_scale: float = 7.5,
|
543 |
+
apply_preprocess: bool = True,
|
544 |
**kwargs,
|
545 |
):
|
546 |
if self.__current_task_name != "scribble":
|
547 |
raise Exception("ControlNet is not loaded with scribble model")
|
548 |
|
549 |
+
generator = get_generators(seed, get_num_return_sequences())
|
550 |
+
|
551 |
+
if apply_preprocess:
|
552 |
+
if get_is_sdxl():
|
553 |
+
# We use sketch in SDXL
|
554 |
+
image = [
|
555 |
+
ControlNet.pidinet_image(image[0]).resize((width, height))
|
556 |
+
] * len(image)
|
557 |
+
else:
|
558 |
+
image = [
|
559 |
+
ControlNet.scribble_image(image[0]).resize((width, height))
|
560 |
+
] * len(image)
|
561 |
|
562 |
sdxl_args = (
|
563 |
{
|
564 |
+
"guidance_scale": guidance_scale,
|
565 |
"adapter_conditioning_scale": 1.0,
|
566 |
"adapter_conditioning_factor": 1.0,
|
567 |
}
|
|
|
577 |
"height": height,
|
578 |
"width": width,
|
579 |
"guidance_scale": guidance_scale,
|
580 |
+
"generator": generator,
|
581 |
**sdxl_args,
|
582 |
**kwargs,
|
583 |
}
|
584 |
result = self.pipe2.__call__(**kwargs)
|
585 |
+
return Result.from_result(result), image[0]
|
586 |
|
587 |
@torch.inference_mode()
|
588 |
def process_linearart(
|
|
|
595 |
height: int,
|
596 |
width: int,
|
597 |
guidance_scale: float = 7.5,
|
598 |
+
apply_preprocess: bool = True,
|
599 |
**kwargs,
|
600 |
):
|
601 |
if self.__current_task_name != "linearart":
|
602 |
raise Exception("ControlNet is not loaded with linearart model")
|
603 |
|
604 |
+
generator = get_generators(seed, get_num_return_sequences())
|
605 |
|
606 |
+
init_image = self.preprocess_image(imageUrl, width, height)
|
607 |
+
|
608 |
+
if apply_preprocess:
|
609 |
+
condition_image = ControlNet.linearart_condition_image(init_image)
|
610 |
+
condition_image = condition_image.resize(init_image.size)
|
611 |
+
else:
|
612 |
+
condition_image = init_image
|
613 |
|
614 |
# we use t2i adapter and the conditioning scale should always be 0.8
|
615 |
sdxl_args = (
|
616 |
{
|
617 |
+
"guidance_scale": guidance_scale,
|
618 |
"adapter_conditioning_scale": 1.0,
|
619 |
"adapter_conditioning_factor": 1.0,
|
620 |
}
|
|
|
623 |
)
|
624 |
|
625 |
kwargs = {
|
626 |
+
"image": [condition_image] * get_num_return_sequences(),
|
627 |
+
"prompt": prompt,
|
628 |
+
"num_inference_steps": num_inference_steps,
|
629 |
+
"negative_prompt": negative_prompt,
|
630 |
+
"height": height,
|
631 |
+
"width": width,
|
632 |
+
"guidance_scale": guidance_scale,
|
633 |
+
"generator": generator,
|
634 |
+
**sdxl_args,
|
635 |
+
**kwargs,
|
636 |
+
}
|
637 |
+
result = self.pipe2.__call__(**kwargs)
|
638 |
+
return Result.from_result(result), condition_image
|
639 |
+
|
640 |
+
@torch.inference_mode()
|
641 |
+
def process_depth(
|
642 |
+
self,
|
643 |
+
imageUrl: str,
|
644 |
+
prompt: Union[str, List[str]],
|
645 |
+
negative_prompt: Union[str, List[str]],
|
646 |
+
num_inference_steps: int,
|
647 |
+
seed: int,
|
648 |
+
height: int,
|
649 |
+
width: int,
|
650 |
+
guidance_scale: float = 7.5,
|
651 |
+
apply_preprocess: bool = True,
|
652 |
+
**kwargs,
|
653 |
+
):
|
654 |
+
if self.__current_task_name != "depth":
|
655 |
+
raise Exception("ControlNet is not loaded with depth model")
|
656 |
+
|
657 |
+
generator = get_generators(seed, get_num_return_sequences())
|
658 |
+
|
659 |
+
init_image = self.preprocess_image(imageUrl, width, height)
|
660 |
+
|
661 |
+
if apply_preprocess:
|
662 |
+
condition_image = ControlNet.depth_image(init_image)
|
663 |
+
condition_image = condition_image.resize(init_image.size)
|
664 |
+
else:
|
665 |
+
condition_image = init_image
|
666 |
+
|
667 |
+
# these hyperparameters are optimal for the depth controlnet in this SDXL model
|
668 |
+
sdxl_args = (
|
669 |
+
{"controlnet_conditioning_scale": 0.2, "control_guidance_end": 0.2}
|
670 |
+
if get_is_sdxl()
|
671 |
+
else {}
|
672 |
+
)
|
673 |
+
|
674 |
+
kwargs = {
|
675 |
+
"image": [condition_image] * get_num_return_sequences(),
|
676 |
"prompt": prompt,
|
677 |
"num_inference_steps": num_inference_steps,
|
678 |
"negative_prompt": negative_prompt,
|
679 |
"height": height,
|
680 |
"width": width,
|
681 |
"guidance_scale": guidance_scale,
|
682 |
+
"generator": generator,
|
683 |
**sdxl_args,
|
684 |
**kwargs,
|
685 |
}
|
686 |
result = self.pipe2.__call__(**kwargs)
|
687 |
+
return Result.from_result(result), condition_image
|
688 |
|
689 |
def cleanup(self):
|
690 |
"""Doesn't do anything considering new diffusers has itself a cleanup mechanism
|
|
|
707 |
def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
|
708 |
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
709 |
if get_is_sdxl():
|
710 |
+
kwargs = {"detect_resolution": 384, "image_resolution": 1024, **kwargs}
|
711 |
+
else:
|
712 |
+
kwargs = {}
|
713 |
|
714 |
image = processor.__call__(input_image=image, **kwargs)
|
715 |
return image
|
716 |
|
717 |
@staticmethod
|
718 |
+
@torch.inference_mode()
|
719 |
def depth_image(image: Image.Image) -> Image.Image:
|
720 |
global midas, midas_transforms
|
721 |
if "midas" not in globals():
|
|
|
761 |
canny_image = Image.fromarray(image_array)
|
762 |
return canny_image
|
763 |
|
764 |
+
def preprocess_image(self, imageUrl, width, height) -> Image.Image:
|
765 |
+
image = download_image(imageUrl, mode="RGBA").resize((width, height))
|
766 |
+
return ImageUtil.alpha_to_white(image)
|
767 |
+
|
768 |
def __resize_for_condition_image(self, image: Image.Image, resolution: int):
|
769 |
input_image = image.convert("RGB")
|
770 |
W, H = input_image.size
|
|
|
782 |
"linearart": "lllyasviel/control_v11p_sd15_lineart",
|
783 |
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
784 |
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
|
785 |
+
"canny_2x": "lllyasviel/control_v11p_sd15_canny, lllyasviel/control_v11f1p_sd15_depth",
|
786 |
}
|
787 |
__model_normal_types = {
|
788 |
"pose": "controlnet",
|
|
|
790 |
"linearart": "controlnet",
|
791 |
"scribble": "controlnet",
|
792 |
"tile_upscaler": "controlnet",
|
793 |
+
"canny_2x": "controlnet",
|
794 |
}
|
795 |
|
796 |
__model_sdxl = {
|
797 |
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
|
798 |
+
"canny": "Autodraft/controlnet-canny-sdxl-1.0",
|
799 |
+
"depth": "Autodraft/controlnet-depth-sdxl-1.0",
|
800 |
+
"canny_2x": "Autodraft/controlnet-canny-sdxl-1.0, Autodraft/controlnet-depth-sdxl-1.0",
|
801 |
"linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
802 |
"scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
803 |
+
"tile_upscaler": "Autodraft/ControlNet_SDXL_tile_upscale",
|
804 |
}
|
805 |
__model_sdxl_types = {
|
806 |
"pose": "controlnet",
|
807 |
"canny": "controlnet",
|
808 |
+
"canny_2x": "controlnet",
|
809 |
+
"depth": "controlnet",
|
810 |
"linearart": "t2i",
|
811 |
"scribble": "t2i",
|
812 |
+
"tile_upscaler": "controlnet",
|
813 |
}
|
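Editor's note: the biggest structural change in controlnets.py is that network modules are now fetched through a small module-level cache (load_network_model_by_key, capped by MAX_CN_MODELS and flushed wholesale once it overflows) instead of being rebuilt on every task switch. A minimal, GPU-free sketch of that caching behaviour follows; the load_weights callback is a hypothetical stand-in for ControlNetModel.from_pretrained / T2IAdapter.from_pretrained.

# Sketch of the "cache up to N, then flush" pattern used by load_network_model_by_key.
from typing import Callable, Dict

_CACHE: Dict[str, object] = {}
MAX_CACHED = 3  # mirrors MAX_CN_MODELS in the diff


def load_cached(repo_id: str, load_weights: Callable[[str], object]) -> object:
    """Return a cached model if present; otherwise load it, flushing the whole
    cache first once it has grown to MAX_CACHED entries (as the diff does)."""
    global _CACHE
    if repo_id in _CACHE:
        return _CACHE[repo_id]
    if len(_CACHE) >= MAX_CACHED:
        _CACHE = {}  # wholesale flush, mirroring clear_networks()
    model = load_weights(repo_id)
    _CACHE[repo_id] = model
    return model


if __name__ == "__main__":
    for rid in ["a", "b", "c", "d", "a"]:
        load_cached(rid, lambda r: f"weights-for-{r}")
    print(sorted(_CACHE))  # ['a', 'd']: loading 'd' flushed the full cache
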
internals/pipelines/high_res.py
CHANGED
@@ -1,15 +1,22 @@
|
|
1 |
import math
|
2 |
-
from typing import List, Optional
|
3 |
|
4 |
from PIL import Image
|
5 |
|
6 |
from internals.data.result import Result
|
7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
|
|
8 |
from internals.util.cache import clear_cuda_and_gc
|
9 |
-
from internals.util.config import
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
-
class HighRes(AbstractPipeline):
|
13 |
def load(self, img2img: Optional[Img2Img] = None):
|
14 |
if hasattr(self, "pipe"):
|
15 |
return
|
@@ -21,6 +28,9 @@ class HighRes(AbstractPipeline):
|
|
21 |
self.pipe = img2img.pipe
|
22 |
self.img2img = img2img
|
23 |
|
|
|
|
|
|
|
24 |
def apply(
|
25 |
self,
|
26 |
prompt: List[str],
|
@@ -28,6 +38,7 @@ class HighRes(AbstractPipeline):
|
|
28 |
images,
|
29 |
width: int,
|
30 |
height: int,
|
|
|
31 |
num_inference_steps: int,
|
32 |
strength: float = 0.5,
|
33 |
guidance_scale: int = 9,
|
@@ -35,7 +46,18 @@ class HighRes(AbstractPipeline):
|
|
35 |
):
|
36 |
clear_cuda_and_gc()
|
37 |
|
|
|
|
|
38 |
images = [image.resize((width, height)) for image in images]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
kwargs = {
|
40 |
"prompt": prompt,
|
41 |
"image": images,
|
@@ -43,9 +65,16 @@ class HighRes(AbstractPipeline):
|
|
43 |
"negative_prompt": negative_prompt,
|
44 |
"guidance_scale": guidance_scale,
|
45 |
"num_inference_steps": num_inference_steps,
|
|
|
46 |
**kwargs,
|
47 |
}
|
|
|
|
|
48 |
result = self.pipe.__call__(**kwargs)
|
|
|
|
|
|
|
|
|
49 |
return Result.from_result(result)
|
50 |
|
51 |
@staticmethod
|
|
|
1 |
import math
|
2 |
+
from typing import Dict, List, Optional
|
3 |
|
4 |
from PIL import Image
|
5 |
|
6 |
from internals.data.result import Result
|
7 |
from internals.pipelines.commons import AbstractPipeline, Img2Img
|
8 |
+
from internals.util import get_generators
|
9 |
from internals.util.cache import clear_cuda_and_gc
|
10 |
+
from internals.util.config import (
|
11 |
+
get_base_dimension,
|
12 |
+
get_is_sdxl,
|
13 |
+
get_model_dir,
|
14 |
+
get_num_return_sequences,
|
15 |
+
)
|
16 |
+
from internals.util.sdxl_lightning import LightningMixin
|
17 |
|
18 |
|
19 |
+
class HighRes(AbstractPipeline, LightningMixin):
|
20 |
def load(self, img2img: Optional[Img2Img] = None):
|
21 |
if hasattr(self, "pipe"):
|
22 |
return
|
|
|
28 |
self.pipe = img2img.pipe
|
29 |
self.img2img = img2img
|
30 |
|
31 |
+
if get_is_sdxl():
|
32 |
+
self.configure_sdxl_lightning(img2img.pipe)
|
33 |
+
|
34 |
def apply(
|
35 |
self,
|
36 |
prompt: List[str],
|
|
|
38 |
images,
|
39 |
width: int,
|
40 |
height: int,
|
41 |
+
seed: int,
|
42 |
num_inference_steps: int,
|
43 |
strength: float = 0.5,
|
44 |
guidance_scale: int = 9,
|
|
|
46 |
):
|
47 |
clear_cuda_and_gc()
|
48 |
|
49 |
+
generator = get_generators(seed, get_num_return_sequences())
|
50 |
+
|
51 |
images = [image.resize((width, height)) for image in images]
|
52 |
+
|
53 |
+
# if get_is_sdxl():
|
54 |
+
# kwargs["guidance_scale"] = kwargs.get("guidance_scale", 15)
|
55 |
+
# kwargs["strength"] = kwargs.get("strength", 0.6)
|
56 |
+
|
57 |
+
if get_is_sdxl():
|
58 |
+
extra_args = self.enable_sdxl_lightning()
|
59 |
+
kwargs.update(extra_args)
|
60 |
+
|
61 |
kwargs = {
|
62 |
"prompt": prompt,
|
63 |
"image": images,
|
|
|
65 |
"negative_prompt": negative_prompt,
|
66 |
"guidance_scale": guidance_scale,
|
67 |
"num_inference_steps": num_inference_steps,
|
68 |
+
"generator": generator,
|
69 |
**kwargs,
|
70 |
}
|
71 |
+
|
72 |
+
print(kwargs)
|
73 |
result = self.pipe.__call__(**kwargs)
|
74 |
+
|
75 |
+
if get_is_sdxl():
|
76 |
+
self.disable_sdxl_lightning()
|
77 |
+
|
78 |
return Result.from_result(result)
|
79 |
|
80 |
@staticmethod
|
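Editor's note: high_res.py, like the other pipelines touched by this commit, now threads a seed through to per-image torch generators via get_generators. The helper lives in internals/util/__init__.py and its body is not shown in this view, so the version below is only a plausible reconstruction that matches the call sites get_generators(seed) and get_generators(seed, get_num_return_sequences()).

# Assumed reconstruction of get_generators (not shown in the diff): one CPU
# generator per requested image, seeded seed, seed+1, ...
from typing import List, Union
import torch


def get_generators(seed: int, n: int = 1) -> Union[torch.Generator, List[torch.Generator]]:
    generators = [torch.Generator().manual_seed(seed + i) for i in range(n)]
    return generators if n > 1 else generators[0]


# One generator per returned image keeps each of the N outputs reproducible
# for a given request seed.
gens = get_generators(1234, 4)
print([g.initial_seed() for g in gens])  # [1234, 1235, 1236, 1237]
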
internals/pipelines/inpaint_imageprocessor.py
ADDED
@@ -0,0 +1,976 @@
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import PIL.Image
|
20 |
+
import torch
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
23 |
+
from PIL import Image, ImageFilter, ImageOps
|
24 |
+
|
25 |
+
PipelineImageInput = Union[
|
26 |
+
PIL.Image.Image,
|
27 |
+
np.ndarray,
|
28 |
+
torch.FloatTensor,
|
29 |
+
List[PIL.Image.Image],
|
30 |
+
List[np.ndarray],
|
31 |
+
List[torch.FloatTensor],
|
32 |
+
]
|
33 |
+
|
34 |
+
PipelineDepthInput = PipelineImageInput
|
35 |
+
|
36 |
+
|
37 |
+
class VaeImageProcessor(ConfigMixin):
|
38 |
+
"""
|
39 |
+
Image processor for VAE.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
43 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
44 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
45 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
46 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
47 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
48 |
+
Resampling filter to use when resizing the image.
|
49 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
50 |
+
Whether to normalize the image to [-1,1].
|
51 |
+
do_binarize (`bool`, *optional*, defaults to `False`):
|
52 |
+
Whether to binarize the image to 0/1.
|
53 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
54 |
+
Whether to convert the images to RGB format.
|
55 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
56 |
+
Whether to convert the images to grayscale format.
|
57 |
+
"""
|
58 |
+
|
59 |
+
config_name = CONFIG_NAME
|
60 |
+
|
61 |
+
@register_to_config
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
do_resize: bool = True,
|
65 |
+
vae_scale_factor: int = 8,
|
66 |
+
resample: str = "lanczos",
|
67 |
+
do_normalize: bool = True,
|
68 |
+
do_binarize: bool = False,
|
69 |
+
do_convert_rgb: bool = False,
|
70 |
+
do_convert_grayscale: bool = False,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
if do_convert_rgb and do_convert_grayscale:
|
74 |
+
raise ValueError(
|
75 |
+
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
|
76 |
+
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
|
77 |
+
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
|
78 |
+
)
|
79 |
+
self.config.do_convert_rgb = False
|
80 |
+
|
81 |
+
@staticmethod
|
82 |
+
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
83 |
+
"""
|
84 |
+
Convert a numpy image or a batch of images to a PIL image.
|
85 |
+
"""
|
86 |
+
if images.ndim == 3:
|
87 |
+
images = images[None, ...]
|
88 |
+
images = (images * 255).round().astype("uint8")
|
89 |
+
if images.shape[-1] == 1:
|
90 |
+
# special case for grayscale (single channel) images
|
91 |
+
pil_images = [
|
92 |
+
Image.fromarray(image.squeeze(), mode="L") for image in images
|
93 |
+
]
|
94 |
+
else:
|
95 |
+
pil_images = [Image.fromarray(image) for image in images]
|
96 |
+
|
97 |
+
return pil_images
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def pil_to_numpy(
|
101 |
+
images: Union[List[PIL.Image.Image], PIL.Image.Image]
|
102 |
+
) -> np.ndarray:
|
103 |
+
"""
|
104 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
105 |
+
"""
|
106 |
+
if not isinstance(images, list):
|
107 |
+
images = [images]
|
108 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
109 |
+
images = np.stack(images, axis=0)
|
110 |
+
|
111 |
+
return images
|
112 |
+
|
113 |
+
@staticmethod
|
114 |
+
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
|
115 |
+
"""
|
116 |
+
Convert a NumPy image to a PyTorch tensor.
|
117 |
+
"""
|
118 |
+
if images.ndim == 3:
|
119 |
+
images = images[..., None]
|
120 |
+
|
121 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
122 |
+
return images
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
|
126 |
+
"""
|
127 |
+
Convert a PyTorch tensor to a NumPy image.
|
128 |
+
"""
|
129 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
130 |
+
return images
|
131 |
+
|
132 |
+
@staticmethod
|
133 |
+
def normalize(
|
134 |
+
images: Union[np.ndarray, torch.Tensor]
|
135 |
+
) -> Union[np.ndarray, torch.Tensor]:
|
136 |
+
"""
|
137 |
+
Normalize an image array to [-1,1].
|
138 |
+
"""
|
139 |
+
return 2.0 * images - 1.0
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def denormalize(
|
143 |
+
images: Union[np.ndarray, torch.Tensor]
|
144 |
+
) -> Union[np.ndarray, torch.Tensor]:
|
145 |
+
"""
|
146 |
+
Denormalize an image array to [0,1].
|
147 |
+
"""
|
148 |
+
return (images / 2 + 0.5).clamp(0, 1)
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
152 |
+
"""
|
153 |
+
Converts a PIL image to RGB format.
|
154 |
+
"""
|
155 |
+
image = image.convert("RGB")
|
156 |
+
|
157 |
+
return image
|
158 |
+
|
159 |
+
@staticmethod
|
160 |
+
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
161 |
+
"""
|
162 |
+
Converts a PIL image to grayscale format.
|
163 |
+
"""
|
164 |
+
image = image.convert("L")
|
165 |
+
|
166 |
+
return image
|
167 |
+
|
168 |
+
@staticmethod
|
169 |
+
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
170 |
+
"""
|
171 |
+
Applies Gaussian blur to an image.
|
172 |
+
"""
|
173 |
+
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
174 |
+
|
175 |
+
return image
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
179 |
+
"""
|
180 |
+
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
|
181 |
+
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
mask_image (PIL.Image.Image): Mask image.
|
185 |
+
width (int): Width of the image to be processed.
|
186 |
+
height (int): Height of the image to be processed.
|
187 |
+
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
|
191 |
+
"""
|
192 |
+
|
193 |
+
mask_image = mask_image.convert("L")
|
194 |
+
mask = np.array(mask_image)
|
195 |
+
|
196 |
+
# 1. find a rectangular region that contains all masked ares in an image
|
197 |
+
h, w = mask.shape
|
198 |
+
crop_left = 0
|
199 |
+
for i in range(w):
|
200 |
+
if not (mask[:, i] == 0).all():
|
201 |
+
break
|
202 |
+
crop_left += 1
|
203 |
+
|
204 |
+
crop_right = 0
|
205 |
+
for i in reversed(range(w)):
|
206 |
+
if not (mask[:, i] == 0).all():
|
207 |
+
break
|
208 |
+
crop_right += 1
|
209 |
+
|
210 |
+
crop_top = 0
|
211 |
+
for i in range(h):
|
212 |
+
if not (mask[i] == 0).all():
|
213 |
+
break
|
214 |
+
crop_top += 1
|
215 |
+
|
216 |
+
crop_bottom = 0
|
217 |
+
for i in reversed(range(h)):
|
218 |
+
if not (mask[i] == 0).all():
|
219 |
+
break
|
220 |
+
crop_bottom += 1
|
221 |
+
|
222 |
+
# 2. add padding to the crop region
|
223 |
+
x1, y1, x2, y2 = (
|
224 |
+
int(max(crop_left - pad, 0)),
|
225 |
+
int(max(crop_top - pad, 0)),
|
226 |
+
int(min(w - crop_right + pad, w)),
|
227 |
+
int(min(h - crop_bottom + pad, h)),
|
228 |
+
)
|
229 |
+
|
230 |
+
# 3. expands crop region to match the aspect ratio of the image to be processed
|
231 |
+
ratio_crop_region = (x2 - x1) / (y2 - y1)
|
232 |
+
ratio_processing = width / height
|
233 |
+
|
234 |
+
if ratio_crop_region > ratio_processing:
|
235 |
+
desired_height = (x2 - x1) / ratio_processing
|
236 |
+
desired_height_diff = int(desired_height - (y2 - y1))
|
237 |
+
y1 -= desired_height_diff // 2
|
238 |
+
y2 += desired_height_diff - desired_height_diff // 2
|
239 |
+
if y2 >= mask_image.height:
|
240 |
+
diff = y2 - mask_image.height
|
241 |
+
y2 -= diff
|
242 |
+
y1 -= diff
|
243 |
+
if y1 < 0:
|
244 |
+
y2 -= y1
|
245 |
+
y1 -= y1
|
246 |
+
if y2 >= mask_image.height:
|
247 |
+
y2 = mask_image.height
|
248 |
+
else:
|
249 |
+
desired_width = (y2 - y1) * ratio_processing
|
250 |
+
desired_width_diff = int(desired_width - (x2 - x1))
|
251 |
+
x1 -= desired_width_diff // 2
|
252 |
+
x2 += desired_width_diff - desired_width_diff // 2
|
253 |
+
if x2 >= mask_image.width:
|
254 |
+
diff = x2 - mask_image.width
|
255 |
+
x2 -= diff
|
256 |
+
x1 -= diff
|
257 |
+
if x1 < 0:
|
258 |
+
x2 -= x1
|
259 |
+
x1 -= x1
|
260 |
+
if x2 >= mask_image.width:
|
261 |
+
x2 = mask_image.width
|
262 |
+
|
263 |
+
return x1, y1, x2, y2
|
264 |
+
|
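Editor's note: a quick, self-contained check of get_crop_region as vendored above (assuming the module is importable as internals.pipelines.inpaint_imageprocessor): a thin 128x32 masked strip inside a 512x512 image comes back as a padded, square crop window, because the region is expanded to the processing aspect ratio.

from PIL import Image, ImageDraw
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor

mask = Image.new("L", (512, 512), 0)
ImageDraw.Draw(mask).rectangle([200, 240, 328, 272], fill=255)  # thin masked strip

x1, y1, x2, y2 = VaeImageProcessor.get_crop_region(mask, width=512, height=512, pad=32)
print((x1, y1, x2, y2), (x2 - x1, y2 - y1))  # (168, 160, 361, 353) (193, 193): square crop
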
265 |
+
def _resize_and_fill(
|
266 |
+
self,
|
267 |
+
image: PIL.Image.Image,
|
268 |
+
width: int,
|
269 |
+
height: int,
|
270 |
+
) -> PIL.Image.Image:
|
271 |
+
"""
|
272 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
image: The image to resize.
|
276 |
+
width: The width to resize the image to.
|
277 |
+
height: The height to resize the image to.
|
278 |
+
"""
|
279 |
+
|
280 |
+
ratio = width / height
|
281 |
+
src_ratio = image.width / image.height
|
282 |
+
|
283 |
+
src_w = width if ratio < src_ratio else image.width * height // image.height
|
284 |
+
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
285 |
+
|
286 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
287 |
+
res = Image.new("RGB", (width, height))
|
288 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
289 |
+
|
290 |
+
if ratio < src_ratio:
|
291 |
+
fill_height = height // 2 - src_h // 2
|
292 |
+
if fill_height > 0:
|
293 |
+
res.paste(
|
294 |
+
resized.resize((width, fill_height), box=(0, 0, width, 0)),
|
295 |
+
box=(0, 0),
|
296 |
+
)
|
297 |
+
res.paste(
|
298 |
+
resized.resize(
|
299 |
+
(width, fill_height),
|
300 |
+
box=(0, resized.height, width, resized.height),
|
301 |
+
),
|
302 |
+
box=(0, fill_height + src_h),
|
303 |
+
)
|
304 |
+
elif ratio > src_ratio:
|
305 |
+
fill_width = width // 2 - src_w // 2
|
306 |
+
if fill_width > 0:
|
307 |
+
res.paste(
|
308 |
+
resized.resize((fill_width, height), box=(0, 0, 0, height)),
|
309 |
+
box=(0, 0),
|
310 |
+
)
|
311 |
+
res.paste(
|
312 |
+
resized.resize(
|
313 |
+
(fill_width, height),
|
314 |
+
box=(resized.width, 0, resized.width, height),
|
315 |
+
),
|
316 |
+
box=(fill_width + src_w, 0),
|
317 |
+
)
|
318 |
+
|
319 |
+
return res
|
320 |
+
|
321 |
+
def _resize_and_crop(
|
322 |
+
self,
|
323 |
+
image: PIL.Image.Image,
|
324 |
+
width: int,
|
325 |
+
height: int,
|
326 |
+
) -> PIL.Image.Image:
|
327 |
+
"""
|
328 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
329 |
+
|
330 |
+
Args:
|
331 |
+
image: The image to resize.
|
332 |
+
width: The width to resize the image to.
|
333 |
+
height: The height to resize the image to.
|
334 |
+
"""
|
335 |
+
ratio = width / height
|
336 |
+
src_ratio = image.width / image.height
|
337 |
+
|
338 |
+
src_w = width if ratio > src_ratio else image.width * height // image.height
|
339 |
+
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
340 |
+
|
341 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
342 |
+
res = Image.new("RGB", (width, height))
|
343 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
344 |
+
return res
|
345 |
+
|
346 |
+
def resize(
|
347 |
+
self,
|
348 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
349 |
+
height: int,
|
350 |
+
width: int,
|
351 |
+
resize_mode: str = "default", # "defalt", "fill", "crop"
|
352 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
353 |
+
"""
|
354 |
+
Resize image.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
358 |
+
The image input, can be a PIL image, numpy array or pytorch tensor.
|
359 |
+
height (`int`):
|
360 |
+
The height to resize to.
|
361 |
+
width (`int`):
|
362 |
+
The width to resize to.
|
363 |
+
resize_mode (`str`, *optional*, defaults to `default`):
|
364 |
+
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
365 |
+
within the specified width and height, and it may not maintaining the original aspect ratio.
|
366 |
+
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
367 |
+
within the dimensions, filling empty with data from image.
|
368 |
+
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
369 |
+
within the dimensions, cropping the excess.
|
370 |
+
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
374 |
+
The resized image.
|
375 |
+
"""
|
376 |
+
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
|
377 |
+
raise ValueError(
|
378 |
+
f"Only PIL image input is supported for resize_mode {resize_mode}"
|
379 |
+
)
|
380 |
+
if isinstance(image, PIL.Image.Image):
|
381 |
+
if resize_mode == "default":
|
382 |
+
image = image.resize(
|
383 |
+
(width, height), resample=PIL_INTERPOLATION[self.config.resample]
|
384 |
+
)
|
385 |
+
elif resize_mode == "fill":
|
386 |
+
image = self._resize_and_fill(image, width, height)
|
387 |
+
elif resize_mode == "crop":
|
388 |
+
image = self._resize_and_crop(image, width, height)
|
389 |
+
else:
|
390 |
+
raise ValueError(f"resize_mode {resize_mode} is not supported")
|
391 |
+
|
392 |
+
elif isinstance(image, torch.Tensor):
|
393 |
+
image = torch.nn.functional.interpolate(
|
394 |
+
image,
|
395 |
+
size=(height, width),
|
396 |
+
)
|
397 |
+
elif isinstance(image, np.ndarray):
|
398 |
+
image = self.numpy_to_pt(image)
|
399 |
+
image = torch.nn.functional.interpolate(
|
400 |
+
image,
|
401 |
+
size=(height, width),
|
402 |
+
)
|
403 |
+
image = self.pt_to_numpy(image)
|
404 |
+
return image
|
405 |
+
|
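Editor's note: the three resize_mode values documented above behave quite differently for non-square sources. A small illustration, again assuming the vendored class is importable; only PIL input supports "fill" and "crop".

from PIL import Image
from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor

proc = VaeImageProcessor(do_resize=True, vae_scale_factor=8)
src = Image.new("RGB", (640, 360), "gray")  # 16:9 source

stretched = proc.resize(src, height=512, width=512, resize_mode="default")  # ignores aspect ratio
filled = proc.resize(src, height=512, width=512, resize_mode="fill")        # keep ratio, pad from edge pixels
cropped = proc.resize(src, height=512, width=512, resize_mode="crop")       # keep ratio, center-crop excess
print(stretched.size, filled.size, cropped.size)  # (512, 512) for all three
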
406 |
+
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
407 |
+
"""
|
408 |
+
Create a mask.
|
409 |
+
|
410 |
+
Args:
|
411 |
+
image (`PIL.Image.Image`):
|
412 |
+
The image input, should be a PIL image.
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
`PIL.Image.Image`:
|
416 |
+
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
417 |
+
"""
|
418 |
+
image[image < 0.5] = 0
|
419 |
+
image[image >= 0.5] = 1
|
420 |
+
|
421 |
+
return image
|
422 |
+
|
423 |
+
def get_default_height_width(
|
424 |
+
self,
|
425 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
426 |
+
height: Optional[int] = None,
|
427 |
+
width: Optional[int] = None,
|
428 |
+
) -> Tuple[int, int]:
|
429 |
+
"""
|
430 |
+
This function return the height and width that are downscaled to the next integer multiple of
|
431 |
+
`vae_scale_factor`.
|
432 |
+
|
433 |
+
Args:
|
434 |
+
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
435 |
+
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
|
436 |
+
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
|
437 |
+
have shape `[batch, channel, height, width]`.
|
438 |
+
height (`int`, *optional*, defaults to `None`):
|
439 |
+
The height in preprocessed image. If `None`, will use the height of `image` input.
|
440 |
+
width (`int`, *optional*`, defaults to `None`):
|
441 |
+
The width in preprocessed. If `None`, will use the width of the `image` input.
|
442 |
+
"""
|
443 |
+
|
444 |
+
if height is None:
|
445 |
+
if isinstance(image, PIL.Image.Image):
|
446 |
+
height = image.height
|
447 |
+
elif isinstance(image, torch.Tensor):
|
448 |
+
height = image.shape[2]
|
449 |
+
            else:
                height = image.shape[1]

        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]
            else:
                width = image.shape[2]

        width, height = (
            x - x % self.config.vae_scale_factor for x in (width, height)
        )  # resize to integer multiple of vae_scale_factor

        return height, width

    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input.

        Args:
            image (`PipelineImageInput`):
                The image input; accepted formats are PIL images, NumPy arrays, PyTorch tensors, or lists of these.
            height (`int`, *optional*, defaults to `None`):
                The height of the preprocessed image. If `None`, uses `get_default_height_width()` to get the default height.
            width (`int`, *optional*, defaults to `None`):
                The width of the preprocessed image. If `None`, uses `get_default_height_width()` to get the default width.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, one of `default`, `fill`, or `crop`. If `default`, resizes the image to fit within the
                specified width and height, which may not maintain the original aspect ratio.
                If `fill`, resizes the image to fit within the specified width and height while maintaining the aspect
                ratio, then centers the image within the dimensions and fills the empty area with data from the image.
                If `crop`, resizes the image to fit within the specified width and height while maintaining the aspect
                ratio, then centers the image within the dimensions and crops the excess.
                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                The crop coordinates for each image in the batch. If `None`, the image is not cropped.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if (
            self.config.do_convert_grayscale
            and isinstance(image, (torch.Tensor, np.ndarray))
            and image.ndim == 3
        ):
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor it could have 2 possible shapes:
                #   1. batch x height x width: we should insert the channel dimension at position 1
                #   2. channel x height x width: we should insert the batch dimension at position 0,
                #      however, since both channel and batch dimensions have size 1, it is the same to insert at position 1
                #   for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                #   1. batch x height x width: insert channel dimension on last position
                #   2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        if isinstance(image, supported_formats):
            image = [image]
        elif not (
            isinstance(image, list)
            and all(isinstance(i, supported_formats) for i in image)
        ):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
            )

        if isinstance(image[0], PIL.Image.Image):
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                height, width = self.get_default_height_width(image[0], height, width)
                image = [
                    self.resize(i, height, width, resize_mode=resize_mode)
                    for i in image
                ]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = (
                np.concatenate(image, axis=0)
                if image[0].ndim == 4
                else np.stack(image, axis=0)
            )

            image = self.numpy_to_pt(image)

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            image = (
                torch.cat(image, axis=0)
                if image[0].ndim == 4
                else torch.stack(image, axis=0)
            )

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == 4:
                return image

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.FloatTensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
                The postprocessed image.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate(
                "Unsupported output_type",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [
                self.denormalize(image[i]) if do_denormalize[i] else image[i]
                for i in range(image.shape[0])
            ]
        )

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        if output_type == "pil":
            return self.numpy_to_pil(image)

    def apply_overlay(
        self,
        mask: PIL.Image.Image,
        init_image: PIL.Image.Image,
        image: PIL.Image.Image,
        crop_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> PIL.Image.Image:
        """
        overlay the inpaint output to the original image
        """

        image = image.resize(init_image.size)
        width, height = image.width, image.height

        init_image = self.resize(init_image, width=width, height=height)
        mask = self.resize(mask, width=width, height=height)

        init_image_masked = PIL.Image.new("RGBa", (width, height))
        init_image_masked.paste(
            init_image.convert("RGBA").convert("RGBa"),
            mask=ImageOps.invert(mask.convert("L")),
        )
        init_image_masked = init_image_masked.convert("RGBA")

        if crop_coords is not None:
            x, y, x2, y2 = crop_coords
            w = x2 - x
            h = y2 - y
            base_image = PIL.Image.new("RGBA", (width, height))
            image = self.resize(image, height=h, width=w, resize_mode="crop")
            base_image.paste(image, (x, y))
            image = base_image.convert("RGB")

        image = image.convert("RGBA")
        image.alpha_composite(init_image_masked)
        image = image.convert("RGB")

        return image


class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [
                Image.fromarray(image.squeeze(), mode="L") for image in images
            ]
        else:
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        return pil_images

    @staticmethod
    def depth_pil_to_numpy(
        images: Union[List[PIL.Image.Image], PIL.Image.Image]
    ) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to NumPy arrays.
        """
        if not isinstance(images, list):
            images = [images]

        images = [
            np.array(image).astype(np.float32) / (2**16 - 1) for image in images
        ]
        images = np.stack(images, axis=0)
        return images

    @staticmethod
    def rgblike_to_depthmap(
        image: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Args:
            image: RGB-like depth image

        Returns: depth map

        """
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy depth image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16")
                for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [
                Image.fromarray(image_depth, mode="I;16")
                for image_depth in images_depth
            ]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.FloatTensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
                The postprocessed image.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate(
                "Unsupported output_type",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [
                self.denormalize(image[i]) if do_denormalize[i] else image[i]
                for i in range(image.shape[0])
            ]
        )

        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack(
                    [self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0
                )
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")

    def preprocess(
        self,
        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_res: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if (
            self.config.do_convert_grayscale
            and isinstance(rgb, (torch.Tensor, np.ndarray))
            and rgb.ndim == 3
        ):
            raise Exception("This is not yet supported")

        if isinstance(rgb, supported_formats):
            rgb = [rgb]
            depth = [depth]
        elif not (
            isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)
        ):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
            )

        if isinstance(rgb[0], PIL.Image.Image):
            if self.config.do_convert_rgb:
                raise Exception("This is not yet supported")
                # rgb = [self.convert_to_rgb(i) for i in rgb]
                # depth = [self.convert_to_depth(i) for i in depth]  # TODO define convert_to_depth
            if self.config.do_resize or target_res:
                height, width = (
                    self.get_default_height_width(rgb[0], height, width)
                    if not target_res
                    else target_res
                )
                rgb = [self.resize(i, height, width) for i in rgb]
                depth = [self.resize(i, height, width) for i in depth]
            rgb = self.pil_to_numpy(rgb)  # to np
            rgb = self.numpy_to_pt(rgb)  # to pt

            depth = self.depth_pil_to_numpy(depth)  # to np
            depth = self.numpy_to_pt(depth)  # to pt

        elif isinstance(rgb[0], np.ndarray):
            rgb = (
                np.concatenate(rgb, axis=0)
                if rgb[0].ndim == 4
                else np.stack(rgb, axis=0)
            )
            rgb = self.numpy_to_pt(rgb)
            height, width = self.get_default_height_width(rgb, height, width)
            if self.config.do_resize:
                rgb = self.resize(rgb, height, width)

            depth = (
                np.concatenate(depth, axis=0)
                if rgb[0].ndim == 4
                else np.stack(depth, axis=0)
            )
            depth = self.numpy_to_pt(depth)
            height, width = self.get_default_height_width(depth, height, width)
            if self.config.do_resize:
                depth = self.resize(depth, height, width)

        elif isinstance(rgb[0], torch.Tensor):
            raise Exception("This is not yet supported")
            # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)

            # if self.config.do_convert_grayscale and rgb.ndim == 3:
            #     rgb = rgb.unsqueeze(1)

            # channel = rgb.shape[1]

            # height, width = self.get_default_height_width(rgb, height, width)
            # if self.config.do_resize:
            #     rgb = self.resize(rgb, height, width)

            # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)

            # if self.config.do_convert_grayscale and depth.ndim == 3:
            #     depth = depth.unsqueeze(1)

            # channel = depth.shape[1]
            # # don't need any preprocess if the image is latents
            # if depth == 4:
            #     return rgb, depth

            # height, width = self.get_default_height_width(depth, height, width)
            # if self.config.do_resize:
            #     depth = self.resize(depth, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if rgb.min() < 0 and do_normalize:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            rgb = self.normalize(rgb)
            depth = self.normalize(depth)

        if self.config.do_binarize:
            rgb = self.binarize(rgb)
            depth = self.binarize(depth)

        return rgb, depth
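A note on the vendored processor above: the tail of `inpaint_imageprocessor.py` mirrors diffusers' `VaeImageProcessor`, so a mask-preparation call looks like the minimal sketch below. This is an illustration only; the file name `mask.png` and the 1024x1024 size are placeholders, not part of the commit.

```python
from PIL import Image

from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor

# Grayscale, binarized, non-normalized processor, as used for inpaint masks below.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

mask = Image.open("mask.png")  # hypothetical local file
mask_tensor = mask_processor.preprocess(mask, height=1024, width=1024)
print(mask_tensor.shape)  # torch.Size([1, 1, 1024, 1024]), values in {0., 1.}
```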
internals/pipelines/inpainter.py
CHANGED
@@ -1,18 +1,27 @@
 from typing import List, Union
 
 import torch
-from diffusers import
+from diffusers import (
+    StableDiffusionInpaintPipeline,
+    StableDiffusionXLInpaintPipeline,
+    UNet2DConditionModel,
+)
 
 from internals.pipelines.commons import AbstractPipeline
+from internals.pipelines.high_res import HighRes
+from internals.pipelines.inpaint_imageprocessor import VaeImageProcessor
+from internals.util import get_generators
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import disable_safety_checker, download_image
 from internals.util.config import (
+    get_base_inpaint_model_revision,
     get_base_inpaint_model_variant,
     get_hf_cache_dir,
     get_hf_token,
     get_inpaint_model_path,
     get_is_sdxl,
     get_model_dir,
+    get_num_return_sequences,
 )
 
 
@@ -32,13 +41,27 @@ class InPainter(AbstractPipeline):
             return
 
         if get_is_sdxl():
-
+            # only take UNet from the repo
+            unet = UNet2DConditionModel.from_pretrained(
                 get_inpaint_model_path(),
                 torch_dtype=torch.float16,
                 cache_dir=get_hf_cache_dir(),
                 token=get_hf_token(),
+                subfolder="unet",
                 variant=get_base_inpaint_model_variant(),
+                revision=get_base_inpaint_model_revision(),
             ).to("cuda")
+            kwargs = {**self.__base.pipe.components, "unet": unet}
+            self.pipe = StableDiffusionXLInpaintPipeline(**kwargs).to("cuda")
+            self.pipe.mask_processor = VaeImageProcessor(
+                vae_scale_factor=self.pipe.vae_scale_factor,
+                do_normalize=False,
+                do_binarize=True,
+                do_convert_grayscale=True,
+            )
+            self.pipe.image_processor = VaeImageProcessor(
+                vae_scale_factor=self.pipe.vae_scale_factor
+            )
         else:
             self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                 get_inpaint_model_path(),
@@ -90,11 +113,18 @@ class InPainter(AbstractPipeline):
         num_inference_steps: int,
         **kwargs,
     ):
-
+        generator = get_generators(seed, get_num_return_sequences())
 
         input_img = download_image(image_url).resize((width, height))
         mask_img = download_image(mask_image_url).resize((width, height))
 
+        if get_is_sdxl():
+            width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
+            mask_img = self.pipe.mask_processor.blur(mask_img, blur_factor=33)
+
+            kwargs["strength"] = 0.999
+            kwargs["padding_mask_crop"] = 1000
+
         kwargs = {
             "prompt": prompt,
             "image": input_img,
@@ -104,6 +134,7 @@ class InPainter(AbstractPipeline):
             "negative_prompt": negative_prompt,
             "num_inference_steps": num_inference_steps,
             "strength": 1.0,
+            "generator": generator,
             **kwargs,
         }
-        return self.pipe.__call__(**kwargs).images
+        return self.pipe.__call__(**kwargs).images, mask_img
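For context on the SDXL branch above: only the UNet of the inpaint checkpoint is loaded, and every other component is reused from the already-loaded base pipeline. A minimal sketch of that composition, assuming a generic base pipeline and an inpaint repo id (both placeholders, not the project's actual model paths):

```python
import torch
from diffusers import (
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)


def build_inpaint_from_base(base: StableDiffusionXLPipeline, inpaint_repo: str):
    # Load only the UNet from the inpaint checkpoint; reuse the VAE, text
    # encoders, tokenizers and scheduler from the already-loaded base pipeline.
    unet = UNet2DConditionModel.from_pretrained(
        inpaint_repo, subfolder="unet", torch_dtype=torch.float16
    ).to("cuda")
    return StableDiffusionXLInpaintPipeline(
        **{**base.components, "unet": unet}
    ).to("cuda")
```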
internals/pipelines/prompt_modifier.py
CHANGED
@@ -2,6 +2,8 @@ from typing import List, Optional
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
+from internals.util.config import get_num_return_sequences
+
 
 class PromptModifier:
     __loaded = False
@@ -38,7 +40,7 @@ class PromptModifier:
             do_sample=False,
             max_new_tokens=75,
             num_beams=4,
-            num_return_sequences=
+            num_return_sequences=get_num_return_sequences(),
             eos_token_id=eos_id,
             pad_token_id=eos_id,
             length_penalty=-1.0,
internals/pipelines/realtime_draw.py
CHANGED
@@ -9,7 +9,13 @@ from internals.pipelines.commons import AbstractPipeline
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.high_res import HighRes
 from internals.pipelines.sdxl_llite_pipeline import SDXLLLiteImg2ImgPipeline
-from internals.util
+from internals.util import get_generators
+from internals.util.config import (
+    get_base_dimension,
+    get_hf_cache_dir,
+    get_is_sdxl,
+    get_num_return_sequences,
+)
 
 
 class RealtimeDraw(AbstractPipeline):
@@ -60,7 +66,7 @@ class RealtimeDraw(AbstractPipeline):
         if get_is_sdxl():
             raise Exception("SDXL is not supported for this method")
 
-
+        generator = get_generators(seed, get_num_return_sequences())
 
         image = ImageUtil.resize_image(image, 512)
 
@@ -70,6 +76,7 @@ class RealtimeDraw(AbstractPipeline):
             prompt=prompt,
             num_inference_steps=15,
             negative_prompt=negative_prompt,
+            generator=generator,
            guidance_scale=10,
            strength=0.8,
        ).images[0]
@@ -84,7 +91,7 @@ class RealtimeDraw(AbstractPipeline):
         image: Optional[Image.Image] = None,
         image2: Optional[Image.Image] = None,
     ):
-
+        generator = get_generators(seed, get_num_return_sequences())
 
         b_dimen = get_base_dimension()
 
@@ -104,6 +111,8 @@ class RealtimeDraw(AbstractPipeline):
             size = HighRes.find_closest_sdxl_aspect_ratio(image.size[0], image.size[1])
             image = image.resize(size)
 
+            torch.manual_seed(seed)
+
             images = self.pipe.__call__(
                 image=image,
                 condition_image=image,
@@ -129,6 +138,7 @@ class RealtimeDraw(AbstractPipeline):
                 num_inference_steps=15,
                 negative_prompt=negative_prompt,
                 guidance_scale=10,
+                generator=generator,
                 strength=0.9,
                 width=image.size[0],
                 height=image.size[1],
internals/pipelines/remove_background.py
CHANGED
@@ -1,20 +1,22 @@
 import io
 from pathlib import Path
 from typing import Union
-import numpy as np
-import cv2
 
+import cv2
+import huggingface_hub
+import numpy as np
+import onnxruntime as rt
 import torch
 import torch.nn.functional as F
+from briarmbg import BriaRMBG  # pyright: ignore
 from PIL import Image
 from rembg import remove
-from
+from torchvision.transforms.functional import normalize
 
 import internals.util.image as ImageUtil
 from carvekit.api.high import HiInterface
+from internals.data.task import ModelType
 from internals.util.commons import download_image, read_url
-import onnxruntime as rt
-import huggingface_hub
 
 
 class RemoveBackground:
@@ -94,3 +96,51 @@ class RemoveBackgroundV2:
         img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
         mask = mask.repeat(3, axis=2)
         return mask, img
+
+
+class RemoveBackgroundV3:
+    def __init__(self):
+        net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        net.to(device)
+        self.net = net
+
+    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
+        if type(image) is str:
+            image = download_image(image, mode="RGBA")
+
+        orig_image = image
+        w, h = orig_im_size = orig_image.size
+        image = self.__resize_image(orig_image)
+        im_np = np.array(image)
+        im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
+        im_tensor = torch.unsqueeze(im_tensor, 0)
+        im_tensor = torch.divide(im_tensor, 255.0)
+        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+        if torch.cuda.is_available():
+            im_tensor = im_tensor.cuda()
+
+        # inference
+        result = self.net(im_tensor)
+        # post process
+        result = torch.squeeze(
+            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
+        )
+        ma = torch.max(result)
+        mi = torch.min(result)
+        result = (result - mi) / (ma - mi)
+        # image to pil
+        im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
+        pil_im = Image.fromarray(np.squeeze(im_array))
+        # paste the mask on the original image
+        new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
+        new_im.paste(orig_image, mask=pil_im)
+        # new_orig_image = orig_image.convert('RGBA')
+
+        return new_im
+
+    def __resize_image(self, image):
+        image = image.convert("RGB")
+        model_input_size = (1024, 1024)
+        image = image.resize(model_input_size, Image.BILINEAR)
+        return image
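A hypothetical usage of the new `RemoveBackgroundV3` class added above; the input URL and output file name are placeholders:

```python
from internals.pipelines.remove_background import RemoveBackgroundV3

remover = RemoveBackgroundV3()                      # downloads briaai/RMBG-1.4 on first use
cutout = remover.remove("https://example.com/product.jpg")  # URL or PIL.Image input
cutout.save("product_cutout.png")                   # RGBA output, background transparent
```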
internals/pipelines/replace_background.py
CHANGED
@@ -16,11 +16,12 @@ import internals.util.image as ImageUtil
 from internals.data.result import Result
 from internals.data.task import ModelType
 from internals.pipelines.commons import AbstractPipeline
+from internals.pipelines.controlnets import ControlNet, load_network_model_by_key
 from internals.pipelines.high_res import HighRes
 from internals.pipelines.inpainter import InPainter
 from internals.pipelines.remove_background import RemoveBackgroundV2
 from internals.pipelines.upscaler import Upscaler
+from internals.util import get_generators
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import download_image
 from internals.util.config import (
@@ -28,6 +29,7 @@ from internals.util.config import (
     get_hf_token,
     get_inpaint_model_path,
     get_model_dir,
+    get_num_return_sequences,
 )
 
 
@@ -43,11 +45,9 @@ class ReplaceBackground(AbstractPipeline):
     ):
         if self.__loaded:
             return
-        controlnet_model =
-            "lllyasviel/control_v11p_sd15_canny",
-
-            cache_dir=get_hf_cache_dir(),
-        ).to("cuda")
+        controlnet_model = load_network_model_by_key(
+            "lllyasviel/control_v11p_sd15_canny", "controlnet"
+        )
         if base:
             pipe = StableDiffusionControlNetPipeline(
                 **base.pipe.components,
@@ -109,8 +109,7 @@ class ReplaceBackground(AbstractPipeline):
         if type(image) is str:
             image = download_image(image)
 
-
-        torch.cuda.manual_seed(seed)
+        generator = get_generators(seed, get_num_return_sequences())
 
         image = image.convert("RGB")
         if max(image.size) > 1024:
@@ -148,6 +147,7 @@ class ReplaceBackground(AbstractPipeline):
             guidance_scale=9,
             height=height,
             num_inference_steps=steps,
+            generator=generator,
             width=width,
         )
         result = Result.from_result(result)
internals/pipelines/safety_checker.py
CHANGED
@@ -31,10 +31,11 @@ class SafetyChecker:
         self.__loaded = True
 
     def apply(self, pipeline: AbstractPipeline):
-
-        if model:
+        if not get_nsfw_access():
             self.load()
 
+        model = self.model if not get_nsfw_access() else None
+
         if not pipeline:
             return
         if hasattr(pipeline, "pipe"):
internals/pipelines/sdxl_llite_pipeline.py
CHANGED
@@ -1251,6 +1251,8 @@ class PipelineLike:
 
 
 class SDXLLLiteImg2ImgPipeline:
+    from diffusers import UNet2DConditionModel
+
     def __init__(self):
         self.SCHEDULER_LINEAR_START = 0.00085
         self.SCHEDULER_LINEAR_END = 0.0120
@@ -1261,7 +1263,7 @@ class SDXLLLiteImg2ImgPipeline:
 
     def replace_unet_modules(
         self,
-        unet:
+        unet: UNet2DConditionModel,
         mem_eff_attn,
         xformers,
         sdpa,
internals/pipelines/sdxl_tile_upscale.py
CHANGED
@@ -4,8 +4,10 @@ from PIL import Image
 from torchvision import transforms
 
 import internals.util.image as ImageUtils
+import internals.util.image as ImageUtil
 from carvekit.api import high
 from internals.data.result import Result
+from internals.data.task import TaskType
 from internals.pipelines.commons import AbstractPipeline, Text2Img
 from internals.pipelines.controlnets import ControlNet
 from internals.pipelines.demofusion_sdxl import DemoFusionSDXLControlNetPipeline
@@ -19,18 +21,16 @@ controlnet = ControlNet()
 
 class SDXLTileUpscaler(AbstractPipeline):
     __loaded = False
+    __current_process_mode = None
 
     def create(self, high_res: HighRes, pipeline: Text2Img, model_id: int):
         if self.__loaded:
            return
        # temporal hack for upscale model till multicontrolnet support is added
-        model = (
-            "thibaud/controlnet-openpose-sdxl-1.0"
-            if int(model_id) == 2000293
-            else "diffusers/controlnet-canny-sdxl-1.0"
-        )
 
-        controlnet = ControlNetModel.from_pretrained(
+        controlnet = ControlNetModel.from_pretrained(
+            "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+        )
         pipe = DemoFusionSDXLControlNetPipeline(
             **pipeline.pipe.components, controlnet=controlnet
         )
@@ -43,6 +43,7 @@ class SDXLTileUpscaler(AbstractPipeline):
 
         self.pipe = pipe
 
+        self.__current_process_mode = TaskType.CANNY.name
         self.__loaded = True
 
     def unload(self):
@@ -52,6 +53,26 @@ class SDXLTileUpscaler(AbstractPipeline):
 
         clear_cuda_and_gc()
 
+    def __reload_controlnet(self, process_mode: str):
+        if self.__current_process_mode == process_mode:
+            return
+
+        model = (
+            "thibaud/controlnet-openpose-sdxl-1.0"
+            if process_mode == TaskType.POSE.name
+            else "diffusers/controlnet-canny-sdxl-1.0"
+        )
+        controlnet = ControlNetModel.from_pretrained(
+            model, torch_dtype=torch.float16
+        ).to("cuda")
+
+        if hasattr(self, "pipe"):
+            self.pipe.controlnet = controlnet
+
+        self.__current_process_mode = process_mode
+
+        clear_cuda_and_gc()
+
     def process(
         self,
         prompt: str,
@@ -61,21 +82,36 @@ class SDXLTileUpscaler(AbstractPipeline):
         width: int,
         height: int,
         model_id: int,
+        seed: int,
+        process_mode: str,
     ):
-
+        generator = torch.manual_seed(seed)
+
+        self.__reload_controlnet(process_mode)
+
+        if process_mode == TaskType.POSE.name:
+            print("Running POSE")
             condition_image = controlnet.detect_pose(imageUrl)
         else:
+            print("Running CANNY")
            condition_image = download_image(imageUrl)
            condition_image = ControlNet.canny_detect_edge(condition_image)
-
+        width, height = HighRes.find_closest_sdxl_aspect_ratio(width, height)
 
-        img =
+        img = download_image(imageUrl).resize((width, height))
        condition_image = condition_image.resize(img.size)
 
        img2 = self.__resize_for_condition_image(img, resize_dimension)
 
+        img = self.pad_image(img)
        image_lr = self.load_and_process_image(img)
-
+
+        out_img = self.pad_image(img2)
+        condition_image = self.pad_image(condition_image)
+
+        print("img", img.size)
+        print("img2", img2.size)
+        print("condition", condition_image.size)
        if int(model_id) == 2000173:
            kwargs = {
                "prompt": prompt,
@@ -83,6 +119,7 @@ class SDXLTileUpscaler(AbstractPipeline):
                "image": img2,
                "strength": 0.3,
                "num_inference_steps": 30,
+                "generator": generator,
            }
            images = self.high_res.pipe.__call__(**kwargs).images
        else:
@@ -90,20 +127,24 @@ class SDXLTileUpscaler(AbstractPipeline):
                image_lr=image_lr,
                prompt=prompt,
                condition_image=condition_image,
-                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
+                negative_prompt="blurry, ugly, duplicate, poorly drawn, deformed, mosaic, "
+                + negative_prompt,
                guidance_scale=11,
                sigma=0.8,
                num_inference_steps=24,
-
-
+                controlnet_conditioning_scale=0.5,
+                generator=generator,
+                width=out_img.size[0],
+                height=out_img.size[1],
            )
            images = images[::-1]
+        iv = ImageUtil.resize_image(img2, images[0].size[0])
+        images = [self.unpad_image(images[0], iv.size)]
        return images, False
 
    def load_and_process_image(self, pil_image):
        transform = transforms.Compose(
            [
-                transforms.Resize((1024, 1024)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
@@ -113,6 +154,36 @@ class SDXLTileUpscaler(AbstractPipeline):
        image = image.to("cuda")
        return image
 
+    def pad_image(self, image):
+        w, h = image.size
+        if w == h:
+            return image
+        elif w > h:
+            new_image = Image.new(image.mode, (w, w), (0, 0, 0))
+            pad_w = 0
+            pad_h = (w - h) // 2
+            new_image.paste(image, (0, pad_h))
+            return new_image
+        else:
+            new_image = Image.new(image.mode, (h, h), (0, 0, 0))
+            pad_w = (h - w) // 2
+            pad_h = 0
+            new_image.paste(image, (pad_w, 0))
+            return new_image
+
+    def unpad_image(self, padded_image, original_size):
+        w, h = original_size
+        if w == h:
+            return padded_image
+        elif w > h:
+            pad_h = (w - h) // 2
+            unpadded_image = padded_image.crop((0, pad_h, w, h + pad_h))
+            return unpadded_image
+        else:
+            pad_w = (h - w) // 2
+            unpadded_image = padded_image.crop((pad_w, 0, w + pad_w, h))
+            return unpadded_image
+
    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        input_image = image.convert("RGB")
        W, H = input_image.size
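The `pad_image`/`unpad_image` helpers above letterbox a non-square image onto a square black canvas and crop it back afterwards. A standalone restatement of that arithmetic for a 768x512 input, for illustration only:

```python
from PIL import Image

img = Image.new("RGB", (768, 512))            # landscape input, w > h
pad_h = (768 - 512) // 2                      # vertical offset used by pad_image
padded = Image.new(img.mode, (768, 768), (0, 0, 0))
padded.paste(img, (0, pad_h))                 # centered on a 768x768 canvas

restored = padded.crop((0, pad_h, 768, 512 + pad_h))  # what unpad_image does
assert restored.size == (768, 512)
```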
internals/pipelines/upscaler.py
CHANGED
@@ -1,7 +1,8 @@
+import io
 import math
 import os
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import cv2
 import numpy as np
@@ -10,7 +11,7 @@ from basicsr.archs.srvgg_arch import SRVGGNetCompact
 from basicsr.utils.download_util import load_file_from_url
 from gfpgan import GFPGANer
 from PIL import Image
-from realesrgan import RealESRGANer
+from realesrgan import RealESRGANer  # pyright: ignore
 
 import internals.util.image as ImageUtil
 from internals.util.commons import download_image
@@ -55,8 +56,12 @@ class Upscaler:
         width: int,
         height: int,
         face_enhance: bool,
-        resize_dimension: int,
+        resize_dimension: Optional[int] = None,
     ) -> bytes:
+        "if resize dimension is not provided, use the smaller of width and height"
+
+        self.load()
+
         model = SRVGGNetCompact(
             num_in_ch=3,
             num_out_ch=3,
@@ -67,7 +72,7 @@ class Upscaler:
         )
         return self.__internal_upscale(
             image,
-            resize_dimension,
+            resize_dimension,  # type: ignore
             face_enhance,
             width,
             height,
@@ -83,6 +88,10 @@ class Upscaler:
         face_enhance: bool,
         resize_dimension: int,
     ) -> bytes:
+        "if resize dimension is not provided, use the smaller of width and height"
+
+        self.load()
+
         model = RRDBNet(
             num_in_ch=3,
             num_out_ch=3,
@@ -124,18 +133,22 @@ class Upscaler:
         model,
     ) -> bytes:
         if type(image) is str:
-            image = download_image(image)
+            image = download_image(image, mode="RGBA")
 
         w, h = image.size
-        if max(w, h) > 1024:
-
+        # if max(w, h) > 1024:
+        #     image = ImageUtil.resize_image(image, dimension=1024)
 
         in_path = str(Path.home() / ".cache" / "input_upscale.png")
         image.save(in_path)
         input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
-        dimension =
+        dimension = max(input_image.shape[0], input_image.shape[1])
+        if not resize_dimension:
+            resize_dimension = max(width, height)
         scale = max(math.floor(resize_dimension / dimension), 2)
 
+        print("Upscaling by: ", scale)
+
         os.chdir(str(Path.home() / ".cache"))
         if scale == 4:
             print("Using 4x-Ultrasharp")
@@ -174,3 +187,7 @@ class Upscaler:
         cv2.imwrite("out.png", output)
         out_bytes = cv2.imencode(".png", output)[1].tobytes()
         return out_bytes
+
+    @staticmethod
+    def to_pil(buffer: bytes, mode="RGB") -> Image.Image:
+        return Image.open(io.BytesIO(buffer)).convert(mode)
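The scale selection in `__internal_upscale` above is a floor division clamped to a minimum of 2. A worked example with illustrative values:

```python
import math

dimension = 640         # larger side of the decoded input image
resize_dimension = 2048  # requested output size (or max(width, height) when omitted)
scale = max(math.floor(resize_dimension / dimension), 2)
print(scale)             # 3 -> used as the RealESRGAN outscale factor downstream
```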
internals/util/__init__.py
CHANGED
@@ -1,7 +1,13 @@
 import os
 
+import torch
+
 from internals.util.config import get_root_dir
 
 
 def getcwd():
     return get_root_dir()
+
+
+def get_generators(seed, num_generators=1):
+    return [torch.Generator().manual_seed(seed + i) for i in range(num_generators)]
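The new `get_generators` helper builds one independently seeded `torch.Generator` per requested image, which is the form diffusers accepts when `generator` is passed as a list. A small sketch of how the seeds line up (illustration only):

```python
import torch

from internals.util import get_generators

generators = get_generators(1234, num_generators=4)
print([g.initial_seed() for g in generators])  # [1234, 1235, 1236, 1237]

# Each output image in a batch can then be reproduced on its own,
# e.g. when sampling noise per generator:
noise = [torch.randn((1, 4, 64, 64), generator=g) for g in generators]
```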
internals/util/cache.py
CHANGED
@@ -1,5 +1,6 @@
 import gc
 import os
+
 import psutil
 import torch
 
@@ -7,6 +8,7 @@ import torch
 def print_memory_usage():
     process = psutil.Process(os.getpid())
     print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:2f} MB")
+    print(f"GPU usage: {torch.cuda.memory_allocated() / 1024 ** 2:2f} MB")
 
 
 def clear_cuda_and_gc():
internals/util/commons.py
CHANGED
@@ -11,7 +11,7 @@ from typing import Any, Optional, Union
 import boto3
 import requests
 
-from internals.util.config import api_endpoint, api_headers
+from internals.util.config import api_endpoint, api_headers, elb_endpoint
 
 s3 = boto3.client("s3")
 import io
@@ -103,7 +103,7 @@ def upload_images(images, processName: str, taskId: str):
         img_io.seek(0)
         key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
         res = requests.post(
-
+            elb_endpoint()
             + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
             + "{}{}_{}.png".format(taskId, processName, i),
             headers=api_headers(),
@@ -129,12 +129,12 @@ def upload_image(image: Union[Image.Image, BytesIO], out_path):
 
     image.seek(0)
     print(
-
+        elb_endpoint()
         + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
         + str(out_path).replace("crecoAI/", ""),
     )
     res = requests.post(
-
+        elb_endpoint()
         + "/autodraft-content/v1.0/upload/crecoai-assets-2?fileName="
         + str(out_path).replace("crecoAI/", ""),
         headers=api_headers(),
internals/util/config.py
CHANGED
@@ -13,7 +13,7 @@ access_token = ""
 root_dir = ""
 model_config = None
 hf_token = base64.b64decode(
-    b"
+    b"aGZfaXRvVVJzTmN1RHZab1hXZ3hIeFRRRGdvSHdrQ2VNUldGbA=="
 ).decode()
 hf_cache_dir = "/tmp/hf_hub"
 
@@ -46,7 +46,7 @@ def set_model_config(config: ModelConfig):
 
 def set_configs_from_task(task: Task):
     global env, nsfw_threshold, nsfw_access, access_token, base_dimension, num_return_sequences
-    name = task.
+    name = task.get_environment()
     if name.startswith("gamma"):
         env = "gamma"
     else:
@@ -120,14 +120,25 @@ def get_base_model_variant():
     return model_config.base_model_variant  # pyright: ignore
 
 
+def get_base_model_revision():
+    global model_config
+    return model_config.base_model_revision  # pyright: ignore
+
+
 def get_base_inpaint_model_variant():
     global model_config
     return model_config.base_inpaint_model_variant  # pyright: ignore
 
 
+def get_base_inpaint_model_revision():
+    global model_config
+    return model_config.base_inpaint_model_revision  # pyright: ignore
+
+
 def api_headers():
     return {
         "Access-Token": access_token,
+        "Host": "api.autodraft.in" if env == "prod" else "gamma-api.autodraft.in",
     }
 
 
@@ -138,8 +149,11 @@ def api_endpoint():
         return "https://gamma-api.autodraft.in"
 
 
-def
+def elb_endpoint():
+    # We use the ELB endpoint for uploading images since
+    # cloudflare has a hard limit of 100mb when the
+    # DNS is proxied
     if env == "prod":
-        return "http://
+        return "http://k8s-prod-ingresse-8ba91151af-2105029163.ap-south-1.elb.amazonaws.com"
     else:
-        return "http://
+        return "http://k8s-gamma-ingresse-fc1051bc41-1227070426.ap-south-1.elb.amazonaws.com"
internals/util/failure_hander.py
CHANGED
@@ -16,10 +16,13 @@ class FailureHandler:
         path = FailureHandler.__task_path
         path.parent.mkdir(parents=True, exist_ok=True)
         if path.exists():
-
-
-
-
+            try:
+                task = Task(json.loads(path.read_text()))
+                set_configs_from_task(task)
+                # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE"))
+                updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
+            except Exception as e:
+                print("Failed to handle task", e)
             os.remove(path)
 
     @staticmethod
internals/util/image.py
CHANGED
@@ -48,3 +48,21 @@ def padd_image(image: Image.Image, to_width: int, to_height: int) -> Image.Image
     img = Image.new("RGBA", (to_width, to_height), (0, 0, 0, 0))
     img.paste(image, ((to_width - iw) // 2, (to_height - ih) // 2))
     return img
+
+
+def alpha_to_white(img: Image.Image) -> Image.Image:
+    if img.mode == "RGBA":
+        data = img.getdata()
+
+        new_data = []
+
+        for item in data:
+            if item[3] == 0:
+                new_data.append((255, 255, 255, 255))
+            else:
+                new_data.append(item)
+
+        img.putdata(new_data)
+
+    img = img.convert("RGB")
+    return img
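For reference, a tiny usage sketch of the new `alpha_to_white` helper added above (the 64x64 canvas is illustrative):

```python
from PIL import Image

from internals.util.image import alpha_to_white

rgba = Image.new("RGBA", (64, 64), (0, 0, 0, 0))   # fully transparent canvas
flat = alpha_to_white(rgba)                         # transparent pixels become white
print(flat.mode, flat.getpixel((0, 0)))             # RGB (255, 255, 255)
```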
internals/util/lora_style.py
CHANGED
@@ -52,9 +52,18 @@ class LoraStyle:
     def patch(self):
         def run(pipe):
             path = self.__style["path"]
+            name = str(self.__style["tag"]).replace(" ", "_")
+            weight = self.__style.get("weight", 1.0)
+            if name not in pipe.get_list_adapters().get("unet", []):
+                print(
+                    f"Loading lora {os.path.basename(path)} with weights {weight}, name: {name}"
+                )
+                pipe.load_lora_weights(
+                    os.path.dirname(path),
+                    weight_name=os.path.basename(path),
+                    adapter_name=name,
+                )
+            pipe.set_adapters([name], adapter_weights=[weight])

         for p in self.pipe:
             run(p)
@@ -105,7 +114,17 @@ class LoraStyle:
     def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
         if key in self.__styles:
             style = self.__styles[key]
+            prompt = f"{', '.join(style['text'])}, {prompt}"
+            prompt = prompt.replace("<NOSEP>, ", "")
+        return prompt
+
+    def append_style_to_prompt(self, prompt: str, key: str) -> str:
+        if key in self.__styles and "text_append" in self.__styles[key]:
+            style = self.__styles[key]
+            if prompt.endswith(","):
+                prompt = prompt[:-1]
+            prompt = f"{prompt}, {', '.join(style['text_append'])}"
+            prompt = prompt.replace("<NOSEP>, ", "")
         return prompt

     def get_patcher(
@@ -140,7 +159,9 @@ class LoraStyle:
                 "path": str(file_path),
                 "weight": attr["weight"],
                 "type": attr["type"],
+                "tag": item["tag"],
                 "text": attr["text"],
+                "text_append": attr.get("text_append", []),
                 "negativePrompt": attr["negativePrompt"],
             }
         return styles
@@ -159,4 +180,7 @@ class LoraStyle:

     @staticmethod
     def unload_lora_weights(pipe):
+        # we keep the lora layers in the adapters and unset it whenever
+        # not required instead of completely unloading it
+        pipe.set_adapters([])
+        # pipe.unload_lora_weights()
internals/util/model_loader.py
CHANGED
@@ -18,7 +18,9 @@ class ModelConfig:
     base_dimension: int = 512
     low_gpu_mem: bool = False
     base_model_variant: Optional[str] = None
+    base_model_revision: Optional[str] = None
     base_inpaint_model_variant: Optional[str] = None
+    base_inpaint_model_revision: Optional[str] = None


 def load_model_from_config(path):
@@ -31,7 +33,11 @@ def load_model_from_config(path):
     is_sdxl = config.get("is_sdxl", False)
     base_dimension = config.get("base_dimension", 512)
     base_model_variant = config.get("base_model_variant", None)
+    base_model_revision = config.get("base_model_revision", None)
     base_inpaint_model_variant = config.get("base_inpaint_model_variant", None)
+    base_inpaint_model_revision = config.get(
+        "base_inpaint_model_revision", None
+    )

     m_config.base_model_path = model_path
     m_config.base_inpaint_model_path = inpaint_model_path
@@ -39,7 +45,9 @@ def load_model_from_config(path):
     m_config.base_dimension = base_dimension
     m_config.low_gpu_mem = config.get("low_gpu_mem", False)
     m_config.base_model_variant = base_model_variant
+    m_config.base_model_revision = base_model_revision
     m_config.base_inpaint_model_variant = base_inpaint_model_variant
+    m_config.base_inpaint_model_revision = base_inpaint_model_revision

     #
     # if config.get("model_type") == "huggingface":
internals/util/prompt.py
CHANGED
@@ -21,6 +21,7 @@ def get_patched_prompt(
     for i in range(len(prompt)):
         prompt[i] = avatar.add_code_names(prompt[i])
         prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
+        prompt[i] = lora_style.append_style_to_prompt(prompt[i], task.get_style())
         if additional:
             prompt[i] = additional + " " + prompt[i]

@@ -51,6 +52,7 @@ def get_patched_prompt_text2img(
     def add_style_and_character(prompt: str, prepend: str = ""):
         prompt = avatar.add_code_names(prompt)
         prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+        prompt = lora_style.append_style_to_prompt(prompt, task.get_style())
         prompt = prepend + prompt
         return prompt

@@ -102,6 +104,7 @@ def get_patched_prompt_tile_upscale(
     lora_style: LoraStyle,
     img_classifier: ImageClassifier,
     img2text: Image2Text,
+    is_sdxl=False,
 ):
     if task.get_prompt():
         prompt = task.get_prompt()
@@ -114,10 +117,12 @@ def get_patched_prompt_tile_upscale(
         prompt = task.PROMPT.merge_blip(blip)

     # remove anomalies in prompt
+    if not is_sdxl:
+        prompt = remove_colors(prompt)

     prompt = avatar.add_code_names(prompt)
     prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+    prompt = lora_style.append_style_to_prompt(prompt, task.get_style())

     if not task.get_style():
         class_name = img_classifier.classify(
internals/util/sdxl_lightning.py
ADDED
@@ -0,0 +1,74 @@
+from pathlib import Path
+from re import S
+from typing import List, Union
+
+from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
+from diffusers.loaders.lora import StableDiffusionXLLoraLoaderMixin
+from torchvision.datasets.utils import download_url
+
+
+class LightningMixin:
+    LORA_8_STEP_URL = "https://huggingface.co/ByteDance/SDXL-Lightning/resolve/main/sdxl_lightning_8step_lora.safetensors"
+
+    __scheduler_old = None
+    __pipe: StableDiffusionXLPipeline = None
+    __scheduler = None
+
+    def configure_sdxl_lightning(self, pipe: StableDiffusionXLPipeline):
+        lora_path = Path.home() / ".cache" / "lora_8_step.safetensors"
+
+        download_url(self.LORA_8_STEP_URL, str(lora_path.parent), lora_path.name)
+
+        pipe.load_lora_weights(str(lora_path), adapter_name="8step_lora")
+        pipe.set_adapters([])
+
+        self.__scheduler = EulerDiscreteScheduler.from_config(
+            pipe.scheduler.config, timestep_spacing="trailing"
+        )
+        self.__scheduler_old = pipe.scheduler
+        self.__pipe = pipe
+
+    def enable_sdxl_lightning(self):
+        pipe = self.__pipe
+        pipe.scheduler = self.__scheduler
+
+        current = pipe.get_active_adapters()
+        current.extend(["8step_lora"])
+
+        weights = self.__find_adapter_weights(current)
+        pipe.set_adapters(current, adapter_weights=weights)
+
+        return {"guidance_scale": 0, "num_inference_steps": 8}
+
+    def disable_sdxl_lightning(self):
+        pipe = self.__pipe
+        pipe.scheduler = self.__scheduler_old
+
+        current = pipe.get_active_adapters()
+        current = [adapter for adapter in current if adapter != "8step_lora"]
+
+        weights = self.__find_adapter_weights(current)
+        pipe.set_adapters(current, adapter_weights=weights)
+
+    def __find_adapter_weights(self, names: List[str]):
+        pipe = self.__pipe
+
+        model = pipe.unet
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        weights = []
+        for adapter_name in names:
+            weight = 1.0
+            for module in model.modules():
+                if isinstance(module, BaseTunerLayer):
+                    if adapter_name in module.scaling:
+                        weight = (
+                            module.scaling[adapter_name]
+                            * module.r[adapter_name]
+                            / module.lora_alpha[adapter_name]
+                        )
+
+            weights.append(weight)
+
+        return weights
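For context, the mixin is meant to be toggled around individual generations: configure once when the pipeline is loaded, enable to switch to the trailing-timestep Euler scheduler plus the 8-step LoRA, and disable to restore the original scheduler and adapter set. A hypothetical usage sketch (the wrapping class and its method names are assumptions, not part of this commit):

class FastText2Img(LightningMixin):
    def __init__(self, pipe):
        self.pipe = pipe
        # one-time setup: downloads the 8-step LoRA and stores the original scheduler
        self.configure_sdxl_lightning(pipe)

    def generate(self, prompt: str):
        extra = self.enable_sdxl_lightning()  # {"guidance_scale": 0, "num_inference_steps": 8}
        try:
            return self.pipe(prompt=prompt, **extra).images[0]
        finally:
            self.disable_sdxl_lightning()  # put the original scheduler and adapters back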
internals/util/slack.py
CHANGED
@@ -14,6 +14,8 @@ class Slack:
         self.error_webhook = "https://hooks.slack.com/services/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"

     def send_alert(self, task: Task, args: Optional[dict]):
+        if task.get_slack_url():
+            self.webhook_url = task.get_slack_url()
         raw = task.get_raw().copy()

         raw["environment"] = get_environment()
@@ -23,6 +25,7 @@ class Slack:
         raw.pop("task_id", None)
         raw.pop("maskImageUrl", None)
         raw.pop("aux_imageUrl", None)
+        raw.pop("slack_url", None)

         if args is not None:
             raw.update(args.items())