Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- auth.py +41 -0
- cldm.py +312 -0
- constants.py +5 -0
- entry_with_update 2.py +46 -0
- face_restoration_helper.py +374 -0
- inpaint_worker 2.py +264 -0
- inpaint_worker.py +264 -0
- launch_util.py +103 -0
- lora.py +152 -0
- model_loader.py +26 -0
- sdxl_styles.py +82 -0
- upscaler.py +34 -0
- util.py +177 -0
- webui.py +623 -0
auth.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import hashlib
|
3 |
+
import modules.constants as constants
|
4 |
+
|
5 |
+
from os.path import exists
|
6 |
+
|
7 |
+
|
8 |
+
def auth_list_to_dict(auth_list):
|
9 |
+
auth_dict = {}
|
10 |
+
for auth_data in auth_list:
|
11 |
+
if 'user' in auth_data:
|
12 |
+
if 'hash' in auth_data:
|
13 |
+
auth_dict |= {auth_data['user']: auth_data['hash']}
|
14 |
+
elif 'pass' in auth_data:
|
15 |
+
auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()}
|
16 |
+
return auth_dict
|
17 |
+
|
18 |
+
|
19 |
+
def load_auth_data(filename=None):
|
20 |
+
auth_dict = None
|
21 |
+
if filename != None and exists(filename):
|
22 |
+
with open(filename, encoding='utf-8') as auth_file:
|
23 |
+
try:
|
24 |
+
auth_obj = json.load(auth_file)
|
25 |
+
if isinstance(auth_obj, list) and len(auth_obj) > 0:
|
26 |
+
auth_dict = auth_list_to_dict(auth_obj)
|
27 |
+
except Exception as e:
|
28 |
+
print('load_auth_data, e: ' + str(e))
|
29 |
+
return auth_dict
|
30 |
+
|
31 |
+
|
32 |
+
auth_dict = load_auth_data(constants.AUTH_FILENAME)
|
33 |
+
|
34 |
+
auth_enabled = auth_dict != None
|
35 |
+
|
36 |
+
|
37 |
+
def check_auth(user, password):
|
38 |
+
if user not in auth_dict:
|
39 |
+
return False
|
40 |
+
else:
|
41 |
+
return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user]
|
cldm.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#taken from: https://github.com/lllyasviel/ControlNet
|
2 |
+
#and modified
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch as th
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from ldm_patched.ldm.modules.diffusionmodules.util import (
|
9 |
+
zero_module,
|
10 |
+
timestep_embedding,
|
11 |
+
)
|
12 |
+
|
13 |
+
from ldm_patched.ldm.modules.attention import SpatialTransformer
|
14 |
+
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
|
15 |
+
from ldm_patched.ldm.util import exists
|
16 |
+
import ldm_patched.modules.ops
|
17 |
+
|
18 |
+
class ControlledUnetModel(UNetModel):
|
19 |
+
#implemented in the ldm unet
|
20 |
+
pass
|
21 |
+
|
22 |
+
class ControlNet(nn.Module):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
image_size,
|
26 |
+
in_channels,
|
27 |
+
model_channels,
|
28 |
+
hint_channels,
|
29 |
+
num_res_blocks,
|
30 |
+
dropout=0,
|
31 |
+
channel_mult=(1, 2, 4, 8),
|
32 |
+
conv_resample=True,
|
33 |
+
dims=2,
|
34 |
+
num_classes=None,
|
35 |
+
use_checkpoint=False,
|
36 |
+
dtype=torch.float32,
|
37 |
+
num_heads=-1,
|
38 |
+
num_head_channels=-1,
|
39 |
+
num_heads_upsample=-1,
|
40 |
+
use_scale_shift_norm=False,
|
41 |
+
resblock_updown=False,
|
42 |
+
use_new_attention_order=False,
|
43 |
+
use_spatial_transformer=False, # custom transformer support
|
44 |
+
transformer_depth=1, # custom transformer support
|
45 |
+
context_dim=None, # custom transformer support
|
46 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
47 |
+
legacy=True,
|
48 |
+
disable_self_attentions=None,
|
49 |
+
num_attention_blocks=None,
|
50 |
+
disable_middle_self_attn=False,
|
51 |
+
use_linear_in_transformer=False,
|
52 |
+
adm_in_channels=None,
|
53 |
+
transformer_depth_middle=None,
|
54 |
+
transformer_depth_output=None,
|
55 |
+
device=None,
|
56 |
+
operations=ldm_patched.modules.ops.disable_weight_init,
|
57 |
+
**kwargs,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
61 |
+
if use_spatial_transformer:
|
62 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
63 |
+
|
64 |
+
if context_dim is not None:
|
65 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
66 |
+
# from omegaconf.listconfig import ListConfig
|
67 |
+
# if type(context_dim) == ListConfig:
|
68 |
+
# context_dim = list(context_dim)
|
69 |
+
|
70 |
+
if num_heads_upsample == -1:
|
71 |
+
num_heads_upsample = num_heads
|
72 |
+
|
73 |
+
if num_heads == -1:
|
74 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
75 |
+
|
76 |
+
if num_head_channels == -1:
|
77 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
78 |
+
|
79 |
+
self.dims = dims
|
80 |
+
self.image_size = image_size
|
81 |
+
self.in_channels = in_channels
|
82 |
+
self.model_channels = model_channels
|
83 |
+
|
84 |
+
if isinstance(num_res_blocks, int):
|
85 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
86 |
+
else:
|
87 |
+
if len(num_res_blocks) != len(channel_mult):
|
88 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
89 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
90 |
+
self.num_res_blocks = num_res_blocks
|
91 |
+
|
92 |
+
if disable_self_attentions is not None:
|
93 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
94 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
95 |
+
if num_attention_blocks is not None:
|
96 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
97 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
98 |
+
|
99 |
+
transformer_depth = transformer_depth[:]
|
100 |
+
|
101 |
+
self.dropout = dropout
|
102 |
+
self.channel_mult = channel_mult
|
103 |
+
self.conv_resample = conv_resample
|
104 |
+
self.num_classes = num_classes
|
105 |
+
self.use_checkpoint = use_checkpoint
|
106 |
+
self.dtype = dtype
|
107 |
+
self.num_heads = num_heads
|
108 |
+
self.num_head_channels = num_head_channels
|
109 |
+
self.num_heads_upsample = num_heads_upsample
|
110 |
+
self.predict_codebook_ids = n_embed is not None
|
111 |
+
|
112 |
+
time_embed_dim = model_channels * 4
|
113 |
+
self.time_embed = nn.Sequential(
|
114 |
+
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
115 |
+
nn.SiLU(),
|
116 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
117 |
+
)
|
118 |
+
|
119 |
+
if self.num_classes is not None:
|
120 |
+
if isinstance(self.num_classes, int):
|
121 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
122 |
+
elif self.num_classes == "continuous":
|
123 |
+
print("setting up linear c_adm embedding layer")
|
124 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
125 |
+
elif self.num_classes == "sequential":
|
126 |
+
assert adm_in_channels is not None
|
127 |
+
self.label_emb = nn.Sequential(
|
128 |
+
nn.Sequential(
|
129 |
+
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
130 |
+
nn.SiLU(),
|
131 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
132 |
+
)
|
133 |
+
)
|
134 |
+
else:
|
135 |
+
raise ValueError()
|
136 |
+
|
137 |
+
self.input_blocks = nn.ModuleList(
|
138 |
+
[
|
139 |
+
TimestepEmbedSequential(
|
140 |
+
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
141 |
+
)
|
142 |
+
]
|
143 |
+
)
|
144 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
145 |
+
|
146 |
+
self.input_hint_block = TimestepEmbedSequential(
|
147 |
+
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
148 |
+
nn.SiLU(),
|
149 |
+
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
150 |
+
nn.SiLU(),
|
151 |
+
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
152 |
+
nn.SiLU(),
|
153 |
+
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
154 |
+
nn.SiLU(),
|
155 |
+
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
156 |
+
nn.SiLU(),
|
157 |
+
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
158 |
+
nn.SiLU(),
|
159 |
+
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
160 |
+
nn.SiLU(),
|
161 |
+
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
162 |
+
)
|
163 |
+
|
164 |
+
self._feature_size = model_channels
|
165 |
+
input_block_chans = [model_channels]
|
166 |
+
ch = model_channels
|
167 |
+
ds = 1
|
168 |
+
for level, mult in enumerate(channel_mult):
|
169 |
+
for nr in range(self.num_res_blocks[level]):
|
170 |
+
layers = [
|
171 |
+
ResBlock(
|
172 |
+
ch,
|
173 |
+
time_embed_dim,
|
174 |
+
dropout,
|
175 |
+
out_channels=mult * model_channels,
|
176 |
+
dims=dims,
|
177 |
+
use_checkpoint=use_checkpoint,
|
178 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
179 |
+
dtype=self.dtype,
|
180 |
+
device=device,
|
181 |
+
operations=operations,
|
182 |
+
)
|
183 |
+
]
|
184 |
+
ch = mult * model_channels
|
185 |
+
num_transformers = transformer_depth.pop(0)
|
186 |
+
if num_transformers > 0:
|
187 |
+
if num_head_channels == -1:
|
188 |
+
dim_head = ch // num_heads
|
189 |
+
else:
|
190 |
+
num_heads = ch // num_head_channels
|
191 |
+
dim_head = num_head_channels
|
192 |
+
if legacy:
|
193 |
+
#num_heads = 1
|
194 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
195 |
+
if exists(disable_self_attentions):
|
196 |
+
disabled_sa = disable_self_attentions[level]
|
197 |
+
else:
|
198 |
+
disabled_sa = False
|
199 |
+
|
200 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
201 |
+
layers.append(
|
202 |
+
SpatialTransformer(
|
203 |
+
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
204 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
205 |
+
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
206 |
+
)
|
207 |
+
)
|
208 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
209 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
210 |
+
self._feature_size += ch
|
211 |
+
input_block_chans.append(ch)
|
212 |
+
if level != len(channel_mult) - 1:
|
213 |
+
out_ch = ch
|
214 |
+
self.input_blocks.append(
|
215 |
+
TimestepEmbedSequential(
|
216 |
+
ResBlock(
|
217 |
+
ch,
|
218 |
+
time_embed_dim,
|
219 |
+
dropout,
|
220 |
+
out_channels=out_ch,
|
221 |
+
dims=dims,
|
222 |
+
use_checkpoint=use_checkpoint,
|
223 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
224 |
+
down=True,
|
225 |
+
dtype=self.dtype,
|
226 |
+
device=device,
|
227 |
+
operations=operations
|
228 |
+
)
|
229 |
+
if resblock_updown
|
230 |
+
else Downsample(
|
231 |
+
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
232 |
+
)
|
233 |
+
)
|
234 |
+
)
|
235 |
+
ch = out_ch
|
236 |
+
input_block_chans.append(ch)
|
237 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
238 |
+
ds *= 2
|
239 |
+
self._feature_size += ch
|
240 |
+
|
241 |
+
if num_head_channels == -1:
|
242 |
+
dim_head = ch // num_heads
|
243 |
+
else:
|
244 |
+
num_heads = ch // num_head_channels
|
245 |
+
dim_head = num_head_channels
|
246 |
+
if legacy:
|
247 |
+
#num_heads = 1
|
248 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
249 |
+
mid_block = [
|
250 |
+
ResBlock(
|
251 |
+
ch,
|
252 |
+
time_embed_dim,
|
253 |
+
dropout,
|
254 |
+
dims=dims,
|
255 |
+
use_checkpoint=use_checkpoint,
|
256 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
257 |
+
dtype=self.dtype,
|
258 |
+
device=device,
|
259 |
+
operations=operations
|
260 |
+
)]
|
261 |
+
if transformer_depth_middle >= 0:
|
262 |
+
mid_block += [SpatialTransformer( # always uses a self-attn
|
263 |
+
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
264 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
265 |
+
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
266 |
+
),
|
267 |
+
ResBlock(
|
268 |
+
ch,
|
269 |
+
time_embed_dim,
|
270 |
+
dropout,
|
271 |
+
dims=dims,
|
272 |
+
use_checkpoint=use_checkpoint,
|
273 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
274 |
+
dtype=self.dtype,
|
275 |
+
device=device,
|
276 |
+
operations=operations
|
277 |
+
)]
|
278 |
+
self.middle_block = TimestepEmbedSequential(*mid_block)
|
279 |
+
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
280 |
+
self._feature_size += ch
|
281 |
+
|
282 |
+
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
283 |
+
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
284 |
+
|
285 |
+
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
286 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
287 |
+
emb = self.time_embed(t_emb)
|
288 |
+
|
289 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
290 |
+
|
291 |
+
outs = []
|
292 |
+
|
293 |
+
hs = []
|
294 |
+
if self.num_classes is not None:
|
295 |
+
assert y.shape[0] == x.shape[0]
|
296 |
+
emb = emb + self.label_emb(y)
|
297 |
+
|
298 |
+
h = x
|
299 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
300 |
+
if guided_hint is not None:
|
301 |
+
h = module(h, emb, context)
|
302 |
+
h += guided_hint
|
303 |
+
guided_hint = None
|
304 |
+
else:
|
305 |
+
h = module(h, emb, context)
|
306 |
+
outs.append(zero_conv(h, emb, context))
|
307 |
+
|
308 |
+
h = self.middle_block(h, emb, context)
|
309 |
+
outs.append(self.middle_block_out(h, emb, context))
|
310 |
+
|
311 |
+
return outs
|
312 |
+
|
constants.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# as in k-diffusion (sampling.py)
|
2 |
+
MIN_SEED = 0
|
3 |
+
MAX_SEED = 2**63 - 1
|
4 |
+
|
5 |
+
AUTH_FILENAME = 'auth.json'
|
entry_with_update 2.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
root = os.path.dirname(os.path.abspath(__file__))
|
6 |
+
sys.path.append(root)
|
7 |
+
os.chdir(root)
|
8 |
+
|
9 |
+
|
10 |
+
try:
|
11 |
+
import pygit2
|
12 |
+
pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
|
13 |
+
|
14 |
+
repo = pygit2.Repository(os.path.abspath(os.path.dirname(__file__)))
|
15 |
+
|
16 |
+
branch_name = repo.head.shorthand
|
17 |
+
|
18 |
+
remote_name = 'origin'
|
19 |
+
remote = repo.remotes[remote_name]
|
20 |
+
|
21 |
+
remote.fetch()
|
22 |
+
|
23 |
+
local_branch_ref = f'refs/heads/{branch_name}'
|
24 |
+
local_branch = repo.lookup_reference(local_branch_ref)
|
25 |
+
|
26 |
+
remote_reference = f'refs/remotes/{remote_name}/{branch_name}'
|
27 |
+
remote_commit = repo.revparse_single(remote_reference)
|
28 |
+
|
29 |
+
merge_result, _ = repo.merge_analysis(remote_commit.id)
|
30 |
+
|
31 |
+
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
|
32 |
+
print("Already up-to-date")
|
33 |
+
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
|
34 |
+
local_branch.set_target(remote_commit.id)
|
35 |
+
repo.head.set_target(remote_commit.id)
|
36 |
+
repo.checkout_tree(repo.get(remote_commit.id))
|
37 |
+
repo.reset(local_branch.target, pygit2.GIT_RESET_HARD)
|
38 |
+
print("Fast-forward merge")
|
39 |
+
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
|
40 |
+
print("Update failed - Did you modify any file?")
|
41 |
+
except Exception as e:
|
42 |
+
print('Update failed.')
|
43 |
+
print(str(e))
|
44 |
+
|
45 |
+
print('Update succeeded.')
|
46 |
+
from launch import *
|
face_restoration_helper.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms.functional import normalize
|
6 |
+
|
7 |
+
from extras.facexlib.detection import init_detection_model
|
8 |
+
from extras.facexlib.parsing import init_parsing_model
|
9 |
+
from extras.facexlib.utils.misc import img2tensor, imwrite
|
10 |
+
|
11 |
+
|
12 |
+
def get_largest_face(det_faces, h, w):
|
13 |
+
|
14 |
+
def get_location(val, length):
|
15 |
+
if val < 0:
|
16 |
+
return 0
|
17 |
+
elif val > length:
|
18 |
+
return length
|
19 |
+
else:
|
20 |
+
return val
|
21 |
+
|
22 |
+
face_areas = []
|
23 |
+
for det_face in det_faces:
|
24 |
+
left = get_location(det_face[0], w)
|
25 |
+
right = get_location(det_face[2], w)
|
26 |
+
top = get_location(det_face[1], h)
|
27 |
+
bottom = get_location(det_face[3], h)
|
28 |
+
face_area = (right - left) * (bottom - top)
|
29 |
+
face_areas.append(face_area)
|
30 |
+
largest_idx = face_areas.index(max(face_areas))
|
31 |
+
return det_faces[largest_idx], largest_idx
|
32 |
+
|
33 |
+
|
34 |
+
def get_center_face(det_faces, h=0, w=0, center=None):
|
35 |
+
if center is not None:
|
36 |
+
center = np.array(center)
|
37 |
+
else:
|
38 |
+
center = np.array([w / 2, h / 2])
|
39 |
+
center_dist = []
|
40 |
+
for det_face in det_faces:
|
41 |
+
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
|
42 |
+
dist = np.linalg.norm(face_center - center)
|
43 |
+
center_dist.append(dist)
|
44 |
+
center_idx = center_dist.index(min(center_dist))
|
45 |
+
return det_faces[center_idx], center_idx
|
46 |
+
|
47 |
+
|
48 |
+
class FaceRestoreHelper(object):
|
49 |
+
"""Helper for the face restoration pipeline (base class)."""
|
50 |
+
|
51 |
+
def __init__(self,
|
52 |
+
upscale_factor,
|
53 |
+
face_size=512,
|
54 |
+
crop_ratio=(1, 1),
|
55 |
+
det_model='retinaface_resnet50',
|
56 |
+
save_ext='png',
|
57 |
+
template_3points=False,
|
58 |
+
pad_blur=False,
|
59 |
+
use_parse=False,
|
60 |
+
device=None,
|
61 |
+
model_rootpath=None):
|
62 |
+
self.template_3points = template_3points # improve robustness
|
63 |
+
self.upscale_factor = upscale_factor
|
64 |
+
# the cropped face ratio based on the square face
|
65 |
+
self.crop_ratio = crop_ratio # (h, w)
|
66 |
+
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
|
67 |
+
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
|
68 |
+
|
69 |
+
if self.template_3points:
|
70 |
+
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
|
71 |
+
else:
|
72 |
+
# standard 5 landmarks for FFHQ faces with 512 x 512
|
73 |
+
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
|
74 |
+
[201.26117, 371.41043], [313.08905, 371.15118]])
|
75 |
+
self.face_template = self.face_template * (face_size / 512.0)
|
76 |
+
if self.crop_ratio[0] > 1:
|
77 |
+
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
|
78 |
+
if self.crop_ratio[1] > 1:
|
79 |
+
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
|
80 |
+
self.save_ext = save_ext
|
81 |
+
self.pad_blur = pad_blur
|
82 |
+
if self.pad_blur is True:
|
83 |
+
self.template_3points = False
|
84 |
+
|
85 |
+
self.all_landmarks_5 = []
|
86 |
+
self.det_faces = []
|
87 |
+
self.affine_matrices = []
|
88 |
+
self.inverse_affine_matrices = []
|
89 |
+
self.cropped_faces = []
|
90 |
+
self.restored_faces = []
|
91 |
+
self.pad_input_imgs = []
|
92 |
+
|
93 |
+
if device is None:
|
94 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
95 |
+
else:
|
96 |
+
self.device = device
|
97 |
+
|
98 |
+
# init face detection model
|
99 |
+
self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)
|
100 |
+
|
101 |
+
# init face parsing model
|
102 |
+
self.use_parse = use_parse
|
103 |
+
self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath)
|
104 |
+
|
105 |
+
def set_upscale_factor(self, upscale_factor):
|
106 |
+
self.upscale_factor = upscale_factor
|
107 |
+
|
108 |
+
def read_image(self, img):
|
109 |
+
"""img can be image path or cv2 loaded image."""
|
110 |
+
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
|
111 |
+
if isinstance(img, str):
|
112 |
+
img = cv2.imread(img)
|
113 |
+
|
114 |
+
if np.max(img) > 256: # 16-bit image
|
115 |
+
img = img / 65535 * 255
|
116 |
+
if len(img.shape) == 2: # gray image
|
117 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
118 |
+
elif img.shape[2] == 4: # RGBA image with alpha channel
|
119 |
+
img = img[:, :, 0:3]
|
120 |
+
|
121 |
+
self.input_img = img
|
122 |
+
|
123 |
+
def get_face_landmarks_5(self,
|
124 |
+
only_keep_largest=False,
|
125 |
+
only_center_face=False,
|
126 |
+
resize=None,
|
127 |
+
blur_ratio=0.01,
|
128 |
+
eye_dist_threshold=None):
|
129 |
+
if resize is None:
|
130 |
+
scale = 1
|
131 |
+
input_img = self.input_img
|
132 |
+
else:
|
133 |
+
h, w = self.input_img.shape[0:2]
|
134 |
+
scale = min(h, w) / resize
|
135 |
+
h, w = int(h / scale), int(w / scale)
|
136 |
+
input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)
|
137 |
+
|
138 |
+
with torch.no_grad():
|
139 |
+
bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
|
140 |
+
for bbox in bboxes:
|
141 |
+
# remove faces with too small eye distance: side faces or too small faces
|
142 |
+
eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
|
143 |
+
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
|
144 |
+
continue
|
145 |
+
|
146 |
+
if self.template_3points:
|
147 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
|
148 |
+
else:
|
149 |
+
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
|
150 |
+
self.all_landmarks_5.append(landmark)
|
151 |
+
self.det_faces.append(bbox[0:5])
|
152 |
+
if len(self.det_faces) == 0:
|
153 |
+
return 0
|
154 |
+
if only_keep_largest:
|
155 |
+
h, w, _ = self.input_img.shape
|
156 |
+
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
|
157 |
+
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
|
158 |
+
elif only_center_face:
|
159 |
+
h, w, _ = self.input_img.shape
|
160 |
+
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
|
161 |
+
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
|
162 |
+
|
163 |
+
# pad blurry images
|
164 |
+
if self.pad_blur:
|
165 |
+
self.pad_input_imgs = []
|
166 |
+
for landmarks in self.all_landmarks_5:
|
167 |
+
# get landmarks
|
168 |
+
eye_left = landmarks[0, :]
|
169 |
+
eye_right = landmarks[1, :]
|
170 |
+
eye_avg = (eye_left + eye_right) * 0.5
|
171 |
+
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
|
172 |
+
eye_to_eye = eye_right - eye_left
|
173 |
+
eye_to_mouth = mouth_avg - eye_avg
|
174 |
+
|
175 |
+
# Get the oriented crop rectangle
|
176 |
+
# x: half width of the oriented crop rectangle
|
177 |
+
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
|
178 |
+
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
|
179 |
+
# norm with the hypotenuse: get the direction
|
180 |
+
x /= np.hypot(*x) # get the hypotenuse of a right triangle
|
181 |
+
rect_scale = 1.5
|
182 |
+
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
|
183 |
+
# y: half height of the oriented crop rectangle
|
184 |
+
y = np.flipud(x) * [-1, 1]
|
185 |
+
|
186 |
+
# c: center
|
187 |
+
c = eye_avg + eye_to_mouth * 0.1
|
188 |
+
# quad: (left_top, left_bottom, right_bottom, right_top)
|
189 |
+
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
|
190 |
+
# qsize: side length of the square
|
191 |
+
qsize = np.hypot(*x) * 2
|
192 |
+
border = max(int(np.rint(qsize * 0.1)), 3)
|
193 |
+
|
194 |
+
# get pad
|
195 |
+
# pad: (width_left, height_top, width_right, height_bottom)
|
196 |
+
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
|
197 |
+
int(np.ceil(max(quad[:, 1]))))
|
198 |
+
pad = [
|
199 |
+
max(-pad[0] + border, 1),
|
200 |
+
max(-pad[1] + border, 1),
|
201 |
+
max(pad[2] - self.input_img.shape[0] + border, 1),
|
202 |
+
max(pad[3] - self.input_img.shape[1] + border, 1)
|
203 |
+
]
|
204 |
+
|
205 |
+
if max(pad) > 1:
|
206 |
+
# pad image
|
207 |
+
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
|
208 |
+
# modify landmark coords
|
209 |
+
landmarks[:, 0] += pad[0]
|
210 |
+
landmarks[:, 1] += pad[1]
|
211 |
+
# blur pad images
|
212 |
+
h, w, _ = pad_img.shape
|
213 |
+
y, x, _ = np.ogrid[:h, :w, :1]
|
214 |
+
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
|
215 |
+
np.float32(w - 1 - x) / pad[2]),
|
216 |
+
1.0 - np.minimum(np.float32(y) / pad[1],
|
217 |
+
np.float32(h - 1 - y) / pad[3]))
|
218 |
+
blur = int(qsize * blur_ratio)
|
219 |
+
if blur % 2 == 0:
|
220 |
+
blur += 1
|
221 |
+
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
|
222 |
+
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
|
223 |
+
|
224 |
+
pad_img = pad_img.astype('float32')
|
225 |
+
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
|
226 |
+
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
|
227 |
+
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
|
228 |
+
self.pad_input_imgs.append(pad_img)
|
229 |
+
else:
|
230 |
+
self.pad_input_imgs.append(np.copy(self.input_img))
|
231 |
+
|
232 |
+
return len(self.all_landmarks_5)
|
233 |
+
|
234 |
+
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
|
235 |
+
"""Align and warp faces with face template.
|
236 |
+
"""
|
237 |
+
if self.pad_blur:
|
238 |
+
assert len(self.pad_input_imgs) == len(
|
239 |
+
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
|
240 |
+
for idx, landmark in enumerate(self.all_landmarks_5):
|
241 |
+
# use 5 landmarks to get affine matrix
|
242 |
+
# use cv2.LMEDS method for the equivalence to skimage transform
|
243 |
+
# ref: https://blog.csdn.net/yichxi/article/details/115827338
|
244 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
|
245 |
+
self.affine_matrices.append(affine_matrix)
|
246 |
+
# warp and crop faces
|
247 |
+
if border_mode == 'constant':
|
248 |
+
border_mode = cv2.BORDER_CONSTANT
|
249 |
+
elif border_mode == 'reflect101':
|
250 |
+
border_mode = cv2.BORDER_REFLECT101
|
251 |
+
elif border_mode == 'reflect':
|
252 |
+
border_mode = cv2.BORDER_REFLECT
|
253 |
+
if self.pad_blur:
|
254 |
+
input_img = self.pad_input_imgs[idx]
|
255 |
+
else:
|
256 |
+
input_img = self.input_img
|
257 |
+
cropped_face = cv2.warpAffine(
|
258 |
+
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
|
259 |
+
self.cropped_faces.append(cropped_face)
|
260 |
+
# save the cropped face
|
261 |
+
if save_cropped_path is not None:
|
262 |
+
path = os.path.splitext(save_cropped_path)[0]
|
263 |
+
save_path = f'{path}_{idx:02d}.{self.save_ext}'
|
264 |
+
imwrite(cropped_face, save_path)
|
265 |
+
|
266 |
+
def get_inverse_affine(self, save_inverse_affine_path=None):
|
267 |
+
"""Get inverse affine matrix."""
|
268 |
+
for idx, affine_matrix in enumerate(self.affine_matrices):
|
269 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
270 |
+
inverse_affine *= self.upscale_factor
|
271 |
+
self.inverse_affine_matrices.append(inverse_affine)
|
272 |
+
# save inverse affine matrices
|
273 |
+
if save_inverse_affine_path is not None:
|
274 |
+
path, _ = os.path.splitext(save_inverse_affine_path)
|
275 |
+
save_path = f'{path}_{idx:02d}.pth'
|
276 |
+
torch.save(inverse_affine, save_path)
|
277 |
+
|
278 |
+
def add_restored_face(self, face):
|
279 |
+
self.restored_faces.append(face)
|
280 |
+
|
281 |
+
def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
|
282 |
+
h, w, _ = self.input_img.shape
|
283 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
284 |
+
|
285 |
+
if upsample_img is None:
|
286 |
+
# simply resize the background
|
287 |
+
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
288 |
+
else:
|
289 |
+
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
290 |
+
|
291 |
+
assert len(self.restored_faces) == len(
|
292 |
+
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
|
293 |
+
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
|
294 |
+
# Add an offset to inverse affine matrix, for more precise back alignment
|
295 |
+
if self.upscale_factor > 1:
|
296 |
+
extra_offset = 0.5 * self.upscale_factor
|
297 |
+
else:
|
298 |
+
extra_offset = 0
|
299 |
+
inverse_affine[:, 2] += extra_offset
|
300 |
+
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
|
301 |
+
|
302 |
+
if self.use_parse:
|
303 |
+
# inference
|
304 |
+
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
|
305 |
+
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
|
306 |
+
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
307 |
+
face_input = torch.unsqueeze(face_input, 0).to(self.device)
|
308 |
+
with torch.no_grad():
|
309 |
+
out = self.face_parse(face_input)[0]
|
310 |
+
out = out.argmax(dim=1).squeeze().cpu().numpy()
|
311 |
+
|
312 |
+
mask = np.zeros(out.shape)
|
313 |
+
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
|
314 |
+
for idx, color in enumerate(MASK_COLORMAP):
|
315 |
+
mask[out == idx] = color
|
316 |
+
# blur the mask
|
317 |
+
mask = cv2.GaussianBlur(mask, (101, 101), 11)
|
318 |
+
mask = cv2.GaussianBlur(mask, (101, 101), 11)
|
319 |
+
# remove the black borders
|
320 |
+
thres = 10
|
321 |
+
mask[:thres, :] = 0
|
322 |
+
mask[-thres:, :] = 0
|
323 |
+
mask[:, :thres] = 0
|
324 |
+
mask[:, -thres:] = 0
|
325 |
+
mask = mask / 255.
|
326 |
+
|
327 |
+
mask = cv2.resize(mask, restored_face.shape[:2])
|
328 |
+
mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
|
329 |
+
inv_soft_mask = mask[:, :, None]
|
330 |
+
pasted_face = inv_restored
|
331 |
+
|
332 |
+
else: # use square parse maps
|
333 |
+
mask = np.ones(self.face_size, dtype=np.float32)
|
334 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
335 |
+
# remove the black borders
|
336 |
+
inv_mask_erosion = cv2.erode(
|
337 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
|
338 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
339 |
+
total_face_area = np.sum(inv_mask_erosion) # // 3
|
340 |
+
# compute the fusion edge based on the area of face
|
341 |
+
w_edge = int(total_face_area**0.5) // 20
|
342 |
+
erosion_radius = w_edge * 2
|
343 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
344 |
+
blur_size = w_edge * 2
|
345 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
346 |
+
if len(upsample_img.shape) == 2: # upsample_img is gray image
|
347 |
+
upsample_img = upsample_img[:, :, None]
|
348 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
349 |
+
|
350 |
+
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
|
351 |
+
alpha = upsample_img[:, :, 3:]
|
352 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
|
353 |
+
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
|
354 |
+
else:
|
355 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
356 |
+
|
357 |
+
if np.max(upsample_img) > 256: # 16-bit image
|
358 |
+
upsample_img = upsample_img.astype(np.uint16)
|
359 |
+
else:
|
360 |
+
upsample_img = upsample_img.astype(np.uint8)
|
361 |
+
if save_path is not None:
|
362 |
+
path = os.path.splitext(save_path)[0]
|
363 |
+
save_path = f'{path}.{self.save_ext}'
|
364 |
+
imwrite(upsample_img, save_path)
|
365 |
+
return upsample_img
|
366 |
+
|
367 |
+
def clean_all(self):
|
368 |
+
self.all_landmarks_5 = []
|
369 |
+
self.restored_faces = []
|
370 |
+
self.affine_matrices = []
|
371 |
+
self.cropped_faces = []
|
372 |
+
self.inverse_affine_matrices = []
|
373 |
+
self.det_faces = []
|
374 |
+
self.pad_input_imgs = []
|
inpaint_worker 2.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from PIL import Image, ImageFilter
|
5 |
+
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
|
6 |
+
from modules.upscaler import perform_upscale
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
|
10 |
+
inpaint_head_model = None
|
11 |
+
|
12 |
+
|
13 |
+
class InpaintHead(torch.nn.Module):
|
14 |
+
def __init__(self, *args, **kwargs):
|
15 |
+
super().__init__(*args, **kwargs)
|
16 |
+
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
|
17 |
+
|
18 |
+
def __call__(self, x):
|
19 |
+
x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
|
20 |
+
return torch.nn.functional.conv2d(input=x, weight=self.head)
|
21 |
+
|
22 |
+
|
23 |
+
current_task = None
|
24 |
+
|
25 |
+
|
26 |
+
def box_blur(x, k):
|
27 |
+
x = Image.fromarray(x)
|
28 |
+
x = x.filter(ImageFilter.BoxBlur(k))
|
29 |
+
return np.array(x)
|
30 |
+
|
31 |
+
|
32 |
+
def max_filter_opencv(x, ksize=3):
|
33 |
+
# Use OpenCV maximum filter
|
34 |
+
# Make sure the input type is int16
|
35 |
+
return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
|
36 |
+
|
37 |
+
|
38 |
+
def morphological_open(x):
|
39 |
+
# Convert array to int16 type via threshold operation
|
40 |
+
x_int16 = np.zeros_like(x, dtype=np.int16)
|
41 |
+
x_int16[x > 127] = 256
|
42 |
+
|
43 |
+
for i in range(32):
|
44 |
+
# Use int16 type to avoid overflow
|
45 |
+
maxed = max_filter_opencv(x_int16, ksize=3) - 8
|
46 |
+
x_int16 = np.maximum(maxed, x_int16)
|
47 |
+
|
48 |
+
# Clip negative values to 0 and convert back to uint8 type
|
49 |
+
x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
|
50 |
+
return x_uint8
|
51 |
+
|
52 |
+
|
53 |
+
def up255(x, t=0):
|
54 |
+
y = np.zeros_like(x).astype(np.uint8)
|
55 |
+
y[x > t] = 255
|
56 |
+
return y
|
57 |
+
|
58 |
+
|
59 |
+
def imsave(x, path):
|
60 |
+
x = Image.fromarray(x)
|
61 |
+
x.save(path)
|
62 |
+
|
63 |
+
|
64 |
+
def regulate_abcd(x, a, b, c, d):
|
65 |
+
H, W = x.shape[:2]
|
66 |
+
if a < 0:
|
67 |
+
a = 0
|
68 |
+
if a > H:
|
69 |
+
a = H
|
70 |
+
if b < 0:
|
71 |
+
b = 0
|
72 |
+
if b > H:
|
73 |
+
b = H
|
74 |
+
if c < 0:
|
75 |
+
c = 0
|
76 |
+
if c > W:
|
77 |
+
c = W
|
78 |
+
if d < 0:
|
79 |
+
d = 0
|
80 |
+
if d > W:
|
81 |
+
d = W
|
82 |
+
return int(a), int(b), int(c), int(d)
|
83 |
+
|
84 |
+
|
85 |
+
def compute_initial_abcd(x):
|
86 |
+
indices = np.where(x)
|
87 |
+
a = np.min(indices[0])
|
88 |
+
b = np.max(indices[0])
|
89 |
+
c = np.min(indices[1])
|
90 |
+
d = np.max(indices[1])
|
91 |
+
abp = (b + a) // 2
|
92 |
+
abm = (b - a) // 2
|
93 |
+
cdp = (d + c) // 2
|
94 |
+
cdm = (d - c) // 2
|
95 |
+
l = int(max(abm, cdm) * 1.15)
|
96 |
+
a = abp - l
|
97 |
+
b = abp + l + 1
|
98 |
+
c = cdp - l
|
99 |
+
d = cdp + l + 1
|
100 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
101 |
+
return a, b, c, d
|
102 |
+
|
103 |
+
|
104 |
+
def solve_abcd(x, a, b, c, d, k):
|
105 |
+
k = float(k)
|
106 |
+
assert 0.0 <= k <= 1.0
|
107 |
+
|
108 |
+
H, W = x.shape[:2]
|
109 |
+
if k == 1.0:
|
110 |
+
return 0, H, 0, W
|
111 |
+
while True:
|
112 |
+
if b - a >= H * k and d - c >= W * k:
|
113 |
+
break
|
114 |
+
|
115 |
+
add_h = (b - a) < (d - c)
|
116 |
+
add_w = not add_h
|
117 |
+
|
118 |
+
if b - a == H:
|
119 |
+
add_w = True
|
120 |
+
|
121 |
+
if d - c == W:
|
122 |
+
add_h = True
|
123 |
+
|
124 |
+
if add_h:
|
125 |
+
a -= 1
|
126 |
+
b += 1
|
127 |
+
|
128 |
+
if add_w:
|
129 |
+
c -= 1
|
130 |
+
d += 1
|
131 |
+
|
132 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
133 |
+
return a, b, c, d
|
134 |
+
|
135 |
+
|
136 |
+
def fooocus_fill(image, mask):
|
137 |
+
current_image = image.copy()
|
138 |
+
raw_image = image.copy()
|
139 |
+
area = np.where(mask < 127)
|
140 |
+
store = raw_image[area]
|
141 |
+
|
142 |
+
for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
|
143 |
+
for _ in range(repeats):
|
144 |
+
current_image = box_blur(current_image, k)
|
145 |
+
current_image[area] = store
|
146 |
+
|
147 |
+
return current_image
|
148 |
+
|
149 |
+
|
150 |
+
class InpaintWorker:
|
151 |
+
def __init__(self, image, mask, use_fill=True, k=0.618):
|
152 |
+
a, b, c, d = compute_initial_abcd(mask > 0)
|
153 |
+
a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
|
154 |
+
|
155 |
+
# interested area
|
156 |
+
self.interested_area = (a, b, c, d)
|
157 |
+
self.interested_mask = mask[a:b, c:d]
|
158 |
+
self.interested_image = image[a:b, c:d]
|
159 |
+
|
160 |
+
# super resolution
|
161 |
+
if get_image_shape_ceil(self.interested_image) < 1024:
|
162 |
+
self.interested_image = perform_upscale(self.interested_image)
|
163 |
+
|
164 |
+
# resize to make images ready for diffusion
|
165 |
+
self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
|
166 |
+
self.interested_fill = self.interested_image.copy()
|
167 |
+
H, W, C = self.interested_image.shape
|
168 |
+
|
169 |
+
# process mask
|
170 |
+
self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
|
171 |
+
|
172 |
+
# compute filling
|
173 |
+
if use_fill:
|
174 |
+
self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
|
175 |
+
|
176 |
+
# soft pixels
|
177 |
+
self.mask = morphological_open(mask)
|
178 |
+
self.image = image
|
179 |
+
|
180 |
+
# ending
|
181 |
+
self.latent = None
|
182 |
+
self.latent_after_swap = None
|
183 |
+
self.swapped = False
|
184 |
+
self.latent_mask = None
|
185 |
+
self.inpaint_head_feature = None
|
186 |
+
return
|
187 |
+
|
188 |
+
def load_latent(self, latent_fill, latent_mask, latent_swap=None):
|
189 |
+
self.latent = latent_fill
|
190 |
+
self.latent_mask = latent_mask
|
191 |
+
self.latent_after_swap = latent_swap
|
192 |
+
return
|
193 |
+
|
194 |
+
def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
|
195 |
+
global inpaint_head_model
|
196 |
+
|
197 |
+
if inpaint_head_model is None:
|
198 |
+
inpaint_head_model = InpaintHead()
|
199 |
+
sd = torch.load(inpaint_head_model_path, map_location='cpu')
|
200 |
+
inpaint_head_model.load_state_dict(sd)
|
201 |
+
|
202 |
+
feed = torch.cat([
|
203 |
+
inpaint_latent_mask,
|
204 |
+
model.model.process_latent_in(inpaint_latent)
|
205 |
+
], dim=1)
|
206 |
+
|
207 |
+
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
|
208 |
+
inpaint_head_feature = inpaint_head_model(feed)
|
209 |
+
|
210 |
+
def input_block_patch(h, transformer_options):
|
211 |
+
if transformer_options["block"][1] == 0:
|
212 |
+
h = h + inpaint_head_feature.to(h)
|
213 |
+
return h
|
214 |
+
|
215 |
+
m = model.clone()
|
216 |
+
m.set_model_input_block_patch(input_block_patch)
|
217 |
+
return m
|
218 |
+
|
219 |
+
def swap(self):
|
220 |
+
if self.swapped:
|
221 |
+
return
|
222 |
+
|
223 |
+
if self.latent is None:
|
224 |
+
return
|
225 |
+
|
226 |
+
if self.latent_after_swap is None:
|
227 |
+
return
|
228 |
+
|
229 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
230 |
+
self.swapped = True
|
231 |
+
return
|
232 |
+
|
233 |
+
def unswap(self):
|
234 |
+
if not self.swapped:
|
235 |
+
return
|
236 |
+
|
237 |
+
if self.latent is None:
|
238 |
+
return
|
239 |
+
|
240 |
+
if self.latent_after_swap is None:
|
241 |
+
return
|
242 |
+
|
243 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
244 |
+
self.swapped = False
|
245 |
+
return
|
246 |
+
|
247 |
+
def color_correction(self, img):
|
248 |
+
fg = img.astype(np.float32)
|
249 |
+
bg = self.image.copy().astype(np.float32)
|
250 |
+
w = self.mask[:, :, None].astype(np.float32) / 255.0
|
251 |
+
y = fg * w + bg * (1 - w)
|
252 |
+
return y.clip(0, 255).astype(np.uint8)
|
253 |
+
|
254 |
+
def post_process(self, img):
|
255 |
+
a, b, c, d = self.interested_area
|
256 |
+
content = resample_image(img, d - c, b - a)
|
257 |
+
result = self.image.copy()
|
258 |
+
result[a:b, c:d] = content
|
259 |
+
result = self.color_correction(result)
|
260 |
+
return result
|
261 |
+
|
262 |
+
def visualize_mask_processing(self):
|
263 |
+
return [self.interested_fill, self.interested_mask, self.interested_image]
|
264 |
+
|
inpaint_worker.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from PIL import Image, ImageFilter
|
5 |
+
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
|
6 |
+
from modules.upscaler import perform_upscale
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
|
10 |
+
inpaint_head_model = None
|
11 |
+
|
12 |
+
|
13 |
+
class InpaintHead(torch.nn.Module):
|
14 |
+
def __init__(self, *args, **kwargs):
|
15 |
+
super().__init__(*args, **kwargs)
|
16 |
+
self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))
|
17 |
+
|
18 |
+
def __call__(self, x):
|
19 |
+
x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
|
20 |
+
return torch.nn.functional.conv2d(input=x, weight=self.head)
|
21 |
+
|
22 |
+
|
23 |
+
current_task = None
|
24 |
+
|
25 |
+
|
26 |
+
def box_blur(x, k):
|
27 |
+
x = Image.fromarray(x)
|
28 |
+
x = x.filter(ImageFilter.BoxBlur(k))
|
29 |
+
return np.array(x)
|
30 |
+
|
31 |
+
|
32 |
+
def max_filter_opencv(x, ksize=3):
|
33 |
+
# Use OpenCV maximum filter
|
34 |
+
# Make sure the input type is int16
|
35 |
+
return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))
|
36 |
+
|
37 |
+
|
38 |
+
def morphological_open(x):
|
39 |
+
# Convert array to int16 type via threshold operation
|
40 |
+
x_int16 = np.zeros_like(x, dtype=np.int16)
|
41 |
+
x_int16[x > 127] = 256
|
42 |
+
|
43 |
+
for i in range(32):
|
44 |
+
# Use int16 type to avoid overflow
|
45 |
+
maxed = max_filter_opencv(x_int16, ksize=3) - 8
|
46 |
+
x_int16 = np.maximum(maxed, x_int16)
|
47 |
+
|
48 |
+
# Clip negative values to 0 and convert back to uint8 type
|
49 |
+
x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
|
50 |
+
return x_uint8
|
51 |
+
|
52 |
+
|
53 |
+
def up255(x, t=0):
|
54 |
+
y = np.zeros_like(x).astype(np.uint8)
|
55 |
+
y[x > t] = 255
|
56 |
+
return y
|
57 |
+
|
58 |
+
|
59 |
+
def imsave(x, path):
|
60 |
+
x = Image.fromarray(x)
|
61 |
+
x.save(path)
|
62 |
+
|
63 |
+
|
64 |
+
def regulate_abcd(x, a, b, c, d):
|
65 |
+
H, W = x.shape[:2]
|
66 |
+
if a < 0:
|
67 |
+
a = 0
|
68 |
+
if a > H:
|
69 |
+
a = H
|
70 |
+
if b < 0:
|
71 |
+
b = 0
|
72 |
+
if b > H:
|
73 |
+
b = H
|
74 |
+
if c < 0:
|
75 |
+
c = 0
|
76 |
+
if c > W:
|
77 |
+
c = W
|
78 |
+
if d < 0:
|
79 |
+
d = 0
|
80 |
+
if d > W:
|
81 |
+
d = W
|
82 |
+
return int(a), int(b), int(c), int(d)
|
83 |
+
|
84 |
+
|
85 |
+
def compute_initial_abcd(x):
|
86 |
+
indices = np.where(x)
|
87 |
+
a = np.min(indices[0])
|
88 |
+
b = np.max(indices[0])
|
89 |
+
c = np.min(indices[1])
|
90 |
+
d = np.max(indices[1])
|
91 |
+
abp = (b + a) // 2
|
92 |
+
abm = (b - a) // 2
|
93 |
+
cdp = (d + c) // 2
|
94 |
+
cdm = (d - c) // 2
|
95 |
+
l = int(max(abm, cdm) * 1.15)
|
96 |
+
a = abp - l
|
97 |
+
b = abp + l + 1
|
98 |
+
c = cdp - l
|
99 |
+
d = cdp + l + 1
|
100 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
101 |
+
return a, b, c, d
|
102 |
+
|
103 |
+
|
104 |
+
def solve_abcd(x, a, b, c, d, k):
|
105 |
+
k = float(k)
|
106 |
+
assert 0.0 <= k <= 1.0
|
107 |
+
|
108 |
+
H, W = x.shape[:2]
|
109 |
+
if k == 1.0:
|
110 |
+
return 0, H, 0, W
|
111 |
+
while True:
|
112 |
+
if b - a >= H * k and d - c >= W * k:
|
113 |
+
break
|
114 |
+
|
115 |
+
add_h = (b - a) < (d - c)
|
116 |
+
add_w = not add_h
|
117 |
+
|
118 |
+
if b - a == H:
|
119 |
+
add_w = True
|
120 |
+
|
121 |
+
if d - c == W:
|
122 |
+
add_h = True
|
123 |
+
|
124 |
+
if add_h:
|
125 |
+
a -= 1
|
126 |
+
b += 1
|
127 |
+
|
128 |
+
if add_w:
|
129 |
+
c -= 1
|
130 |
+
d += 1
|
131 |
+
|
132 |
+
a, b, c, d = regulate_abcd(x, a, b, c, d)
|
133 |
+
return a, b, c, d
|
134 |
+
|
135 |
+
|
136 |
+
def fooocus_fill(image, mask):
|
137 |
+
current_image = image.copy()
|
138 |
+
raw_image = image.copy()
|
139 |
+
area = np.where(mask < 127)
|
140 |
+
store = raw_image[area]
|
141 |
+
|
142 |
+
for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
|
143 |
+
for _ in range(repeats):
|
144 |
+
current_image = box_blur(current_image, k)
|
145 |
+
current_image[area] = store
|
146 |
+
|
147 |
+
return current_image
|
148 |
+
|
149 |
+
|
150 |
+
class InpaintWorker:
|
151 |
+
def __init__(self, image, mask, use_fill=True, k=0.618):
|
152 |
+
a, b, c, d = compute_initial_abcd(mask > 0)
|
153 |
+
a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)
|
154 |
+
|
155 |
+
# interested area
|
156 |
+
self.interested_area = (a, b, c, d)
|
157 |
+
self.interested_mask = mask[a:b, c:d]
|
158 |
+
self.interested_image = image[a:b, c:d]
|
159 |
+
|
160 |
+
# super resolution
|
161 |
+
if get_image_shape_ceil(self.interested_image) < 1024:
|
162 |
+
self.interested_image = perform_upscale(self.interested_image)
|
163 |
+
|
164 |
+
# resize to make images ready for diffusion
|
165 |
+
self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
|
166 |
+
self.interested_fill = self.interested_image.copy()
|
167 |
+
H, W, C = self.interested_image.shape
|
168 |
+
|
169 |
+
# process mask
|
170 |
+
self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)
|
171 |
+
|
172 |
+
# compute filling
|
173 |
+
if use_fill:
|
174 |
+
self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)
|
175 |
+
|
176 |
+
# soft pixels
|
177 |
+
self.mask = morphological_open(mask)
|
178 |
+
self.image = image
|
179 |
+
|
180 |
+
# ending
|
181 |
+
self.latent = None
|
182 |
+
self.latent_after_swap = None
|
183 |
+
self.swapped = False
|
184 |
+
self.latent_mask = None
|
185 |
+
self.inpaint_head_feature = None
|
186 |
+
return
|
187 |
+
|
188 |
+
def load_latent(self, latent_fill, latent_mask, latent_swap=None):
|
189 |
+
self.latent = latent_fill
|
190 |
+
self.latent_mask = latent_mask
|
191 |
+
self.latent_after_swap = latent_swap
|
192 |
+
return
|
193 |
+
|
194 |
+
def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
|
195 |
+
global inpaint_head_model
|
196 |
+
|
197 |
+
if inpaint_head_model is None:
|
198 |
+
inpaint_head_model = InpaintHead()
|
199 |
+
sd = torch.load(inpaint_head_model_path, map_location='cpu')
|
200 |
+
inpaint_head_model.load_state_dict(sd)
|
201 |
+
|
202 |
+
feed = torch.cat([
|
203 |
+
inpaint_latent_mask,
|
204 |
+
model.model.process_latent_in(inpaint_latent)
|
205 |
+
], dim=1)
|
206 |
+
|
207 |
+
inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
|
208 |
+
inpaint_head_feature = inpaint_head_model(feed)
|
209 |
+
|
210 |
+
def input_block_patch(h, transformer_options):
|
211 |
+
if transformer_options["block"][1] == 0:
|
212 |
+
h = h + inpaint_head_feature.to(h)
|
213 |
+
return h
|
214 |
+
|
215 |
+
m = model.clone()
|
216 |
+
m.set_model_input_block_patch(input_block_patch)
|
217 |
+
return m
|
218 |
+
|
219 |
+
def swap(self):
|
220 |
+
if self.swapped:
|
221 |
+
return
|
222 |
+
|
223 |
+
if self.latent is None:
|
224 |
+
return
|
225 |
+
|
226 |
+
if self.latent_after_swap is None:
|
227 |
+
return
|
228 |
+
|
229 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
230 |
+
self.swapped = True
|
231 |
+
return
|
232 |
+
|
233 |
+
def unswap(self):
|
234 |
+
if not self.swapped:
|
235 |
+
return
|
236 |
+
|
237 |
+
if self.latent is None:
|
238 |
+
return
|
239 |
+
|
240 |
+
if self.latent_after_swap is None:
|
241 |
+
return
|
242 |
+
|
243 |
+
self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
|
244 |
+
self.swapped = False
|
245 |
+
return
|
246 |
+
|
247 |
+
def color_correction(self, img):
|
248 |
+
fg = img.astype(np.float32)
|
249 |
+
bg = self.image.copy().astype(np.float32)
|
250 |
+
w = self.mask[:, :, None].astype(np.float32) / 255.0
|
251 |
+
y = fg * w + bg * (1 - w)
|
252 |
+
return y.clip(0, 255).astype(np.uint8)
|
253 |
+
|
254 |
+
def post_process(self, img):
|
255 |
+
a, b, c, d = self.interested_area
|
256 |
+
content = resample_image(img, d - c, b - a)
|
257 |
+
result = self.image.copy()
|
258 |
+
result[a:b, c:d] = content
|
259 |
+
result = self.color_correction(result)
|
260 |
+
return result
|
261 |
+
|
262 |
+
def visualize_mask_processing(self):
|
263 |
+
return [self.interested_fill, self.interested_mask, self.interested_image]
|
264 |
+
|
launch_util.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
import importlib.util
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
import re
|
7 |
+
import logging
|
8 |
+
import importlib.metadata
|
9 |
+
import packaging.version
|
10 |
+
from packaging.requirements import Requirement
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
16 |
+
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
17 |
+
|
18 |
+
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
19 |
+
|
20 |
+
python = sys.executable
|
21 |
+
default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1")
|
22 |
+
index_url = os.environ.get('INDEX_URL', "")
|
23 |
+
|
24 |
+
modules_path = os.path.dirname(os.path.realpath(__file__))
|
25 |
+
script_path = os.path.dirname(modules_path)
|
26 |
+
|
27 |
+
|
28 |
+
def is_installed(package):
|
29 |
+
try:
|
30 |
+
spec = importlib.util.find_spec(package)
|
31 |
+
except ModuleNotFoundError:
|
32 |
+
return False
|
33 |
+
|
34 |
+
return spec is not None
|
35 |
+
|
36 |
+
|
37 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
38 |
+
if desc is not None:
|
39 |
+
print(desc)
|
40 |
+
|
41 |
+
run_kwargs = {
|
42 |
+
"args": command,
|
43 |
+
"shell": True,
|
44 |
+
"env": os.environ if custom_env is None else custom_env,
|
45 |
+
"encoding": 'utf8',
|
46 |
+
"errors": 'ignore',
|
47 |
+
}
|
48 |
+
|
49 |
+
if not live:
|
50 |
+
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
51 |
+
|
52 |
+
result = subprocess.run(**run_kwargs)
|
53 |
+
|
54 |
+
if result.returncode != 0:
|
55 |
+
error_bits = [
|
56 |
+
f"{errdesc or 'Error running command'}.",
|
57 |
+
f"Command: {command}",
|
58 |
+
f"Error code: {result.returncode}",
|
59 |
+
]
|
60 |
+
if result.stdout:
|
61 |
+
error_bits.append(f"stdout: {result.stdout}")
|
62 |
+
if result.stderr:
|
63 |
+
error_bits.append(f"stderr: {result.stderr}")
|
64 |
+
raise RuntimeError("\n".join(error_bits))
|
65 |
+
|
66 |
+
return (result.stdout or "")
|
67 |
+
|
68 |
+
|
69 |
+
def run_pip(command, desc=None, live=default_command_live):
|
70 |
+
try:
|
71 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
72 |
+
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}",
|
73 |
+
errdesc=f"Couldn't install {desc}", live=live)
|
74 |
+
except Exception as e:
|
75 |
+
print(e)
|
76 |
+
print(f'CMD Failed {desc}: {command}')
|
77 |
+
return None
|
78 |
+
|
79 |
+
|
80 |
+
def requirements_met(requirements_file):
|
81 |
+
with open(requirements_file, "r", encoding="utf8") as file:
|
82 |
+
for line in file:
|
83 |
+
line = line.strip()
|
84 |
+
if line == "" or line.startswith('#'):
|
85 |
+
continue
|
86 |
+
|
87 |
+
requirement = Requirement(line)
|
88 |
+
package = requirement.name
|
89 |
+
|
90 |
+
try:
|
91 |
+
version_installed = importlib.metadata.version(package)
|
92 |
+
installed_version = packaging.version.parse(version_installed)
|
93 |
+
|
94 |
+
# Check if the installed version satisfies the requirement
|
95 |
+
if installed_version not in requirement.specifier:
|
96 |
+
print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}")
|
97 |
+
return False
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error checking version for {package}: {e}")
|
100 |
+
return False
|
101 |
+
|
102 |
+
return True
|
103 |
+
|
lora.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def match_lora(lora, to_load):
|
2 |
+
patch_dict = {}
|
3 |
+
loaded_keys = set()
|
4 |
+
for x in to_load:
|
5 |
+
real_load_key = to_load[x]
|
6 |
+
if real_load_key in lora:
|
7 |
+
patch_dict[real_load_key] = ('fooocus', lora[real_load_key])
|
8 |
+
loaded_keys.add(real_load_key)
|
9 |
+
continue
|
10 |
+
|
11 |
+
alpha_name = "{}.alpha".format(x)
|
12 |
+
alpha = None
|
13 |
+
if alpha_name in lora.keys():
|
14 |
+
alpha = lora[alpha_name].item()
|
15 |
+
loaded_keys.add(alpha_name)
|
16 |
+
|
17 |
+
regular_lora = "{}.lora_up.weight".format(x)
|
18 |
+
diffusers_lora = "{}_lora.up.weight".format(x)
|
19 |
+
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
20 |
+
A_name = None
|
21 |
+
|
22 |
+
if regular_lora in lora.keys():
|
23 |
+
A_name = regular_lora
|
24 |
+
B_name = "{}.lora_down.weight".format(x)
|
25 |
+
mid_name = "{}.lora_mid.weight".format(x)
|
26 |
+
elif diffusers_lora in lora.keys():
|
27 |
+
A_name = diffusers_lora
|
28 |
+
B_name = "{}_lora.down.weight".format(x)
|
29 |
+
mid_name = None
|
30 |
+
elif transformers_lora in lora.keys():
|
31 |
+
A_name = transformers_lora
|
32 |
+
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
33 |
+
mid_name = None
|
34 |
+
|
35 |
+
if A_name is not None:
|
36 |
+
mid = None
|
37 |
+
if mid_name is not None and mid_name in lora.keys():
|
38 |
+
mid = lora[mid_name]
|
39 |
+
loaded_keys.add(mid_name)
|
40 |
+
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
41 |
+
loaded_keys.add(A_name)
|
42 |
+
loaded_keys.add(B_name)
|
43 |
+
|
44 |
+
|
45 |
+
######## loha
|
46 |
+
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
47 |
+
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
48 |
+
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
49 |
+
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
50 |
+
hada_t1_name = "{}.hada_t1".format(x)
|
51 |
+
hada_t2_name = "{}.hada_t2".format(x)
|
52 |
+
if hada_w1_a_name in lora.keys():
|
53 |
+
hada_t1 = None
|
54 |
+
hada_t2 = None
|
55 |
+
if hada_t1_name in lora.keys():
|
56 |
+
hada_t1 = lora[hada_t1_name]
|
57 |
+
hada_t2 = lora[hada_t2_name]
|
58 |
+
loaded_keys.add(hada_t1_name)
|
59 |
+
loaded_keys.add(hada_t2_name)
|
60 |
+
|
61 |
+
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
62 |
+
loaded_keys.add(hada_w1_a_name)
|
63 |
+
loaded_keys.add(hada_w1_b_name)
|
64 |
+
loaded_keys.add(hada_w2_a_name)
|
65 |
+
loaded_keys.add(hada_w2_b_name)
|
66 |
+
|
67 |
+
|
68 |
+
######## lokr
|
69 |
+
lokr_w1_name = "{}.lokr_w1".format(x)
|
70 |
+
lokr_w2_name = "{}.lokr_w2".format(x)
|
71 |
+
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
72 |
+
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
73 |
+
lokr_t2_name = "{}.lokr_t2".format(x)
|
74 |
+
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
75 |
+
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
76 |
+
|
77 |
+
lokr_w1 = None
|
78 |
+
if lokr_w1_name in lora.keys():
|
79 |
+
lokr_w1 = lora[lokr_w1_name]
|
80 |
+
loaded_keys.add(lokr_w1_name)
|
81 |
+
|
82 |
+
lokr_w2 = None
|
83 |
+
if lokr_w2_name in lora.keys():
|
84 |
+
lokr_w2 = lora[lokr_w2_name]
|
85 |
+
loaded_keys.add(lokr_w2_name)
|
86 |
+
|
87 |
+
lokr_w1_a = None
|
88 |
+
if lokr_w1_a_name in lora.keys():
|
89 |
+
lokr_w1_a = lora[lokr_w1_a_name]
|
90 |
+
loaded_keys.add(lokr_w1_a_name)
|
91 |
+
|
92 |
+
lokr_w1_b = None
|
93 |
+
if lokr_w1_b_name in lora.keys():
|
94 |
+
lokr_w1_b = lora[lokr_w1_b_name]
|
95 |
+
loaded_keys.add(lokr_w1_b_name)
|
96 |
+
|
97 |
+
lokr_w2_a = None
|
98 |
+
if lokr_w2_a_name in lora.keys():
|
99 |
+
lokr_w2_a = lora[lokr_w2_a_name]
|
100 |
+
loaded_keys.add(lokr_w2_a_name)
|
101 |
+
|
102 |
+
lokr_w2_b = None
|
103 |
+
if lokr_w2_b_name in lora.keys():
|
104 |
+
lokr_w2_b = lora[lokr_w2_b_name]
|
105 |
+
loaded_keys.add(lokr_w2_b_name)
|
106 |
+
|
107 |
+
lokr_t2 = None
|
108 |
+
if lokr_t2_name in lora.keys():
|
109 |
+
lokr_t2 = lora[lokr_t2_name]
|
110 |
+
loaded_keys.add(lokr_t2_name)
|
111 |
+
|
112 |
+
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
113 |
+
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
114 |
+
|
115 |
+
#glora
|
116 |
+
a1_name = "{}.a1.weight".format(x)
|
117 |
+
a2_name = "{}.a2.weight".format(x)
|
118 |
+
b1_name = "{}.b1.weight".format(x)
|
119 |
+
b2_name = "{}.b2.weight".format(x)
|
120 |
+
if a1_name in lora:
|
121 |
+
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
122 |
+
loaded_keys.add(a1_name)
|
123 |
+
loaded_keys.add(a2_name)
|
124 |
+
loaded_keys.add(b1_name)
|
125 |
+
loaded_keys.add(b2_name)
|
126 |
+
|
127 |
+
w_norm_name = "{}.w_norm".format(x)
|
128 |
+
b_norm_name = "{}.b_norm".format(x)
|
129 |
+
w_norm = lora.get(w_norm_name, None)
|
130 |
+
b_norm = lora.get(b_norm_name, None)
|
131 |
+
|
132 |
+
if w_norm is not None:
|
133 |
+
loaded_keys.add(w_norm_name)
|
134 |
+
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
135 |
+
if b_norm is not None:
|
136 |
+
loaded_keys.add(b_norm_name)
|
137 |
+
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
138 |
+
|
139 |
+
diff_name = "{}.diff".format(x)
|
140 |
+
diff_weight = lora.get(diff_name, None)
|
141 |
+
if diff_weight is not None:
|
142 |
+
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
143 |
+
loaded_keys.add(diff_name)
|
144 |
+
|
145 |
+
diff_bias_name = "{}.diff_b".format(x)
|
146 |
+
diff_bias = lora.get(diff_bias_name, None)
|
147 |
+
if diff_bias is not None:
|
148 |
+
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
149 |
+
loaded_keys.add(diff_bias_name)
|
150 |
+
|
151 |
+
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
152 |
+
return patch_dict, remaining_dict
|
model_loader.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from urllib.parse import urlparse
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
|
6 |
+
def load_file_from_url(
|
7 |
+
url: str,
|
8 |
+
*,
|
9 |
+
model_dir: str,
|
10 |
+
progress: bool = True,
|
11 |
+
file_name: Optional[str] = None,
|
12 |
+
) -> str:
|
13 |
+
"""Download a file from `url` into `model_dir`, using the file present if possible.
|
14 |
+
|
15 |
+
Returns the path to the downloaded file.
|
16 |
+
"""
|
17 |
+
os.makedirs(model_dir, exist_ok=True)
|
18 |
+
if not file_name:
|
19 |
+
parts = urlparse(url)
|
20 |
+
file_name = os.path.basename(parts.path)
|
21 |
+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
22 |
+
if not os.path.exists(cached_file):
|
23 |
+
print(f'Downloading: "{url}" to {cached_file}\n')
|
24 |
+
from torch.hub import download_url_to_file
|
25 |
+
download_url_to_file(url, cached_file, progress=progress)
|
26 |
+
return cached_file
|
sdxl_styles.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
|
5 |
+
from modules.util import get_files_from_folder
|
6 |
+
|
7 |
+
|
8 |
+
# cannot use modules.config - validators causing circular imports
|
9 |
+
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
10 |
+
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
|
11 |
+
wildcards_max_bfs_depth = 64
|
12 |
+
|
13 |
+
|
14 |
+
def normalize_key(k):
|
15 |
+
k = k.replace('-', ' ')
|
16 |
+
words = k.split(' ')
|
17 |
+
words = [w[:1].upper() + w[1:].lower() for w in words]
|
18 |
+
k = ' '.join(words)
|
19 |
+
k = k.replace('3d', '3D')
|
20 |
+
k = k.replace('Sai', 'SAI')
|
21 |
+
k = k.replace('Mre', 'MRE')
|
22 |
+
k = k.replace('(s', '(S')
|
23 |
+
return k
|
24 |
+
|
25 |
+
|
26 |
+
styles = {}
|
27 |
+
|
28 |
+
styles_files = get_files_from_folder(styles_path, ['.json'])
|
29 |
+
|
30 |
+
for x in ['sdxl_styles_fooocus.json',
|
31 |
+
'sdxl_styles_sai.json',
|
32 |
+
'sdxl_styles_mre.json',
|
33 |
+
'sdxl_styles_twri.json',
|
34 |
+
'sdxl_styles_diva.json',
|
35 |
+
'sdxl_styles_marc_k3nt3l.json']:
|
36 |
+
if x in styles_files:
|
37 |
+
styles_files.remove(x)
|
38 |
+
styles_files.append(x)
|
39 |
+
|
40 |
+
for styles_file in styles_files:
|
41 |
+
try:
|
42 |
+
with open(os.path.join(styles_path, styles_file), encoding='utf-8') as f:
|
43 |
+
for entry in json.load(f):
|
44 |
+
name = normalize_key(entry['name'])
|
45 |
+
prompt = entry['prompt'] if 'prompt' in entry else ''
|
46 |
+
negative_prompt = entry['negative_prompt'] if 'negative_prompt' in entry else ''
|
47 |
+
styles[name] = (prompt, negative_prompt)
|
48 |
+
except Exception as e:
|
49 |
+
print(str(e))
|
50 |
+
print(f'Failed to load style file {styles_file}')
|
51 |
+
|
52 |
+
style_keys = list(styles.keys())
|
53 |
+
fooocus_expansion = "Fooocus V2"
|
54 |
+
legal_style_names = [fooocus_expansion] + style_keys
|
55 |
+
|
56 |
+
|
57 |
+
def apply_style(style, positive):
|
58 |
+
p, n = styles[style]
|
59 |
+
return p.replace('{prompt}', positive).splitlines(), n.splitlines()
|
60 |
+
|
61 |
+
|
62 |
+
def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
|
63 |
+
for _ in range(wildcards_max_bfs_depth):
|
64 |
+
placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
|
65 |
+
if len(placeholders) == 0:
|
66 |
+
return wildcard_text
|
67 |
+
|
68 |
+
print(f'[Wildcards] processing: {wildcard_text}')
|
69 |
+
for placeholder in placeholders:
|
70 |
+
try:
|
71 |
+
words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines()
|
72 |
+
words = [x for x in words if x != '']
|
73 |
+
assert len(words) > 0
|
74 |
+
wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
|
75 |
+
except:
|
76 |
+
print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
|
77 |
+
f'Using "{placeholder}" as a normal word.')
|
78 |
+
wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
|
79 |
+
print(f'[Wildcards] {wildcard_text}')
|
80 |
+
|
81 |
+
print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
|
82 |
+
return wildcard_text
|
upscaler.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import modules.core as core
|
4 |
+
|
5 |
+
from ldm_patched.pfn.architecture.RRDB import RRDBNet as ESRGAN
|
6 |
+
from ldm_patched.contrib.external_upscale_model import ImageUpscaleWithModel
|
7 |
+
from collections import OrderedDict
|
8 |
+
from modules.config import path_upscale_models
|
9 |
+
|
10 |
+
model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
11 |
+
opImageUpscaleWithModel = ImageUpscaleWithModel()
|
12 |
+
model = None
|
13 |
+
|
14 |
+
|
15 |
+
def perform_upscale(img):
|
16 |
+
global model
|
17 |
+
|
18 |
+
print(f'Upscaling image with shape {str(img.shape)} ...')
|
19 |
+
|
20 |
+
if model is None:
|
21 |
+
sd = torch.load(model_filename)
|
22 |
+
sdo = OrderedDict()
|
23 |
+
for k, v in sd.items():
|
24 |
+
sdo[k.replace('residual_block_', 'RDB')] = v
|
25 |
+
del sd
|
26 |
+
model = ESRGAN(sdo)
|
27 |
+
model.cpu()
|
28 |
+
model.eval()
|
29 |
+
|
30 |
+
img = core.numpy_to_pytorch(img)
|
31 |
+
img = opImageUpscaleWithModel.upscale(model, img)[0]
|
32 |
+
img = core.pytorch_to_numpy(img)[0]
|
33 |
+
|
34 |
+
return img
|
util.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import datetime
|
3 |
+
import random
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
12 |
+
|
13 |
+
|
14 |
+
def erode_or_dilate(x, k):
|
15 |
+
k = int(k)
|
16 |
+
if k > 0:
|
17 |
+
return cv2.dilate(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=k)
|
18 |
+
if k < 0:
|
19 |
+
return cv2.erode(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=-k)
|
20 |
+
return x
|
21 |
+
|
22 |
+
|
23 |
+
def resample_image(im, width, height):
|
24 |
+
im = Image.fromarray(im)
|
25 |
+
im = im.resize((int(width), int(height)), resample=LANCZOS)
|
26 |
+
return np.array(im)
|
27 |
+
|
28 |
+
|
29 |
+
def resize_image(im, width, height, resize_mode=1):
|
30 |
+
"""
|
31 |
+
Resizes an image with the specified resize_mode, width, and height.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
resize_mode: The mode to use when resizing the image.
|
35 |
+
0: Resize the image to the specified width and height.
|
36 |
+
1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
37 |
+
2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
38 |
+
im: The image to resize.
|
39 |
+
width: The width to resize the image to.
|
40 |
+
height: The height to resize the image to.
|
41 |
+
"""
|
42 |
+
|
43 |
+
im = Image.fromarray(im)
|
44 |
+
|
45 |
+
def resize(im, w, h):
|
46 |
+
return im.resize((w, h), resample=LANCZOS)
|
47 |
+
|
48 |
+
if resize_mode == 0:
|
49 |
+
res = resize(im, width, height)
|
50 |
+
|
51 |
+
elif resize_mode == 1:
|
52 |
+
ratio = width / height
|
53 |
+
src_ratio = im.width / im.height
|
54 |
+
|
55 |
+
src_w = width if ratio > src_ratio else im.width * height // im.height
|
56 |
+
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
57 |
+
|
58 |
+
resized = resize(im, src_w, src_h)
|
59 |
+
res = Image.new("RGB", (width, height))
|
60 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
61 |
+
|
62 |
+
else:
|
63 |
+
ratio = width / height
|
64 |
+
src_ratio = im.width / im.height
|
65 |
+
|
66 |
+
src_w = width if ratio < src_ratio else im.width * height // im.height
|
67 |
+
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
68 |
+
|
69 |
+
resized = resize(im, src_w, src_h)
|
70 |
+
res = Image.new("RGB", (width, height))
|
71 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
72 |
+
|
73 |
+
if ratio < src_ratio:
|
74 |
+
fill_height = height // 2 - src_h // 2
|
75 |
+
if fill_height > 0:
|
76 |
+
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
77 |
+
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
78 |
+
elif ratio > src_ratio:
|
79 |
+
fill_width = width // 2 - src_w // 2
|
80 |
+
if fill_width > 0:
|
81 |
+
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
82 |
+
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
83 |
+
|
84 |
+
return np.array(res)
|
85 |
+
|
86 |
+
|
87 |
+
def get_shape_ceil(h, w):
|
88 |
+
return math.ceil(((h * w) ** 0.5) / 64.0) * 64.0
|
89 |
+
|
90 |
+
|
91 |
+
def get_image_shape_ceil(im):
|
92 |
+
H, W = im.shape[:2]
|
93 |
+
return get_shape_ceil(H, W)
|
94 |
+
|
95 |
+
|
96 |
+
def set_image_shape_ceil(im, shape_ceil):
|
97 |
+
shape_ceil = float(shape_ceil)
|
98 |
+
|
99 |
+
H_origin, W_origin, _ = im.shape
|
100 |
+
H, W = H_origin, W_origin
|
101 |
+
|
102 |
+
for _ in range(256):
|
103 |
+
current_shape_ceil = get_shape_ceil(H, W)
|
104 |
+
if abs(current_shape_ceil - shape_ceil) < 0.1:
|
105 |
+
break
|
106 |
+
k = shape_ceil / current_shape_ceil
|
107 |
+
H = int(round(float(H) * k / 64.0) * 64)
|
108 |
+
W = int(round(float(W) * k / 64.0) * 64)
|
109 |
+
|
110 |
+
if H == H_origin and W == W_origin:
|
111 |
+
return im
|
112 |
+
|
113 |
+
return resample_image(im, width=W, height=H)
|
114 |
+
|
115 |
+
|
116 |
+
def HWC3(x):
|
117 |
+
assert x.dtype == np.uint8
|
118 |
+
if x.ndim == 2:
|
119 |
+
x = x[:, :, None]
|
120 |
+
assert x.ndim == 3
|
121 |
+
H, W, C = x.shape
|
122 |
+
assert C == 1 or C == 3 or C == 4
|
123 |
+
if C == 3:
|
124 |
+
return x
|
125 |
+
if C == 1:
|
126 |
+
return np.concatenate([x, x, x], axis=2)
|
127 |
+
if C == 4:
|
128 |
+
color = x[:, :, 0:3].astype(np.float32)
|
129 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
130 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
131 |
+
y = y.clip(0, 255).astype(np.uint8)
|
132 |
+
return y
|
133 |
+
|
134 |
+
|
135 |
+
def remove_empty_str(items, default=None):
|
136 |
+
items = [x for x in items if x != ""]
|
137 |
+
if len(items) == 0 and default is not None:
|
138 |
+
return [default]
|
139 |
+
return items
|
140 |
+
|
141 |
+
|
142 |
+
def join_prompts(*args, **kwargs):
|
143 |
+
prompts = [str(x) for x in args if str(x) != ""]
|
144 |
+
if len(prompts) == 0:
|
145 |
+
return ""
|
146 |
+
if len(prompts) == 1:
|
147 |
+
return prompts[0]
|
148 |
+
return ', '.join(prompts)
|
149 |
+
|
150 |
+
|
151 |
+
def generate_temp_filename(folder='./outputs/', extension='png'):
|
152 |
+
current_time = datetime.datetime.now()
|
153 |
+
date_string = current_time.strftime("%Y-%m-%d")
|
154 |
+
time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S")
|
155 |
+
random_number = random.randint(1000, 9999)
|
156 |
+
filename = f"{time_string}_{random_number}.{extension}"
|
157 |
+
result = os.path.join(folder, date_string, filename)
|
158 |
+
return date_string, os.path.abspath(os.path.realpath(result)), filename
|
159 |
+
|
160 |
+
|
161 |
+
def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
162 |
+
if not os.path.isdir(folder_path):
|
163 |
+
raise ValueError("Folder path is not a valid directory.")
|
164 |
+
|
165 |
+
filenames = []
|
166 |
+
|
167 |
+
for root, dirs, files in os.walk(folder_path):
|
168 |
+
relative_path = os.path.relpath(root, folder_path)
|
169 |
+
if relative_path == ".":
|
170 |
+
relative_path = ""
|
171 |
+
for filename in files:
|
172 |
+
_, file_extension = os.path.splitext(filename)
|
173 |
+
if (exensions == None or file_extension.lower() in exensions) and (name_filter == None or name_filter in _):
|
174 |
+
path = os.path.join(relative_path, filename)
|
175 |
+
filenames.append(path)
|
176 |
+
|
177 |
+
return sorted(filenames, key=lambda x: -1 if os.sep in x else 1)
|
webui.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import shared
|
7 |
+
import modules.config
|
8 |
+
import fooocus_version
|
9 |
+
import modules.html
|
10 |
+
import modules.async_worker as worker
|
11 |
+
import modules.constants as constants
|
12 |
+
import modules.flags as flags
|
13 |
+
import modules.gradio_hijack as grh
|
14 |
+
import modules.advanced_parameters as advanced_parameters
|
15 |
+
import modules.style_sorter as style_sorter
|
16 |
+
import modules.meta_parser
|
17 |
+
import args_manager
|
18 |
+
import copy
|
19 |
+
|
20 |
+
from modules.sdxl_styles import legal_style_names
|
21 |
+
from modules.private_logger import get_current_html_path
|
22 |
+
from modules.ui_gradio_extensions import reload_javascript
|
23 |
+
from modules.auth import auth_enabled, check_auth
|
24 |
+
|
25 |
+
|
26 |
+
def generate_clicked(*args):
|
27 |
+
import ldm_patched.modules.model_management as model_management
|
28 |
+
|
29 |
+
with model_management.interrupt_processing_mutex:
|
30 |
+
model_management.interrupt_processing = False
|
31 |
+
|
32 |
+
# outputs=[progress_html, progress_window, progress_gallery, gallery]
|
33 |
+
|
34 |
+
execution_start_time = time.perf_counter()
|
35 |
+
task = worker.AsyncTask(args=list(args))
|
36 |
+
finished = False
|
37 |
+
|
38 |
+
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
|
39 |
+
gr.update(visible=True, value=None), \
|
40 |
+
gr.update(visible=False, value=None), \
|
41 |
+
gr.update(visible=False)
|
42 |
+
|
43 |
+
worker.async_tasks.append(task)
|
44 |
+
|
45 |
+
while not finished:
|
46 |
+
time.sleep(0.01)
|
47 |
+
if len(task.yields) > 0:
|
48 |
+
flag, product = task.yields.pop(0)
|
49 |
+
if flag == 'preview':
|
50 |
+
|
51 |
+
# help bad internet connection by skipping duplicated preview
|
52 |
+
if len(task.yields) > 0: # if we have the next item
|
53 |
+
if task.yields[0][0] == 'preview': # if the next item is also a preview
|
54 |
+
# print('Skipped one preview for better internet connection.')
|
55 |
+
continue
|
56 |
+
|
57 |
+
percentage, title, image = product
|
58 |
+
yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
|
59 |
+
gr.update(visible=True, value=image) if image is not None else gr.update(), \
|
60 |
+
gr.update(), \
|
61 |
+
gr.update(visible=False)
|
62 |
+
if flag == 'results':
|
63 |
+
yield gr.update(visible=True), \
|
64 |
+
gr.update(visible=True), \
|
65 |
+
gr.update(visible=True, value=product), \
|
66 |
+
gr.update(visible=False)
|
67 |
+
if flag == 'finish':
|
68 |
+
yield gr.update(visible=False), \
|
69 |
+
gr.update(visible=False), \
|
70 |
+
gr.update(visible=False), \
|
71 |
+
gr.update(visible=True, value=product)
|
72 |
+
finished = True
|
73 |
+
|
74 |
+
execution_time = time.perf_counter() - execution_start_time
|
75 |
+
print(f'Total time: {execution_time:.2f} seconds')
|
76 |
+
return
|
77 |
+
|
78 |
+
|
79 |
+
reload_javascript()
|
80 |
+
|
81 |
+
title = f'Fooocus {fooocus_version.version}'
|
82 |
+
|
83 |
+
if isinstance(args_manager.args.preset, str):
|
84 |
+
title += ' ' + args_manager.args.preset
|
85 |
+
|
86 |
+
shared.gradio_root = gr.Blocks(
|
87 |
+
title=title,
|
88 |
+
css=modules.html.css).queue()
|
89 |
+
|
90 |
+
with shared.gradio_root:
|
91 |
+
with gr.Row():
|
92 |
+
with gr.Column(scale=2):
|
93 |
+
with gr.Row():
|
94 |
+
progress_window = grh.Image(label='Preview', show_label=True, visible=False, height=768,
|
95 |
+
elem_classes=['main_view'])
|
96 |
+
progress_gallery = gr.Gallery(label='Finished Images', show_label=True, object_fit='contain',
|
97 |
+
height=768, visible=False, elem_classes=['main_view', 'image_gallery'])
|
98 |
+
progress_html = gr.HTML(value=modules.html.make_progress_html(32, 'Progress 32%'), visible=False,
|
99 |
+
elem_id='progress-bar', elem_classes='progress-bar')
|
100 |
+
gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', visible=True, height=768,
|
101 |
+
elem_classes=['resizable_area', 'main_view', 'final_gallery', 'image_gallery'],
|
102 |
+
elem_id='final_gallery')
|
103 |
+
with gr.Row(elem_classes='type_row'):
|
104 |
+
with gr.Column(scale=17):
|
105 |
+
prompt = gr.Textbox(show_label=False, placeholder="Type prompt here or paste parameters.", elem_id='positive_prompt',
|
106 |
+
container=False, autofocus=True, elem_classes='type_row', lines=1024)
|
107 |
+
|
108 |
+
default_prompt = modules.config.default_prompt
|
109 |
+
if isinstance(default_prompt, str) and default_prompt != '':
|
110 |
+
shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
|
111 |
+
|
112 |
+
with gr.Column(scale=3, min_width=0):
|
113 |
+
generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True)
|
114 |
+
load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False)
|
115 |
+
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
|
116 |
+
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
|
117 |
+
|
118 |
+
def stop_clicked():
|
119 |
+
import ldm_patched.modules.model_management as model_management
|
120 |
+
shared.last_stop = 'stop'
|
121 |
+
model_management.interrupt_current_processing()
|
122 |
+
return [gr.update(interactive=False)] * 2
|
123 |
+
|
124 |
+
def skip_clicked():
|
125 |
+
import ldm_patched.modules.model_management as model_management
|
126 |
+
shared.last_stop = 'skip'
|
127 |
+
model_management.interrupt_current_processing()
|
128 |
+
return
|
129 |
+
|
130 |
+
stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
|
131 |
+
queue=False, show_progress=False, _js='cancelGenerateForever')
|
132 |
+
skip_button.click(skip_clicked, queue=False, show_progress=False)
|
133 |
+
with gr.Row(elem_classes='advanced_check_row'):
|
134 |
+
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
|
135 |
+
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
|
136 |
+
with gr.Row(visible=False) as image_input_panel:
|
137 |
+
with gr.Tabs():
|
138 |
+
with gr.TabItem(label='Upscale or Variation') as uov_tab:
|
139 |
+
with gr.Row():
|
140 |
+
with gr.Column():
|
141 |
+
uov_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy')
|
142 |
+
with gr.Column():
|
143 |
+
uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
|
144 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>')
|
145 |
+
with gr.TabItem(label='Image Prompt') as ip_tab:
|
146 |
+
with gr.Row():
|
147 |
+
ip_images = []
|
148 |
+
ip_types = []
|
149 |
+
ip_stops = []
|
150 |
+
ip_weights = []
|
151 |
+
ip_ctrls = []
|
152 |
+
ip_ad_cols = []
|
153 |
+
for _ in range(4):
|
154 |
+
with gr.Column():
|
155 |
+
ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300)
|
156 |
+
ip_images.append(ip_image)
|
157 |
+
ip_ctrls.append(ip_image)
|
158 |
+
with gr.Column(visible=False) as ad_col:
|
159 |
+
with gr.Row():
|
160 |
+
default_end, default_weight = flags.default_parameters[flags.default_ip]
|
161 |
+
|
162 |
+
ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=default_end)
|
163 |
+
ip_stops.append(ip_stop)
|
164 |
+
ip_ctrls.append(ip_stop)
|
165 |
+
|
166 |
+
ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=default_weight)
|
167 |
+
ip_weights.append(ip_weight)
|
168 |
+
ip_ctrls.append(ip_weight)
|
169 |
+
|
170 |
+
ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=flags.default_ip, container=False)
|
171 |
+
ip_types.append(ip_type)
|
172 |
+
ip_ctrls.append(ip_type)
|
173 |
+
|
174 |
+
ip_type.change(lambda x: flags.default_parameters[x], inputs=[ip_type], outputs=[ip_stop, ip_weight], queue=False, show_progress=False)
|
175 |
+
ip_ad_cols.append(ad_col)
|
176 |
+
ip_advanced = gr.Checkbox(label='Advanced', value=False, container=False)
|
177 |
+
gr.HTML('* \"Image Prompt\" is powered by Fooocus Image Mixture Engine (v1.0.1). <a href="https://github.com/lllyasviel/Fooocus/discussions/557" target="_blank">\U0001F4D4 Document</a>')
|
178 |
+
|
179 |
+
def ip_advance_checked(x):
|
180 |
+
return [gr.update(visible=x)] * len(ip_ad_cols) + \
|
181 |
+
[flags.default_ip] * len(ip_types) + \
|
182 |
+
[flags.default_parameters[flags.default_ip][0]] * len(ip_stops) + \
|
183 |
+
[flags.default_parameters[flags.default_ip][1]] * len(ip_weights)
|
184 |
+
|
185 |
+
ip_advanced.change(ip_advance_checked, inputs=ip_advanced,
|
186 |
+
outputs=ip_ad_cols + ip_types + ip_stops + ip_weights,
|
187 |
+
queue=False, show_progress=False)
|
188 |
+
with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
|
189 |
+
with gr.Row():
|
190 |
+
inpaint_input_image = grh.Image(label='Drag inpaint or outpaint image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
|
191 |
+
inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', height=500, visible=False)
|
192 |
+
|
193 |
+
with gr.Row():
|
194 |
+
inpaint_additional_prompt = gr.Textbox(placeholder="Describe what you want to inpaint.", elem_id='inpaint_additional_prompt', label='Inpaint Additional Prompt', visible=False)
|
195 |
+
outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint Direction')
|
196 |
+
inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_default, label='Method')
|
197 |
+
example_inpaint_prompts = gr.Dataset(samples=modules.config.example_inpaint_prompts, label='Additional Prompt Quick List', components=[inpaint_additional_prompt], visible=False)
|
198 |
+
gr.HTML('* Powered by Fooocus Inpaint Engine <a href="https://github.com/lllyasviel/Fooocus/discussions/414" target="_blank">\U0001F4D4 Document</a>')
|
199 |
+
example_inpaint_prompts.click(lambda x: x[0], inputs=example_inpaint_prompts, outputs=inpaint_additional_prompt, show_progress=False, queue=False)
|
200 |
+
with gr.TabItem(label='Describe') as desc_tab:
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column():
|
203 |
+
desc_input_image = grh.Image(label='Drag any image to here', source='upload', type='numpy')
|
204 |
+
with gr.Column():
|
205 |
+
desc_method = gr.Radio(
|
206 |
+
label='Content Type',
|
207 |
+
choices=[flags.desc_type_photo, flags.desc_type_anime],
|
208 |
+
value=flags.desc_type_photo)
|
209 |
+
desc_btn = gr.Button(value='Describe this Image into Prompt')
|
210 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>')
|
211 |
+
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
|
212 |
+
down_js = "() => {viewer_to_bottom();}"
|
213 |
+
|
214 |
+
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox,
|
215 |
+
outputs=image_input_panel, queue=False, show_progress=False, _js=switch_js)
|
216 |
+
ip_advanced.change(lambda: None, queue=False, show_progress=False, _js=down_js)
|
217 |
+
|
218 |
+
current_tab = gr.Textbox(value='uov', visible=False)
|
219 |
+
uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
220 |
+
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
221 |
+
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
222 |
+
desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
223 |
+
|
224 |
+
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
|
225 |
+
with gr.Tab(label='Setting'):
|
226 |
+
performance_selection = gr.Radio(label='Performance',
|
227 |
+
choices=modules.flags.performance_selections,
|
228 |
+
value=modules.config.default_performance)
|
229 |
+
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
|
230 |
+
value=modules.config.default_aspect_ratio, info='width × height',
|
231 |
+
elem_classes='aspect_ratios')
|
232 |
+
image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
|
233 |
+
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
|
234 |
+
info='Describing what you do not want to see.', lines=2,
|
235 |
+
elem_id='negative_prompt',
|
236 |
+
value=modules.config.default_prompt_negative)
|
237 |
+
seed_random = gr.Checkbox(label='Random', value=True)
|
238 |
+
image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354
|
239 |
+
|
240 |
+
def random_checked(r):
|
241 |
+
return gr.update(visible=not r)
|
242 |
+
|
243 |
+
def refresh_seed(r, seed_string):
|
244 |
+
if r:
|
245 |
+
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
246 |
+
else:
|
247 |
+
try:
|
248 |
+
seed_value = int(seed_string)
|
249 |
+
if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
|
250 |
+
return seed_value
|
251 |
+
except ValueError:
|
252 |
+
pass
|
253 |
+
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
|
254 |
+
|
255 |
+
seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed],
|
256 |
+
queue=False, show_progress=False)
|
257 |
+
|
258 |
+
if not args_manager.args.disable_image_log:
|
259 |
+
gr.HTML(f'<a href="file={get_current_html_path()}" target="_blank">\U0001F4DA History Log</a>')
|
260 |
+
|
261 |
+
with gr.Tab(label='Style'):
|
262 |
+
style_sorter.try_load_sorted_styles(
|
263 |
+
style_names=legal_style_names,
|
264 |
+
default_selected=modules.config.default_styles)
|
265 |
+
|
266 |
+
style_search_bar = gr.Textbox(show_label=False, container=False,
|
267 |
+
placeholder="\U0001F50E Type here to search styles ...",
|
268 |
+
value="",
|
269 |
+
label='Search Styles')
|
270 |
+
style_selections = gr.CheckboxGroup(show_label=False, container=False,
|
271 |
+
choices=copy.deepcopy(style_sorter.all_styles),
|
272 |
+
value=copy.deepcopy(modules.config.default_styles),
|
273 |
+
label='Selected Styles',
|
274 |
+
elem_classes=['style_selections'])
|
275 |
+
gradio_receiver_style_selections = gr.Textbox(elem_id='gradio_receiver_style_selections', visible=False)
|
276 |
+
|
277 |
+
shared.gradio_root.load(lambda: gr.update(choices=copy.deepcopy(style_sorter.all_styles)),
|
278 |
+
outputs=style_selections)
|
279 |
+
|
280 |
+
style_search_bar.change(style_sorter.search_styles,
|
281 |
+
inputs=[style_selections, style_search_bar],
|
282 |
+
outputs=style_selections,
|
283 |
+
queue=False,
|
284 |
+
show_progress=False).then(
|
285 |
+
lambda: None, _js='()=>{refresh_style_localization();}')
|
286 |
+
|
287 |
+
gradio_receiver_style_selections.input(style_sorter.sort_styles,
|
288 |
+
inputs=style_selections,
|
289 |
+
outputs=style_selections,
|
290 |
+
queue=False,
|
291 |
+
show_progress=False).then(
|
292 |
+
lambda: None, _js='()=>{refresh_style_localization();}')
|
293 |
+
|
294 |
+
with gr.Tab(label='Model'):
|
295 |
+
with gr.Group():
|
296 |
+
with gr.Row():
|
297 |
+
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
|
298 |
+
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
|
299 |
+
|
300 |
+
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
|
301 |
+
info='Use 0.4 for SD1.5 realistic models; '
|
302 |
+
'or 0.667 for SD1.5 anime models; '
|
303 |
+
'or 0.8 for XL-refiners; '
|
304 |
+
'or any value for switching two SDXL models.',
|
305 |
+
value=modules.config.default_refiner_switch,
|
306 |
+
visible=modules.config.default_refiner_model_name != 'None')
|
307 |
+
|
308 |
+
refiner_model.change(lambda x: gr.update(visible=x != 'None'),
|
309 |
+
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
|
310 |
+
|
311 |
+
with gr.Group():
|
312 |
+
lora_ctrls = []
|
313 |
+
|
314 |
+
for i, (n, v) in enumerate(modules.config.default_loras):
|
315 |
+
with gr.Row():
|
316 |
+
lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
|
317 |
+
choices=['None'] + modules.config.lora_filenames, value=n)
|
318 |
+
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
|
319 |
+
elem_classes='lora_weight')
|
320 |
+
lora_ctrls += [lora_model, lora_weight]
|
321 |
+
|
322 |
+
with gr.Row():
|
323 |
+
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
324 |
+
with gr.Tab(label='Advanced'):
|
325 |
+
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
|
326 |
+
value=modules.config.default_cfg_scale,
|
327 |
+
info='Higher value means style is cleaner, vivider, and more artistic.')
|
328 |
+
sharpness = gr.Slider(label='Image Sharpness', minimum=0.0, maximum=30.0, step=0.001,
|
329 |
+
value=modules.config.default_sample_sharpness,
|
330 |
+
info='Higher value means image and texture are sharper.')
|
331 |
+
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
|
332 |
+
dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
|
333 |
+
|
334 |
+
with gr.Column(visible=False) as dev_tools:
|
335 |
+
with gr.Tab(label='Debug Tools'):
|
336 |
+
adm_scaler_positive = gr.Slider(label='Positive ADM Guidance Scaler', minimum=0.1, maximum=3.0,
|
337 |
+
step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ')
|
338 |
+
adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0,
|
339 |
+
step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ')
|
340 |
+
adm_scaler_end = gr.Slider(label='ADM Guidance End At Step', minimum=0.0, maximum=1.0,
|
341 |
+
step=0.001, value=0.3,
|
342 |
+
info='When to end the guidance from positive/negative ADM. ')
|
343 |
+
|
344 |
+
refiner_swap_method = gr.Dropdown(label='Refiner swap method', value='joint',
|
345 |
+
choices=['joint', 'separate', 'vae'])
|
346 |
+
|
347 |
+
adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01,
|
348 |
+
value=modules.config.default_cfg_tsnr,
|
349 |
+
info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
|
350 |
+
'(effective when real CFG > mimicked CFG).')
|
351 |
+
sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
|
352 |
+
value=modules.config.default_sampler)
|
353 |
+
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
|
354 |
+
value=modules.config.default_scheduler)
|
355 |
+
|
356 |
+
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
|
357 |
+
info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
|
358 |
+
value=False)
|
359 |
+
|
360 |
+
overwrite_step = gr.Slider(label='Forced Overwrite of Sampling Step',
|
361 |
+
minimum=-1, maximum=200, step=1,
|
362 |
+
value=modules.config.default_overwrite_step,
|
363 |
+
info='Set as -1 to disable. For developer debugging.')
|
364 |
+
overwrite_switch = gr.Slider(label='Forced Overwrite of Refiner Switch Step',
|
365 |
+
minimum=-1, maximum=200, step=1,
|
366 |
+
value=modules.config.default_overwrite_switch,
|
367 |
+
info='Set as -1 to disable. For developer debugging.')
|
368 |
+
overwrite_width = gr.Slider(label='Forced Overwrite of Generating Width',
|
369 |
+
minimum=-1, maximum=2048, step=1, value=-1,
|
370 |
+
info='Set as -1 to disable. For developer debugging. '
|
371 |
+
'Results will be worse for non-standard numbers that SDXL is not trained on.')
|
372 |
+
overwrite_height = gr.Slider(label='Forced Overwrite of Generating Height',
|
373 |
+
minimum=-1, maximum=2048, step=1, value=-1,
|
374 |
+
info='Set as -1 to disable. For developer debugging. '
|
375 |
+
'Results will be worse for non-standard numbers that SDXL is not trained on.')
|
376 |
+
overwrite_vary_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Vary"',
|
377 |
+
minimum=-1, maximum=1.0, step=0.001, value=-1,
|
378 |
+
info='Set as negative number to disable. For developer debugging.')
|
379 |
+
overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"',
|
380 |
+
minimum=-1, maximum=1.0, step=0.001, value=-1,
|
381 |
+
info='Set as negative number to disable. For developer debugging.')
|
382 |
+
disable_preview = gr.Checkbox(label='Disable Preview', value=False,
|
383 |
+
info='Disable preview during generation.')
|
384 |
+
|
385 |
+
with gr.Tab(label='Control'):
|
386 |
+
debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False,
|
387 |
+
info='See the results from preprocessors.')
|
388 |
+
skipping_cn_preprocessor = gr.Checkbox(label='Skip Preprocessors', value=False,
|
389 |
+
info='Do not preprocess images. (Inputs are already canny/depth/cropped-face/etc.)')
|
390 |
+
|
391 |
+
mixing_image_prompt_and_vary_upscale = gr.Checkbox(label='Mixing Image Prompt and Vary/Upscale',
|
392 |
+
value=False)
|
393 |
+
mixing_image_prompt_and_inpaint = gr.Checkbox(label='Mixing Image Prompt and Inpaint',
|
394 |
+
value=False)
|
395 |
+
|
396 |
+
controlnet_softness = gr.Slider(label='Softness of ControlNet', minimum=0.0, maximum=1.0,
|
397 |
+
step=0.001, value=0.25,
|
398 |
+
info='Similar to the Control Mode in A1111 (use 0.0 to disable). ')
|
399 |
+
|
400 |
+
with gr.Tab(label='Canny'):
|
401 |
+
canny_low_threshold = gr.Slider(label='Canny Low Threshold', minimum=1, maximum=255,
|
402 |
+
step=1, value=64)
|
403 |
+
canny_high_threshold = gr.Slider(label='Canny High Threshold', minimum=1, maximum=255,
|
404 |
+
step=1, value=128)
|
405 |
+
|
406 |
+
with gr.Tab(label='Inpaint'):
|
407 |
+
debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
|
408 |
+
inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
|
409 |
+
inpaint_engine = gr.Dropdown(label='Inpaint Engine',
|
410 |
+
value=modules.config.default_inpaint_engine_version,
|
411 |
+
choices=flags.inpaint_engine_versions,
|
412 |
+
info='Version of Fooocus inpaint model')
|
413 |
+
inpaint_strength = gr.Slider(label='Inpaint Denoising Strength',
|
414 |
+
minimum=0.0, maximum=1.0, step=0.001, value=1.0,
|
415 |
+
info='Same as the denoising strength in A1111 inpaint. '
|
416 |
+
'Only used in inpaint, not used in outpaint. '
|
417 |
+
'(Outpaint always use 1.0)')
|
418 |
+
inpaint_respective_field = gr.Slider(label='Inpaint Respective Field',
|
419 |
+
minimum=0.0, maximum=1.0, step=0.001, value=0.618,
|
420 |
+
info='The area to inpaint. '
|
421 |
+
'Value 0 is same as "Only Masked" in A1111. '
|
422 |
+
'Value 1 is same as "Whole Image" in A1111. '
|
423 |
+
'Only used in inpaint, not used in outpaint. '
|
424 |
+
'(Outpaint always use 1.0)')
|
425 |
+
inpaint_erode_or_dilate = gr.Slider(label='Mask Erode or Dilate',
|
426 |
+
minimum=-64, maximum=64, step=1, value=0,
|
427 |
+
info='Positive value will make white area in the mask larger, '
|
428 |
+
'negative value will make white area smaller.'
|
429 |
+
'(default is 0, always process before any mask invert)')
|
430 |
+
inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
|
431 |
+
invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
|
432 |
+
|
433 |
+
inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine,
|
434 |
+
inpaint_strength, inpaint_respective_field,
|
435 |
+
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
|
436 |
+
|
437 |
+
inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x),
|
438 |
+
inputs=inpaint_mask_upload_checkbox,
|
439 |
+
outputs=inpaint_mask_image, queue=False, show_progress=False)
|
440 |
+
|
441 |
+
with gr.Tab(label='FreeU'):
|
442 |
+
freeu_enabled = gr.Checkbox(label='Enabled', value=False)
|
443 |
+
freeu_b1 = gr.Slider(label='B1', minimum=0, maximum=2, step=0.01, value=1.01)
|
444 |
+
freeu_b2 = gr.Slider(label='B2', minimum=0, maximum=2, step=0.01, value=1.02)
|
445 |
+
freeu_s1 = gr.Slider(label='S1', minimum=0, maximum=4, step=0.01, value=0.99)
|
446 |
+
freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95)
|
447 |
+
freeu_ctrls = [freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2]
|
448 |
+
|
449 |
+
adps = [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name,
|
450 |
+
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height,
|
451 |
+
overwrite_vary_strength, overwrite_upscale_strength,
|
452 |
+
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint,
|
453 |
+
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness,
|
454 |
+
canny_low_threshold, canny_high_threshold, refiner_swap_method]
|
455 |
+
adps += freeu_ctrls
|
456 |
+
adps += inpaint_ctrls
|
457 |
+
|
458 |
+
def dev_mode_checked(r):
|
459 |
+
return gr.update(visible=r)
|
460 |
+
|
461 |
+
|
462 |
+
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
|
463 |
+
queue=False, show_progress=False)
|
464 |
+
|
465 |
+
def model_refresh_clicked():
|
466 |
+
modules.config.update_all_model_names()
|
467 |
+
results = []
|
468 |
+
results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
|
469 |
+
for i in range(5):
|
470 |
+
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
471 |
+
return results
|
472 |
+
|
473 |
+
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
|
474 |
+
queue=False, show_progress=False)
|
475 |
+
|
476 |
+
performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11 +
|
477 |
+
[gr.update(visible=x != 'Extreme Speed')] * 1,
|
478 |
+
inputs=performance_selection,
|
479 |
+
outputs=[
|
480 |
+
guidance_scale, sharpness, adm_scaler_end, adm_scaler_positive,
|
481 |
+
adm_scaler_negative, refiner_switch, refiner_model, sampler_name,
|
482 |
+
scheduler_name, adaptive_cfg, refiner_swap_method, negative_prompt
|
483 |
+
], queue=False, show_progress=False)
|
484 |
+
|
485 |
+
advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, advanced_column,
|
486 |
+
queue=False, show_progress=False) \
|
487 |
+
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)
|
488 |
+
|
489 |
+
def inpaint_mode_change(mode):
|
490 |
+
assert mode in modules.flags.inpaint_options
|
491 |
+
|
492 |
+
# inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
|
493 |
+
# inpaint_disable_initial_latent, inpaint_engine,
|
494 |
+
# inpaint_strength, inpaint_respective_field
|
495 |
+
|
496 |
+
if mode == modules.flags.inpaint_option_detail:
|
497 |
+
return [
|
498 |
+
gr.update(visible=True), gr.update(visible=False, value=[]),
|
499 |
+
gr.Dataset.update(visible=True, samples=modules.config.example_inpaint_prompts),
|
500 |
+
False, 'None', 0.5, 0.0
|
501 |
+
]
|
502 |
+
|
503 |
+
if mode == modules.flags.inpaint_option_modify:
|
504 |
+
return [
|
505 |
+
gr.update(visible=True), gr.update(visible=False, value=[]),
|
506 |
+
gr.Dataset.update(visible=False, samples=modules.config.example_inpaint_prompts),
|
507 |
+
True, modules.config.default_inpaint_engine_version, 1.0, 0.0
|
508 |
+
]
|
509 |
+
|
510 |
+
return [
|
511 |
+
gr.update(visible=False, value=''), gr.update(visible=True),
|
512 |
+
gr.Dataset.update(visible=False, samples=modules.config.example_inpaint_prompts),
|
513 |
+
False, modules.config.default_inpaint_engine_version, 1.0, 0.618
|
514 |
+
]
|
515 |
+
|
516 |
+
inpaint_mode.input(inpaint_mode_change, inputs=inpaint_mode, outputs=[
|
517 |
+
inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
|
518 |
+
inpaint_disable_initial_latent, inpaint_engine,
|
519 |
+
inpaint_strength, inpaint_respective_field
|
520 |
+
], show_progress=False, queue=False)
|
521 |
+
|
522 |
+
ctrls = [
|
523 |
+
prompt, negative_prompt, style_selections,
|
524 |
+
performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
|
525 |
+
]
|
526 |
+
|
527 |
+
ctrls += [base_model, refiner_model, refiner_switch] + lora_ctrls
|
528 |
+
ctrls += [input_image_checkbox, current_tab]
|
529 |
+
ctrls += [uov_method, uov_input_image]
|
530 |
+
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
|
531 |
+
ctrls += ip_ctrls
|
532 |
+
|
533 |
+
state_is_generating = gr.State(False)
|
534 |
+
|
535 |
+
def parse_meta(raw_prompt_txt, is_generating):
|
536 |
+
loaded_json = None
|
537 |
+
try:
|
538 |
+
if '{' in raw_prompt_txt:
|
539 |
+
if '}' in raw_prompt_txt:
|
540 |
+
if ':' in raw_prompt_txt:
|
541 |
+
loaded_json = json.loads(raw_prompt_txt)
|
542 |
+
assert isinstance(loaded_json, dict)
|
543 |
+
except:
|
544 |
+
loaded_json = None
|
545 |
+
|
546 |
+
if loaded_json is None:
|
547 |
+
if is_generating:
|
548 |
+
return gr.update(), gr.update(), gr.update()
|
549 |
+
else:
|
550 |
+
return gr.update(), gr.update(visible=True), gr.update(visible=False)
|
551 |
+
|
552 |
+
return json.dumps(loaded_json), gr.update(visible=False), gr.update(visible=True)
|
553 |
+
|
554 |
+
prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
|
555 |
+
|
556 |
+
load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=[
|
557 |
+
advanced_checkbox,
|
558 |
+
image_number,
|
559 |
+
prompt,
|
560 |
+
negative_prompt,
|
561 |
+
style_selections,
|
562 |
+
performance_selection,
|
563 |
+
aspect_ratios_selection,
|
564 |
+
overwrite_width,
|
565 |
+
overwrite_height,
|
566 |
+
sharpness,
|
567 |
+
guidance_scale,
|
568 |
+
adm_scaler_positive,
|
569 |
+
adm_scaler_negative,
|
570 |
+
adm_scaler_end,
|
571 |
+
base_model,
|
572 |
+
refiner_model,
|
573 |
+
refiner_switch,
|
574 |
+
sampler_name,
|
575 |
+
scheduler_name,
|
576 |
+
seed_random,
|
577 |
+
image_seed,
|
578 |
+
generate_button,
|
579 |
+
load_parameter_button
|
580 |
+
] + lora_ctrls, queue=False, show_progress=False)
|
581 |
+
|
582 |
+
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True),
|
583 |
+
outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \
|
584 |
+
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
|
585 |
+
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
|
586 |
+
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
|
587 |
+
.then(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), gr.update(visible=False, interactive=False), False),
|
588 |
+
outputs=[generate_button, stop_button, skip_button, state_is_generating]) \
|
589 |
+
.then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
|
590 |
+
|
591 |
+
for notification_file in ['notification.ogg', 'notification.mp3']:
|
592 |
+
if os.path.exists(notification_file):
|
593 |
+
gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
|
594 |
+
break
|
595 |
+
|
596 |
+
def trigger_describe(mode, img):
|
597 |
+
if mode == flags.desc_type_photo:
|
598 |
+
from extras.interrogate import default_interrogator as default_interrogator_photo
|
599 |
+
return default_interrogator_photo(img), ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
|
600 |
+
if mode == flags.desc_type_anime:
|
601 |
+
from extras.wd14tagger import default_interrogator as default_interrogator_anime
|
602 |
+
return default_interrogator_anime(img), ["Fooocus V2", "Fooocus Masterpiece"]
|
603 |
+
return mode, ["Fooocus V2"]
|
604 |
+
|
605 |
+
desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image],
|
606 |
+
outputs=[prompt, style_selections], show_progress=True, queue=True)
|
607 |
+
|
608 |
+
|
609 |
+
def dump_default_english_config():
|
610 |
+
from modules.localization import dump_english_config
|
611 |
+
dump_english_config(grh.all_components)
|
612 |
+
|
613 |
+
|
614 |
+
# dump_default_english_config()
|
615 |
+
|
616 |
+
shared.gradio_root.launch(
|
617 |
+
inbrowser=args_manager.args.in_browser,
|
618 |
+
server_name=args_manager.args.listen,
|
619 |
+
server_port=args_manager.args.port,
|
620 |
+
share=args_manager.args.share,
|
621 |
+
auth=check_auth if args_manager.args.share and auth_enabled else None,
|
622 |
+
blocked_paths=[constants.AUTH_FILENAME]
|
623 |
+
)
|