Upload !adetailer.py
Browse files- !adetailer.py +665 -0
!adetailer.py
ADDED
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from copy import copy, deepcopy
|
10 |
+
from pathlib import Path
|
11 |
+
from textwrap import dedent
|
12 |
+
from typing import Any
|
13 |
+
from modules import scripts, script_callbacks, sd_samplers, sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_commo
|
14 |
+
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
import torch
|
18 |
+
|
19 |
+
import modules # noqa: F401
|
20 |
+
from adetailer import (
|
21 |
+
AFTER_DETAILER,
|
22 |
+
__version__,
|
23 |
+
get_models,
|
24 |
+
mediapipe_predict,
|
25 |
+
ultralytics_predict,
|
26 |
+
)
|
27 |
+
from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, EnableChecker
|
28 |
+
from adetailer.common import PredictOutput
|
29 |
+
from adetailer.mask import filter_by_ratio, mask_preprocess, sort_bboxes
|
30 |
+
from adetailer.ui import adui, ordinal, suffix
|
31 |
+
from controlnet_ext import ControlNetExt, controlnet_exists
|
32 |
+
from controlnet_ext.restore import (
|
33 |
+
CNHijackRestore,
|
34 |
+
cn_allow_script_control,
|
35 |
+
cn_restore_unet_hook,
|
36 |
+
)
|
37 |
+
from sd_webui import images, safe, script_callbacks, scripts, shared
|
38 |
+
from sd_webui.paths import data_path, models_path
|
39 |
+
from sd_webui.processing import (
|
40 |
+
StableDiffusionProcessingImg2Img,
|
41 |
+
create_infotext,
|
42 |
+
process_images,
|
43 |
+
)
|
44 |
+
from sd_webui.shared import cmd_opts, opts, state
|
45 |
+
|
46 |
+
try:
|
47 |
+
from rich import print
|
48 |
+
from rich.traceback import install
|
49 |
+
|
50 |
+
install(show_locals=True)
|
51 |
+
except Exception:
|
52 |
+
pass
|
53 |
+
|
54 |
+
no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
|
55 |
+
adetailer_dir = Path(models_path, "adetailer")
|
56 |
+
model_mapping = get_models(adetailer_dir, huggingface=not no_huggingface)
|
57 |
+
txt2img_submit_button = img2img_submit_button = None
|
58 |
+
SCRIPT_DEFAULT = "dynamic_prompting,dynamic_thresholding,wildcard_recursive,wildcards"
|
59 |
+
|
60 |
+
if (
|
61 |
+
not adetailer_dir.exists()
|
62 |
+
and adetailer_dir.parent.exists()
|
63 |
+
and os.access(adetailer_dir.parent, os.W_OK)
|
64 |
+
):
|
65 |
+
adetailer_dir.mkdir()
|
66 |
+
|
67 |
+
print(
|
68 |
+
f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
@contextmanager
|
73 |
+
def change_torch_load():
|
74 |
+
orig = torch.load
|
75 |
+
try:
|
76 |
+
torch.load = safe.unsafe_torch_load
|
77 |
+
yield
|
78 |
+
finally:
|
79 |
+
torch.load = orig
|
80 |
+
|
81 |
+
|
82 |
+
@contextmanager
|
83 |
+
def pause_total_tqdm():
|
84 |
+
orig = opts.data.get("multiple_tqdm", True)
|
85 |
+
try:
|
86 |
+
opts.data["multiple_tqdm"] = False
|
87 |
+
yield
|
88 |
+
finally:
|
89 |
+
opts.data["multiple_tqdm"] = orig
|
90 |
+
|
91 |
+
|
92 |
+
class AfterDetailerScript(scripts.Script):
|
93 |
+
def __init__(self):
|
94 |
+
super().__init__()
|
95 |
+
self.ultralytics_device = self.get_ultralytics_device()
|
96 |
+
|
97 |
+
self.controlnet_ext = None
|
98 |
+
self.cn_script = None
|
99 |
+
self.cn_latest_network = None
|
100 |
+
|
101 |
+
def title(self):
|
102 |
+
return AFTER_DETAILER
|
103 |
+
|
104 |
+
def show(self, is_img2img):
|
105 |
+
return scripts.AlwaysVisible
|
106 |
+
|
107 |
+
def ui(self, is_img2img):
|
108 |
+
num_models = opts.data.get("ad_max_models", 2)
|
109 |
+
model_list = list(model_mapping.keys())
|
110 |
+
|
111 |
+
components, infotext_fields = adui(
|
112 |
+
num_models,
|
113 |
+
is_img2img,
|
114 |
+
model_list,
|
115 |
+
txt2img_submit_button,
|
116 |
+
img2img_submit_button,
|
117 |
+
)
|
118 |
+
|
119 |
+
self.infotext_fields = infotext_fields
|
120 |
+
return components
|
121 |
+
|
122 |
+
def init_controlnet_ext(self) -> None:
|
123 |
+
if self.controlnet_ext is not None:
|
124 |
+
return
|
125 |
+
self.controlnet_ext = ControlNetExt()
|
126 |
+
|
127 |
+
if controlnet_exists:
|
128 |
+
try:
|
129 |
+
self.controlnet_ext.init_controlnet()
|
130 |
+
except ImportError:
|
131 |
+
error = traceback.format_exc()
|
132 |
+
print(
|
133 |
+
f"[-] ADetailer: ControlNetExt init failed:\n{error}",
|
134 |
+
file=sys.stderr,
|
135 |
+
)
|
136 |
+
|
137 |
+
def update_controlnet_args(self, p, args: ADetailerArgs) -> None:
|
138 |
+
if self.controlnet_ext is None:
|
139 |
+
self.init_controlnet_ext()
|
140 |
+
|
141 |
+
if (
|
142 |
+
self.controlnet_ext is not None
|
143 |
+
and self.controlnet_ext.cn_available
|
144 |
+
and args.ad_controlnet_model != "None"
|
145 |
+
):
|
146 |
+
self.controlnet_ext.update_scripts_args(
|
147 |
+
p, args.ad_controlnet_model, args.ad_controlnet_weight
|
148 |
+
)
|
149 |
+
|
150 |
+
def is_ad_enabled(self, *args_) -> bool:
|
151 |
+
if len(args_) == 0 or (len(args_) == 1 and isinstance(args_[0], bool)):
|
152 |
+
message = f"""
|
153 |
+
[-] ADetailer: Not enough arguments passed to ADetailer.
|
154 |
+
input: {args_!r}
|
155 |
+
"""
|
156 |
+
raise ValueError(dedent(message))
|
157 |
+
a0 = args_[0]
|
158 |
+
a1 = args_[1] if len(args_) > 1 else None
|
159 |
+
checker = EnableChecker(a0=a0, a1=a1)
|
160 |
+
return checker.is_enabled()
|
161 |
+
|
162 |
+
def get_args(self, *args_) -> list[ADetailerArgs]:
|
163 |
+
"""
|
164 |
+
`args_` is at least 1 in length by `is_ad_enabled` immediately above
|
165 |
+
"""
|
166 |
+
args = [arg for arg in args_ if isinstance(arg, dict)]
|
167 |
+
|
168 |
+
if not args:
|
169 |
+
message = f"[-] ADetailer: Invalid arguments passed to ADetailer: {args_!r}"
|
170 |
+
raise ValueError(message)
|
171 |
+
|
172 |
+
all_inputs = []
|
173 |
+
|
174 |
+
for n, arg_dict in enumerate(args, 1):
|
175 |
+
try:
|
176 |
+
inp = ADetailerArgs(**arg_dict)
|
177 |
+
except ValueError as e:
|
178 |
+
msgs = [
|
179 |
+
f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
|
180 |
+
]
|
181 |
+
for attr in ALL_ARGS.attrs:
|
182 |
+
arg = arg_dict.get(attr)
|
183 |
+
dtype = type(arg)
|
184 |
+
arg = "DEFAULT" if arg is None else repr(arg)
|
185 |
+
msgs.append(f" {attr}: {arg} ({dtype})")
|
186 |
+
raise ValueError("\n".join(msgs)) from e
|
187 |
+
|
188 |
+
all_inputs.append(inp)
|
189 |
+
|
190 |
+
return all_inputs
|
191 |
+
|
192 |
+
def extra_params(self, arg_list: list[ADetailerArgs]) -> dict:
|
193 |
+
params = {}
|
194 |
+
for n, args in enumerate(arg_list):
|
195 |
+
params.update(args.extra_params(suffix=suffix(n)))
|
196 |
+
params["ADetailer version"] = __version__
|
197 |
+
return params
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def get_ultralytics_device() -> str:
|
201 |
+
'`device = ""` means autodetect'
|
202 |
+
device = ""
|
203 |
+
if platform.system() == "Darwin":
|
204 |
+
return device
|
205 |
+
|
206 |
+
if any(getattr(cmd_opts, vram, False) for vram in ["lowvram", "medvram"]):
|
207 |
+
device = "cpu"
|
208 |
+
|
209 |
+
return device
|
210 |
+
|
211 |
+
def prompt_blank_replacement(
|
212 |
+
self, all_prompts: list[str], i: int, default: str
|
213 |
+
) -> str:
|
214 |
+
if not all_prompts:
|
215 |
+
return default
|
216 |
+
if i < len(all_prompts):
|
217 |
+
return all_prompts[i]
|
218 |
+
j = i % len(all_prompts)
|
219 |
+
return all_prompts[j]
|
220 |
+
|
221 |
+
def _get_prompt(
|
222 |
+
self, ad_prompt: str, all_prompts: list[str], i: int, default: str
|
223 |
+
) -> list[str]:
|
224 |
+
prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
|
225 |
+
blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
|
226 |
+
for n in range(len(prompts)):
|
227 |
+
if not prompts[n]:
|
228 |
+
prompts[n] = blank_replacement
|
229 |
+
return prompts
|
230 |
+
|
231 |
+
def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
|
232 |
+
i = p._idx
|
233 |
+
|
234 |
+
prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt)
|
235 |
+
negative_prompt = self._get_prompt(
|
236 |
+
args.ad_negative_prompt, p.all_negative_prompts, i, p.negative_prompt
|
237 |
+
)
|
238 |
+
|
239 |
+
return prompt, negative_prompt
|
240 |
+
|
241 |
+
def get_seed(self, p) -> tuple[int, int]:
|
242 |
+
i = p._idx
|
243 |
+
|
244 |
+
if not p.all_seeds:
|
245 |
+
seed = p.seed
|
246 |
+
elif i < len(p.all_seeds):
|
247 |
+
seed = p.all_seeds[i]
|
248 |
+
else:
|
249 |
+
j = i % len(p.all_seeds)
|
250 |
+
seed = p.all_seeds[j]
|
251 |
+
|
252 |
+
if not p.all_subseeds:
|
253 |
+
subseed = p.subseed
|
254 |
+
elif i < len(p.all_subseeds):
|
255 |
+
subseed = p.all_subseeds[i]
|
256 |
+
else:
|
257 |
+
j = i % len(p.all_subseeds)
|
258 |
+
subseed = p.all_subseeds[j]
|
259 |
+
|
260 |
+
return seed, subseed
|
261 |
+
|
262 |
+
def get_width_height(self, p, args: ADetailerArgs) -> tuple[int, int]:
|
263 |
+
if args.ad_use_inpaint_width_height:
|
264 |
+
width = args.ad_inpaint_width
|
265 |
+
height = args.ad_inpaint_height
|
266 |
+
else:
|
267 |
+
width = p.width
|
268 |
+
height = p.height
|
269 |
+
|
270 |
+
return width, height
|
271 |
+
|
272 |
+
def get_steps(self, p, args: ADetailerArgs) -> int:
|
273 |
+
if args.ad_use_steps:
|
274 |
+
return args.ad_steps
|
275 |
+
return p.steps
|
276 |
+
|
277 |
+
def get_cfg_scale(self, p, args: ADetailerArgs) -> float:
|
278 |
+
if args.ad_use_cfg_scale:
|
279 |
+
return args.ad_cfg_scale
|
280 |
+
return p.cfg_scale
|
281 |
+
|
282 |
+
def infotext(self, p) -> str:
|
283 |
+
return create_infotext(
|
284 |
+
p, p.all_prompts, p.all_seeds, p.all_subseeds, None, 0, 0
|
285 |
+
)
|
286 |
+
|
287 |
+
def write_params_txt(self, p) -> None:
|
288 |
+
infotext = self.infotext(p)
|
289 |
+
params_txt = Path(data_path, "params.txt")
|
290 |
+
params_txt.write_text(infotext, encoding="utf-8")
|
291 |
+
|
292 |
+
def script_filter(self, p, args: ADetailerArgs):
|
293 |
+
script_runner = copy(p.scripts)
|
294 |
+
script_args = deepcopy(p.script_args)
|
295 |
+
self.disable_controlnet_units(script_args)
|
296 |
+
|
297 |
+
ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
|
298 |
+
if not ad_only_seleted_scripts:
|
299 |
+
return script_runner, script_args
|
300 |
+
|
301 |
+
ad_script_names = opts.data.get("ad_script_names", SCRIPT_DEFAULT)
|
302 |
+
script_names_set = {
|
303 |
+
name
|
304 |
+
for script_name in ad_script_names.split(",")
|
305 |
+
for name in (script_name, script_name.strip())
|
306 |
+
}
|
307 |
+
|
308 |
+
if args.ad_controlnet_model != "None":
|
309 |
+
script_names_set.add("controlnet")
|
310 |
+
|
311 |
+
filtered_alwayson = []
|
312 |
+
for script_object in script_runner.alwayson_scripts:
|
313 |
+
filepath = script_object.filename
|
314 |
+
filename = Path(filepath).stem
|
315 |
+
if filename in script_names_set:
|
316 |
+
filtered_alwayson.append(script_object)
|
317 |
+
if filename == "controlnet":
|
318 |
+
self.cn_script = script_object
|
319 |
+
self.cn_latest_network = script_object.latest_network
|
320 |
+
|
321 |
+
script_runner.alwayson_scripts = filtered_alwayson
|
322 |
+
return script_runner, script_args
|
323 |
+
|
324 |
+
def disable_controlnet_units(self, script_args: list[Any]) -> None:
|
325 |
+
for obj in script_args:
|
326 |
+
if "controlnet" in obj.__class__.__name__.lower():
|
327 |
+
if hasattr(obj, "enabled"):
|
328 |
+
obj.enabled = False
|
329 |
+
if hasattr(obj, "input_mode"):
|
330 |
+
obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
|
331 |
+
|
332 |
+
elif isinstance(obj, dict) and "module" in obj:
|
333 |
+
obj["enabled"] = False
|
334 |
+
|
335 |
+
def get_i2i_p(self, p, args: ADetailerArgs, image):
|
336 |
+
seed, subseed = self.get_seed(p)
|
337 |
+
width, height = self.get_width_height(p, args)
|
338 |
+
steps = self.get_steps(p, args)
|
339 |
+
cfg_scale = self.get_cfg_scale(p, args)
|
340 |
+
|
341 |
+
sampler_name = p.sampler_name
|
342 |
+
if sampler_name in ["PLMS", "UniPC"]:
|
343 |
+
sampler_name = "Euler"
|
344 |
+
|
345 |
+
i2i = StableDiffusionProcessingImg2Img(
|
346 |
+
init_images=[image],
|
347 |
+
resize_mode=0,
|
348 |
+
denoising_strength=args.ad_denoising_strength,
|
349 |
+
mask=None,
|
350 |
+
mask_blur=args.ad_mask_blur,
|
351 |
+
inpainting_fill=1,
|
352 |
+
inpaint_full_res=args.ad_inpaint_full_res,
|
353 |
+
inpaint_full_res_padding=args.ad_inpaint_full_res_padding,
|
354 |
+
inpainting_mask_invert=0,
|
355 |
+
sd_model=p.sd_model,
|
356 |
+
outpath_samples=p.outpath_samples,
|
357 |
+
outpath_grids=p.outpath_grids,
|
358 |
+
prompt="", # replace later
|
359 |
+
negative_prompt="",
|
360 |
+
styles=p.styles,
|
361 |
+
seed=seed,
|
362 |
+
subseed=subseed,
|
363 |
+
subseed_strength=p.subseed_strength,
|
364 |
+
seed_resize_from_h=p.seed_resize_from_h,
|
365 |
+
seed_resize_from_w=p.seed_resize_from_w,
|
366 |
+
sampler_name=sampler_name,
|
367 |
+
batch_size=1,
|
368 |
+
n_iter=1,
|
369 |
+
steps=steps,
|
370 |
+
cfg_scale=cfg_scale,
|
371 |
+
width=width,
|
372 |
+
height=height,
|
373 |
+
restore_faces=args.ad_restore_face,
|
374 |
+
tiling=p.tiling,
|
375 |
+
extra_generation_params=p.extra_generation_params,
|
376 |
+
do_not_save_samples=True,
|
377 |
+
do_not_save_grid=True,
|
378 |
+
)
|
379 |
+
|
380 |
+
i2i.scripts, i2i.script_args = self.script_filter(p, args)
|
381 |
+
i2i._disable_adetailer = True
|
382 |
+
|
383 |
+
if args.ad_controlnet_model != "None":
|
384 |
+
self.update_controlnet_args(i2i, args)
|
385 |
+
else:
|
386 |
+
i2i.control_net_enabled = False
|
387 |
+
|
388 |
+
return i2i
|
389 |
+
|
390 |
+
def save_image(self, p, image, *, condition: str, suffix: str) -> None:
|
391 |
+
i = p._idx
|
392 |
+
seed, _ = self.get_seed(p)
|
393 |
+
|
394 |
+
if opts.data.get(condition, False):
|
395 |
+
images.save_image(
|
396 |
+
image=image,
|
397 |
+
path=p.outpath_samples,
|
398 |
+
basename="",
|
399 |
+
seed=seed,
|
400 |
+
prompt=p.all_prompts[i] if i < len(p.all_prompts) else p.prompt,
|
401 |
+
extension=opts.samples_format,
|
402 |
+
info=self.infotext(p),
|
403 |
+
p=p,
|
404 |
+
suffix=suffix,
|
405 |
+
)
|
406 |
+
|
407 |
+
def get_ad_model(self, name: str):
|
408 |
+
if name not in model_mapping:
|
409 |
+
msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
|
410 |
+
raise ValueError(msg)
|
411 |
+
return model_mapping[name]
|
412 |
+
|
413 |
+
def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
|
414 |
+
sortby = opts.data.get("ad_bbox_sortby", BBOX_SORTBY[0])
|
415 |
+
sortby_idx = BBOX_SORTBY.index(sortby)
|
416 |
+
pred = sort_bboxes(pred, sortby_idx)
|
417 |
+
return pred
|
418 |
+
|
419 |
+
def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
|
420 |
+
pred = filter_by_ratio(
|
421 |
+
pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
|
422 |
+
)
|
423 |
+
pred = self.sort_bboxes(pred)
|
424 |
+
return mask_preprocess(
|
425 |
+
pred.masks,
|
426 |
+
kernel=args.ad_dilate_erode,
|
427 |
+
x_offset=args.ad_x_offset,
|
428 |
+
y_offset=args.ad_y_offset,
|
429 |
+
merge_invert=args.ad_mask_merge_invert,
|
430 |
+
)
|
431 |
+
|
432 |
+
def i2i_prompts_replace(
|
433 |
+
self, i2i, prompts: list[str], negative_prompts: list[str], j: int
|
434 |
+
):
|
435 |
+
i1 = min(j, len(prompts) - 1)
|
436 |
+
i2 = min(j, len(negative_prompts) - 1)
|
437 |
+
prompt = prompts[i1]
|
438 |
+
negative_prompt = negative_prompts[i2]
|
439 |
+
i2i.prompt = prompt
|
440 |
+
i2i.negative_prompt = negative_prompt
|
441 |
+
|
442 |
+
def is_need_call_process(self, p):
|
443 |
+
i = p._idx
|
444 |
+
n_iter = p.iteration
|
445 |
+
bs = p.batch_size
|
446 |
+
return (i == (n_iter + 1) * bs - 1) and (i != len(p.all_prompts) - 1)
|
447 |
+
|
448 |
+
def process(self, p, *args_):
|
449 |
+
if getattr(p, "_disable_adetailer", False):
|
450 |
+
return
|
451 |
+
|
452 |
+
if self.is_ad_enabled(*args_):
|
453 |
+
arg_list = self.get_args(*args_)
|
454 |
+
extra_params = self.extra_params(arg_list)
|
455 |
+
p.extra_generation_params.update(extra_params)
|
456 |
+
|
457 |
+
p._idx = -1
|
458 |
+
|
459 |
+
def _postprocess_image(self, p, pp, args: ADetailerArgs, *, n: int = 0) -> bool:
|
460 |
+
"""
|
461 |
+
Returns
|
462 |
+
-------
|
463 |
+
bool
|
464 |
+
|
465 |
+
`True` if image was processed, `False` otherwise.
|
466 |
+
"""
|
467 |
+
if state.interrupted:
|
468 |
+
return False
|
469 |
+
|
470 |
+
i = p._idx
|
471 |
+
|
472 |
+
i2i = self.get_i2i_p(p, args, pp.image)
|
473 |
+
seed, subseed = self.get_seed(p)
|
474 |
+
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
475 |
+
|
476 |
+
is_mediapipe = args.ad_model.lower().startswith("mediapipe")
|
477 |
+
|
478 |
+
kwargs = {}
|
479 |
+
if is_mediapipe:
|
480 |
+
predictor = mediapipe_predict
|
481 |
+
ad_model = args.ad_model
|
482 |
+
else:
|
483 |
+
predictor = ultralytics_predict
|
484 |
+
ad_model = self.get_ad_model(args.ad_model)
|
485 |
+
kwargs["device"] = self.ultralytics_device
|
486 |
+
|
487 |
+
with change_torch_load():
|
488 |
+
pred = predictor(ad_model, pp.image, args.ad_conf, **kwargs)
|
489 |
+
|
490 |
+
masks = self.pred_preprocessing(pred, args)
|
491 |
+
|
492 |
+
if not masks:
|
493 |
+
print(
|
494 |
+
f"[-] ADetailer: nothing detected on image {i + 1} with {ordinal(n + 1)} settings."
|
495 |
+
)
|
496 |
+
return False
|
497 |
+
|
498 |
+
self.save_image(
|
499 |
+
p,
|
500 |
+
pred.preview,
|
501 |
+
condition="ad_save_previews",
|
502 |
+
suffix="-ad-preview" + suffix(n, "-"),
|
503 |
+
)
|
504 |
+
|
505 |
+
steps = len(masks)
|
506 |
+
processed = None
|
507 |
+
state.job_count += steps
|
508 |
+
|
509 |
+
if is_mediapipe:
|
510 |
+
print(f"mediapipe: {steps} detected.")
|
511 |
+
|
512 |
+
p2 = copy(i2i)
|
513 |
+
for j in range(steps):
|
514 |
+
p2.image_mask = masks[j]
|
515 |
+
self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
|
516 |
+
|
517 |
+
if not re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
|
518 |
+
if args.ad_controlnet_model == "None":
|
519 |
+
cn_restore_unet_hook(p2, self.cn_latest_network)
|
520 |
+
processed = process_images(p2)
|
521 |
+
|
522 |
+
p2 = copy(i2i)
|
523 |
+
p2.init_images = [processed.images[0]]
|
524 |
+
|
525 |
+
p2.seed = seed + j + 1
|
526 |
+
p2.subseed = subseed + j + 1
|
527 |
+
|
528 |
+
if processed is not None:
|
529 |
+
pp.image = processed.images[0]
|
530 |
+
return True
|
531 |
+
|
532 |
+
return False
|
533 |
+
|
534 |
+
def postprocess_image(self, p, pp, *args_):
|
535 |
+
if getattr(p, "_disable_adetailer", False):
|
536 |
+
return
|
537 |
+
|
538 |
+
if not self.is_ad_enabled(*args_):
|
539 |
+
return
|
540 |
+
|
541 |
+
p._idx = getattr(p, "_idx", -1) + 1
|
542 |
+
init_image = copy(pp.image)
|
543 |
+
arg_list = self.get_args(*args_)
|
544 |
+
|
545 |
+
is_processed = False
|
546 |
+
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
|
547 |
+
for n, args in enumerate(arg_list):
|
548 |
+
if args.ad_model == "None":
|
549 |
+
continue
|
550 |
+
is_processed |= self._postprocess_image(p, pp, args, n=n)
|
551 |
+
|
552 |
+
if is_processed:
|
553 |
+
self.save_image(
|
554 |
+
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
|
555 |
+
)
|
556 |
+
|
557 |
+
if self.cn_script is not None and self.is_need_call_process(p):
|
558 |
+
self.cn_script.process(p)
|
559 |
+
|
560 |
+
try:
|
561 |
+
if p._idx == len(p.all_prompts) - 1:
|
562 |
+
self.write_params_txt(p)
|
563 |
+
except Exception:
|
564 |
+
pass
|
565 |
+
|
566 |
+
|
567 |
+
def on_after_component(component, **_kwargs):
|
568 |
+
global txt2img_submit_button, img2img_submit_button
|
569 |
+
if getattr(component, "elem_id", None) == "txt2img_generate":
|
570 |
+
txt2img_submit_button = component
|
571 |
+
return
|
572 |
+
|
573 |
+
if getattr(component, "elem_id", None) == "img2img_generate":
|
574 |
+
img2img_submit_button = component
|
575 |
+
|
576 |
+
|
577 |
+
def on_ui_settings():
|
578 |
+
section = ("ADetailer", AFTER_DETAILER)
|
579 |
+
shared.opts.add_option(
|
580 |
+
"ad_max_models",
|
581 |
+
shared.OptionInfo(
|
582 |
+
default=2,
|
583 |
+
label="Max models",
|
584 |
+
component=gr.Slider,
|
585 |
+
component_args={"minimum": 1, "maximum": 5, "step": 1},
|
586 |
+
section=section,
|
587 |
+
),
|
588 |
+
)
|
589 |
+
|
590 |
+
shared.opts.add_option(
|
591 |
+
"ad_save_previews",
|
592 |
+
shared.OptionInfo(False, "Save mask previews", section=section),
|
593 |
+
)
|
594 |
+
|
595 |
+
shared.opts.add_option(
|
596 |
+
"ad_save_images_before",
|
597 |
+
shared.OptionInfo(False, "Save images before ADetailer", section=section),
|
598 |
+
)
|
599 |
+
|
600 |
+
shared.opts.add_option(
|
601 |
+
"ad_only_seleted_scripts",
|
602 |
+
shared.OptionInfo(
|
603 |
+
True, "Apply only selected scripts to ADetailer", section=section
|
604 |
+
),
|
605 |
+
)
|
606 |
+
|
607 |
+
textbox_args = {
|
608 |
+
"placeholder": "comma-separated list of script names",
|
609 |
+
"interactive": True,
|
610 |
+
}
|
611 |
+
|
612 |
+
shared.opts.add_option(
|
613 |
+
"ad_script_names",
|
614 |
+
shared.OptionInfo(
|
615 |
+
default=SCRIPT_DEFAULT,
|
616 |
+
label="Script names to apply to ADetailer (separated by comma)",
|
617 |
+
component=gr.Textbox,
|
618 |
+
component_args=textbox_args,
|
619 |
+
section=section,
|
620 |
+
),
|
621 |
+
)
|
622 |
+
|
623 |
+
shared.opts.add_option(
|
624 |
+
"ad_bbox_sortby",
|
625 |
+
shared.OptionInfo(
|
626 |
+
default="None",
|
627 |
+
label="Sort bounding boxes by",
|
628 |
+
component=gr.Radio,
|
629 |
+
component_args={"choices": BBOX_SORTBY},
|
630 |
+
section=section,
|
631 |
+
),
|
632 |
+
)
|
633 |
+
|
634 |
+
########################################################
|
635 |
+
|
636 |
+
def make_axis_options():
|
637 |
+
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
|
638 |
+
ad_denoising_strength = getattr(p, 'inp denoising strength', args.ad_denoising_strength)
|
639 |
+
p.extra_generation_params["ad_denoising_strength"] = args.ad_denoising_strength
|
640 |
+
extra_axis_options = [
|
641 |
+
xyz_grid.AxisOption("[ade] inpaint denoising strength", float, xyz_grid.apply_field("inp denoising strength"))
|
642 |
+
|
643 |
+
]
|
644 |
+
|
645 |
+
def callbackBeforeUi():
|
646 |
+
try:
|
647 |
+
make_axis_options()
|
648 |
+
except Exception as e:
|
649 |
+
traceback.print_exc()
|
650 |
+
print(f"Failed to add support for X/Y/Z Plot Script because: {e}")
|
651 |
+
|
652 |
+
script_callbacks.on_before_ui(callbackBeforeUi)
|
653 |
+
|
654 |
+
|
655 |
+
|
656 |
+
|
657 |
+
|
658 |
+
|
659 |
+
|
660 |
+
|
661 |
+
|
662 |
+
|
663 |
+
|
664 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
665 |
+
script_callbacks.on_after_component(on_after_component)
|