osmunphotography committed
Commit 54ac7d4 · verified · 1 Parent(s): c723d2f

Upload 14 files

Files changed (14)
  1. auth.py +41 -0
  2. cldm.py +312 -0
  3. constants.py +5 -0
  4. entry_with_update 2.py +46 -0
  5. face_restoration_helper.py +374 -0
  6. inpaint_worker 2.py +264 -0
  7. inpaint_worker.py +264 -0
  8. launch_util.py +103 -0
  9. lora.py +152 -0
  10. model_loader.py +26 -0
  11. sdxl_styles.py +82 -0
  12. upscaler.py +34 -0
  13. util.py +177 -0
  14. webui.py +623 -0
auth.py ADDED
@@ -0,0 +1,41 @@
import json
import hashlib
import modules.constants as constants

from os.path import exists


def auth_list_to_dict(auth_list):
    auth_dict = {}
    for auth_data in auth_list:
        if 'user' in auth_data:
            if 'hash' in auth_data:
                auth_dict |= {auth_data['user']: auth_data['hash']}
            elif 'pass' in auth_data:
                auth_dict |= {auth_data['user']: hashlib.sha256(bytes(auth_data['pass'], encoding='utf-8')).hexdigest()}
    return auth_dict


def load_auth_data(filename=None):
    auth_dict = None
    if filename != None and exists(filename):
        with open(filename, encoding='utf-8') as auth_file:
            try:
                auth_obj = json.load(auth_file)
                if isinstance(auth_obj, list) and len(auth_obj) > 0:
                    auth_dict = auth_list_to_dict(auth_obj)
            except Exception as e:
                print('load_auth_data, e: ' + str(e))
    return auth_dict


auth_dict = load_auth_data(constants.AUTH_FILENAME)

auth_enabled = auth_dict != None


def check_auth(user, password):
    if user not in auth_dict:
        return False
    else:
        return hashlib.sha256(bytes(password, encoding='utf-8')).hexdigest() == auth_dict[user]
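The module expects auth.json to be a JSON list of objects, each with a 'user' plus either a precomputed 'hash' or a plain 'pass' that is hashed with SHA-256 on load; check_auth has the (user, password) -> bool shape that login callbacks typically expect. A minimal sketch of the list-to-dict step, assuming this file is importable as modules.auth (the import path is an assumption):

from modules.auth import auth_list_to_dict  # import path is an assumption

# shape of the expected auth.json contents
sample = [{'user': 'alice', 'pass': 'hunter2'},
          {'user': 'bob', 'hash': 'deadbeef' * 8}]

print(auth_list_to_dict(sample))
# {'alice': '<sha256 hex of "hunter2">', 'bob': 'deadbeefdeadbeef...'}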
cldm.py ADDED
@@ -0,0 +1,312 @@
#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import torch
import torch as th
import torch.nn as nn

from ldm_patched.ldm.modules.diffusionmodules.util import (
    zero_module,
    timestep_embedding,
)

from ldm_patched.ldm.modules.attention import SpatialTransformer
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ldm_patched.ldm.util import exists
import ldm_patched.modules.ops


class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
    pass


class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,            # custom transformer support
        context_dim=None,               # custom transformer support
        n_embed=None,                   # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        device=None,
        operations=ldm_patched.modules.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        transformer_depth = transformer_depth[:]

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

        self.input_hint_block = TimestepEmbedSequential(
            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        mid_block = [
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            )]
        if transformer_depth_middle >= 0:
            mid_block += [SpatialTransformer(  # always uses a self-attn
                              ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                              disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                              use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                          ),
                          ResBlock(
                              ch,
                              time_embed_dim,
                              dropout,
                              dims=dims,
                              use_checkpoint=use_checkpoint,
                              use_scale_shift_norm=use_scale_shift_norm,
                              dtype=self.dtype,
                              device=device,
                              operations=operations
                          )]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
        self._feature_size += ch

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
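The ControlNet above encodes a control hint through input_hint_block, adds it to the first UNet feature map, and returns one zero-convolution residual per input block plus one for the middle block. A small instantiation sketch with made-up sizes; the import path is an assumption, weights are uninitialized here (real use loads a ControlNet checkpoint), and production Fooocus/SDXL configs use model_channels=320 and context_dim=2048:

import torch
from cldm import ControlNet  # import path is an assumption

net = ControlNet(
    image_size=32,                 # bookkeeping only
    in_channels=4,                 # latent channels
    model_channels=64,             # toy size
    hint_channels=3,               # RGB control image
    num_res_blocks=2,
    channel_mult=(1, 2, 4),
    use_spatial_transformer=True,
    transformer_depth=[1, 1, 1, 1, 1, 1],   # one entry per res block per level
    transformer_depth_middle=1,
    context_dim=768,               # toy size
    num_head_channels=32,
    use_linear_in_transformer=True,
)

x = torch.randn(1, 4, 64, 64)        # noisy latent
hint = torch.randn(1, 3, 512, 512)   # control hint, 8x the latent resolution
timesteps = torch.tensor([999])
context = torch.randn(1, 77, 768)    # text conditioning

outs = net(x, hint, timesteps, context)
print([tuple(o.shape) for o in outs])  # one residual per input block + one for the middle block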
constants.py ADDED
@@ -0,0 +1,5 @@
# as in k-diffusion (sampling.py)
MIN_SEED = 0
MAX_SEED = 2**63 - 1

AUTH_FILENAME = 'auth.json'
entry_with_update 2.py ADDED
@@ -0,0 +1,46 @@
import os
import sys


root = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root)
os.chdir(root)


try:
    import pygit2
    pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)

    repo = pygit2.Repository(os.path.abspath(os.path.dirname(__file__)))

    branch_name = repo.head.shorthand

    remote_name = 'origin'
    remote = repo.remotes[remote_name]

    remote.fetch()

    local_branch_ref = f'refs/heads/{branch_name}'
    local_branch = repo.lookup_reference(local_branch_ref)

    remote_reference = f'refs/remotes/{remote_name}/{branch_name}'
    remote_commit = repo.revparse_single(remote_reference)

    merge_result, _ = repo.merge_analysis(remote_commit.id)

    if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
        print("Already up-to-date")
    elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
        local_branch.set_target(remote_commit.id)
        repo.head.set_target(remote_commit.id)
        repo.checkout_tree(repo.get(remote_commit.id))
        repo.reset(local_branch.target, pygit2.GIT_RESET_HARD)
        print("Fast-forward merge")
    elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
        print("Update failed - Did you modify any file?")
except Exception as e:
    print('Update failed.')
    print(str(e))

print('Update succeeded.')
from launch import *
face_restoration_helper.py ADDED
@@ -0,0 +1,374 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from extras.facexlib.detection import init_detection_model
from extras.facexlib.parsing import init_parsing_model
from extras.facexlib.utils.misc import img2tensor, imwrite


def get_largest_face(det_faces, h, w):

    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None,
                 model_rootpath=None):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = upscale_factor
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))

        if self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])
        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

        # init face detection model
        self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath)

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be image path or cv2 loaded image."""
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = min(h, w) / resize
            h, w = int(h / scale), int(w / scale)
            input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)

        with torch.no_grad():
            bboxes = self.face_det.detect_faces(input_img, 0.97) * scale
        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])
        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template.
        """
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, face):
        self.restored_faces.append(face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            # Add an offset to inverse affine matrix, for more precise back alignment
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    mask[out == idx] = color
                # blur the mask
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                mask = cv2.GaussianBlur(mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                mask[:thres, :] = 0
                mask[-thres:, :] = 0
                mask[:, :thres] = 0
                mask[:, -thres:] = 0
                mask = mask / 255.

                mask = cv2.resize(mask, restored_face.shape[:2])
                mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_mask = mask[:, :, None]
                pasted_face = inv_restored

            else:  # use square parse maps
                mask = np.ones(self.face_size, dtype=np.float32)
                inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
                # remove the black borders
                inv_mask_erosion = cv2.erode(
                    inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
                pasted_face = inv_mask_erosion[:, :, None] * inv_restored
                total_face_area = np.sum(inv_mask_erosion)  # // 3
                # compute the fusion edge based on the area of face
                w_edge = int(total_face_area**0.5) // 20
                erosion_radius = w_edge * 2
                inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
                blur_size = w_edge * 2
                inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
                if len(upsample_img.shape) == 2:  # upsample_img is gray image
                    upsample_img = upsample_img[:, :, None]
                inv_soft_mask = inv_soft_mask[:, :, None]

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)
        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
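A rough end-to-end sketch of how the helper above is meant to be driven (detect, align, restore, paste back). The import path and file names are assumptions, the facexlib detection/parsing weights are fetched on first use, and the restoration model itself is out of scope here, so the "restored" face is just the cropped face passed through unchanged:

from face_restoration_helper import FaceRestoreHelper  # import path is an assumption

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True)
helper.read_image('input.jpg')                                   # illustrative path
num_faces = helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
helper.align_warp_face()

for cropped_face in helper.cropped_faces:
    restored = cropped_face          # placeholder: a real pipeline runs a face-restoration model here
    helper.add_restored_face(restored)

helper.get_inverse_affine()
result = helper.paste_faces_to_input_image(save_path='output.png')
helper.clean_all()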
inpaint_worker 2.py ADDED
@@ -0,0 +1,264 @@
(Content identical to inpaint_worker.py, listed in full below.)
inpaint_worker.py ADDED
@@ -0,0 +1,264 @@
import torch
import numpy as np

from PIL import Image, ImageFilter
from modules.util import resample_image, set_image_shape_ceil, get_image_shape_ceil
from modules.upscaler import perform_upscale
import cv2


inpaint_head_model = None


class InpaintHead(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device='cpu'))

    def __call__(self, x):
        x = torch.nn.functional.pad(x, (1, 1, 1, 1), "replicate")
        return torch.nn.functional.conv2d(input=x, weight=self.head)


current_task = None


def box_blur(x, k):
    x = Image.fromarray(x)
    x = x.filter(ImageFilter.BoxBlur(k))
    return np.array(x)


def max_filter_opencv(x, ksize=3):
    # Use OpenCV maximum filter
    # Make sure the input type is int16
    return cv2.dilate(x, np.ones((ksize, ksize), dtype=np.int16))


def morphological_open(x):
    # Convert array to int16 type via threshold operation
    x_int16 = np.zeros_like(x, dtype=np.int16)
    x_int16[x > 127] = 256

    for i in range(32):
        # Use int16 type to avoid overflow
        maxed = max_filter_opencv(x_int16, ksize=3) - 8
        x_int16 = np.maximum(maxed, x_int16)

    # Clip negative values to 0 and convert back to uint8 type
    x_uint8 = np.clip(x_int16, 0, 255).astype(np.uint8)
    return x_uint8


def up255(x, t=0):
    y = np.zeros_like(x).astype(np.uint8)
    y[x > t] = 255
    return y


def imsave(x, path):
    x = Image.fromarray(x)
    x.save(path)


def regulate_abcd(x, a, b, c, d):
    H, W = x.shape[:2]
    if a < 0:
        a = 0
    if a > H:
        a = H
    if b < 0:
        b = 0
    if b > H:
        b = H
    if c < 0:
        c = 0
    if c > W:
        c = W
    if d < 0:
        d = 0
    if d > W:
        d = W
    return int(a), int(b), int(c), int(d)


def compute_initial_abcd(x):
    indices = np.where(x)
    a = np.min(indices[0])
    b = np.max(indices[0])
    c = np.min(indices[1])
    d = np.max(indices[1])
    abp = (b + a) // 2
    abm = (b - a) // 2
    cdp = (d + c) // 2
    cdm = (d - c) // 2
    l = int(max(abm, cdm) * 1.15)
    a = abp - l
    b = abp + l + 1
    c = cdp - l
    d = cdp + l + 1
    a, b, c, d = regulate_abcd(x, a, b, c, d)
    return a, b, c, d


def solve_abcd(x, a, b, c, d, k):
    k = float(k)
    assert 0.0 <= k <= 1.0

    H, W = x.shape[:2]
    if k == 1.0:
        return 0, H, 0, W
    while True:
        if b - a >= H * k and d - c >= W * k:
            break

        add_h = (b - a) < (d - c)
        add_w = not add_h

        if b - a == H:
            add_w = True

        if d - c == W:
            add_h = True

        if add_h:
            a -= 1
            b += 1

        if add_w:
            c -= 1
            d += 1

        a, b, c, d = regulate_abcd(x, a, b, c, d)
    return a, b, c, d


def fooocus_fill(image, mask):
    current_image = image.copy()
    raw_image = image.copy()
    area = np.where(mask < 127)
    store = raw_image[area]

    for k, repeats in [(512, 2), (256, 2), (128, 4), (64, 4), (33, 8), (15, 8), (5, 16), (3, 16)]:
        for _ in range(repeats):
            current_image = box_blur(current_image, k)
            current_image[area] = store

    return current_image


class InpaintWorker:
    def __init__(self, image, mask, use_fill=True, k=0.618):
        a, b, c, d = compute_initial_abcd(mask > 0)
        a, b, c, d = solve_abcd(mask, a, b, c, d, k=k)

        # interested area
        self.interested_area = (a, b, c, d)
        self.interested_mask = mask[a:b, c:d]
        self.interested_image = image[a:b, c:d]

        # super resolution
        if get_image_shape_ceil(self.interested_image) < 1024:
            self.interested_image = perform_upscale(self.interested_image)

        # resize to make images ready for diffusion
        self.interested_image = set_image_shape_ceil(self.interested_image, 1024)
        self.interested_fill = self.interested_image.copy()
        H, W, C = self.interested_image.shape

        # process mask
        self.interested_mask = up255(resample_image(self.interested_mask, W, H), t=127)

        # compute filling
        if use_fill:
            self.interested_fill = fooocus_fill(self.interested_image, self.interested_mask)

        # soft pixels
        self.mask = morphological_open(mask)
        self.image = image

        # ending
        self.latent = None
        self.latent_after_swap = None
        self.swapped = False
        self.latent_mask = None
        self.inpaint_head_feature = None
        return

    def load_latent(self, latent_fill, latent_mask, latent_swap=None):
        self.latent = latent_fill
        self.latent_mask = latent_mask
        self.latent_after_swap = latent_swap
        return

    def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, model):
        global inpaint_head_model

        if inpaint_head_model is None:
            inpaint_head_model = InpaintHead()
            sd = torch.load(inpaint_head_model_path, map_location='cpu')
            inpaint_head_model.load_state_dict(sd)

        feed = torch.cat([
            inpaint_latent_mask,
            model.model.process_latent_in(inpaint_latent)
        ], dim=1)

        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        inpaint_head_feature = inpaint_head_model(feed)

        def input_block_patch(h, transformer_options):
            if transformer_options["block"][1] == 0:
                h = h + inpaint_head_feature.to(h)
            return h

        m = model.clone()
        m.set_model_input_block_patch(input_block_patch)
        return m

    def swap(self):
        if self.swapped:
            return

        if self.latent is None:
            return

        if self.latent_after_swap is None:
            return

        self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
        self.swapped = True
        return

    def unswap(self):
        if not self.swapped:
            return

        if self.latent is None:
            return

        if self.latent_after_swap is None:
            return

        self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
        self.swapped = False
        return

    def color_correction(self, img):
        fg = img.astype(np.float32)
        bg = self.image.copy().astype(np.float32)
        w = self.mask[:, :, None].astype(np.float32) / 255.0
        y = fg * w + bg * (1 - w)
        return y.clip(0, 255).astype(np.uint8)

    def post_process(self, img):
        a, b, c, d = self.interested_area
        content = resample_image(img, d - c, b - a)
        result = self.image.copy()
        result[a:b, c:d] = content
        result = self.color_correction(result)
        return result

    def visualize_mask_processing(self):
        return [self.interested_fill, self.interested_mask, self.interested_image]
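A minimal sketch of how the worker above is constructed: `image` is an RGB uint8 array and `mask` is a single-channel uint8 array where white marks the region to repaint. The import path and the downstream diffusion step are assumptions; the placeholder sizes are chosen large enough that the crop stays above the 1024 threshold, so perform_upscale (and its model download) is not triggered:

import numpy as np
from modules.inpaint_worker import InpaintWorker  # import path is an assumption

image = np.zeros((1536, 1536, 3), dtype=np.uint8)   # placeholder photo
mask = np.zeros((1536, 1536), dtype=np.uint8)
mask[200:1300, 200:1300] = 255                      # region to repaint

worker = InpaintWorker(image, mask, use_fill=True, k=0.618)
fill, m, crop = worker.visualize_mask_processing()  # inputs prepared for diffusion

# after sampling, the decoded crop is pasted back and blended with the soft mask:
fake_decoded = crop.copy()                          # stand-in for the diffusion output
final = worker.post_process(fake_decoded)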
launch_util.py ADDED
@@ -0,0 +1,103 @@
import os
import importlib
import importlib.util
import subprocess
import sys
import re
import logging
import importlib.metadata
import packaging.version
from packaging.requirements import Requirement


logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")

python = sys.executable
default_command_live = (os.environ.get('LAUNCH_LIVE_OUTPUT') == "1")
index_url = os.environ.get('INDEX_URL', "")

modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        return False

    return spec is not None


def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
    if desc is not None:
        print(desc)

    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ if custom_env is None else custom_env,
        "encoding": 'utf8',
        "errors": 'ignore',
    }

    if not live:
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE

    result = subprocess.run(**run_kwargs)

    if result.returncode != 0:
        error_bits = [
            f"{errdesc or 'Error running command'}.",
            f"Command: {command}",
            f"Error code: {result.returncode}",
        ]
        if result.stdout:
            error_bits.append(f"stdout: {result.stdout}")
        if result.stderr:
            error_bits.append(f"stderr: {result.stderr}")
        raise RuntimeError("\n".join(error_bits))

    return (result.stdout or "")


def run_pip(command, desc=None, live=default_command_live):
    try:
        index_url_line = f' --index-url {index_url}' if index_url != '' else ''
        return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}",
                   errdesc=f"Couldn't install {desc}", live=live)
    except Exception as e:
        print(e)
        print(f'CMD Failed {desc}: {command}')
        return None


def requirements_met(requirements_file):
    with open(requirements_file, "r", encoding="utf8") as file:
        for line in file:
            line = line.strip()
            if line == "" or line.startswith('#'):
                continue

            requirement = Requirement(line)
            package = requirement.name

            try:
                version_installed = importlib.metadata.version(package)
                installed_version = packaging.version.parse(version_installed)

                # Check if the installed version satisfies the requirement
                if installed_version not in requirement.specifier:
                    print(f"Version mismatch for {package}: Installed version {version_installed} does not meet requirement {requirement}")
                    return False
            except Exception as e:
                print(f"Error checking version for {package}: {e}")
                return False

    return True
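These helpers follow the usual launcher pattern: probe for a module, pip-install it if missing, and verify a requirements file before starting. A short sketch; the import path, the version pin, and the requirements file name are assumptions:

from modules.launch_util import is_installed, run_pip, requirements_met  # import path is an assumption

if not is_installed('pygit2'):
    run_pip('install pygit2==1.12.2', 'pygit2')          # version pin is illustrative

if not requirements_met('requirements_versions.txt'):    # file name is an assumption
    run_pip('install -r requirements_versions.txt', 'requirements')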
lora.py ADDED
@@ -0,0 +1,152 @@
def match_lora(lora, to_load):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        real_load_key = to_load[x]
        if real_load_key in lora:
            patch_dict[real_load_key] = ('fooocus', lora[real_load_key])
            loaded_keys.add(real_load_key)
            continue

        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)


        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)


        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))

        #glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
    return patch_dict, remaining_dict
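match_lora is keyed purely on dictionary names: `to_load` maps a LoRA key prefix to the target model weight name, and the function recognises regular LoRA, diffusers/transformers layouts, LoHa, LoKr, GLoRA and plain diff weights. A toy example with random tensors, showing the ('lora', (up, down, alpha, mid)) patch format and the leftover keys; the import path and key names are illustrative:

import torch
from lora import match_lora  # import path is an assumption

prefix = 'lora_unet_output_blocks_0_1_proj_in'
lora_sd = {
    f'{prefix}.lora_up.weight': torch.randn(640, 8),
    f'{prefix}.lora_down.weight': torch.randn(8, 640),
    f'{prefix}.alpha': torch.tensor(8.0),
    'some_unrelated_key': torch.zeros(1),
}
to_load = {prefix: 'diffusion_model.output_blocks.0.1.proj_in.weight'}

patches, remaining = match_lora(lora_sd, to_load)
print(list(patches.keys()))          # ['diffusion_model.output_blocks.0.1.proj_in.weight']
print(patches[to_load[prefix]][0])   # 'lora'
print(list(remaining.keys()))        # ['some_unrelated_key']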
model_loader.py ADDED
@@ -0,0 +1,26 @@
import os
from urllib.parse import urlparse
from typing import Optional


def load_file_from_url(
        url: str,
        *,
        model_dir: str,
        progress: bool = True,
        file_name: Optional[str] = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    Returns the path to the downloaded file.
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        from torch.hub import download_url_to_file
        download_url_to_file(url, cached_file, progress=progress)
    return cached_file
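A small usage sketch; the URL and target directory are placeholders and the import path is an assumption. The helper only downloads when the file is not already cached, so it is safe to call on every launch:

from modules.model_loader import load_file_from_url  # import path is an assumption

path = load_file_from_url(
    url='https://example.com/models/some_model.safetensors',  # placeholder URL
    model_dir='./models/checkpoints',
    file_name=None,            # defaults to the basename of the URL path
)
print(path)                    # absolute path to the cached file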
sdxl_styles.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ import re
3
+ import json
4
+
5
+ from modules.util import get_files_from_folder
6
+
7
+
8
+ # cannot use modules.config here - its validators cause circular imports
9
+ styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
10
+ wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
11
+ wildcards_max_bfs_depth = 64
12
+
13
+
14
+ def normalize_key(k):
15
+ k = k.replace('-', ' ')
16
+ words = k.split(' ')
17
+ words = [w[:1].upper() + w[1:].lower() for w in words]
18
+ k = ' '.join(words)
19
+ k = k.replace('3d', '3D')
20
+ k = k.replace('Sai', 'SAI')
21
+ k = k.replace('Mre', 'MRE')
22
+ k = k.replace('(s', '(S')
23
+ return k
24
+
25
+
26
+ styles = {}
27
+
28
+ styles_files = get_files_from_folder(styles_path, ['.json'])
29
+
30
+ for x in ['sdxl_styles_fooocus.json',
31
+ 'sdxl_styles_sai.json',
32
+ 'sdxl_styles_mre.json',
33
+ 'sdxl_styles_twri.json',
34
+ 'sdxl_styles_diva.json',
35
+ 'sdxl_styles_marc_k3nt3l.json']:
36
+ if x in styles_files:
37
+ styles_files.remove(x)
38
+ styles_files.append(x)
39
+
40
+ for styles_file in styles_files:
41
+ try:
42
+ with open(os.path.join(styles_path, styles_file), encoding='utf-8') as f:
43
+ for entry in json.load(f):
44
+ name = normalize_key(entry['name'])
45
+ prompt = entry['prompt'] if 'prompt' in entry else ''
46
+ negative_prompt = entry['negative_prompt'] if 'negative_prompt' in entry else ''
47
+ styles[name] = (prompt, negative_prompt)
48
+ except Exception as e:
49
+ print(str(e))
50
+ print(f'Failed to load style file {styles_file}')
51
+
52
+ style_keys = list(styles.keys())
53
+ fooocus_expansion = "Fooocus V2"
54
+ legal_style_names = [fooocus_expansion] + style_keys
55
+
56
+
57
+ def apply_style(style, positive):
58
+ p, n = styles[style]
59
+ return p.replace('{prompt}', positive).splitlines(), n.splitlines()
60
+
61
+
62
+ def apply_wildcards(wildcard_text, rng, directory=wildcards_path):
63
+ for _ in range(wildcards_max_bfs_depth):
64
+ placeholders = re.findall(r'__([\w-]+)__', wildcard_text)
65
+ if len(placeholders) == 0:
66
+ return wildcard_text
67
+
68
+ print(f'[Wildcards] processing: {wildcard_text}')
69
+ for placeholder in placeholders:
70
+ try:
71
+ words = open(os.path.join(directory, f'{placeholder}.txt'), encoding='utf-8').read().splitlines()
72
+ words = [x for x in words if x != '']
73
+ assert len(words) > 0
74
+ wildcard_text = wildcard_text.replace(f'__{placeholder}__', rng.choice(words), 1)
75
+ except:
76
+ print(f'[Wildcards] Warning: {placeholder}.txt missing or empty. '
77
+ f'Using "{placeholder}" as a normal word.')
78
+ wildcard_text = wildcard_text.replace(f'__{placeholder}__', placeholder)
79
+ print(f'[Wildcards] {wildcard_text}')
80
+
81
+ print(f'[Wildcards] BFS stack overflow. Current text: {wildcard_text}')
82
+ return wildcard_text
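A short usage sketch of the two entry points above, assuming the bundled style JSON files are present and that a wildcards/color.txt word list exists; the style and wildcard names are illustrative.

import random
from modules.sdxl_styles import apply_style, apply_wildcards, legal_style_names

print(legal_style_names[:3])  # "Fooocus V2" plus styles loaded from the JSON files

# apply_style substitutes the user prompt into the style's '{prompt}' slot and
# returns the positive and negative templates as lists of lines.
positive_lines, negative_lines = apply_style('Fooocus Enhance', 'a castle on a hill')

# apply_wildcards repeatedly expands __name__ placeholders from wildcards/<name>.txt,
# up to wildcards_max_bfs_depth passes, using the supplied RNG for reproducibility.
prompt = apply_wildcards('a __color__ car on a rainy street', rng=random.Random(42))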
upscaler.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+ import torch
3
+ import modules.core as core
4
+
5
+ from ldm_patched.pfn.architecture.RRDB import RRDBNet as ESRGAN
6
+ from ldm_patched.contrib.external_upscale_model import ImageUpscaleWithModel
7
+ from collections import OrderedDict
8
+ from modules.config import path_upscale_models
9
+
10
+ model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
11
+ opImageUpscaleWithModel = ImageUpscaleWithModel()
12
+ model = None
13
+
14
+
15
+ def perform_upscale(img):
16
+ global model
17
+
18
+ print(f'Upscaling image with shape {str(img.shape)} ...')
19
+
20
+ if model is None:
21
+ sd = torch.load(model_filename)
22
+ sdo = OrderedDict()
23
+ for k, v in sd.items():
24
+ sdo[k.replace('residual_block_', 'RDB')] = v
25
+ del sd
26
+ model = ESRGAN(sdo)
27
+ model.cpu()
28
+ model.eval()
29
+
30
+ img = core.numpy_to_pytorch(img)
31
+ img = opImageUpscaleWithModel.upscale(model, img)[0]
32
+ img = core.pytorch_to_numpy(img)[0]
33
+
34
+ return img
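A usage sketch of the function above, assuming the upscaler weights have already been downloaded to path_upscale_models; the input image is synthetic and the output scale depends on the model.

import numpy as np
from modules.upscaler import perform_upscale

img = np.zeros((256, 256, 3), dtype=np.uint8)  # any HWC uint8 RGB image
out = perform_upscale(img)                      # lazily loads the ESRGAN weights on first call (CPU)
print(out.shape)                                # spatial size grows by the model's scale factor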
util.py ADDED
@@ -0,0 +1,177 @@
1
+ import numpy as np
2
+ import datetime
3
+ import random
4
+ import math
5
+ import os
6
+ import cv2
7
+
8
+ from PIL import Image
9
+
10
+
11
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
12
+
13
+
14
+ def erode_or_dilate(x, k):
15
+ k = int(k)
16
+ if k > 0:
17
+ return cv2.dilate(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=k)
18
+ if k < 0:
19
+ return cv2.erode(x, kernel=np.ones(shape=(3, 3), dtype=np.uint8), iterations=-k)
20
+ return x
21
+
22
+
23
+ def resample_image(im, width, height):
24
+ im = Image.fromarray(im)
25
+ im = im.resize((int(width), int(height)), resample=LANCZOS)
26
+ return np.array(im)
27
+
28
+
29
+ def resize_image(im, width, height, resize_mode=1):
30
+ """
31
+ Resizes an image with the specified resize_mode, width, and height.
32
+
33
+ Args:
34
+ resize_mode: The mode to use when resizing the image.
35
+ 0: Resize the image to the specified width and height.
36
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
37
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
38
+ im: The image to resize.
39
+ width: The width to resize the image to.
40
+ height: The height to resize the image to.
41
+ """
42
+
43
+ im = Image.fromarray(im)
44
+
45
+ def resize(im, w, h):
46
+ return im.resize((w, h), resample=LANCZOS)
47
+
48
+ if resize_mode == 0:
49
+ res = resize(im, width, height)
50
+
51
+ elif resize_mode == 1:
52
+ ratio = width / height
53
+ src_ratio = im.width / im.height
54
+
55
+ src_w = width if ratio > src_ratio else im.width * height // im.height
56
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
57
+
58
+ resized = resize(im, src_w, src_h)
59
+ res = Image.new("RGB", (width, height))
60
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
61
+
62
+ else:
63
+ ratio = width / height
64
+ src_ratio = im.width / im.height
65
+
66
+ src_w = width if ratio < src_ratio else im.width * height // im.height
67
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
68
+
69
+ resized = resize(im, src_w, src_h)
70
+ res = Image.new("RGB", (width, height))
71
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
72
+
73
+ if ratio < src_ratio:
74
+ fill_height = height // 2 - src_h // 2
75
+ if fill_height > 0:
76
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
77
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
78
+ elif ratio > src_ratio:
79
+ fill_width = width // 2 - src_w // 2
80
+ if fill_width > 0:
81
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
82
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
83
+
84
+ return np.array(res)
85
+
86
+
87
+ def get_shape_ceil(h, w):
88
+ return math.ceil(((h * w) ** 0.5) / 64.0) * 64.0
89
+
90
+
91
+ def get_image_shape_ceil(im):
92
+ H, W = im.shape[:2]
93
+ return get_shape_ceil(H, W)
94
+
95
+
96
+ def set_image_shape_ceil(im, shape_ceil):
97
+ shape_ceil = float(shape_ceil)
98
+
99
+ H_origin, W_origin, _ = im.shape
100
+ H, W = H_origin, W_origin
101
+
102
+ for _ in range(256):
103
+ current_shape_ceil = get_shape_ceil(H, W)
104
+ if abs(current_shape_ceil - shape_ceil) < 0.1:
105
+ break
106
+ k = shape_ceil / current_shape_ceil
107
+ H = int(round(float(H) * k / 64.0) * 64)
108
+ W = int(round(float(W) * k / 64.0) * 64)
109
+
110
+ if H == H_origin and W == W_origin:
111
+ return im
112
+
113
+ return resample_image(im, width=W, height=H)
114
+
115
+
116
+ def HWC3(x):
117
+ assert x.dtype == np.uint8
118
+ if x.ndim == 2:
119
+ x = x[:, :, None]
120
+ assert x.ndim == 3
121
+ H, W, C = x.shape
122
+ assert C == 1 or C == 3 or C == 4
123
+ if C == 3:
124
+ return x
125
+ if C == 1:
126
+ return np.concatenate([x, x, x], axis=2)
127
+ if C == 4:
128
+ color = x[:, :, 0:3].astype(np.float32)
129
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
130
+ y = color * alpha + 255.0 * (1.0 - alpha)
131
+ y = y.clip(0, 255).astype(np.uint8)
132
+ return y
133
+
134
+
135
+ def remove_empty_str(items, default=None):
136
+ items = [x for x in items if x != ""]
137
+ if len(items) == 0 and default is not None:
138
+ return [default]
139
+ return items
140
+
141
+
142
+ def join_prompts(*args, **kwargs):
143
+ prompts = [str(x) for x in args if str(x) != ""]
144
+ if len(prompts) == 0:
145
+ return ""
146
+ if len(prompts) == 1:
147
+ return prompts[0]
148
+ return ', '.join(prompts)
149
+
150
+
151
+ def generate_temp_filename(folder='./outputs/', extension='png'):
152
+ current_time = datetime.datetime.now()
153
+ date_string = current_time.strftime("%Y-%m-%d")
154
+ time_string = current_time.strftime("%Y-%m-%d_%H-%M-%S")
155
+ random_number = random.randint(1000, 9999)
156
+ filename = f"{time_string}_{random_number}.{extension}"
157
+ result = os.path.join(folder, date_string, filename)
158
+ return date_string, os.path.abspath(os.path.realpath(result)), filename
159
+
160
+
161
+ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
162
+ if not os.path.isdir(folder_path):
163
+ raise ValueError("Folder path is not a valid directory.")
164
+
165
+ filenames = []
166
+
167
+ for root, dirs, files in os.walk(folder_path):
168
+ relative_path = os.path.relpath(root, folder_path)
169
+ if relative_path == ".":
170
+ relative_path = ""
171
+ for filename in files:
172
+ _, file_extension = os.path.splitext(filename)
173
+ if (exensions == None or file_extension.lower() in exensions) and (name_filter == None or name_filter in _):
174
+ path = os.path.join(relative_path, filename)
175
+ filenames.append(path)
176
+
177
+ return sorted(filenames, key=lambda x: -1 if os.sep in x else 1)
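A quick sketch of the image helpers above on a synthetic input, illustrating the three resize modes and the RGBA flattening done by HWC3.

import numpy as np
from modules.util import HWC3, resize_image, get_shape_ceil

rgba = np.zeros((300, 500, 4), dtype=np.uint8)
rgb = HWC3(rgba)                                         # composites alpha over white -> (300, 500, 3)

stretched = resize_image(rgb, 512, 512, resize_mode=0)   # plain resize, aspect ratio not kept
cropped = resize_image(rgb, 512, 512, resize_mode=1)     # fill the target, center, crop the excess
padded = resize_image(rgb, 512, 512, resize_mode=2)      # fit inside the target, pad with stretched edges

print(get_shape_ceil(300, 500))  # 448.0: sqrt(300 * 500) rounded up to a multiple of 64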
webui.py ADDED
@@ -0,0 +1,623 @@
1
+ import gradio as gr
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import shared
7
+ import modules.config
8
+ import fooocus_version
9
+ import modules.html
10
+ import modules.async_worker as worker
11
+ import modules.constants as constants
12
+ import modules.flags as flags
13
+ import modules.gradio_hijack as grh
14
+ import modules.advanced_parameters as advanced_parameters
15
+ import modules.style_sorter as style_sorter
16
+ import modules.meta_parser
17
+ import args_manager
18
+ import copy
19
+
20
+ from modules.sdxl_styles import legal_style_names
21
+ from modules.private_logger import get_current_html_path
22
+ from modules.ui_gradio_extensions import reload_javascript
23
+ from modules.auth import auth_enabled, check_auth
24
+
25
+
26
+ def generate_clicked(*args):
27
+ import ldm_patched.modules.model_management as model_management
28
+
29
+ with model_management.interrupt_processing_mutex:
30
+ model_management.interrupt_processing = False
31
+
32
+ # outputs=[progress_html, progress_window, progress_gallery, gallery]
33
+
34
+ execution_start_time = time.perf_counter()
35
+ task = worker.AsyncTask(args=list(args))
36
+ finished = False
37
+
38
+ yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
39
+ gr.update(visible=True, value=None), \
40
+ gr.update(visible=False, value=None), \
41
+ gr.update(visible=False)
42
+
43
+ worker.async_tasks.append(task)
44
+
45
+ while not finished:
46
+ time.sleep(0.01)
47
+ if len(task.yields) > 0:
48
+ flag, product = task.yields.pop(0)
49
+ if flag == 'preview':
50
+
51
+ # help bad internet connection by skipping duplicated preview
52
+ if len(task.yields) > 0: # if we have the next item
53
+ if task.yields[0][0] == 'preview': # if the next item is also a preview
54
+ # print('Skipped one preview for better internet connection.')
55
+ continue
56
+
57
+ percentage, title, image = product
58
+ yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
59
+ gr.update(visible=True, value=image) if image is not None else gr.update(), \
60
+ gr.update(), \
61
+ gr.update(visible=False)
62
+ if flag == 'results':
63
+ yield gr.update(visible=True), \
64
+ gr.update(visible=True), \
65
+ gr.update(visible=True, value=product), \
66
+ gr.update(visible=False)
67
+ if flag == 'finish':
68
+ yield gr.update(visible=False), \
69
+ gr.update(visible=False), \
70
+ gr.update(visible=False), \
71
+ gr.update(visible=True, value=product)
72
+ finished = True
73
+
74
+ execution_time = time.perf_counter() - execution_start_time
75
+ print(f'Total time: {execution_time:.2f} seconds')
76
+ return
77
+
78
+
79
+ reload_javascript()
80
+
81
+ title = f'Fooocus {fooocus_version.version}'
82
+
83
+ if isinstance(args_manager.args.preset, str):
84
+ title += ' ' + args_manager.args.preset
85
+
86
+ shared.gradio_root = gr.Blocks(
87
+ title=title,
88
+ css=modules.html.css).queue()
89
+
90
+ with shared.gradio_root:
91
+ with gr.Row():
92
+ with gr.Column(scale=2):
93
+ with gr.Row():
94
+ progress_window = grh.Image(label='Preview', show_label=True, visible=False, height=768,
95
+ elem_classes=['main_view'])
96
+ progress_gallery = gr.Gallery(label='Finished Images', show_label=True, object_fit='contain',
97
+ height=768, visible=False, elem_classes=['main_view', 'image_gallery'])
98
+ progress_html = gr.HTML(value=modules.html.make_progress_html(32, 'Progress 32%'), visible=False,
99
+ elem_id='progress-bar', elem_classes='progress-bar')
100
+ gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain', visible=True, height=768,
101
+ elem_classes=['resizable_area', 'main_view', 'final_gallery', 'image_gallery'],
102
+ elem_id='final_gallery')
103
+ with gr.Row(elem_classes='type_row'):
104
+ with gr.Column(scale=17):
105
+ prompt = gr.Textbox(show_label=False, placeholder="Type prompt here or paste parameters.", elem_id='positive_prompt',
106
+ container=False, autofocus=True, elem_classes='type_row', lines=1024)
107
+
108
+ default_prompt = modules.config.default_prompt
109
+ if isinstance(default_prompt, str) and default_prompt != '':
110
+ shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
111
+
112
+ with gr.Column(scale=3, min_width=0):
113
+ generate_button = gr.Button(label="Generate", value="Generate", elem_classes='type_row', elem_id='generate_button', visible=True)
114
+ load_parameter_button = gr.Button(label="Load Parameters", value="Load Parameters", elem_classes='type_row', elem_id='load_parameter_button', visible=False)
115
+ skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
116
+ stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
117
+
118
+ def stop_clicked():
119
+ import ldm_patched.modules.model_management as model_management
120
+ shared.last_stop = 'stop'
121
+ model_management.interrupt_current_processing()
122
+ return [gr.update(interactive=False)] * 2
123
+
124
+ def skip_clicked():
125
+ import ldm_patched.modules.model_management as model_management
126
+ shared.last_stop = 'skip'
127
+ model_management.interrupt_current_processing()
128
+ return
129
+
130
+ stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
131
+ queue=False, show_progress=False, _js='cancelGenerateForever')
132
+ skip_button.click(skip_clicked, queue=False, show_progress=False)
133
+ with gr.Row(elem_classes='advanced_check_row'):
134
+ input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
135
+ advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
136
+ with gr.Row(visible=False) as image_input_panel:
137
+ with gr.Tabs():
138
+ with gr.TabItem(label='Upscale or Variation') as uov_tab:
139
+ with gr.Row():
140
+ with gr.Column():
141
+ uov_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy')
142
+ with gr.Column():
143
+ uov_method = gr.Radio(label='Upscale or Variation:', choices=flags.uov_list, value=flags.disabled)
144
+ gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/390" target="_blank">\U0001F4D4 Document</a>')
145
+ with gr.TabItem(label='Image Prompt') as ip_tab:
146
+ with gr.Row():
147
+ ip_images = []
148
+ ip_types = []
149
+ ip_stops = []
150
+ ip_weights = []
151
+ ip_ctrls = []
152
+ ip_ad_cols = []
153
+ for _ in range(4):
154
+ with gr.Column():
155
+ ip_image = grh.Image(label='Image', source='upload', type='numpy', show_label=False, height=300)
156
+ ip_images.append(ip_image)
157
+ ip_ctrls.append(ip_image)
158
+ with gr.Column(visible=False) as ad_col:
159
+ with gr.Row():
160
+ default_end, default_weight = flags.default_parameters[flags.default_ip]
161
+
162
+ ip_stop = gr.Slider(label='Stop At', minimum=0.0, maximum=1.0, step=0.001, value=default_end)
163
+ ip_stops.append(ip_stop)
164
+ ip_ctrls.append(ip_stop)
165
+
166
+ ip_weight = gr.Slider(label='Weight', minimum=0.0, maximum=2.0, step=0.001, value=default_weight)
167
+ ip_weights.append(ip_weight)
168
+ ip_ctrls.append(ip_weight)
169
+
170
+ ip_type = gr.Radio(label='Type', choices=flags.ip_list, value=flags.default_ip, container=False)
171
+ ip_types.append(ip_type)
172
+ ip_ctrls.append(ip_type)
173
+
174
+ ip_type.change(lambda x: flags.default_parameters[x], inputs=[ip_type], outputs=[ip_stop, ip_weight], queue=False, show_progress=False)
175
+ ip_ad_cols.append(ad_col)
176
+ ip_advanced = gr.Checkbox(label='Advanced', value=False, container=False)
177
+ gr.HTML('* \"Image Prompt\" is powered by Fooocus Image Mixture Engine (v1.0.1). <a href="https://github.com/lllyasviel/Fooocus/discussions/557" target="_blank">\U0001F4D4 Document</a>')
178
+
179
+ def ip_advance_checked(x):
180
+ return [gr.update(visible=x)] * len(ip_ad_cols) + \
181
+ [flags.default_ip] * len(ip_types) + \
182
+ [flags.default_parameters[flags.default_ip][0]] * len(ip_stops) + \
183
+ [flags.default_parameters[flags.default_ip][1]] * len(ip_weights)
184
+
185
+ ip_advanced.change(ip_advance_checked, inputs=ip_advanced,
186
+ outputs=ip_ad_cols + ip_types + ip_stops + ip_weights,
187
+ queue=False, show_progress=False)
188
+ with gr.TabItem(label='Inpaint or Outpaint') as inpaint_tab:
189
+ with gr.Row():
190
+ inpaint_input_image = grh.Image(label='Drag inpaint or outpaint image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
191
+ inpaint_mask_image = grh.Image(label='Mask Upload', source='upload', type='numpy', height=500, visible=False)
192
+
193
+ with gr.Row():
194
+ inpaint_additional_prompt = gr.Textbox(placeholder="Describe what you want to inpaint.", elem_id='inpaint_additional_prompt', label='Inpaint Additional Prompt', visible=False)
195
+ outpaint_selections = gr.CheckboxGroup(choices=['Left', 'Right', 'Top', 'Bottom'], value=[], label='Outpaint Direction')
196
+ inpaint_mode = gr.Dropdown(choices=modules.flags.inpaint_options, value=modules.flags.inpaint_option_default, label='Method')
197
+ example_inpaint_prompts = gr.Dataset(samples=modules.config.example_inpaint_prompts, label='Additional Prompt Quick List', components=[inpaint_additional_prompt], visible=False)
198
+ gr.HTML('* Powered by Fooocus Inpaint Engine <a href="https://github.com/lllyasviel/Fooocus/discussions/414" target="_blank">\U0001F4D4 Document</a>')
199
+ example_inpaint_prompts.click(lambda x: x[0], inputs=example_inpaint_prompts, outputs=inpaint_additional_prompt, show_progress=False, queue=False)
200
+ with gr.TabItem(label='Describe') as desc_tab:
201
+ with gr.Row():
202
+ with gr.Column():
203
+ desc_input_image = grh.Image(label='Drag any image to here', source='upload', type='numpy')
204
+ with gr.Column():
205
+ desc_method = gr.Radio(
206
+ label='Content Type',
207
+ choices=[flags.desc_type_photo, flags.desc_type_anime],
208
+ value=flags.desc_type_photo)
209
+ desc_btn = gr.Button(value='Describe this Image into Prompt')
210
+ gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/1363" target="_blank">\U0001F4D4 Document</a>')
211
+ switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
212
+ down_js = "() => {viewer_to_bottom();}"
213
+
214
+ input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox,
215
+ outputs=image_input_panel, queue=False, show_progress=False, _js=switch_js)
216
+ ip_advanced.change(lambda: None, queue=False, show_progress=False, _js=down_js)
217
+
218
+ current_tab = gr.Textbox(value='uov', visible=False)
219
+ uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
220
+ inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
221
+ ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
222
+ desc_tab.select(lambda: 'desc', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
223
+
224
+ with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
225
+ with gr.Tab(label='Setting'):
226
+ performance_selection = gr.Radio(label='Performance',
227
+ choices=modules.flags.performance_selections,
228
+ value=modules.config.default_performance)
229
+ aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
230
+ value=modules.config.default_aspect_ratio, info='width × height',
231
+ elem_classes='aspect_ratios')
232
+ image_number = gr.Slider(label='Image Number', minimum=1, maximum=modules.config.default_max_image_number, step=1, value=modules.config.default_image_number)
233
+ negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
234
+ info='Describing what you do not want to see.', lines=2,
235
+ elem_id='negative_prompt',
236
+ value=modules.config.default_prompt_negative)
237
+ seed_random = gr.Checkbox(label='Random', value=True)
238
+ image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354
239
+
240
+ def random_checked(r):
241
+ return gr.update(visible=not r)
242
+
243
+ def refresh_seed(r, seed_string):
244
+ if r:
245
+ return random.randint(constants.MIN_SEED, constants.MAX_SEED)
246
+ else:
247
+ try:
248
+ seed_value = int(seed_string)
249
+ if constants.MIN_SEED <= seed_value <= constants.MAX_SEED:
250
+ return seed_value
251
+ except ValueError:
252
+ pass
253
+ return random.randint(constants.MIN_SEED, constants.MAX_SEED)
254
+
255
+ seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed],
256
+ queue=False, show_progress=False)
257
+
258
+ if not args_manager.args.disable_image_log:
259
+ gr.HTML(f'<a href="file={get_current_html_path()}" target="_blank">\U0001F4DA History Log</a>')
260
+
261
+ with gr.Tab(label='Style'):
262
+ style_sorter.try_load_sorted_styles(
263
+ style_names=legal_style_names,
264
+ default_selected=modules.config.default_styles)
265
+
266
+ style_search_bar = gr.Textbox(show_label=False, container=False,
267
+ placeholder="\U0001F50E Type here to search styles ...",
268
+ value="",
269
+ label='Search Styles')
270
+ style_selections = gr.CheckboxGroup(show_label=False, container=False,
271
+ choices=copy.deepcopy(style_sorter.all_styles),
272
+ value=copy.deepcopy(modules.config.default_styles),
273
+ label='Selected Styles',
274
+ elem_classes=['style_selections'])
275
+ gradio_receiver_style_selections = gr.Textbox(elem_id='gradio_receiver_style_selections', visible=False)
276
+
277
+ shared.gradio_root.load(lambda: gr.update(choices=copy.deepcopy(style_sorter.all_styles)),
278
+ outputs=style_selections)
279
+
280
+ style_search_bar.change(style_sorter.search_styles,
281
+ inputs=[style_selections, style_search_bar],
282
+ outputs=style_selections,
283
+ queue=False,
284
+ show_progress=False).then(
285
+ lambda: None, _js='()=>{refresh_style_localization();}')
286
+
287
+ gradio_receiver_style_selections.input(style_sorter.sort_styles,
288
+ inputs=style_selections,
289
+ outputs=style_selections,
290
+ queue=False,
291
+ show_progress=False).then(
292
+ lambda: None, _js='()=>{refresh_style_localization();}')
293
+
294
+ with gr.Tab(label='Model'):
295
+ with gr.Group():
296
+ with gr.Row():
297
+ base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
298
+ refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
299
+
300
+ refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
301
+ info='Use 0.4 for SD1.5 realistic models; '
302
+ 'or 0.667 for SD1.5 anime models; '
303
+ 'or 0.8 for XL-refiners; '
304
+ 'or any value for switching two SDXL models.',
305
+ value=modules.config.default_refiner_switch,
306
+ visible=modules.config.default_refiner_model_name != 'None')
307
+
308
+ refiner_model.change(lambda x: gr.update(visible=x != 'None'),
309
+ inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
310
+
311
+ with gr.Group():
312
+ lora_ctrls = []
313
+
314
+ for i, (n, v) in enumerate(modules.config.default_loras):
315
+ with gr.Row():
316
+ lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
317
+ choices=['None'] + modules.config.lora_filenames, value=n)
318
+ lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
319
+ elem_classes='lora_weight')
320
+ lora_ctrls += [lora_model, lora_weight]
321
+
322
+ with gr.Row():
323
+ model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
324
+ with gr.Tab(label='Advanced'):
325
+ guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01,
326
+ value=modules.config.default_cfg_scale,
327
+ info='Higher value means style is cleaner, vivider, and more artistic.')
328
+ sharpness = gr.Slider(label='Image Sharpness', minimum=0.0, maximum=30.0, step=0.001,
329
+ value=modules.config.default_sample_sharpness,
330
+ info='Higher value means image and texture are sharper.')
331
+ gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
332
+ dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
333
+
334
+ with gr.Column(visible=False) as dev_tools:
335
+ with gr.Tab(label='Debug Tools'):
336
+ adm_scaler_positive = gr.Slider(label='Positive ADM Guidance Scaler', minimum=0.1, maximum=3.0,
337
+ step=0.001, value=1.5, info='The scaler multiplied to positive ADM (use 1.0 to disable). ')
338
+ adm_scaler_negative = gr.Slider(label='Negative ADM Guidance Scaler', minimum=0.1, maximum=3.0,
339
+ step=0.001, value=0.8, info='The scaler multiplied to negative ADM (use 1.0 to disable). ')
340
+ adm_scaler_end = gr.Slider(label='ADM Guidance End At Step', minimum=0.0, maximum=1.0,
341
+ step=0.001, value=0.3,
342
+ info='When to end the guidance from positive/negative ADM. ')
343
+
344
+ refiner_swap_method = gr.Dropdown(label='Refiner swap method', value='joint',
345
+ choices=['joint', 'separate', 'vae'])
346
+
347
+ adaptive_cfg = gr.Slider(label='CFG Mimicking from TSNR', minimum=1.0, maximum=30.0, step=0.01,
348
+ value=modules.config.default_cfg_tsnr,
349
+ info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
350
+ '(effective when real CFG > mimicked CFG).')
351
+ sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
352
+ value=modules.config.default_sampler)
353
+ scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
354
+ value=modules.config.default_scheduler)
355
+
356
+ generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
357
+ info='(Experimental) This may cause performance problems on some computers and certain internet conditions.',
358
+ value=False)
359
+
360
+ overwrite_step = gr.Slider(label='Forced Overwrite of Sampling Step',
361
+ minimum=-1, maximum=200, step=1,
362
+ value=modules.config.default_overwrite_step,
363
+ info='Set as -1 to disable. For developer debugging.')
364
+ overwrite_switch = gr.Slider(label='Forced Overwrite of Refiner Switch Step',
365
+ minimum=-1, maximum=200, step=1,
366
+ value=modules.config.default_overwrite_switch,
367
+ info='Set as -1 to disable. For developer debugging.')
368
+ overwrite_width = gr.Slider(label='Forced Overwrite of Generating Width',
369
+ minimum=-1, maximum=2048, step=1, value=-1,
370
+ info='Set as -1 to disable. For developer debugging. '
371
+ 'Results will be worse for non-standard numbers that SDXL is not trained on.')
372
+ overwrite_height = gr.Slider(label='Forced Overwrite of Generating Height',
373
+ minimum=-1, maximum=2048, step=1, value=-1,
374
+ info='Set as -1 to disable. For developer debugging. '
375
+ 'Results will be worse for non-standard numbers that SDXL is not trained on.')
376
+ overwrite_vary_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Vary"',
377
+ minimum=-1, maximum=1.0, step=0.001, value=-1,
378
+ info='Set as negative number to disable. For developer debugging.')
379
+ overwrite_upscale_strength = gr.Slider(label='Forced Overwrite of Denoising Strength of "Upscale"',
380
+ minimum=-1, maximum=1.0, step=0.001, value=-1,
381
+ info='Set as negative number to disable. For developer debugging.')
382
+ disable_preview = gr.Checkbox(label='Disable Preview', value=False,
383
+ info='Disable preview during generation.')
384
+
385
+ with gr.Tab(label='Control'):
386
+ debugging_cn_preprocessor = gr.Checkbox(label='Debug Preprocessors', value=False,
387
+ info='See the results from preprocessors.')
388
+ skipping_cn_preprocessor = gr.Checkbox(label='Skip Preprocessors', value=False,
389
+ info='Do not preprocess images. (Inputs are already canny/depth/cropped-face/etc.)')
390
+
391
+ mixing_image_prompt_and_vary_upscale = gr.Checkbox(label='Mixing Image Prompt and Vary/Upscale',
392
+ value=False)
393
+ mixing_image_prompt_and_inpaint = gr.Checkbox(label='Mixing Image Prompt and Inpaint',
394
+ value=False)
395
+
396
+ controlnet_softness = gr.Slider(label='Softness of ControlNet', minimum=0.0, maximum=1.0,
397
+ step=0.001, value=0.25,
398
+ info='Similar to the Control Mode in A1111 (use 0.0 to disable). ')
399
+
400
+ with gr.Tab(label='Canny'):
401
+ canny_low_threshold = gr.Slider(label='Canny Low Threshold', minimum=1, maximum=255,
402
+ step=1, value=64)
403
+ canny_high_threshold = gr.Slider(label='Canny High Threshold', minimum=1, maximum=255,
404
+ step=1, value=128)
405
+
406
+ with gr.Tab(label='Inpaint'):
407
+ debugging_inpaint_preprocessor = gr.Checkbox(label='Debug Inpaint Preprocessing', value=False)
408
+ inpaint_disable_initial_latent = gr.Checkbox(label='Disable initial latent in inpaint', value=False)
409
+ inpaint_engine = gr.Dropdown(label='Inpaint Engine',
410
+ value=modules.config.default_inpaint_engine_version,
411
+ choices=flags.inpaint_engine_versions,
412
+ info='Version of Fooocus inpaint model')
413
+ inpaint_strength = gr.Slider(label='Inpaint Denoising Strength',
414
+ minimum=0.0, maximum=1.0, step=0.001, value=1.0,
415
+ info='Same as the denoising strength in A1111 inpaint. '
416
+ 'Only used in inpaint, not used in outpaint. '
417
+ '(Outpaint always use 1.0)')
418
+ inpaint_respective_field = gr.Slider(label='Inpaint Respective Field',
419
+ minimum=0.0, maximum=1.0, step=0.001, value=0.618,
420
+ info='The area to inpaint. '
421
+ 'Value 0 is same as "Only Masked" in A1111. '
422
+ 'Value 1 is same as "Whole Image" in A1111. '
423
+ 'Only used in inpaint, not used in outpaint. '
424
+ '(Outpaint always use 1.0)')
425
+ inpaint_erode_or_dilate = gr.Slider(label='Mask Erode or Dilate',
426
+ minimum=-64, maximum=64, step=1, value=0,
427
+ info='Positive value will make white area in the mask larger, '
428
+ 'negative value will make white area smaller.'
429
+ '(default is 0, always process before any mask invert)')
430
+ inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
431
+ invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
432
+
433
+ inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine,
434
+ inpaint_strength, inpaint_respective_field,
435
+ inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
436
+
437
+ inpaint_mask_upload_checkbox.change(lambda x: gr.update(visible=x),
438
+ inputs=inpaint_mask_upload_checkbox,
439
+ outputs=inpaint_mask_image, queue=False, show_progress=False)
440
+
441
+ with gr.Tab(label='FreeU'):
442
+ freeu_enabled = gr.Checkbox(label='Enabled', value=False)
443
+ freeu_b1 = gr.Slider(label='B1', minimum=0, maximum=2, step=0.01, value=1.01)
444
+ freeu_b2 = gr.Slider(label='B2', minimum=0, maximum=2, step=0.01, value=1.02)
445
+ freeu_s1 = gr.Slider(label='S1', minimum=0, maximum=4, step=0.01, value=0.99)
446
+ freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95)
447
+ freeu_ctrls = [freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2]
448
+
449
+ adps = [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name,
450
+ scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height,
451
+ overwrite_vary_strength, overwrite_upscale_strength,
452
+ mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint,
453
+ debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness,
454
+ canny_low_threshold, canny_high_threshold, refiner_swap_method]
455
+ adps += freeu_ctrls
456
+ adps += inpaint_ctrls
457
+
458
+ def dev_mode_checked(r):
459
+ return gr.update(visible=r)
460
+
461
+
462
+ dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
463
+ queue=False, show_progress=False)
464
+
465
+ def model_refresh_clicked():
466
+ modules.config.update_all_model_names()
467
+ results = []
468
+ results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
469
+ for i in range(5):
470
+ results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
471
+ return results
472
+
473
+ model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
474
+ queue=False, show_progress=False)
475
+
476
+ performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11 +
477
+ [gr.update(visible=x != 'Extreme Speed')] * 1,
478
+ inputs=performance_selection,
479
+ outputs=[
480
+ guidance_scale, sharpness, adm_scaler_end, adm_scaler_positive,
481
+ adm_scaler_negative, refiner_switch, refiner_model, sampler_name,
482
+ scheduler_name, adaptive_cfg, refiner_swap_method, negative_prompt
483
+ ], queue=False, show_progress=False)
484
+
485
+ advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, advanced_column,
486
+ queue=False, show_progress=False) \
487
+ .then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)
488
+
489
+ def inpaint_mode_change(mode):
490
+ assert mode in modules.flags.inpaint_options
491
+
492
+ # inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
493
+ # inpaint_disable_initial_latent, inpaint_engine,
494
+ # inpaint_strength, inpaint_respective_field
495
+
496
+ if mode == modules.flags.inpaint_option_detail:
497
+ return [
498
+ gr.update(visible=True), gr.update(visible=False, value=[]),
499
+ gr.Dataset.update(visible=True, samples=modules.config.example_inpaint_prompts),
500
+ False, 'None', 0.5, 0.0
501
+ ]
502
+
503
+ if mode == modules.flags.inpaint_option_modify:
504
+ return [
505
+ gr.update(visible=True), gr.update(visible=False, value=[]),
506
+ gr.Dataset.update(visible=False, samples=modules.config.example_inpaint_prompts),
507
+ True, modules.config.default_inpaint_engine_version, 1.0, 0.0
508
+ ]
509
+
510
+ return [
511
+ gr.update(visible=False, value=''), gr.update(visible=True),
512
+ gr.Dataset.update(visible=False, samples=modules.config.example_inpaint_prompts),
513
+ False, modules.config.default_inpaint_engine_version, 1.0, 0.618
514
+ ]
515
+
516
+ inpaint_mode.input(inpaint_mode_change, inputs=inpaint_mode, outputs=[
517
+ inpaint_additional_prompt, outpaint_selections, example_inpaint_prompts,
518
+ inpaint_disable_initial_latent, inpaint_engine,
519
+ inpaint_strength, inpaint_respective_field
520
+ ], show_progress=False, queue=False)
521
+
522
+ ctrls = [
523
+ prompt, negative_prompt, style_selections,
524
+ performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
525
+ ]
526
+
527
+ ctrls += [base_model, refiner_model, refiner_switch] + lora_ctrls
528
+ ctrls += [input_image_checkbox, current_tab]
529
+ ctrls += [uov_method, uov_input_image]
530
+ ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
531
+ ctrls += ip_ctrls
532
+
533
+ state_is_generating = gr.State(False)
534
+
535
+ def parse_meta(raw_prompt_txt, is_generating):
536
+ loaded_json = None
537
+ try:
538
+ if '{' in raw_prompt_txt:
539
+ if '}' in raw_prompt_txt:
540
+ if ':' in raw_prompt_txt:
541
+ loaded_json = json.loads(raw_prompt_txt)
542
+ assert isinstance(loaded_json, dict)
543
+ except:
544
+ loaded_json = None
545
+
546
+ if loaded_json is None:
547
+ if is_generating:
548
+ return gr.update(), gr.update(), gr.update()
549
+ else:
550
+ return gr.update(), gr.update(visible=True), gr.update(visible=False)
551
+
552
+ return json.dumps(loaded_json), gr.update(visible=False), gr.update(visible=True)
553
+
554
+ prompt.input(parse_meta, inputs=[prompt, state_is_generating], outputs=[prompt, generate_button, load_parameter_button], queue=False, show_progress=False)
555
+
556
+ load_parameter_button.click(modules.meta_parser.load_parameter_button_click, inputs=[prompt, state_is_generating], outputs=[
557
+ advanced_checkbox,
558
+ image_number,
559
+ prompt,
560
+ negative_prompt,
561
+ style_selections,
562
+ performance_selection,
563
+ aspect_ratios_selection,
564
+ overwrite_width,
565
+ overwrite_height,
566
+ sharpness,
567
+ guidance_scale,
568
+ adm_scaler_positive,
569
+ adm_scaler_negative,
570
+ adm_scaler_end,
571
+ base_model,
572
+ refiner_model,
573
+ refiner_switch,
574
+ sampler_name,
575
+ scheduler_name,
576
+ seed_random,
577
+ image_seed,
578
+ generate_button,
579
+ load_parameter_button
580
+ ] + lora_ctrls, queue=False, show_progress=False)
581
+
582
+ generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True),
583
+ outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \
584
+ .then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
585
+ .then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
586
+ .then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
587
+ .then(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), gr.update(visible=False, interactive=False), False),
588
+ outputs=[generate_button, stop_button, skip_button, state_is_generating]) \
589
+ .then(fn=lambda: None, _js='playNotification').then(fn=lambda: None, _js='refresh_grid_delayed')
590
+
591
+ for notification_file in ['notification.ogg', 'notification.mp3']:
592
+ if os.path.exists(notification_file):
593
+ gr.Audio(interactive=False, value=notification_file, elem_id='audio_notification', visible=False)
594
+ break
595
+
596
+ def trigger_describe(mode, img):
597
+ if mode == flags.desc_type_photo:
598
+ from extras.interrogate import default_interrogator as default_interrogator_photo
599
+ return default_interrogator_photo(img), ["Fooocus V2", "Fooocus Enhance", "Fooocus Sharp"]
600
+ if mode == flags.desc_type_anime:
601
+ from extras.wd14tagger import default_interrogator as default_interrogator_anime
602
+ return default_interrogator_anime(img), ["Fooocus V2", "Fooocus Masterpiece"]
603
+ return mode, ["Fooocus V2"]
604
+
605
+ desc_btn.click(trigger_describe, inputs=[desc_method, desc_input_image],
606
+ outputs=[prompt, style_selections], show_progress=True, queue=True)
607
+
608
+
609
+ def dump_default_english_config():
610
+ from modules.localization import dump_english_config
611
+ dump_english_config(grh.all_components)
612
+
613
+
614
+ # dump_default_english_config()
615
+
616
+ shared.gradio_root.launch(
617
+ inbrowser=args_manager.args.in_browser,
618
+ server_name=args_manager.args.listen,
619
+ server_port=args_manager.args.port,
620
+ share=args_manager.args.share,
621
+ auth=check_auth if args_manager.args.share and auth_enabled else None,
622
+ blocked_paths=[constants.AUTH_FILENAME]
623
+ )
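generate_clicked above polls task.yields for ('preview' | 'results' | 'finish', payload) tuples produced by the async worker. The stripped-down sketch below shows the producer side of that protocol; the worker internals and class are stand-ins, and only the flag names and payload shapes are taken from the consumer loop above.

# Hypothetical stand-in for modules.async_worker.AsyncTask, for illustration only.
class FakeTask:
    def __init__(self, args):
        self.args = args
        self.yields = []

def fake_worker(task):
    # Progress updates: (percentage, title, preview_image_or_None).
    task.yields.append(('preview', (5, 'Loading models ...', None)))
    task.yields.append(('preview', (50, 'Sampling step 15/30 ...', None)))
    # Intermediate results shown in the progress gallery.
    task.yields.append(('results', ['image_0.png']))
    # Final gallery payload; generate_clicked stops polling after this.
    task.yields.append(('finish', ['image_0.png']))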