Upload 22 files
Browse files- modules/timer.py +91 -0
- modules/txt2img.py +66 -0
- modules/ui.py +1366 -0
- modules/ui_checkpoint_merger.py +124 -0
- modules/ui_common.py +268 -0
- modules/ui_components.py +145 -0
- modules/ui_extensions.py +669 -0
- modules/ui_extra_networks.py +455 -0
- modules/ui_extra_networks_checkpoints.py +41 -0
- modules/ui_extra_networks_checkpoints_user_metadata.py +66 -0
- modules/ui_extra_networks_hypernets.py +39 -0
- modules/ui_extra_networks_textual_inversion.py +36 -0
- modules/ui_extra_networks_user_metadata.py +205 -0
- modules/ui_gradio_extensions.py +69 -0
- modules/ui_loadsave.py +227 -0
- modules/ui_postprocessing.py +57 -0
- modules/ui_prompt_styles.py +110 -0
- modules/ui_settings.py +296 -0
- modules/ui_tempdir.py +88 -0
- modules/upscaler.py +144 -0
- modules/util.py +58 -0
- modules/xlmr.py +137 -0
modules/timer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
|
5 |
+
class TimerSubcategory:
|
6 |
+
def __init__(self, timer, category):
|
7 |
+
self.timer = timer
|
8 |
+
self.category = category
|
9 |
+
self.start = None
|
10 |
+
self.original_base_category = timer.base_category
|
11 |
+
|
12 |
+
def __enter__(self):
|
13 |
+
self.start = time.time()
|
14 |
+
self.timer.base_category = self.original_base_category + self.category + "/"
|
15 |
+
self.timer.subcategory_level += 1
|
16 |
+
|
17 |
+
if self.timer.print_log:
|
18 |
+
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
|
19 |
+
|
20 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
21 |
+
elapsed_for_subcategroy = time.time() - self.start
|
22 |
+
self.timer.base_category = self.original_base_category
|
23 |
+
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
|
24 |
+
self.timer.subcategory_level -= 1
|
25 |
+
self.timer.record(self.category, disable_log=True)
|
26 |
+
|
27 |
+
|
28 |
+
class Timer:
|
29 |
+
def __init__(self, print_log=False):
|
30 |
+
self.start = time.time()
|
31 |
+
self.records = {}
|
32 |
+
self.total = 0
|
33 |
+
self.base_category = ''
|
34 |
+
self.print_log = print_log
|
35 |
+
self.subcategory_level = 0
|
36 |
+
|
37 |
+
def elapsed(self):
|
38 |
+
end = time.time()
|
39 |
+
res = end - self.start
|
40 |
+
self.start = end
|
41 |
+
return res
|
42 |
+
|
43 |
+
def add_time_to_record(self, category, amount):
|
44 |
+
if category not in self.records:
|
45 |
+
self.records[category] = 0
|
46 |
+
|
47 |
+
self.records[category] += amount
|
48 |
+
|
49 |
+
def record(self, category, extra_time=0, disable_log=False):
|
50 |
+
e = self.elapsed()
|
51 |
+
|
52 |
+
self.add_time_to_record(self.base_category + category, e + extra_time)
|
53 |
+
|
54 |
+
self.total += e + extra_time
|
55 |
+
|
56 |
+
if self.print_log and not disable_log:
|
57 |
+
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
|
58 |
+
|
59 |
+
def subcategory(self, name):
|
60 |
+
self.elapsed()
|
61 |
+
|
62 |
+
subcat = TimerSubcategory(self, name)
|
63 |
+
return subcat
|
64 |
+
|
65 |
+
def summary(self):
|
66 |
+
res = f"{self.total:.1f}s"
|
67 |
+
|
68 |
+
additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
|
69 |
+
if not additions:
|
70 |
+
return res
|
71 |
+
|
72 |
+
res += " ("
|
73 |
+
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
74 |
+
res += ")"
|
75 |
+
|
76 |
+
return res
|
77 |
+
|
78 |
+
def dump(self):
|
79 |
+
return {'total': self.total, 'records': self.records}
|
80 |
+
|
81 |
+
def reset(self):
|
82 |
+
self.__init__()
|
83 |
+
|
84 |
+
|
85 |
+
parser = argparse.ArgumentParser(add_help=False)
|
86 |
+
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
|
87 |
+
args = parser.parse_known_args()[0]
|
88 |
+
|
89 |
+
startup_timer = Timer(print_log=args.log_startup)
|
90 |
+
|
91 |
+
startup_record = None
|
modules/txt2img.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import closing
|
2 |
+
|
3 |
+
import modules.scripts
|
4 |
+
from modules import processing
|
5 |
+
from modules.generation_parameters_copypaste import create_override_settings_dict
|
6 |
+
from modules.shared import opts, cmd_opts
|
7 |
+
import modules.shared as shared
|
8 |
+
from modules.ui import plaintext_to_html
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
|
12 |
+
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
13 |
+
override_settings = create_override_settings_dict(override_settings_texts)
|
14 |
+
|
15 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
16 |
+
sd_model=shared.sd_model,
|
17 |
+
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
18 |
+
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
19 |
+
prompt=prompt,
|
20 |
+
styles=prompt_styles,
|
21 |
+
negative_prompt=negative_prompt,
|
22 |
+
sampler_name=sampler_name,
|
23 |
+
batch_size=batch_size,
|
24 |
+
n_iter=n_iter,
|
25 |
+
steps=steps,
|
26 |
+
cfg_scale=cfg_scale,
|
27 |
+
width=width,
|
28 |
+
height=height,
|
29 |
+
enable_hr=enable_hr,
|
30 |
+
denoising_strength=denoising_strength if enable_hr else None,
|
31 |
+
hr_scale=hr_scale,
|
32 |
+
hr_upscaler=hr_upscaler,
|
33 |
+
hr_second_pass_steps=hr_second_pass_steps,
|
34 |
+
hr_resize_x=hr_resize_x,
|
35 |
+
hr_resize_y=hr_resize_y,
|
36 |
+
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
37 |
+
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
38 |
+
hr_prompt=hr_prompt,
|
39 |
+
hr_negative_prompt=hr_negative_prompt,
|
40 |
+
override_settings=override_settings,
|
41 |
+
)
|
42 |
+
|
43 |
+
p.scripts = modules.scripts.scripts_txt2img
|
44 |
+
p.script_args = args
|
45 |
+
|
46 |
+
p.user = request.username
|
47 |
+
|
48 |
+
if cmd_opts.enable_console_prompts:
|
49 |
+
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
50 |
+
|
51 |
+
with closing(p):
|
52 |
+
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
53 |
+
|
54 |
+
if processed is None:
|
55 |
+
processed = processing.process_images(p)
|
56 |
+
|
57 |
+
shared.total_tqdm.clear()
|
58 |
+
|
59 |
+
generation_info_js = processed.js()
|
60 |
+
if opts.samples_log_stdout:
|
61 |
+
print(generation_info_js)
|
62 |
+
|
63 |
+
if opts.do_not_show_images:
|
64 |
+
processed.images = []
|
65 |
+
|
66 |
+
return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
modules/ui.py
ADDED
@@ -0,0 +1,1366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import mimetypes
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from functools import reduce
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import gradio.utils
|
10 |
+
import numpy as np
|
11 |
+
from PIL import Image, PngImagePlugin # noqa: F401
|
12 |
+
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
13 |
+
|
14 |
+
from modules import gradio_extensons # noqa: F401
|
15 |
+
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
16 |
+
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
17 |
+
from modules.paths import script_path
|
18 |
+
from modules.ui_common import create_refresh_button
|
19 |
+
from modules.ui_gradio_extensions import reload_javascript
|
20 |
+
|
21 |
+
from modules.shared import opts, cmd_opts
|
22 |
+
|
23 |
+
import modules.generation_parameters_copypaste as parameters_copypaste
|
24 |
+
import modules.hypernetworks.ui as hypernetworks_ui
|
25 |
+
import modules.textual_inversion.ui as textual_inversion_ui
|
26 |
+
import modules.textual_inversion.textual_inversion as textual_inversion
|
27 |
+
import modules.shared as shared
|
28 |
+
import modules.images
|
29 |
+
from modules import prompt_parser
|
30 |
+
from modules.sd_hijack import model_hijack
|
31 |
+
from modules.generation_parameters_copypaste import image_from_url_text
|
32 |
+
|
33 |
+
create_setting_component = ui_settings.create_setting_component
|
34 |
+
|
35 |
+
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
|
36 |
+
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
|
37 |
+
|
38 |
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
39 |
+
mimetypes.init()
|
40 |
+
mimetypes.add_type('application/javascript', '.js')
|
41 |
+
|
42 |
+
# Likewise, add explicit content-type header for certain missing image types
|
43 |
+
mimetypes.add_type('image/webp', '.webp')
|
44 |
+
|
45 |
+
if not cmd_opts.share and not cmd_opts.listen:
|
46 |
+
# fix gradio phoning home
|
47 |
+
gradio.utils.version_check = lambda: None
|
48 |
+
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
|
49 |
+
|
50 |
+
if cmd_opts.ngrok is not None:
|
51 |
+
import modules.ngrok as ngrok
|
52 |
+
print('ngrok authtoken detected, trying to connect...')
|
53 |
+
ngrok.connect(
|
54 |
+
cmd_opts.ngrok,
|
55 |
+
cmd_opts.port if cmd_opts.port is not None else 7860,
|
56 |
+
cmd_opts.ngrok_options
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def gr_show(visible=True):
|
61 |
+
return {"visible": visible, "__type__": "update"}
|
62 |
+
|
63 |
+
|
64 |
+
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
65 |
+
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
|
66 |
+
|
67 |
+
# Using constants for these since the variation selector isn't visible.
|
68 |
+
# Important that they exactly match script.js for tooltip to work.
|
69 |
+
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
70 |
+
reuse_symbol = '\u267b\ufe0f' # ♻️
|
71 |
+
paste_symbol = '\u2199\ufe0f' # ↙
|
72 |
+
refresh_symbol = '\U0001f504' # 🔄
|
73 |
+
save_style_symbol = '\U0001f4be' # 💾
|
74 |
+
apply_style_symbol = '\U0001f4cb' # 📋
|
75 |
+
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
76 |
+
extra_networks_symbol = '\U0001F3B4' # 🎴
|
77 |
+
switch_values_symbol = '\U000021C5' # ⇅
|
78 |
+
restore_progress_symbol = '\U0001F300' # 🌀
|
79 |
+
detect_image_size_symbol = '\U0001F4D0' # 📐
|
80 |
+
|
81 |
+
|
82 |
+
plaintext_to_html = ui_common.plaintext_to_html
|
83 |
+
|
84 |
+
|
85 |
+
def send_gradio_gallery_to_image(x):
|
86 |
+
if len(x) == 0:
|
87 |
+
return None
|
88 |
+
return image_from_url_text(x[0])
|
89 |
+
|
90 |
+
|
91 |
+
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
|
92 |
+
if not enable:
|
93 |
+
return ""
|
94 |
+
|
95 |
+
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
|
96 |
+
p.calculate_target_resolution()
|
97 |
+
|
98 |
+
return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
99 |
+
|
100 |
+
|
101 |
+
def resize_from_to_html(width, height, scale_by):
|
102 |
+
target_width = int(width * scale_by)
|
103 |
+
target_height = int(height * scale_by)
|
104 |
+
|
105 |
+
if not target_width or not target_height:
|
106 |
+
return "no image selected"
|
107 |
+
|
108 |
+
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
|
109 |
+
|
110 |
+
|
111 |
+
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
|
112 |
+
if mode in {0, 1, 3, 4}:
|
113 |
+
return [interrogation_function(ii_singles[mode]), None]
|
114 |
+
elif mode == 2:
|
115 |
+
return [interrogation_function(ii_singles[mode]["image"]), None]
|
116 |
+
elif mode == 5:
|
117 |
+
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
118 |
+
images = shared.listfiles(ii_input_dir)
|
119 |
+
print(f"Will process {len(images)} images.")
|
120 |
+
if ii_output_dir != "":
|
121 |
+
os.makedirs(ii_output_dir, exist_ok=True)
|
122 |
+
else:
|
123 |
+
ii_output_dir = ii_input_dir
|
124 |
+
|
125 |
+
for image in images:
|
126 |
+
img = Image.open(image)
|
127 |
+
filename = os.path.basename(image)
|
128 |
+
left, _ = os.path.splitext(filename)
|
129 |
+
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
|
130 |
+
|
131 |
+
return [gr.update(), None]
|
132 |
+
|
133 |
+
|
134 |
+
def interrogate(image):
|
135 |
+
prompt = shared.interrogator.interrogate(image.convert("RGB"))
|
136 |
+
return gr.update() if prompt is None else prompt
|
137 |
+
|
138 |
+
|
139 |
+
def interrogate_deepbooru(image):
|
140 |
+
prompt = deepbooru.model.tag(image)
|
141 |
+
return gr.update() if prompt is None else prompt
|
142 |
+
|
143 |
+
|
144 |
+
def connect_clear_prompt(button):
|
145 |
+
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
146 |
+
button.click(
|
147 |
+
_js="clear_prompt",
|
148 |
+
fn=None,
|
149 |
+
inputs=[],
|
150 |
+
outputs=[],
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def update_token_counter(text, steps):
|
155 |
+
try:
|
156 |
+
text, _ = extra_networks.parse_prompt(text)
|
157 |
+
|
158 |
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
159 |
+
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
160 |
+
|
161 |
+
except Exception:
|
162 |
+
# a parsing error can happen here during typing, and we don't want to bother the user with
|
163 |
+
# messages related to it in console
|
164 |
+
prompt_schedules = [[[steps, text]]]
|
165 |
+
|
166 |
+
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
167 |
+
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
168 |
+
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
|
169 |
+
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
170 |
+
|
171 |
+
|
172 |
+
class Toprow:
|
173 |
+
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
174 |
+
|
175 |
+
def __init__(self, is_img2img):
|
176 |
+
id_part = "img2img" if is_img2img else "txt2img"
|
177 |
+
self.id_part = id_part
|
178 |
+
|
179 |
+
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
180 |
+
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
181 |
+
with gr.Row():
|
182 |
+
with gr.Column(scale=80):
|
183 |
+
with gr.Row():
|
184 |
+
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
185 |
+
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
186 |
+
|
187 |
+
with gr.Row():
|
188 |
+
with gr.Column(scale=80):
|
189 |
+
with gr.Row():
|
190 |
+
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
191 |
+
|
192 |
+
self.button_interrogate = None
|
193 |
+
self.button_deepbooru = None
|
194 |
+
if is_img2img:
|
195 |
+
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
196 |
+
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
197 |
+
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
198 |
+
|
199 |
+
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
200 |
+
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
201 |
+
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
202 |
+
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
203 |
+
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
204 |
+
|
205 |
+
self.skip.click(
|
206 |
+
fn=lambda: shared.state.skip(),
|
207 |
+
inputs=[],
|
208 |
+
outputs=[],
|
209 |
+
)
|
210 |
+
|
211 |
+
self.interrupt.click(
|
212 |
+
fn=lambda: shared.state.interrupt(),
|
213 |
+
inputs=[],
|
214 |
+
outputs=[],
|
215 |
+
)
|
216 |
+
|
217 |
+
with gr.Row(elem_id=f"{id_part}_tools"):
|
218 |
+
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
219 |
+
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
220 |
+
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
221 |
+
|
222 |
+
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
223 |
+
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
224 |
+
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
225 |
+
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
226 |
+
|
227 |
+
self.clear_prompt_button.click(
|
228 |
+
fn=lambda *x: x,
|
229 |
+
_js="confirm_clear_prompt",
|
230 |
+
inputs=[self.prompt, self.negative_prompt],
|
231 |
+
outputs=[self.prompt, self.negative_prompt],
|
232 |
+
)
|
233 |
+
|
234 |
+
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
|
235 |
+
|
236 |
+
self.prompt_img.change(
|
237 |
+
fn=modules.images.image_data,
|
238 |
+
inputs=[self.prompt_img],
|
239 |
+
outputs=[self.prompt, self.prompt_img],
|
240 |
+
show_progress=False,
|
241 |
+
)
|
242 |
+
|
243 |
+
|
244 |
+
def setup_progressbar(*args, **kwargs):
|
245 |
+
pass
|
246 |
+
|
247 |
+
|
248 |
+
def apply_setting(key, value):
|
249 |
+
if value is None:
|
250 |
+
return gr.update()
|
251 |
+
|
252 |
+
if shared.cmd_opts.freeze_settings:
|
253 |
+
return gr.update()
|
254 |
+
|
255 |
+
# dont allow model to be swapped when model hash exists in prompt
|
256 |
+
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
|
257 |
+
return gr.update()
|
258 |
+
|
259 |
+
if key == "sd_model_checkpoint":
|
260 |
+
ckpt_info = sd_models.get_closet_checkpoint_match(value)
|
261 |
+
|
262 |
+
if ckpt_info is not None:
|
263 |
+
value = ckpt_info.title
|
264 |
+
else:
|
265 |
+
return gr.update()
|
266 |
+
|
267 |
+
comp_args = opts.data_labels[key].component_args
|
268 |
+
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
269 |
+
return
|
270 |
+
|
271 |
+
valtype = type(opts.data_labels[key].default)
|
272 |
+
oldval = opts.data.get(key, None)
|
273 |
+
opts.data[key] = valtype(value) if valtype != type(None) else value
|
274 |
+
if oldval != value and opts.data_labels[key].onchange is not None:
|
275 |
+
opts.data_labels[key].onchange()
|
276 |
+
|
277 |
+
opts.save(shared.config_filename)
|
278 |
+
return getattr(opts, key)
|
279 |
+
|
280 |
+
|
281 |
+
def create_output_panel(tabname, outdir):
|
282 |
+
return ui_common.create_output_panel(tabname, outdir)
|
283 |
+
|
284 |
+
|
285 |
+
def create_sampler_and_steps_selection(choices, tabname):
|
286 |
+
if opts.samplers_in_dropdown:
|
287 |
+
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
288 |
+
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
289 |
+
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
290 |
+
else:
|
291 |
+
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
292 |
+
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
293 |
+
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
294 |
+
|
295 |
+
return steps, sampler_name
|
296 |
+
|
297 |
+
|
298 |
+
def ordered_ui_categories():
|
299 |
+
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
300 |
+
|
301 |
+
for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
|
302 |
+
yield category
|
303 |
+
|
304 |
+
|
305 |
+
def create_override_settings_dropdown(tabname, row):
|
306 |
+
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
|
307 |
+
|
308 |
+
dropdown.change(
|
309 |
+
fn=lambda x: gr.Dropdown.update(visible=bool(x)),
|
310 |
+
inputs=[dropdown],
|
311 |
+
outputs=[dropdown],
|
312 |
+
)
|
313 |
+
|
314 |
+
return dropdown
|
315 |
+
|
316 |
+
|
317 |
+
def create_ui():
|
318 |
+
import modules.img2img
|
319 |
+
import modules.txt2img
|
320 |
+
|
321 |
+
reload_javascript()
|
322 |
+
|
323 |
+
parameters_copypaste.reset()
|
324 |
+
|
325 |
+
scripts.scripts_current = scripts.scripts_txt2img
|
326 |
+
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
327 |
+
|
328 |
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
329 |
+
toprow = Toprow(is_img2img=False)
|
330 |
+
|
331 |
+
dummy_component = gr.Label(visible=False)
|
332 |
+
|
333 |
+
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
334 |
+
extra_tabs.__enter__()
|
335 |
+
|
336 |
+
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
337 |
+
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
338 |
+
scripts.scripts_txt2img.prepare_ui()
|
339 |
+
|
340 |
+
for category in ordered_ui_categories():
|
341 |
+
if category == "sampler":
|
342 |
+
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
343 |
+
|
344 |
+
elif category == "dimensions":
|
345 |
+
with FormRow():
|
346 |
+
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
347 |
+
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
|
348 |
+
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
349 |
+
|
350 |
+
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
351 |
+
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
352 |
+
|
353 |
+
if opts.dimensions_and_batch_together:
|
354 |
+
with gr.Column(elem_id="txt2img_column_batch"):
|
355 |
+
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
356 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
357 |
+
|
358 |
+
elif category == "cfg":
|
359 |
+
with gr.Row():
|
360 |
+
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
|
361 |
+
|
362 |
+
elif category == "checkboxes":
|
363 |
+
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
364 |
+
pass
|
365 |
+
|
366 |
+
elif category == "accordions":
|
367 |
+
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
|
368 |
+
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
|
369 |
+
with enable_hr.extra():
|
370 |
+
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
371 |
+
|
372 |
+
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
373 |
+
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
374 |
+
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
375 |
+
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
376 |
+
|
377 |
+
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
|
378 |
+
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
379 |
+
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
380 |
+
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
381 |
+
|
382 |
+
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
383 |
+
|
384 |
+
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
385 |
+
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
386 |
+
|
387 |
+
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
388 |
+
|
389 |
+
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
390 |
+
with gr.Column(scale=80):
|
391 |
+
with gr.Row():
|
392 |
+
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
393 |
+
with gr.Column(scale=80):
|
394 |
+
with gr.Row():
|
395 |
+
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
396 |
+
|
397 |
+
scripts.scripts_txt2img.setup_ui_for_section(category)
|
398 |
+
|
399 |
+
elif category == "batch":
|
400 |
+
if not opts.dimensions_and_batch_together:
|
401 |
+
with FormRow(elem_id="txt2img_column_batch"):
|
402 |
+
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
|
403 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
404 |
+
|
405 |
+
elif category == "override_settings":
|
406 |
+
with FormRow(elem_id="txt2img_override_settings_row") as row:
|
407 |
+
override_settings = create_override_settings_dropdown('txt2img', row)
|
408 |
+
|
409 |
+
elif category == "scripts":
|
410 |
+
with FormGroup(elem_id="txt2img_script_container"):
|
411 |
+
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
412 |
+
|
413 |
+
if category not in {"accordions"}:
|
414 |
+
scripts.scripts_txt2img.setup_ui_for_section(category)
|
415 |
+
|
416 |
+
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
417 |
+
|
418 |
+
for component in hr_resolution_preview_inputs:
|
419 |
+
event = component.release if isinstance(component, gr.Slider) else component.change
|
420 |
+
|
421 |
+
event(
|
422 |
+
fn=calc_resolution_hires,
|
423 |
+
inputs=hr_resolution_preview_inputs,
|
424 |
+
outputs=[hr_final_resolution],
|
425 |
+
show_progress=False,
|
426 |
+
)
|
427 |
+
event(
|
428 |
+
None,
|
429 |
+
_js="onCalcResolutionHires",
|
430 |
+
inputs=hr_resolution_preview_inputs,
|
431 |
+
outputs=[],
|
432 |
+
show_progress=False,
|
433 |
+
)
|
434 |
+
|
435 |
+
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
436 |
+
|
437 |
+
txt2img_args = dict(
|
438 |
+
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
439 |
+
_js="submit",
|
440 |
+
inputs=[
|
441 |
+
dummy_component,
|
442 |
+
toprow.prompt,
|
443 |
+
toprow.negative_prompt,
|
444 |
+
toprow.ui_styles.dropdown,
|
445 |
+
steps,
|
446 |
+
sampler_name,
|
447 |
+
batch_count,
|
448 |
+
batch_size,
|
449 |
+
cfg_scale,
|
450 |
+
height,
|
451 |
+
width,
|
452 |
+
enable_hr,
|
453 |
+
denoising_strength,
|
454 |
+
hr_scale,
|
455 |
+
hr_upscaler,
|
456 |
+
hr_second_pass_steps,
|
457 |
+
hr_resize_x,
|
458 |
+
hr_resize_y,
|
459 |
+
hr_checkpoint_name,
|
460 |
+
hr_sampler_name,
|
461 |
+
hr_prompt,
|
462 |
+
hr_negative_prompt,
|
463 |
+
override_settings,
|
464 |
+
|
465 |
+
] + custom_inputs,
|
466 |
+
|
467 |
+
outputs=[
|
468 |
+
txt2img_gallery,
|
469 |
+
generation_info,
|
470 |
+
html_info,
|
471 |
+
html_log,
|
472 |
+
],
|
473 |
+
show_progress=False,
|
474 |
+
)
|
475 |
+
|
476 |
+
toprow.prompt.submit(**txt2img_args)
|
477 |
+
toprow.submit.click(**txt2img_args)
|
478 |
+
|
479 |
+
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
|
480 |
+
|
481 |
+
toprow.restore_progress_button.click(
|
482 |
+
fn=progress.restore_progress,
|
483 |
+
_js="restoreProgressTxt2img",
|
484 |
+
inputs=[dummy_component],
|
485 |
+
outputs=[
|
486 |
+
txt2img_gallery,
|
487 |
+
generation_info,
|
488 |
+
html_info,
|
489 |
+
html_log,
|
490 |
+
],
|
491 |
+
show_progress=False,
|
492 |
+
)
|
493 |
+
|
494 |
+
txt2img_paste_fields = [
|
495 |
+
(toprow.prompt, "Prompt"),
|
496 |
+
(toprow.negative_prompt, "Negative prompt"),
|
497 |
+
(steps, "Steps"),
|
498 |
+
(sampler_name, "Sampler"),
|
499 |
+
(cfg_scale, "CFG scale"),
|
500 |
+
(width, "Size-1"),
|
501 |
+
(height, "Size-2"),
|
502 |
+
(batch_size, "Batch size"),
|
503 |
+
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
504 |
+
(denoising_strength, "Denoising strength"),
|
505 |
+
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
506 |
+
(hr_scale, "Hires upscale"),
|
507 |
+
(hr_upscaler, "Hires upscaler"),
|
508 |
+
(hr_second_pass_steps, "Hires steps"),
|
509 |
+
(hr_resize_x, "Hires resize-1"),
|
510 |
+
(hr_resize_y, "Hires resize-2"),
|
511 |
+
(hr_checkpoint_name, "Hires checkpoint"),
|
512 |
+
(hr_sampler_name, "Hires sampler"),
|
513 |
+
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
514 |
+
(hr_prompt, "Hires prompt"),
|
515 |
+
(hr_negative_prompt, "Hires negative prompt"),
|
516 |
+
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
517 |
+
*scripts.scripts_txt2img.infotext_fields
|
518 |
+
]
|
519 |
+
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
|
520 |
+
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
521 |
+
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
522 |
+
))
|
523 |
+
|
524 |
+
txt2img_preview_params = [
|
525 |
+
toprow.prompt,
|
526 |
+
toprow.negative_prompt,
|
527 |
+
steps,
|
528 |
+
sampler_name,
|
529 |
+
cfg_scale,
|
530 |
+
scripts.scripts_txt2img.script('Seed').seed,
|
531 |
+
width,
|
532 |
+
height,
|
533 |
+
]
|
534 |
+
|
535 |
+
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
536 |
+
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
537 |
+
|
538 |
+
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
539 |
+
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
540 |
+
|
541 |
+
extra_tabs.__exit__()
|
542 |
+
|
543 |
+
scripts.scripts_current = scripts.scripts_img2img
|
544 |
+
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
545 |
+
|
546 |
+
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
547 |
+
toprow = Toprow(is_img2img=True)
|
548 |
+
|
549 |
+
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
550 |
+
extra_tabs.__enter__()
|
551 |
+
|
552 |
+
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
553 |
+
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
554 |
+
copy_image_buttons = []
|
555 |
+
copy_image_destinations = {}
|
556 |
+
|
557 |
+
def add_copy_image_controls(tab_name, elem):
|
558 |
+
with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
|
559 |
+
gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
|
560 |
+
|
561 |
+
for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
|
562 |
+
if name == tab_name:
|
563 |
+
gr.Button(title, interactive=False)
|
564 |
+
copy_image_destinations[name] = elem
|
565 |
+
continue
|
566 |
+
|
567 |
+
button = gr.Button(title)
|
568 |
+
copy_image_buttons.append((button, name, elem))
|
569 |
+
|
570 |
+
with gr.Tabs(elem_id="mode_img2img"):
|
571 |
+
img2img_selected_tab = gr.State(0)
|
572 |
+
|
573 |
+
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
574 |
+
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
575 |
+
add_copy_image_controls('img2img', init_img)
|
576 |
+
|
577 |
+
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
578 |
+
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
579 |
+
add_copy_image_controls('sketch', sketch)
|
580 |
+
|
581 |
+
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
582 |
+
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
583 |
+
add_copy_image_controls('inpaint', init_img_with_mask)
|
584 |
+
|
585 |
+
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
586 |
+
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
587 |
+
inpaint_color_sketch_orig = gr.State(None)
|
588 |
+
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
589 |
+
|
590 |
+
def update_orig(image, state):
|
591 |
+
if image is not None:
|
592 |
+
same_size = state is not None and state.size == image.size
|
593 |
+
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
594 |
+
edited = same_size and has_exact_match
|
595 |
+
return image if not edited or state is None else state
|
596 |
+
|
597 |
+
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
598 |
+
|
599 |
+
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
600 |
+
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
601 |
+
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
602 |
+
|
603 |
+
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
604 |
+
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
605 |
+
gr.HTML(
|
606 |
+
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
607 |
+
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
608 |
+
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
609 |
+
f"{hidden}</p>"
|
610 |
+
)
|
611 |
+
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
612 |
+
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
613 |
+
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
614 |
+
with gr.Accordion("PNG info", open=False):
|
615 |
+
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
616 |
+
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
617 |
+
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
618 |
+
|
619 |
+
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
620 |
+
|
621 |
+
for i, tab in enumerate(img2img_tabs):
|
622 |
+
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
623 |
+
|
624 |
+
def copy_image(img):
|
625 |
+
if isinstance(img, dict) and 'image' in img:
|
626 |
+
return img['image']
|
627 |
+
|
628 |
+
return img
|
629 |
+
|
630 |
+
for button, name, elem in copy_image_buttons:
|
631 |
+
button.click(
|
632 |
+
fn=copy_image,
|
633 |
+
inputs=[elem],
|
634 |
+
outputs=[copy_image_destinations[name]],
|
635 |
+
)
|
636 |
+
button.click(
|
637 |
+
fn=lambda: None,
|
638 |
+
_js=f"switch_to_{name.replace(' ', '_')}",
|
639 |
+
inputs=[],
|
640 |
+
outputs=[],
|
641 |
+
)
|
642 |
+
|
643 |
+
with FormRow():
|
644 |
+
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
645 |
+
|
646 |
+
scripts.scripts_img2img.prepare_ui()
|
647 |
+
|
648 |
+
for category in ordered_ui_categories():
|
649 |
+
if category == "sampler":
|
650 |
+
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
651 |
+
|
652 |
+
elif category == "dimensions":
|
653 |
+
with FormRow():
|
654 |
+
with gr.Column(elem_id="img2img_column_size", scale=4):
|
655 |
+
selected_scale_tab = gr.State(value=0)
|
656 |
+
|
657 |
+
with gr.Tabs():
|
658 |
+
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
|
659 |
+
with FormRow():
|
660 |
+
with gr.Column(elem_id="img2img_column_size", scale=4):
|
661 |
+
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
662 |
+
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
663 |
+
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
664 |
+
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
665 |
+
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
666 |
+
|
667 |
+
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
668 |
+
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
669 |
+
|
670 |
+
with FormRow():
|
671 |
+
scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
|
672 |
+
gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
|
673 |
+
button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
|
674 |
+
|
675 |
+
on_change_args = dict(
|
676 |
+
fn=resize_from_to_html,
|
677 |
+
_js="currentImg2imgSourceResolution",
|
678 |
+
inputs=[dummy_component, dummy_component, scale_by],
|
679 |
+
outputs=scale_by_html,
|
680 |
+
show_progress=False,
|
681 |
+
)
|
682 |
+
|
683 |
+
scale_by.release(**on_change_args)
|
684 |
+
button_update_resize_to.click(**on_change_args)
|
685 |
+
|
686 |
+
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
687 |
+
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
688 |
+
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
689 |
+
for component in [init_img, sketch]:
|
690 |
+
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
691 |
+
|
692 |
+
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
693 |
+
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
694 |
+
|
695 |
+
if opts.dimensions_and_batch_together:
|
696 |
+
with gr.Column(elem_id="img2img_column_batch"):
|
697 |
+
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
698 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
699 |
+
|
700 |
+
elif category == "denoising":
|
701 |
+
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
702 |
+
|
703 |
+
elif category == "cfg":
|
704 |
+
with gr.Row():
|
705 |
+
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
706 |
+
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
707 |
+
|
708 |
+
elif category == "checkboxes":
|
709 |
+
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
710 |
+
pass
|
711 |
+
|
712 |
+
elif category == "accordions":
|
713 |
+
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
|
714 |
+
scripts.scripts_img2img.setup_ui_for_section(category)
|
715 |
+
|
716 |
+
elif category == "batch":
|
717 |
+
if not opts.dimensions_and_batch_together:
|
718 |
+
with FormRow(elem_id="img2img_column_batch"):
|
719 |
+
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
720 |
+
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
721 |
+
|
722 |
+
elif category == "override_settings":
|
723 |
+
with FormRow(elem_id="img2img_override_settings_row") as row:
|
724 |
+
override_settings = create_override_settings_dropdown('img2img', row)
|
725 |
+
|
726 |
+
elif category == "scripts":
|
727 |
+
with FormGroup(elem_id="img2img_script_container"):
|
728 |
+
custom_inputs = scripts.scripts_img2img.setup_ui()
|
729 |
+
|
730 |
+
elif category == "inpaint":
|
731 |
+
with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
|
732 |
+
with FormRow():
|
733 |
+
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
|
734 |
+
mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
|
735 |
+
|
736 |
+
with FormRow():
|
737 |
+
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
|
738 |
+
|
739 |
+
with FormRow():
|
740 |
+
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
|
741 |
+
|
742 |
+
with FormRow():
|
743 |
+
with gr.Column():
|
744 |
+
inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
|
745 |
+
|
746 |
+
with gr.Column(scale=4):
|
747 |
+
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
748 |
+
|
749 |
+
def select_img2img_tab(tab):
|
750 |
+
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
751 |
+
|
752 |
+
for i, elem in enumerate(img2img_tabs):
|
753 |
+
elem.select(
|
754 |
+
fn=lambda tab=i: select_img2img_tab(tab),
|
755 |
+
inputs=[],
|
756 |
+
outputs=[inpaint_controls, mask_alpha],
|
757 |
+
)
|
758 |
+
|
759 |
+
if category not in {"accordions"}:
|
760 |
+
scripts.scripts_img2img.setup_ui_for_section(category)
|
761 |
+
|
762 |
+
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
763 |
+
|
764 |
+
img2img_args = dict(
|
765 |
+
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
766 |
+
_js="submit_img2img",
|
767 |
+
inputs=[
|
768 |
+
dummy_component,
|
769 |
+
dummy_component,
|
770 |
+
toprow.prompt,
|
771 |
+
toprow.negative_prompt,
|
772 |
+
toprow.ui_styles.dropdown,
|
773 |
+
init_img,
|
774 |
+
sketch,
|
775 |
+
init_img_with_mask,
|
776 |
+
inpaint_color_sketch,
|
777 |
+
inpaint_color_sketch_orig,
|
778 |
+
init_img_inpaint,
|
779 |
+
init_mask_inpaint,
|
780 |
+
steps,
|
781 |
+
sampler_name,
|
782 |
+
mask_blur,
|
783 |
+
mask_alpha,
|
784 |
+
inpainting_fill,
|
785 |
+
batch_count,
|
786 |
+
batch_size,
|
787 |
+
cfg_scale,
|
788 |
+
image_cfg_scale,
|
789 |
+
denoising_strength,
|
790 |
+
selected_scale_tab,
|
791 |
+
height,
|
792 |
+
width,
|
793 |
+
scale_by,
|
794 |
+
resize_mode,
|
795 |
+
inpaint_full_res,
|
796 |
+
inpaint_full_res_padding,
|
797 |
+
inpainting_mask_invert,
|
798 |
+
img2img_batch_input_dir,
|
799 |
+
img2img_batch_output_dir,
|
800 |
+
img2img_batch_inpaint_mask_dir,
|
801 |
+
override_settings,
|
802 |
+
img2img_batch_use_png_info,
|
803 |
+
img2img_batch_png_info_props,
|
804 |
+
img2img_batch_png_info_dir,
|
805 |
+
] + custom_inputs,
|
806 |
+
outputs=[
|
807 |
+
img2img_gallery,
|
808 |
+
generation_info,
|
809 |
+
html_info,
|
810 |
+
html_log,
|
811 |
+
],
|
812 |
+
show_progress=False,
|
813 |
+
)
|
814 |
+
|
815 |
+
interrogate_args = dict(
|
816 |
+
_js="get_img2img_tab_index",
|
817 |
+
inputs=[
|
818 |
+
dummy_component,
|
819 |
+
img2img_batch_input_dir,
|
820 |
+
img2img_batch_output_dir,
|
821 |
+
init_img,
|
822 |
+
sketch,
|
823 |
+
init_img_with_mask,
|
824 |
+
inpaint_color_sketch,
|
825 |
+
init_img_inpaint,
|
826 |
+
],
|
827 |
+
outputs=[toprow.prompt, dummy_component],
|
828 |
+
)
|
829 |
+
|
830 |
+
toprow.prompt.submit(**img2img_args)
|
831 |
+
toprow.submit.click(**img2img_args)
|
832 |
+
|
833 |
+
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
|
834 |
+
|
835 |
+
detect_image_size_btn.click(
|
836 |
+
fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
|
837 |
+
_js="currentImg2imgSourceResolution",
|
838 |
+
inputs=[dummy_component, dummy_component, dummy_component],
|
839 |
+
outputs=[width, height],
|
840 |
+
show_progress=False,
|
841 |
+
)
|
842 |
+
|
843 |
+
toprow.restore_progress_button.click(
|
844 |
+
fn=progress.restore_progress,
|
845 |
+
_js="restoreProgressImg2img",
|
846 |
+
inputs=[dummy_component],
|
847 |
+
outputs=[
|
848 |
+
img2img_gallery,
|
849 |
+
generation_info,
|
850 |
+
html_info,
|
851 |
+
html_log,
|
852 |
+
],
|
853 |
+
show_progress=False,
|
854 |
+
)
|
855 |
+
|
856 |
+
toprow.button_interrogate.click(
|
857 |
+
fn=lambda *args: process_interrogate(interrogate, *args),
|
858 |
+
**interrogate_args,
|
859 |
+
)
|
860 |
+
|
861 |
+
toprow.button_deepbooru.click(
|
862 |
+
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
|
863 |
+
**interrogate_args,
|
864 |
+
)
|
865 |
+
|
866 |
+
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
867 |
+
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
868 |
+
|
869 |
+
img2img_paste_fields = [
|
870 |
+
(toprow.prompt, "Prompt"),
|
871 |
+
(toprow.negative_prompt, "Negative prompt"),
|
872 |
+
(steps, "Steps"),
|
873 |
+
(sampler_name, "Sampler"),
|
874 |
+
(cfg_scale, "CFG scale"),
|
875 |
+
(image_cfg_scale, "Image CFG scale"),
|
876 |
+
(width, "Size-1"),
|
877 |
+
(height, "Size-2"),
|
878 |
+
(batch_size, "Batch size"),
|
879 |
+
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
880 |
+
(denoising_strength, "Denoising strength"),
|
881 |
+
(mask_blur, "Mask blur"),
|
882 |
+
*scripts.scripts_img2img.infotext_fields
|
883 |
+
]
|
884 |
+
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
|
885 |
+
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
|
886 |
+
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
887 |
+
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
|
888 |
+
))
|
889 |
+
|
890 |
+
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
|
891 |
+
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
|
892 |
+
|
893 |
+
extra_tabs.__exit__()
|
894 |
+
|
895 |
+
scripts.scripts_current = None
|
896 |
+
|
897 |
+
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
898 |
+
ui_postprocessing.create_ui()
|
899 |
+
|
900 |
+
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
|
901 |
+
with gr.Row(equal_height=False):
|
902 |
+
with gr.Column(variant='panel'):
|
903 |
+
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
|
904 |
+
|
905 |
+
with gr.Column(variant='panel'):
|
906 |
+
html = gr.HTML()
|
907 |
+
generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
|
908 |
+
html2 = gr.HTML()
|
909 |
+
with gr.Row():
|
910 |
+
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
|
911 |
+
|
912 |
+
for tabname, button in buttons.items():
|
913 |
+
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
914 |
+
paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
|
915 |
+
))
|
916 |
+
|
917 |
+
image.change(
|
918 |
+
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
919 |
+
inputs=[image],
|
920 |
+
outputs=[html, generation_info, html2],
|
921 |
+
)
|
922 |
+
|
923 |
+
modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
|
924 |
+
|
925 |
+
with gr.Blocks(analytics_enabled=False) as train_interface:
|
926 |
+
with gr.Row(equal_height=False):
|
927 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
928 |
+
|
929 |
+
with gr.Row(variant="compact", equal_height=False):
|
930 |
+
with gr.Tabs(elem_id="train_tabs"):
|
931 |
+
|
932 |
+
with gr.Tab(label="Create embedding", id="create_embedding"):
|
933 |
+
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
|
934 |
+
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
|
935 |
+
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
|
936 |
+
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
|
937 |
+
|
938 |
+
with gr.Row():
|
939 |
+
with gr.Column(scale=3):
|
940 |
+
gr.HTML(value="")
|
941 |
+
|
942 |
+
with gr.Column():
|
943 |
+
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
|
944 |
+
|
945 |
+
with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
|
946 |
+
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
|
947 |
+
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
|
948 |
+
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
|
949 |
+
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
|
950 |
+
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
|
951 |
+
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
|
952 |
+
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
|
953 |
+
new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
|
954 |
+
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
|
955 |
+
|
956 |
+
with gr.Row():
|
957 |
+
with gr.Column(scale=3):
|
958 |
+
gr.HTML(value="")
|
959 |
+
|
960 |
+
with gr.Column():
|
961 |
+
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
962 |
+
|
963 |
+
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
964 |
+
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
965 |
+
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
966 |
+
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
967 |
+
process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
|
968 |
+
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
969 |
+
|
970 |
+
with gr.Row():
|
971 |
+
process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
|
972 |
+
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
|
973 |
+
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
|
974 |
+
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
|
975 |
+
process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
|
976 |
+
process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
|
977 |
+
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
|
978 |
+
|
979 |
+
with gr.Row(visible=False) as process_split_extra_row:
|
980 |
+
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
|
981 |
+
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
|
982 |
+
|
983 |
+
with gr.Row(visible=False) as process_focal_crop_row:
|
984 |
+
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
|
985 |
+
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
|
986 |
+
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
|
987 |
+
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
|
988 |
+
|
989 |
+
with gr.Column(visible=False) as process_multicrop_col:
|
990 |
+
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
|
991 |
+
with gr.Row():
|
992 |
+
process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
|
993 |
+
process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
|
994 |
+
with gr.Row():
|
995 |
+
process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
|
996 |
+
process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
|
997 |
+
with gr.Row():
|
998 |
+
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
|
999 |
+
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
|
1000 |
+
|
1001 |
+
with gr.Row():
|
1002 |
+
with gr.Column(scale=3):
|
1003 |
+
gr.HTML(value="")
|
1004 |
+
|
1005 |
+
with gr.Column():
|
1006 |
+
with gr.Row():
|
1007 |
+
interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
|
1008 |
+
run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
|
1009 |
+
|
1010 |
+
process_split.change(
|
1011 |
+
fn=lambda show: gr_show(show),
|
1012 |
+
inputs=[process_split],
|
1013 |
+
outputs=[process_split_extra_row],
|
1014 |
+
)
|
1015 |
+
|
1016 |
+
process_focal_crop.change(
|
1017 |
+
fn=lambda show: gr_show(show),
|
1018 |
+
inputs=[process_focal_crop],
|
1019 |
+
outputs=[process_focal_crop_row],
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
process_multicrop.change(
|
1023 |
+
fn=lambda show: gr_show(show),
|
1024 |
+
inputs=[process_multicrop],
|
1025 |
+
outputs=[process_multicrop_col],
|
1026 |
+
)
|
1027 |
+
|
1028 |
+
def get_textual_inversion_template_names():
|
1029 |
+
return sorted(textual_inversion.textual_inversion_templates)
|
1030 |
+
|
1031 |
+
with gr.Tab(label="Train", id="train"):
|
1032 |
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
1033 |
+
with FormRow():
|
1034 |
+
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
1035 |
+
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
1036 |
+
|
1037 |
+
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
|
1038 |
+
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
|
1039 |
+
|
1040 |
+
with FormRow():
|
1041 |
+
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
1042 |
+
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
|
1043 |
+
|
1044 |
+
with FormRow():
|
1045 |
+
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
|
1046 |
+
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
|
1047 |
+
|
1048 |
+
with FormRow():
|
1049 |
+
batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
|
1050 |
+
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
|
1051 |
+
|
1052 |
+
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
|
1053 |
+
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
|
1054 |
+
|
1055 |
+
with FormRow():
|
1056 |
+
template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
|
1057 |
+
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
|
1058 |
+
|
1059 |
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
|
1060 |
+
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
|
1061 |
+
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
|
1062 |
+
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
|
1063 |
+
|
1064 |
+
with FormRow():
|
1065 |
+
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
|
1066 |
+
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
|
1067 |
+
|
1068 |
+
use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
|
1069 |
+
|
1070 |
+
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
|
1071 |
+
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
|
1072 |
+
|
1073 |
+
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
|
1074 |
+
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
|
1075 |
+
|
1076 |
+
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
|
1077 |
+
|
1078 |
+
with gr.Row():
|
1079 |
+
train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
|
1080 |
+
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
|
1081 |
+
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
|
1082 |
+
|
1083 |
+
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
|
1084 |
+
|
1085 |
+
script_callbacks.ui_train_tabs_callback(params)
|
1086 |
+
|
1087 |
+
with gr.Column(elem_id='ti_gallery_container'):
|
1088 |
+
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
1089 |
+
gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
|
1090 |
+
gr.HTML(elem_id="ti_progress", value="")
|
1091 |
+
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
1092 |
+
|
1093 |
+
create_embedding.click(
|
1094 |
+
fn=textual_inversion_ui.create_embedding,
|
1095 |
+
inputs=[
|
1096 |
+
new_embedding_name,
|
1097 |
+
initialization_text,
|
1098 |
+
nvpt,
|
1099 |
+
overwrite_old_embedding,
|
1100 |
+
],
|
1101 |
+
outputs=[
|
1102 |
+
train_embedding_name,
|
1103 |
+
ti_output,
|
1104 |
+
ti_outcome,
|
1105 |
+
]
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
create_hypernetwork.click(
|
1109 |
+
fn=hypernetworks_ui.create_hypernetwork,
|
1110 |
+
inputs=[
|
1111 |
+
new_hypernetwork_name,
|
1112 |
+
new_hypernetwork_sizes,
|
1113 |
+
overwrite_old_hypernetwork,
|
1114 |
+
new_hypernetwork_layer_structure,
|
1115 |
+
new_hypernetwork_activation_func,
|
1116 |
+
new_hypernetwork_initialization_option,
|
1117 |
+
new_hypernetwork_add_layer_norm,
|
1118 |
+
new_hypernetwork_use_dropout,
|
1119 |
+
new_hypernetwork_dropout_structure
|
1120 |
+
],
|
1121 |
+
outputs=[
|
1122 |
+
train_hypernetwork_name,
|
1123 |
+
ti_output,
|
1124 |
+
ti_outcome,
|
1125 |
+
]
|
1126 |
+
)
|
1127 |
+
|
1128 |
+
run_preprocess.click(
|
1129 |
+
fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
|
1130 |
+
_js="start_training_textual_inversion",
|
1131 |
+
inputs=[
|
1132 |
+
dummy_component,
|
1133 |
+
process_src,
|
1134 |
+
process_dst,
|
1135 |
+
process_width,
|
1136 |
+
process_height,
|
1137 |
+
preprocess_txt_action,
|
1138 |
+
process_keep_original_size,
|
1139 |
+
process_flip,
|
1140 |
+
process_split,
|
1141 |
+
process_caption,
|
1142 |
+
process_caption_deepbooru,
|
1143 |
+
process_split_threshold,
|
1144 |
+
process_overlap_ratio,
|
1145 |
+
process_focal_crop,
|
1146 |
+
process_focal_crop_face_weight,
|
1147 |
+
process_focal_crop_entropy_weight,
|
1148 |
+
process_focal_crop_edges_weight,
|
1149 |
+
process_focal_crop_debug,
|
1150 |
+
process_multicrop,
|
1151 |
+
process_multicrop_mindim,
|
1152 |
+
process_multicrop_maxdim,
|
1153 |
+
process_multicrop_minarea,
|
1154 |
+
process_multicrop_maxarea,
|
1155 |
+
process_multicrop_objective,
|
1156 |
+
process_multicrop_threshold,
|
1157 |
+
],
|
1158 |
+
outputs=[
|
1159 |
+
ti_output,
|
1160 |
+
ti_outcome,
|
1161 |
+
],
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
train_embedding.click(
|
1165 |
+
fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
|
1166 |
+
_js="start_training_textual_inversion",
|
1167 |
+
inputs=[
|
1168 |
+
dummy_component,
|
1169 |
+
train_embedding_name,
|
1170 |
+
embedding_learn_rate,
|
1171 |
+
batch_size,
|
1172 |
+
gradient_step,
|
1173 |
+
dataset_directory,
|
1174 |
+
log_directory,
|
1175 |
+
training_width,
|
1176 |
+
training_height,
|
1177 |
+
varsize,
|
1178 |
+
steps,
|
1179 |
+
clip_grad_mode,
|
1180 |
+
clip_grad_value,
|
1181 |
+
shuffle_tags,
|
1182 |
+
tag_drop_out,
|
1183 |
+
latent_sampling_method,
|
1184 |
+
use_weight,
|
1185 |
+
create_image_every,
|
1186 |
+
save_embedding_every,
|
1187 |
+
template_file,
|
1188 |
+
save_image_with_stored_embedding,
|
1189 |
+
preview_from_txt2img,
|
1190 |
+
*txt2img_preview_params,
|
1191 |
+
],
|
1192 |
+
outputs=[
|
1193 |
+
ti_output,
|
1194 |
+
ti_outcome,
|
1195 |
+
]
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
train_hypernetwork.click(
|
1199 |
+
fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
|
1200 |
+
_js="start_training_textual_inversion",
|
1201 |
+
inputs=[
|
1202 |
+
dummy_component,
|
1203 |
+
train_hypernetwork_name,
|
1204 |
+
hypernetwork_learn_rate,
|
1205 |
+
batch_size,
|
1206 |
+
gradient_step,
|
1207 |
+
dataset_directory,
|
1208 |
+
log_directory,
|
1209 |
+
training_width,
|
1210 |
+
training_height,
|
1211 |
+
varsize,
|
1212 |
+
steps,
|
1213 |
+
clip_grad_mode,
|
1214 |
+
clip_grad_value,
|
1215 |
+
shuffle_tags,
|
1216 |
+
tag_drop_out,
|
1217 |
+
latent_sampling_method,
|
1218 |
+
use_weight,
|
1219 |
+
create_image_every,
|
1220 |
+
save_embedding_every,
|
1221 |
+
template_file,
|
1222 |
+
preview_from_txt2img,
|
1223 |
+
*txt2img_preview_params,
|
1224 |
+
],
|
1225 |
+
outputs=[
|
1226 |
+
ti_output,
|
1227 |
+
ti_outcome,
|
1228 |
+
]
|
1229 |
+
)
|
1230 |
+
|
1231 |
+
interrupt_training.click(
|
1232 |
+
fn=lambda: shared.state.interrupt(),
|
1233 |
+
inputs=[],
|
1234 |
+
outputs=[],
|
1235 |
+
)
|
1236 |
+
|
1237 |
+
interrupt_preprocessing.click(
|
1238 |
+
fn=lambda: shared.state.interrupt(),
|
1239 |
+
inputs=[],
|
1240 |
+
outputs=[],
|
1241 |
+
)
|
1242 |
+
|
1243 |
+
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
1244 |
+
|
1245 |
+
settings = ui_settings.UiSettings()
|
1246 |
+
settings.create_ui(loadsave, dummy_component)
|
1247 |
+
|
1248 |
+
interfaces = [
|
1249 |
+
(txt2img_interface, "txt2img", "txt2img"),
|
1250 |
+
(img2img_interface, "img2img", "img2img"),
|
1251 |
+
(extras_interface, "Extras", "extras"),
|
1252 |
+
(pnginfo_interface, "PNG Info", "pnginfo"),
|
1253 |
+
(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
|
1254 |
+
(train_interface, "Train", "train"),
|
1255 |
+
]
|
1256 |
+
|
1257 |
+
interfaces += script_callbacks.ui_tabs_callback()
|
1258 |
+
interfaces += [(settings.interface, "Settings", "settings")]
|
1259 |
+
|
1260 |
+
extensions_interface = ui_extensions.create_ui()
|
1261 |
+
interfaces += [(extensions_interface, "Extensions", "extensions")]
|
1262 |
+
|
1263 |
+
shared.tab_names = []
|
1264 |
+
for _interface, label, _ifid in interfaces:
|
1265 |
+
shared.tab_names.append(label)
|
1266 |
+
|
1267 |
+
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
1268 |
+
settings.add_quicksettings()
|
1269 |
+
|
1270 |
+
parameters_copypaste.connect_paste_params_buttons()
|
1271 |
+
|
1272 |
+
with gr.Tabs(elem_id="tabs") as tabs:
|
1273 |
+
tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
|
1274 |
+
sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
|
1275 |
+
|
1276 |
+
for interface, label, ifid in sorted_interfaces:
|
1277 |
+
if label in shared.opts.hidden_tabs:
|
1278 |
+
continue
|
1279 |
+
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
|
1280 |
+
interface.render()
|
1281 |
+
|
1282 |
+
if ifid not in ["extensions", "settings"]:
|
1283 |
+
loadsave.add_block(interface, ifid)
|
1284 |
+
|
1285 |
+
loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
|
1286 |
+
|
1287 |
+
loadsave.setup_ui()
|
1288 |
+
|
1289 |
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
1290 |
+
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
1291 |
+
|
1292 |
+
footer = shared.html("footer.html")
|
1293 |
+
footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
|
1294 |
+
gr.HTML(footer, elem_id="footer")
|
1295 |
+
|
1296 |
+
settings.add_functionality(demo)
|
1297 |
+
|
1298 |
+
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
1299 |
+
settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
1300 |
+
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
1301 |
+
|
1302 |
+
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
|
1303 |
+
|
1304 |
+
loadsave.dump_defaults()
|
1305 |
+
demo.ui_loadsave = loadsave
|
1306 |
+
|
1307 |
+
return demo
|
1308 |
+
|
1309 |
+
|
1310 |
+
def versions_html():
|
1311 |
+
import torch
|
1312 |
+
import launch
|
1313 |
+
|
1314 |
+
python_version = ".".join([str(x) for x in sys.version_info[0:3]])
|
1315 |
+
commit = launch.commit_hash()
|
1316 |
+
tag = launch.git_tag()
|
1317 |
+
|
1318 |
+
if shared.xformers_available:
|
1319 |
+
import xformers
|
1320 |
+
xformers_version = xformers.__version__
|
1321 |
+
else:
|
1322 |
+
xformers_version = "N/A"
|
1323 |
+
|
1324 |
+
return f"""
|
1325 |
+
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
|
1326 |
+
 • 
|
1327 |
+
python: <span title="{sys.version}">{python_version}</span>
|
1328 |
+
 • 
|
1329 |
+
torch: {getattr(torch, '__long_version__',torch.__version__)}
|
1330 |
+
 • 
|
1331 |
+
xformers: {xformers_version}
|
1332 |
+
 • 
|
1333 |
+
gradio: {gr.__version__}
|
1334 |
+
 • 
|
1335 |
+
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
1336 |
+
"""
|
1337 |
+
|
1338 |
+
|
1339 |
+
def setup_ui_api(app):
|
1340 |
+
from pydantic import BaseModel, Field
|
1341 |
+
from typing import List
|
1342 |
+
|
1343 |
+
class QuicksettingsHint(BaseModel):
|
1344 |
+
name: str = Field(title="Name of the quicksettings field")
|
1345 |
+
label: str = Field(title="Label of the quicksettings field")
|
1346 |
+
|
1347 |
+
def quicksettings_hint():
|
1348 |
+
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
|
1349 |
+
|
1350 |
+
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
|
1351 |
+
|
1352 |
+
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
1353 |
+
|
1354 |
+
app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
|
1355 |
+
|
1356 |
+
def download_sysinfo(attachment=False):
|
1357 |
+
from fastapi.responses import PlainTextResponse
|
1358 |
+
|
1359 |
+
text = sysinfo.get()
|
1360 |
+
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
1361 |
+
|
1362 |
+
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
1363 |
+
|
1364 |
+
app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
|
1365 |
+
app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
|
1366 |
+
|
modules/ui_checkpoint_merger.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from modules import sd_models, sd_vae, errors, extras, call_queue
|
5 |
+
from modules.ui_components import FormRow
|
6 |
+
from modules.ui_common import create_refresh_button
|
7 |
+
|
8 |
+
|
9 |
+
def update_interp_description(value):
|
10 |
+
interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
|
11 |
+
interp_descriptions = {
|
12 |
+
"No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
|
13 |
+
"Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
|
14 |
+
"Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
|
15 |
+
}
|
16 |
+
return interp_descriptions[value]
|
17 |
+
|
18 |
+
|
19 |
+
def modelmerger(*args):
|
20 |
+
try:
|
21 |
+
results = extras.run_modelmerger(*args)
|
22 |
+
except Exception as e:
|
23 |
+
errors.report("Error loading/saving model file", exc_info=True)
|
24 |
+
sd_models.list_models() # to remove the potentially missing models from the list
|
25 |
+
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
|
26 |
+
return results
|
27 |
+
|
28 |
+
|
29 |
+
class UiCheckpointMerger:
|
30 |
+
def __init__(self):
|
31 |
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
32 |
+
with gr.Row(equal_height=False):
|
33 |
+
with gr.Column(variant='compact'):
|
34 |
+
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
|
35 |
+
|
36 |
+
with FormRow(elem_id="modelmerger_models"):
|
37 |
+
self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
|
38 |
+
create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
|
39 |
+
|
40 |
+
self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
|
41 |
+
create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
|
42 |
+
|
43 |
+
self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
|
44 |
+
create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
|
45 |
+
|
46 |
+
self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
|
47 |
+
self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
|
48 |
+
self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
|
49 |
+
self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
|
50 |
+
|
51 |
+
with FormRow():
|
52 |
+
self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
53 |
+
self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
54 |
+
|
55 |
+
with FormRow():
|
56 |
+
with gr.Column():
|
57 |
+
self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
|
58 |
+
|
59 |
+
with gr.Column():
|
60 |
+
with FormRow():
|
61 |
+
self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
|
62 |
+
create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
|
63 |
+
|
64 |
+
with FormRow():
|
65 |
+
self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
|
66 |
+
|
67 |
+
with gr.Accordion("Metadata", open=False) as metadata_editor:
|
68 |
+
with FormRow():
|
69 |
+
self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
|
70 |
+
self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
|
71 |
+
self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
|
72 |
+
|
73 |
+
self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
|
74 |
+
self.read_metadata = gr.Button("Read metadata from selected checkpoints")
|
75 |
+
|
76 |
+
with FormRow():
|
77 |
+
self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
|
78 |
+
|
79 |
+
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
|
80 |
+
with gr.Group(elem_id="modelmerger_results_panel"):
|
81 |
+
self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
|
82 |
+
|
83 |
+
self.metadata_editor = metadata_editor
|
84 |
+
self.blocks = modelmerger_interface
|
85 |
+
|
86 |
+
def setup_ui(self, dummy_component, sd_model_checkpoint_component):
|
87 |
+
self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
|
88 |
+
|
89 |
+
self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
|
90 |
+
|
91 |
+
self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
|
92 |
+
self.modelmerger_merge.click(
|
93 |
+
fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
|
94 |
+
_js='modelmerger',
|
95 |
+
inputs=[
|
96 |
+
dummy_component,
|
97 |
+
self.primary_model_name,
|
98 |
+
self.secondary_model_name,
|
99 |
+
self.tertiary_model_name,
|
100 |
+
self.interp_method,
|
101 |
+
self.interp_amount,
|
102 |
+
self.save_as_half,
|
103 |
+
self.custom_name,
|
104 |
+
self.checkpoint_format,
|
105 |
+
self.config_source,
|
106 |
+
self.bake_in_vae,
|
107 |
+
self.discard_weights,
|
108 |
+
self.save_metadata,
|
109 |
+
self.add_merge_recipe,
|
110 |
+
self.copy_metadata_fields,
|
111 |
+
self.metadata_json,
|
112 |
+
],
|
113 |
+
outputs=[
|
114 |
+
self.primary_model_name,
|
115 |
+
self.secondary_model_name,
|
116 |
+
self.tertiary_model_name,
|
117 |
+
sd_model_checkpoint_component,
|
118 |
+
self.modelmerger_result,
|
119 |
+
]
|
120 |
+
)
|
121 |
+
|
122 |
+
# Required as a workaround for change() event not triggering when loading values from ui-config.json
|
123 |
+
self.interp_description.value = update_interp_description(self.interp_method.value)
|
124 |
+
|
modules/ui_common.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
import platform
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import subprocess as sp
|
9 |
+
|
10 |
+
from modules import call_queue, shared
|
11 |
+
from modules.generation_parameters_copypaste import image_from_url_text
|
12 |
+
import modules.images
|
13 |
+
from modules.ui_components import ToolButton
|
14 |
+
import modules.generation_parameters_copypaste as parameters_copypaste
|
15 |
+
|
16 |
+
folder_symbol = '\U0001f4c2' # 📂
|
17 |
+
refresh_symbol = '\U0001f504' # 🔄
|
18 |
+
|
19 |
+
|
20 |
+
def update_generation_info(generation_info, html_info, img_index):
|
21 |
+
try:
|
22 |
+
generation_info = json.loads(generation_info)
|
23 |
+
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
|
24 |
+
return html_info, gr.update()
|
25 |
+
return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
|
26 |
+
except Exception:
|
27 |
+
pass
|
28 |
+
# if the json parse or anything else fails, just return the old html_info
|
29 |
+
return html_info, gr.update()
|
30 |
+
|
31 |
+
|
32 |
+
def plaintext_to_html(text, classname=None):
|
33 |
+
content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
|
34 |
+
|
35 |
+
return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
|
36 |
+
|
37 |
+
|
38 |
+
def save_files(js_data, images, do_make_zip, index):
|
39 |
+
import csv
|
40 |
+
filenames = []
|
41 |
+
fullfns = []
|
42 |
+
|
43 |
+
#quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
|
44 |
+
class MyObject:
|
45 |
+
def __init__(self, d=None):
|
46 |
+
if d is not None:
|
47 |
+
for key, value in d.items():
|
48 |
+
setattr(self, key, value)
|
49 |
+
|
50 |
+
data = json.loads(js_data)
|
51 |
+
|
52 |
+
p = MyObject(data)
|
53 |
+
path = shared.opts.outdir_save
|
54 |
+
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
|
55 |
+
extension: str = shared.opts.samples_format
|
56 |
+
start_index = 0
|
57 |
+
only_one = False
|
58 |
+
|
59 |
+
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
60 |
+
only_one = True
|
61 |
+
images = [images[index]]
|
62 |
+
start_index = index
|
63 |
+
|
64 |
+
os.makedirs(shared.opts.outdir_save, exist_ok=True)
|
65 |
+
|
66 |
+
with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
67 |
+
at_start = file.tell() == 0
|
68 |
+
writer = csv.writer(file)
|
69 |
+
if at_start:
|
70 |
+
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
71 |
+
|
72 |
+
for image_index, filedata in enumerate(images, start_index):
|
73 |
+
image = image_from_url_text(filedata)
|
74 |
+
|
75 |
+
is_grid = image_index < p.index_of_first_image
|
76 |
+
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
77 |
+
|
78 |
+
p.batch_index = image_index-1
|
79 |
+
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
80 |
+
|
81 |
+
filename = os.path.relpath(fullfn, path)
|
82 |
+
filenames.append(filename)
|
83 |
+
fullfns.append(fullfn)
|
84 |
+
if txt_fullfn:
|
85 |
+
filenames.append(os.path.basename(txt_fullfn))
|
86 |
+
fullfns.append(txt_fullfn)
|
87 |
+
|
88 |
+
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
89 |
+
|
90 |
+
# Make Zip
|
91 |
+
if do_make_zip:
|
92 |
+
zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
|
93 |
+
namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
|
94 |
+
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
|
95 |
+
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
|
96 |
+
|
97 |
+
from zipfile import ZipFile
|
98 |
+
with ZipFile(zip_filepath, "w") as zip_file:
|
99 |
+
for i in range(len(fullfns)):
|
100 |
+
with open(fullfns[i], mode="rb") as f:
|
101 |
+
zip_file.writestr(filenames[i], f.read())
|
102 |
+
fullfns.insert(0, zip_filepath)
|
103 |
+
|
104 |
+
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
105 |
+
|
106 |
+
|
107 |
+
def create_output_panel(tabname, outdir):
|
108 |
+
|
109 |
+
def open_folder(f):
|
110 |
+
if not os.path.exists(f):
|
111 |
+
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
|
112 |
+
return
|
113 |
+
elif not os.path.isdir(f):
|
114 |
+
print(f"""
|
115 |
+
WARNING
|
116 |
+
An open_folder request was made with an argument that is not a folder.
|
117 |
+
This could be an error or a malicious attempt to run code on your computer.
|
118 |
+
Requested path was: {f}
|
119 |
+
""", file=sys.stderr)
|
120 |
+
return
|
121 |
+
|
122 |
+
if not shared.cmd_opts.hide_ui_dir_config:
|
123 |
+
path = os.path.normpath(f)
|
124 |
+
if platform.system() == "Windows":
|
125 |
+
os.startfile(path)
|
126 |
+
elif platform.system() == "Darwin":
|
127 |
+
sp.Popen(["open", path])
|
128 |
+
elif "microsoft-standard-WSL2" in platform.uname().release:
|
129 |
+
sp.Popen(["wsl-open", path])
|
130 |
+
else:
|
131 |
+
sp.Popen(["xdg-open", path])
|
132 |
+
|
133 |
+
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
134 |
+
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
135 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
136 |
+
|
137 |
+
generation_info = None
|
138 |
+
with gr.Column():
|
139 |
+
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
140 |
+
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
141 |
+
|
142 |
+
if tabname != "extras":
|
143 |
+
save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
|
144 |
+
save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
|
145 |
+
|
146 |
+
buttons = {
|
147 |
+
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
|
148 |
+
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
|
149 |
+
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
|
150 |
+
}
|
151 |
+
|
152 |
+
open_folder_button.click(
|
153 |
+
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
154 |
+
inputs=[],
|
155 |
+
outputs=[],
|
156 |
+
)
|
157 |
+
|
158 |
+
if tabname != "extras":
|
159 |
+
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
|
160 |
+
|
161 |
+
with gr.Group():
|
162 |
+
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
163 |
+
html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
|
164 |
+
|
165 |
+
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
|
166 |
+
if tabname == 'txt2img' or tabname == 'img2img':
|
167 |
+
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
|
168 |
+
generation_info_button.click(
|
169 |
+
fn=update_generation_info,
|
170 |
+
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
|
171 |
+
inputs=[generation_info, html_info, html_info],
|
172 |
+
outputs=[html_info, html_info],
|
173 |
+
show_progress=False,
|
174 |
+
)
|
175 |
+
|
176 |
+
save.click(
|
177 |
+
fn=call_queue.wrap_gradio_call(save_files),
|
178 |
+
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
|
179 |
+
inputs=[
|
180 |
+
generation_info,
|
181 |
+
result_gallery,
|
182 |
+
html_info,
|
183 |
+
html_info,
|
184 |
+
],
|
185 |
+
outputs=[
|
186 |
+
download_files,
|
187 |
+
html_log,
|
188 |
+
],
|
189 |
+
show_progress=False,
|
190 |
+
)
|
191 |
+
|
192 |
+
save_zip.click(
|
193 |
+
fn=call_queue.wrap_gradio_call(save_files),
|
194 |
+
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
|
195 |
+
inputs=[
|
196 |
+
generation_info,
|
197 |
+
result_gallery,
|
198 |
+
html_info,
|
199 |
+
html_info,
|
200 |
+
],
|
201 |
+
outputs=[
|
202 |
+
download_files,
|
203 |
+
html_log,
|
204 |
+
]
|
205 |
+
)
|
206 |
+
|
207 |
+
else:
|
208 |
+
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
|
209 |
+
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
|
210 |
+
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
|
211 |
+
|
212 |
+
paste_field_names = []
|
213 |
+
if tabname == "txt2img":
|
214 |
+
paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
|
215 |
+
elif tabname == "img2img":
|
216 |
+
paste_field_names = modules.scripts.scripts_img2img.paste_field_names
|
217 |
+
|
218 |
+
for paste_tabname, paste_button in buttons.items():
|
219 |
+
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
|
220 |
+
paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
|
221 |
+
paste_field_names=paste_field_names
|
222 |
+
))
|
223 |
+
|
224 |
+
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
|
225 |
+
|
226 |
+
|
227 |
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
228 |
+
refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
|
229 |
+
|
230 |
+
label = None
|
231 |
+
for comp in refresh_components:
|
232 |
+
label = getattr(comp, 'label', None)
|
233 |
+
if label is not None:
|
234 |
+
break
|
235 |
+
|
236 |
+
def refresh():
|
237 |
+
refresh_method()
|
238 |
+
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
239 |
+
|
240 |
+
for k, v in args.items():
|
241 |
+
for comp in refresh_components:
|
242 |
+
setattr(comp, k, v)
|
243 |
+
|
244 |
+
return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))
|
245 |
+
|
246 |
+
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
|
247 |
+
refresh_button.click(
|
248 |
+
fn=refresh,
|
249 |
+
inputs=[],
|
250 |
+
outputs=refresh_components
|
251 |
+
)
|
252 |
+
return refresh_button
|
253 |
+
|
254 |
+
|
255 |
+
def setup_dialog(button_show, dialog, *, button_close=None):
|
256 |
+
"""Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
|
257 |
+
|
258 |
+
dialog.visible = False
|
259 |
+
|
260 |
+
button_show.click(
|
261 |
+
fn=lambda: gr.update(visible=True),
|
262 |
+
inputs=[],
|
263 |
+
outputs=[dialog],
|
264 |
+
).then(fn=None, _js="function(){ popupId('" + dialog.elem_id + "'); }")
|
265 |
+
|
266 |
+
if button_close:
|
267 |
+
button_close.click(fn=None, _js="closePopup")
|
268 |
+
|
modules/ui_components.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
|
4 |
+
class FormComponent:
|
5 |
+
def get_expected_parent(self):
|
6 |
+
return gr.components.Form
|
7 |
+
|
8 |
+
|
9 |
+
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
|
10 |
+
|
11 |
+
|
12 |
+
class ToolButton(FormComponent, gr.Button):
|
13 |
+
"""Small button with single emoji as text, fits inside gradio forms"""
|
14 |
+
|
15 |
+
def __init__(self, *args, **kwargs):
|
16 |
+
classes = kwargs.pop("elem_classes", [])
|
17 |
+
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
|
18 |
+
|
19 |
+
def get_block_name(self):
|
20 |
+
return "button"
|
21 |
+
|
22 |
+
|
23 |
+
class ResizeHandleRow(gr.Row):
|
24 |
+
"""Same as gr.Row but fits inside gradio forms"""
|
25 |
+
|
26 |
+
def __init__(self, **kwargs):
|
27 |
+
super().__init__(**kwargs)
|
28 |
+
|
29 |
+
self.elem_classes.append("resize-handle-row")
|
30 |
+
|
31 |
+
def get_block_name(self):
|
32 |
+
return "row"
|
33 |
+
|
34 |
+
|
35 |
+
class FormRow(FormComponent, gr.Row):
|
36 |
+
"""Same as gr.Row but fits inside gradio forms"""
|
37 |
+
|
38 |
+
def get_block_name(self):
|
39 |
+
return "row"
|
40 |
+
|
41 |
+
|
42 |
+
class FormColumn(FormComponent, gr.Column):
|
43 |
+
"""Same as gr.Column but fits inside gradio forms"""
|
44 |
+
|
45 |
+
def get_block_name(self):
|
46 |
+
return "column"
|
47 |
+
|
48 |
+
|
49 |
+
class FormGroup(FormComponent, gr.Group):
|
50 |
+
"""Same as gr.Group but fits inside gradio forms"""
|
51 |
+
|
52 |
+
def get_block_name(self):
|
53 |
+
return "group"
|
54 |
+
|
55 |
+
|
56 |
+
class FormHTML(FormComponent, gr.HTML):
|
57 |
+
"""Same as gr.HTML but fits inside gradio forms"""
|
58 |
+
|
59 |
+
def get_block_name(self):
|
60 |
+
return "html"
|
61 |
+
|
62 |
+
|
63 |
+
class FormColorPicker(FormComponent, gr.ColorPicker):
|
64 |
+
"""Same as gr.ColorPicker but fits inside gradio forms"""
|
65 |
+
|
66 |
+
def get_block_name(self):
|
67 |
+
return "colorpicker"
|
68 |
+
|
69 |
+
|
70 |
+
class DropdownMulti(FormComponent, gr.Dropdown):
|
71 |
+
"""Same as gr.Dropdown but always multiselect"""
|
72 |
+
def __init__(self, **kwargs):
|
73 |
+
super().__init__(multiselect=True, **kwargs)
|
74 |
+
|
75 |
+
def get_block_name(self):
|
76 |
+
return "dropdown"
|
77 |
+
|
78 |
+
|
79 |
+
class DropdownEditable(FormComponent, gr.Dropdown):
|
80 |
+
"""Same as gr.Dropdown but allows editing value"""
|
81 |
+
def __init__(self, **kwargs):
|
82 |
+
super().__init__(allow_custom_value=True, **kwargs)
|
83 |
+
|
84 |
+
def get_block_name(self):
|
85 |
+
return "dropdown"
|
86 |
+
|
87 |
+
|
88 |
+
class InputAccordion(gr.Checkbox):
|
89 |
+
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
|
90 |
+
|
91 |
+
Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
|
92 |
+
"""
|
93 |
+
|
94 |
+
global_index = 0
|
95 |
+
|
96 |
+
def __init__(self, value, **kwargs):
|
97 |
+
self.accordion_id = kwargs.get('elem_id')
|
98 |
+
if self.accordion_id is None:
|
99 |
+
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
|
100 |
+
InputAccordion.global_index += 1
|
101 |
+
|
102 |
+
kwargs_checkbox = {
|
103 |
+
**kwargs,
|
104 |
+
"elem_id": f"{self.accordion_id}-checkbox",
|
105 |
+
"visible": False,
|
106 |
+
}
|
107 |
+
super().__init__(value, **kwargs_checkbox)
|
108 |
+
|
109 |
+
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
|
110 |
+
|
111 |
+
kwargs_accordion = {
|
112 |
+
**kwargs,
|
113 |
+
"elem_id": self.accordion_id,
|
114 |
+
"label": kwargs.get('label', 'Accordion'),
|
115 |
+
"elem_classes": ['input-accordion'],
|
116 |
+
"open": value,
|
117 |
+
}
|
118 |
+
self.accordion = gr.Accordion(**kwargs_accordion)
|
119 |
+
|
120 |
+
def extra(self):
|
121 |
+
"""Allows you to put something into the label of the accordion.
|
122 |
+
|
123 |
+
Use it like this:
|
124 |
+
|
125 |
+
```
|
126 |
+
with InputAccordion(False, label="Accordion") as acc:
|
127 |
+
with acc.extra():
|
128 |
+
FormHTML(value="hello", min_width=0)
|
129 |
+
|
130 |
+
...
|
131 |
+
```
|
132 |
+
"""
|
133 |
+
|
134 |
+
return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
|
135 |
+
|
136 |
+
def __enter__(self):
|
137 |
+
self.accordion.__enter__()
|
138 |
+
return self
|
139 |
+
|
140 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
141 |
+
self.accordion.__exit__(exc_type, exc_val, exc_tb)
|
142 |
+
|
143 |
+
def get_block_name(self):
|
144 |
+
return "checkbox"
|
145 |
+
|
modules/ui_extensions.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from datetime import datetime, timezone
|
6 |
+
|
7 |
+
import git
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import html
|
11 |
+
import shutil
|
12 |
+
import errno
|
13 |
+
|
14 |
+
from modules import extensions, shared, paths, config_states, errors, restart
|
15 |
+
from modules.paths_internal import config_states_dir
|
16 |
+
from modules.call_queue import wrap_gradio_gpu_call
|
17 |
+
|
18 |
+
available_extensions = {"extensions": []}
|
19 |
+
STYLE_PRIMARY = ' style="color: var(--primary-400)"'
|
20 |
+
|
21 |
+
|
22 |
+
def check_access():
|
23 |
+
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
|
24 |
+
|
25 |
+
|
26 |
+
def apply_and_restart(disable_list, update_list, disable_all):
|
27 |
+
check_access()
|
28 |
+
|
29 |
+
disabled = json.loads(disable_list)
|
30 |
+
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
31 |
+
|
32 |
+
update = json.loads(update_list)
|
33 |
+
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
34 |
+
|
35 |
+
if update:
|
36 |
+
save_config_state("Backup (pre-update)")
|
37 |
+
|
38 |
+
update = set(update)
|
39 |
+
|
40 |
+
for ext in extensions.extensions:
|
41 |
+
if ext.name not in update:
|
42 |
+
continue
|
43 |
+
|
44 |
+
try:
|
45 |
+
ext.fetch_and_reset_hard()
|
46 |
+
except Exception:
|
47 |
+
errors.report(f"Error getting updates for {ext.name}", exc_info=True)
|
48 |
+
|
49 |
+
shared.opts.disabled_extensions = disabled
|
50 |
+
shared.opts.disable_all_extensions = disable_all
|
51 |
+
shared.opts.save(shared.config_filename)
|
52 |
+
|
53 |
+
if restart.is_restartable():
|
54 |
+
restart.restart_program()
|
55 |
+
else:
|
56 |
+
restart.stop_program()
|
57 |
+
|
58 |
+
|
59 |
+
def save_config_state(name):
|
60 |
+
current_config_state = config_states.get_config()
|
61 |
+
if not name:
|
62 |
+
name = "Config"
|
63 |
+
current_config_state["name"] = name
|
64 |
+
timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
|
65 |
+
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
66 |
+
print(f"Saving backup of webui/extension state to {filename}.")
|
67 |
+
with open(filename, "w", encoding="utf-8") as f:
|
68 |
+
json.dump(current_config_state, f, indent=4)
|
69 |
+
config_states.list_config_states()
|
70 |
+
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
71 |
+
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
72 |
+
return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
|
73 |
+
|
74 |
+
|
75 |
+
def restore_config_state(confirmed, config_state_name, restore_type):
|
76 |
+
if config_state_name == "Current":
|
77 |
+
return "<span>Select a config to restore from.</span>"
|
78 |
+
if not confirmed:
|
79 |
+
return "<span>Cancelled.</span>"
|
80 |
+
|
81 |
+
check_access()
|
82 |
+
|
83 |
+
config_state = config_states.all_config_states[config_state_name]
|
84 |
+
|
85 |
+
print(f"*** Restoring webui state from backup: {restore_type} ***")
|
86 |
+
|
87 |
+
if restore_type == "extensions" or restore_type == "both":
|
88 |
+
shared.opts.restore_config_state_file = config_state["filepath"]
|
89 |
+
shared.opts.save(shared.config_filename)
|
90 |
+
|
91 |
+
if restore_type == "webui" or restore_type == "both":
|
92 |
+
config_states.restore_webui_config(config_state)
|
93 |
+
|
94 |
+
shared.state.request_restart()
|
95 |
+
|
96 |
+
return ""
|
97 |
+
|
98 |
+
|
99 |
+
def check_updates(id_task, disable_list):
|
100 |
+
check_access()
|
101 |
+
|
102 |
+
disabled = json.loads(disable_list)
|
103 |
+
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
|
104 |
+
|
105 |
+
exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
|
106 |
+
shared.state.job_count = len(exts)
|
107 |
+
|
108 |
+
for ext in exts:
|
109 |
+
shared.state.textinfo = ext.name
|
110 |
+
|
111 |
+
try:
|
112 |
+
ext.check_updates()
|
113 |
+
except FileNotFoundError as e:
|
114 |
+
if 'FETCH_HEAD' not in str(e):
|
115 |
+
raise
|
116 |
+
except Exception:
|
117 |
+
errors.report(f"Error checking updates for {ext.name}", exc_info=True)
|
118 |
+
|
119 |
+
shared.state.nextjob()
|
120 |
+
|
121 |
+
return extension_table(), ""
|
122 |
+
|
123 |
+
|
124 |
+
def make_commit_link(commit_hash, remote, text=None):
|
125 |
+
if text is None:
|
126 |
+
text = commit_hash[:8]
|
127 |
+
if remote.startswith("https://github.com/"):
|
128 |
+
if remote.endswith(".git"):
|
129 |
+
remote = remote[:-4]
|
130 |
+
href = remote + "/commit/" + commit_hash
|
131 |
+
return f'<a href="{href}" target="_blank">{text}</a>'
|
132 |
+
else:
|
133 |
+
return text
|
134 |
+
|
135 |
+
|
136 |
+
def extension_table():
|
137 |
+
code = f"""<!-- {time.time()} -->
|
138 |
+
<table id="extensions">
|
139 |
+
<thead>
|
140 |
+
<tr>
|
141 |
+
<th>
|
142 |
+
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
|
143 |
+
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
|
144 |
+
</th>
|
145 |
+
<th>URL</th>
|
146 |
+
<th>Branch</th>
|
147 |
+
<th>Version</th>
|
148 |
+
<th>Date</th>
|
149 |
+
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
|
150 |
+
</tr>
|
151 |
+
</thead>
|
152 |
+
<tbody>
|
153 |
+
"""
|
154 |
+
|
155 |
+
for ext in extensions.extensions:
|
156 |
+
ext: extensions.Extension
|
157 |
+
ext.read_info_from_repo()
|
158 |
+
|
159 |
+
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
160 |
+
|
161 |
+
if ext.can_update:
|
162 |
+
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
163 |
+
else:
|
164 |
+
ext_status = ext.status
|
165 |
+
|
166 |
+
style = ""
|
167 |
+
if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
168 |
+
style = STYLE_PRIMARY
|
169 |
+
|
170 |
+
version_link = ext.version
|
171 |
+
if ext.commit_hash and ext.remote:
|
172 |
+
version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
|
173 |
+
|
174 |
+
code += f"""
|
175 |
+
<tr>
|
176 |
+
<td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
|
177 |
+
<td>{remote}</td>
|
178 |
+
<td>{ext.branch}</td>
|
179 |
+
<td>{version_link}</td>
|
180 |
+
<td>{datetime.fromtimestamp(ext.commit_date) if ext.commit_date else ""}</td>
|
181 |
+
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
182 |
+
</tr>
|
183 |
+
"""
|
184 |
+
|
185 |
+
code += """
|
186 |
+
</tbody>
|
187 |
+
</table>
|
188 |
+
"""
|
189 |
+
|
190 |
+
return code
|
191 |
+
|
192 |
+
|
193 |
+
def update_config_states_table(state_name):
|
194 |
+
if state_name == "Current":
|
195 |
+
config_state = config_states.get_config()
|
196 |
+
else:
|
197 |
+
config_state = config_states.all_config_states[state_name]
|
198 |
+
|
199 |
+
config_name = config_state.get("name", "Config")
|
200 |
+
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
201 |
+
filepath = config_state.get("filepath", "<unknown>")
|
202 |
+
|
203 |
+
try:
|
204 |
+
webui_remote = config_state["webui"]["remote"] or ""
|
205 |
+
webui_branch = config_state["webui"]["branch"]
|
206 |
+
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
207 |
+
webui_commit_date = config_state["webui"]["commit_date"]
|
208 |
+
if webui_commit_date:
|
209 |
+
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
|
210 |
+
else:
|
211 |
+
webui_commit_date = "<unknown>"
|
212 |
+
|
213 |
+
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
|
214 |
+
commit_link = make_commit_link(webui_commit_hash, webui_remote)
|
215 |
+
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
|
216 |
+
|
217 |
+
current_webui = config_states.get_webui_config()
|
218 |
+
|
219 |
+
style_remote = ""
|
220 |
+
style_branch = ""
|
221 |
+
style_commit = ""
|
222 |
+
if current_webui["remote"] != webui_remote:
|
223 |
+
style_remote = STYLE_PRIMARY
|
224 |
+
if current_webui["branch"] != webui_branch:
|
225 |
+
style_branch = STYLE_PRIMARY
|
226 |
+
if current_webui["commit_hash"] != webui_commit_hash:
|
227 |
+
style_commit = STYLE_PRIMARY
|
228 |
+
|
229 |
+
code = f"""<!-- {time.time()} -->
|
230 |
+
<h2>Config Backup: {config_name}</h2>
|
231 |
+
<div><b>Filepath:</b> {filepath}</div>
|
232 |
+
<div><b>Created at:</b> {created_date}</div>
|
233 |
+
<h2>WebUI State</h2>
|
234 |
+
<table id="config_state_webui">
|
235 |
+
<thead>
|
236 |
+
<tr>
|
237 |
+
<th>URL</th>
|
238 |
+
<th>Branch</th>
|
239 |
+
<th>Commit</th>
|
240 |
+
<th>Date</th>
|
241 |
+
</tr>
|
242 |
+
</thead>
|
243 |
+
<tbody>
|
244 |
+
<tr>
|
245 |
+
<td>
|
246 |
+
<label{style_remote}>{remote}</label>
|
247 |
+
</td>
|
248 |
+
<td>
|
249 |
+
<label{style_branch}>{webui_branch}</label>
|
250 |
+
</td>
|
251 |
+
<td>
|
252 |
+
<label{style_commit}>{commit_link}</label>
|
253 |
+
</td>
|
254 |
+
<td>
|
255 |
+
<label{style_commit}>{date_link}</label>
|
256 |
+
</td>
|
257 |
+
</tr>
|
258 |
+
</tbody>
|
259 |
+
</table>
|
260 |
+
<h2>Extension State</h2>
|
261 |
+
<table id="config_state_extensions">
|
262 |
+
<thead>
|
263 |
+
<tr>
|
264 |
+
<th>Extension</th>
|
265 |
+
<th>URL</th>
|
266 |
+
<th>Branch</th>
|
267 |
+
<th>Commit</th>
|
268 |
+
<th>Date</th>
|
269 |
+
</tr>
|
270 |
+
</thead>
|
271 |
+
<tbody>
|
272 |
+
"""
|
273 |
+
|
274 |
+
ext_map = {ext.name: ext for ext in extensions.extensions}
|
275 |
+
|
276 |
+
for ext_name, ext_conf in config_state["extensions"].items():
|
277 |
+
ext_remote = ext_conf["remote"] or ""
|
278 |
+
ext_branch = ext_conf["branch"] or "<unknown>"
|
279 |
+
ext_enabled = ext_conf["enabled"]
|
280 |
+
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
|
281 |
+
ext_commit_date = ext_conf["commit_date"]
|
282 |
+
if ext_commit_date:
|
283 |
+
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
|
284 |
+
else:
|
285 |
+
ext_commit_date = "<unknown>"
|
286 |
+
|
287 |
+
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
|
288 |
+
commit_link = make_commit_link(ext_commit_hash, ext_remote)
|
289 |
+
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
|
290 |
+
|
291 |
+
style_enabled = ""
|
292 |
+
style_remote = ""
|
293 |
+
style_branch = ""
|
294 |
+
style_commit = ""
|
295 |
+
if ext_name in ext_map:
|
296 |
+
current_ext = ext_map[ext_name]
|
297 |
+
current_ext.read_info_from_repo()
|
298 |
+
if current_ext.enabled != ext_enabled:
|
299 |
+
style_enabled = STYLE_PRIMARY
|
300 |
+
if current_ext.remote != ext_remote:
|
301 |
+
style_remote = STYLE_PRIMARY
|
302 |
+
if current_ext.branch != ext_branch:
|
303 |
+
style_branch = STYLE_PRIMARY
|
304 |
+
if current_ext.commit_hash != ext_commit_hash:
|
305 |
+
style_commit = STYLE_PRIMARY
|
306 |
+
|
307 |
+
code += f""" <tr>
|
308 |
+
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
309 |
+
<td><label{style_remote}>{remote}</label></td>
|
310 |
+
<td><label{style_branch}>{ext_branch}</label></td>
|
311 |
+
<td><label{style_commit}>{commit_link}</label></td>
|
312 |
+
<td><label{style_commit}>{date_link}</label></td>
|
313 |
+
</tr>
|
314 |
+
"""
|
315 |
+
|
316 |
+
code += """ </tbody>
|
317 |
+
</table>"""
|
318 |
+
|
319 |
+
except Exception as e:
|
320 |
+
print(f"[ERROR]: Config states {filepath}, {e}")
|
321 |
+
code = f"""<!-- {time.time()} -->
|
322 |
+
<h2>Config Backup: {config_name}</h2>
|
323 |
+
<div><b>Filepath:</b> {filepath}</div>
|
324 |
+
<div><b>Created at:</b> {created_date}</div>
|
325 |
+
<h2>This file is corrupted</h2>"""
|
326 |
+
|
327 |
+
return code
|
328 |
+
|
329 |
+
|
330 |
+
def normalize_git_url(url):
|
331 |
+
if url is None:
|
332 |
+
return ""
|
333 |
+
|
334 |
+
url = url.replace(".git", "")
|
335 |
+
return url
|
336 |
+
|
337 |
+
|
338 |
+
def install_extension_from_url(dirname, url, branch_name=None):
|
339 |
+
check_access()
|
340 |
+
|
341 |
+
if isinstance(dirname, str):
|
342 |
+
dirname = dirname.strip()
|
343 |
+
if isinstance(url, str):
|
344 |
+
url = url.strip()
|
345 |
+
|
346 |
+
assert url, 'No URL specified'
|
347 |
+
|
348 |
+
if dirname is None or dirname == "":
|
349 |
+
*parts, last_part = url.split('/')
|
350 |
+
last_part = normalize_git_url(last_part)
|
351 |
+
|
352 |
+
dirname = last_part
|
353 |
+
|
354 |
+
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
355 |
+
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
356 |
+
|
357 |
+
normalized_url = normalize_git_url(url)
|
358 |
+
if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
|
359 |
+
raise Exception(f'Extension with this URL is already installed: {url}')
|
360 |
+
|
361 |
+
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
|
362 |
+
|
363 |
+
try:
|
364 |
+
shutil.rmtree(tmpdir, True)
|
365 |
+
if not branch_name:
|
366 |
+
# if no branch is specified, use the default branch
|
367 |
+
with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
|
368 |
+
repo.remote().fetch()
|
369 |
+
for submodule in repo.submodules:
|
370 |
+
submodule.update()
|
371 |
+
else:
|
372 |
+
with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
|
373 |
+
repo.remote().fetch()
|
374 |
+
for submodule in repo.submodules:
|
375 |
+
submodule.update()
|
376 |
+
try:
|
377 |
+
os.rename(tmpdir, target_dir)
|
378 |
+
except OSError as err:
|
379 |
+
if err.errno == errno.EXDEV:
|
380 |
+
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
381 |
+
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
382 |
+
shutil.move(tmpdir, target_dir)
|
383 |
+
else:
|
384 |
+
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
385 |
+
raise err
|
386 |
+
|
387 |
+
import launch
|
388 |
+
launch.run_extension_installer(target_dir)
|
389 |
+
|
390 |
+
extensions.list_extensions()
|
391 |
+
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
|
392 |
+
finally:
|
393 |
+
shutil.rmtree(tmpdir, True)
|
394 |
+
|
395 |
+
|
396 |
+
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
|
397 |
+
ext_table, message = install_extension_from_url(None, url)
|
398 |
+
|
399 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
400 |
+
|
401 |
+
return code, ext_table, message, ''
|
402 |
+
|
403 |
+
|
404 |
+
def refresh_available_extensions(url, hide_tags, sort_column):
|
405 |
+
global available_extensions
|
406 |
+
|
407 |
+
import urllib.request
|
408 |
+
with urllib.request.urlopen(url) as response:
|
409 |
+
text = response.read()
|
410 |
+
|
411 |
+
available_extensions = json.loads(text)
|
412 |
+
|
413 |
+
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
|
414 |
+
|
415 |
+
return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
|
416 |
+
|
417 |
+
|
418 |
+
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
|
419 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
420 |
+
|
421 |
+
return code, ''
|
422 |
+
|
423 |
+
|
424 |
+
def search_extensions(filter_text, hide_tags, sort_column):
|
425 |
+
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
|
426 |
+
|
427 |
+
return code, ''
|
428 |
+
|
429 |
+
|
430 |
+
sort_ordering = [
|
431 |
+
# (reverse, order_by_function)
|
432 |
+
(True, lambda x: x.get('added', 'z')),
|
433 |
+
(False, lambda x: x.get('added', 'z')),
|
434 |
+
(False, lambda x: x.get('name', 'z')),
|
435 |
+
(True, lambda x: x.get('name', 'z')),
|
436 |
+
(False, lambda x: 'z'),
|
437 |
+
(True, lambda x: x.get('commit_time', '')),
|
438 |
+
(True, lambda x: x.get('created_at', '')),
|
439 |
+
(True, lambda x: x.get('stars', 0)),
|
440 |
+
]
|
441 |
+
|
442 |
+
|
443 |
+
def get_date(info: dict, key):
|
444 |
+
try:
|
445 |
+
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone().strftime("%Y-%m-%d")
|
446 |
+
except (ValueError, TypeError):
|
447 |
+
return ''
|
448 |
+
|
449 |
+
|
450 |
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
451 |
+
extlist = available_extensions["extensions"]
|
452 |
+
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
453 |
+
|
454 |
+
tags = available_extensions.get("tags", {})
|
455 |
+
tags_to_hide = set(hide_tags)
|
456 |
+
hidden = 0
|
457 |
+
|
458 |
+
code = f"""<!-- {time.time()} -->
|
459 |
+
<table id="available_extensions">
|
460 |
+
<thead>
|
461 |
+
<tr>
|
462 |
+
<th>Extension</th>
|
463 |
+
<th>Description</th>
|
464 |
+
<th>Action</th>
|
465 |
+
</tr>
|
466 |
+
</thead>
|
467 |
+
<tbody>
|
468 |
+
"""
|
469 |
+
|
470 |
+
sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
|
471 |
+
|
472 |
+
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
|
473 |
+
name = ext.get("name", "noname")
|
474 |
+
stars = int(ext.get("stars", 0))
|
475 |
+
added = ext.get('added', 'unknown')
|
476 |
+
update_time = get_date(ext, 'commit_time')
|
477 |
+
create_time = get_date(ext, 'created_at')
|
478 |
+
url = ext.get("url", None)
|
479 |
+
description = ext.get("description", "")
|
480 |
+
extension_tags = ext.get("tags", [])
|
481 |
+
|
482 |
+
if url is None:
|
483 |
+
continue
|
484 |
+
|
485 |
+
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
486 |
+
extension_tags = extension_tags + ["installed"] if existing else extension_tags
|
487 |
+
|
488 |
+
if any(x for x in extension_tags if x in tags_to_hide):
|
489 |
+
hidden += 1
|
490 |
+
continue
|
491 |
+
|
492 |
+
if filter_text and filter_text.strip():
|
493 |
+
if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
|
494 |
+
hidden += 1
|
495 |
+
continue
|
496 |
+
|
497 |
+
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
|
498 |
+
|
499 |
+
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
500 |
+
|
501 |
+
code += f"""
|
502 |
+
<tr>
|
503 |
+
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
|
504 |
+
<td>{html.escape(description)}<p class="info">
|
505 |
+
<span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
|
506 |
+
<td>{install_code}</td>
|
507 |
+
</tr>
|
508 |
+
|
509 |
+
"""
|
510 |
+
|
511 |
+
for tag in [x for x in extension_tags if x not in tags]:
|
512 |
+
tags[tag] = tag
|
513 |
+
|
514 |
+
code += """
|
515 |
+
</tbody>
|
516 |
+
</table>
|
517 |
+
"""
|
518 |
+
|
519 |
+
if hidden > 0:
|
520 |
+
code += f"<p>Extension hidden: {hidden}</p>"
|
521 |
+
|
522 |
+
return code, list(tags)
|
523 |
+
|
524 |
+
|
525 |
+
def preload_extensions_git_metadata():
|
526 |
+
for extension in extensions.extensions:
|
527 |
+
extension.read_info_from_repo()
|
528 |
+
|
529 |
+
|
530 |
+
def create_ui():
|
531 |
+
import modules.ui
|
532 |
+
|
533 |
+
config_states.list_config_states()
|
534 |
+
|
535 |
+
threading.Thread(target=preload_extensions_git_metadata).start()
|
536 |
+
|
537 |
+
with gr.Blocks(analytics_enabled=False) as ui:
|
538 |
+
with gr.Tabs(elem_id="tabs_extensions"):
|
539 |
+
with gr.TabItem("Installed", id="installed"):
|
540 |
+
|
541 |
+
with gr.Row(elem_id="extensions_installed_top"):
|
542 |
+
apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
|
543 |
+
apply = gr.Button(value=apply_label, variant="primary")
|
544 |
+
check = gr.Button(value="Check for updates")
|
545 |
+
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
|
546 |
+
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
|
547 |
+
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
|
548 |
+
|
549 |
+
html = ""
|
550 |
+
|
551 |
+
if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
|
552 |
+
if shared.cmd_opts.disable_all_extensions:
|
553 |
+
msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
|
554 |
+
elif shared.opts.disable_all_extensions != "none":
|
555 |
+
msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
|
556 |
+
elif shared.cmd_opts.disable_extra_extensions:
|
557 |
+
msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
|
558 |
+
html = f'<span style="color: var(--primary-400);">{msg}</span>'
|
559 |
+
|
560 |
+
with gr.Row():
|
561 |
+
info = gr.HTML(html)
|
562 |
+
|
563 |
+
with gr.Row(elem_classes="progress-container"):
|
564 |
+
extensions_table = gr.HTML('Loading...', elem_id="extensions_installed_html")
|
565 |
+
|
566 |
+
ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
|
567 |
+
|
568 |
+
apply.click(
|
569 |
+
fn=apply_and_restart,
|
570 |
+
_js="extensions_apply",
|
571 |
+
inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
|
572 |
+
outputs=[],
|
573 |
+
)
|
574 |
+
|
575 |
+
check.click(
|
576 |
+
fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
|
577 |
+
_js="extensions_check",
|
578 |
+
inputs=[info, extensions_disabled_list],
|
579 |
+
outputs=[extensions_table, info],
|
580 |
+
)
|
581 |
+
|
582 |
+
with gr.TabItem("Available", id="available"):
|
583 |
+
with gr.Row():
|
584 |
+
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
585 |
+
extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
|
586 |
+
available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
|
587 |
+
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
|
588 |
+
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
589 |
+
|
590 |
+
with gr.Row():
|
591 |
+
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
592 |
+
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
|
593 |
+
|
594 |
+
with gr.Row():
|
595 |
+
search_extensions_text = gr.Text(label="Search", container=False)
|
596 |
+
|
597 |
+
install_result = gr.HTML()
|
598 |
+
available_extensions_table = gr.HTML()
|
599 |
+
|
600 |
+
refresh_available_extensions_button.click(
|
601 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
|
602 |
+
inputs=[available_extensions_index, hide_tags, sort_column],
|
603 |
+
outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
|
604 |
+
)
|
605 |
+
|
606 |
+
install_extension_button.click(
|
607 |
+
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
|
608 |
+
inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
|
609 |
+
outputs=[available_extensions_table, extensions_table, install_result],
|
610 |
+
)
|
611 |
+
|
612 |
+
search_extensions_text.change(
|
613 |
+
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
|
614 |
+
inputs=[search_extensions_text, hide_tags, sort_column],
|
615 |
+
outputs=[available_extensions_table, install_result],
|
616 |
+
)
|
617 |
+
|
618 |
+
hide_tags.change(
|
619 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
620 |
+
inputs=[hide_tags, sort_column, search_extensions_text],
|
621 |
+
outputs=[available_extensions_table, install_result]
|
622 |
+
)
|
623 |
+
|
624 |
+
sort_column.change(
|
625 |
+
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
|
626 |
+
inputs=[hide_tags, sort_column, search_extensions_text],
|
627 |
+
outputs=[available_extensions_table, install_result]
|
628 |
+
)
|
629 |
+
|
630 |
+
with gr.TabItem("Install from URL", id="install_from_url"):
|
631 |
+
install_url = gr.Text(label="URL for extension's git repository")
|
632 |
+
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
|
633 |
+
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
634 |
+
install_button = gr.Button(value="Install", variant="primary")
|
635 |
+
install_result = gr.HTML(elem_id="extension_install_result")
|
636 |
+
|
637 |
+
install_button.click(
|
638 |
+
fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
|
639 |
+
inputs=[install_dirname, install_url, install_branch],
|
640 |
+
outputs=[install_url, extensions_table, install_result],
|
641 |
+
)
|
642 |
+
|
643 |
+
with gr.TabItem("Backup/Restore"):
|
644 |
+
with gr.Row(elem_id="extensions_backup_top_row"):
|
645 |
+
config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
|
646 |
+
modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
|
647 |
+
config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
|
648 |
+
config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
|
649 |
+
with gr.Row(elem_id="extensions_backup_top_row2"):
|
650 |
+
config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
|
651 |
+
config_save_button = gr.Button(value="Save Current Config")
|
652 |
+
|
653 |
+
config_states_info = gr.HTML("")
|
654 |
+
config_states_table = gr.HTML("Loading...")
|
655 |
+
ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
|
656 |
+
|
657 |
+
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
|
658 |
+
|
659 |
+
dummy_component = gr.Label(visible=False)
|
660 |
+
config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
|
661 |
+
|
662 |
+
config_states_list.change(
|
663 |
+
fn=update_config_states_table,
|
664 |
+
inputs=[config_states_list],
|
665 |
+
outputs=[config_states_table],
|
666 |
+
)
|
667 |
+
|
668 |
+
|
669 |
+
return ui
|
modules/ui_extra_networks.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
import urllib.parse
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
|
6 |
+
from modules.images import read_info_from_image, save_image_with_geninfo
|
7 |
+
import gradio as gr
|
8 |
+
import json
|
9 |
+
import html
|
10 |
+
from fastapi.exceptions import HTTPException
|
11 |
+
|
12 |
+
from modules.generation_parameters_copypaste import image_from_url_text
|
13 |
+
from modules.ui_components import ToolButton
|
14 |
+
|
15 |
+
extra_pages = []
|
16 |
+
allowed_dirs = set()
|
17 |
+
|
18 |
+
|
19 |
+
def register_page(page):
|
20 |
+
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
21 |
+
|
22 |
+
extra_pages.append(page)
|
23 |
+
allowed_dirs.clear()
|
24 |
+
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
|
25 |
+
|
26 |
+
|
27 |
+
def fetch_file(filename: str = ""):
|
28 |
+
from starlette.responses import FileResponse
|
29 |
+
|
30 |
+
if not os.path.isfile(filename):
|
31 |
+
raise HTTPException(status_code=404, detail="File not found")
|
32 |
+
|
33 |
+
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
34 |
+
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
35 |
+
|
36 |
+
ext = os.path.splitext(filename)[1].lower()
|
37 |
+
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
38 |
+
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
39 |
+
|
40 |
+
# would profit from returning 304
|
41 |
+
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
42 |
+
|
43 |
+
|
44 |
+
def get_metadata(page: str = "", item: str = ""):
|
45 |
+
from starlette.responses import JSONResponse
|
46 |
+
|
47 |
+
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
48 |
+
if page is None:
|
49 |
+
return JSONResponse({})
|
50 |
+
|
51 |
+
metadata = page.metadata.get(item)
|
52 |
+
if metadata is None:
|
53 |
+
return JSONResponse({})
|
54 |
+
|
55 |
+
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
|
56 |
+
|
57 |
+
|
58 |
+
def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
59 |
+
from starlette.responses import JSONResponse
|
60 |
+
|
61 |
+
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
62 |
+
|
63 |
+
try:
|
64 |
+
item = page.create_item(name, enable_filter=False)
|
65 |
+
page.items[name] = item
|
66 |
+
except Exception as e:
|
67 |
+
errors.display(e, "creating item for extra network")
|
68 |
+
item = page.items.get(name)
|
69 |
+
|
70 |
+
page.read_user_metadata(item)
|
71 |
+
item_html = page.create_html_for_item(item, tabname)
|
72 |
+
|
73 |
+
return JSONResponse({"html": item_html})
|
74 |
+
|
75 |
+
|
76 |
+
def add_pages_to_demo(app):
|
77 |
+
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
78 |
+
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
79 |
+
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
|
80 |
+
|
81 |
+
|
82 |
+
def quote_js(s):
|
83 |
+
s = s.replace('\\', '\\\\')
|
84 |
+
s = s.replace('"', '\\"')
|
85 |
+
return f'"{s}"'
|
86 |
+
|
87 |
+
|
88 |
+
class ExtraNetworksPage:
|
89 |
+
def __init__(self, title):
|
90 |
+
self.title = title
|
91 |
+
self.name = title.lower()
|
92 |
+
self.id_page = self.name.replace(" ", "_")
|
93 |
+
self.card_page = shared.html("extra-networks-card.html")
|
94 |
+
self.allow_negative_prompt = False
|
95 |
+
self.metadata = {}
|
96 |
+
self.items = {}
|
97 |
+
|
98 |
+
def refresh(self):
|
99 |
+
pass
|
100 |
+
|
101 |
+
def read_user_metadata(self, item):
|
102 |
+
filename = item.get("filename", None)
|
103 |
+
metadata = extra_networks.get_user_metadata(filename)
|
104 |
+
|
105 |
+
desc = metadata.get("description", None)
|
106 |
+
if desc is not None:
|
107 |
+
item["description"] = desc
|
108 |
+
|
109 |
+
item["user_metadata"] = metadata
|
110 |
+
|
111 |
+
def link_preview(self, filename):
|
112 |
+
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
113 |
+
mtime = os.path.getmtime(filename)
|
114 |
+
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
115 |
+
|
116 |
+
def search_terms_from_path(self, filename, possible_directories=None):
|
117 |
+
abspath = os.path.abspath(filename)
|
118 |
+
|
119 |
+
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
120 |
+
parentdir = os.path.abspath(parentdir)
|
121 |
+
if abspath.startswith(parentdir):
|
122 |
+
return abspath[len(parentdir):].replace('\\', '/')
|
123 |
+
|
124 |
+
return ""
|
125 |
+
|
126 |
+
def create_html(self, tabname):
|
127 |
+
items_html = ''
|
128 |
+
|
129 |
+
self.metadata = {}
|
130 |
+
|
131 |
+
subdirs = {}
|
132 |
+
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
133 |
+
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
134 |
+
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
135 |
+
x = os.path.join(root, dirname)
|
136 |
+
|
137 |
+
if not os.path.isdir(x):
|
138 |
+
continue
|
139 |
+
|
140 |
+
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
141 |
+
while subdir.startswith("/"):
|
142 |
+
subdir = subdir[1:]
|
143 |
+
|
144 |
+
is_empty = len(os.listdir(x)) == 0
|
145 |
+
if not is_empty and not subdir.endswith("/"):
|
146 |
+
subdir = subdir + "/"
|
147 |
+
|
148 |
+
if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
149 |
+
continue
|
150 |
+
|
151 |
+
subdirs[subdir] = 1
|
152 |
+
|
153 |
+
if subdirs:
|
154 |
+
subdirs = {"": 1, **subdirs}
|
155 |
+
|
156 |
+
subdirs_html = "".join([f"""
|
157 |
+
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
|
158 |
+
{html.escape(subdir if subdir!="" else "all")}
|
159 |
+
</button>
|
160 |
+
""" for subdir in subdirs])
|
161 |
+
|
162 |
+
self.items = {x["name"]: x for x in self.list_items()}
|
163 |
+
for item in self.items.values():
|
164 |
+
metadata = item.get("metadata")
|
165 |
+
if metadata:
|
166 |
+
self.metadata[item["name"]] = metadata
|
167 |
+
|
168 |
+
if "user_metadata" not in item:
|
169 |
+
self.read_user_metadata(item)
|
170 |
+
|
171 |
+
items_html += self.create_html_for_item(item, tabname)
|
172 |
+
|
173 |
+
if items_html == '':
|
174 |
+
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
175 |
+
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
176 |
+
|
177 |
+
self_name_id = self.name.replace(" ", "_")
|
178 |
+
|
179 |
+
res = f"""
|
180 |
+
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
|
181 |
+
{subdirs_html}
|
182 |
+
</div>
|
183 |
+
<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
|
184 |
+
{items_html}
|
185 |
+
</div>
|
186 |
+
"""
|
187 |
+
|
188 |
+
return res
|
189 |
+
|
190 |
+
def create_item(self, name, index=None):
|
191 |
+
raise NotImplementedError()
|
192 |
+
|
193 |
+
def list_items(self):
|
194 |
+
raise NotImplementedError()
|
195 |
+
|
196 |
+
def allowed_directories_for_previews(self):
|
197 |
+
return []
|
198 |
+
|
199 |
+
def create_html_for_item(self, item, tabname):
|
200 |
+
"""
|
201 |
+
Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
|
202 |
+
"""
|
203 |
+
|
204 |
+
preview = item.get("preview", None)
|
205 |
+
|
206 |
+
onclick = item.get("onclick", None)
|
207 |
+
if onclick is None:
|
208 |
+
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
209 |
+
|
210 |
+
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
211 |
+
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
212 |
+
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
213 |
+
metadata_button = ""
|
214 |
+
metadata = item.get("metadata")
|
215 |
+
if metadata:
|
216 |
+
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
|
217 |
+
|
218 |
+
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
|
219 |
+
|
220 |
+
local_path = ""
|
221 |
+
filename = item.get("filename", "")
|
222 |
+
for reldir in self.allowed_directories_for_previews():
|
223 |
+
absdir = os.path.abspath(reldir)
|
224 |
+
|
225 |
+
if filename.startswith(absdir):
|
226 |
+
local_path = filename[len(absdir):]
|
227 |
+
|
228 |
+
# if this is true, the item must not be shown in the default view, and must instead only be
|
229 |
+
# shown when searching for it
|
230 |
+
if shared.opts.extra_networks_hidden_models == "Always":
|
231 |
+
search_only = False
|
232 |
+
else:
|
233 |
+
search_only = "/." in local_path or "\\." in local_path
|
234 |
+
|
235 |
+
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
236 |
+
return ""
|
237 |
+
|
238 |
+
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
|
239 |
+
|
240 |
+
args = {
|
241 |
+
"background_image": background_image,
|
242 |
+
"style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
|
243 |
+
"prompt": item.get("prompt", None),
|
244 |
+
"tabname": quote_js(tabname),
|
245 |
+
"local_preview": quote_js(item["local_preview"]),
|
246 |
+
"name": html.escape(item["name"]),
|
247 |
+
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
|
248 |
+
"card_clicked": onclick,
|
249 |
+
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
|
250 |
+
"search_term": item.get("search_term", ""),
|
251 |
+
"metadata_button": metadata_button,
|
252 |
+
"edit_button": edit_button,
|
253 |
+
"search_only": " search_only" if search_only else "",
|
254 |
+
"sort_keys": sort_keys,
|
255 |
+
}
|
256 |
+
|
257 |
+
return self.card_page.format(**args)
|
258 |
+
|
259 |
+
def get_sort_keys(self, path):
|
260 |
+
"""
|
261 |
+
List of default keys used for sorting in the UI.
|
262 |
+
"""
|
263 |
+
pth = Path(path)
|
264 |
+
stat = pth.stat()
|
265 |
+
return {
|
266 |
+
"date_created": int(stat.st_ctime or 0),
|
267 |
+
"date_modified": int(stat.st_mtime or 0),
|
268 |
+
"name": pth.name.lower(),
|
269 |
+
}
|
270 |
+
|
271 |
+
def find_preview(self, path):
|
272 |
+
"""
|
273 |
+
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
274 |
+
"""
|
275 |
+
|
276 |
+
preview_extensions = ["png", "jpg", "jpeg", "webp"]
|
277 |
+
if shared.opts.samples_format not in preview_extensions:
|
278 |
+
preview_extensions.append(shared.opts.samples_format)
|
279 |
+
|
280 |
+
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
|
281 |
+
|
282 |
+
for file in potential_files:
|
283 |
+
if os.path.isfile(file):
|
284 |
+
return self.link_preview(file)
|
285 |
+
|
286 |
+
return None
|
287 |
+
|
288 |
+
def find_description(self, path):
|
289 |
+
"""
|
290 |
+
Find and read a description file for a given path (without extension).
|
291 |
+
"""
|
292 |
+
for file in [f"{path}.txt", f"{path}.description.txt"]:
|
293 |
+
try:
|
294 |
+
with open(file, "r", encoding="utf-8", errors="replace") as f:
|
295 |
+
return f.read()
|
296 |
+
except OSError:
|
297 |
+
pass
|
298 |
+
return None
|
299 |
+
|
300 |
+
def create_user_metadata_editor(self, ui, tabname):
|
301 |
+
return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
|
302 |
+
|
303 |
+
|
304 |
+
def initialize():
|
305 |
+
extra_pages.clear()
|
306 |
+
|
307 |
+
|
308 |
+
def register_default_pages():
|
309 |
+
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
|
310 |
+
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
|
311 |
+
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
|
312 |
+
register_page(ExtraNetworksPageTextualInversion())
|
313 |
+
register_page(ExtraNetworksPageHypernetworks())
|
314 |
+
register_page(ExtraNetworksPageCheckpoints())
|
315 |
+
|
316 |
+
|
317 |
+
class ExtraNetworksUi:
|
318 |
+
def __init__(self):
|
319 |
+
self.pages = None
|
320 |
+
"""gradio HTML components related to extra networks' pages"""
|
321 |
+
|
322 |
+
self.page_contents = None
|
323 |
+
"""HTML content of the above; empty initially, filled when extra pages have to be shown"""
|
324 |
+
|
325 |
+
self.stored_extra_pages = None
|
326 |
+
|
327 |
+
self.button_save_preview = None
|
328 |
+
self.preview_target_filename = None
|
329 |
+
|
330 |
+
self.tabname = None
|
331 |
+
|
332 |
+
|
333 |
+
def pages_in_preferred_order(pages):
|
334 |
+
tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
|
335 |
+
|
336 |
+
def tab_name_score(name):
|
337 |
+
name = name.lower()
|
338 |
+
for i, possible_match in enumerate(tab_order):
|
339 |
+
if possible_match in name:
|
340 |
+
return i
|
341 |
+
|
342 |
+
return len(pages)
|
343 |
+
|
344 |
+
tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
|
345 |
+
|
346 |
+
return sorted(pages, key=lambda x: tab_scores[x.name])
|
347 |
+
|
348 |
+
|
349 |
+
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
350 |
+
from modules.ui import switch_values_symbol
|
351 |
+
|
352 |
+
ui = ExtraNetworksUi()
|
353 |
+
ui.pages = []
|
354 |
+
ui.pages_contents = []
|
355 |
+
ui.user_metadata_editors = []
|
356 |
+
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
|
357 |
+
ui.tabname = tabname
|
358 |
+
|
359 |
+
related_tabs = []
|
360 |
+
|
361 |
+
for page in ui.stored_extra_pages:
|
362 |
+
with gr.Tab(page.title, id=page.id_page) as tab:
|
363 |
+
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
364 |
+
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
365 |
+
ui.pages.append(page_elem)
|
366 |
+
|
367 |
+
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
|
368 |
+
|
369 |
+
editor = page.create_user_metadata_editor(ui, tabname)
|
370 |
+
editor.create_ui()
|
371 |
+
ui.user_metadata_editors.append(editor)
|
372 |
+
|
373 |
+
related_tabs.append(tab)
|
374 |
+
|
375 |
+
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
376 |
+
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
377 |
+
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
378 |
+
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
379 |
+
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
380 |
+
|
381 |
+
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
382 |
+
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
383 |
+
|
384 |
+
for tab in unrelated_tabs:
|
385 |
+
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
386 |
+
|
387 |
+
for tab in related_tabs:
|
388 |
+
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
389 |
+
|
390 |
+
def pages_html():
|
391 |
+
if not ui.pages_contents:
|
392 |
+
return refresh()
|
393 |
+
|
394 |
+
return ui.pages_contents
|
395 |
+
|
396 |
+
def refresh():
|
397 |
+
for pg in ui.stored_extra_pages:
|
398 |
+
pg.refresh()
|
399 |
+
|
400 |
+
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
|
401 |
+
|
402 |
+
return ui.pages_contents
|
403 |
+
|
404 |
+
interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
|
405 |
+
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
|
406 |
+
|
407 |
+
return ui
|
408 |
+
|
409 |
+
|
410 |
+
def path_is_parent(parent_path, child_path):
|
411 |
+
parent_path = os.path.abspath(parent_path)
|
412 |
+
child_path = os.path.abspath(child_path)
|
413 |
+
|
414 |
+
return child_path.startswith(parent_path)
|
415 |
+
|
416 |
+
|
417 |
+
def setup_ui(ui, gallery):
|
418 |
+
def save_preview(index, images, filename):
|
419 |
+
# this function is here for backwards compatibility and likely will be removed soon
|
420 |
+
|
421 |
+
if len(images) == 0:
|
422 |
+
print("There is no image in gallery to save as a preview.")
|
423 |
+
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
424 |
+
|
425 |
+
index = int(index)
|
426 |
+
index = 0 if index < 0 else index
|
427 |
+
index = len(images) - 1 if index >= len(images) else index
|
428 |
+
|
429 |
+
img_info = images[index if index >= 0 else 0]
|
430 |
+
image = image_from_url_text(img_info)
|
431 |
+
geninfo, items = read_info_from_image(image)
|
432 |
+
|
433 |
+
is_allowed = False
|
434 |
+
for extra_page in ui.stored_extra_pages:
|
435 |
+
if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
|
436 |
+
is_allowed = True
|
437 |
+
break
|
438 |
+
|
439 |
+
assert is_allowed, f'writing to {filename} is not allowed'
|
440 |
+
|
441 |
+
save_image_with_geninfo(image, geninfo, filename)
|
442 |
+
|
443 |
+
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
|
444 |
+
|
445 |
+
ui.button_save_preview.click(
|
446 |
+
fn=save_preview,
|
447 |
+
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
448 |
+
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
|
449 |
+
outputs=[*ui.pages]
|
450 |
+
)
|
451 |
+
|
452 |
+
for editor in ui.user_metadata_editors:
|
453 |
+
editor.setup_ui(gallery)
|
454 |
+
|
455 |
+
|
modules/ui_extra_networks_checkpoints.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
import os
|
3 |
+
|
4 |
+
from modules import shared, ui_extra_networks, sd_models
|
5 |
+
from modules.ui_extra_networks import quote_js
|
6 |
+
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
|
7 |
+
|
8 |
+
|
9 |
+
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__('Checkpoints')
|
12 |
+
|
13 |
+
def refresh(self):
|
14 |
+
shared.refresh_checkpoints()
|
15 |
+
|
16 |
+
def create_item(self, name, index=None, enable_filter=True):
|
17 |
+
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
18 |
+
path, ext = os.path.splitext(checkpoint.filename)
|
19 |
+
return {
|
20 |
+
"name": checkpoint.name_for_extra,
|
21 |
+
"filename": checkpoint.filename,
|
22 |
+
"shorthash": checkpoint.shorthash,
|
23 |
+
"preview": self.find_preview(path),
|
24 |
+
"description": self.find_description(path),
|
25 |
+
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
26 |
+
"onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
|
27 |
+
"local_preview": f"{path}.{shared.opts.samples_format}",
|
28 |
+
"metadata": checkpoint.metadata,
|
29 |
+
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
|
30 |
+
}
|
31 |
+
|
32 |
+
def list_items(self):
|
33 |
+
names = list(sd_models.checkpoints_list)
|
34 |
+
for index, name in enumerate(names):
|
35 |
+
yield self.create_item(name, index)
|
36 |
+
|
37 |
+
def allowed_directories_for_previews(self):
|
38 |
+
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
39 |
+
|
40 |
+
def create_user_metadata_editor(self, ui, tabname):
|
41 |
+
return CheckpointUserMetadataEditor(ui, tabname, self)
|
modules/ui_extra_networks_checkpoints_user_metadata.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules import ui_extra_networks_user_metadata, sd_vae, shared
|
4 |
+
from modules.ui_common import create_refresh_button
|
5 |
+
|
6 |
+
|
7 |
+
class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
|
8 |
+
def __init__(self, ui, tabname, page):
|
9 |
+
super().__init__(ui, tabname, page)
|
10 |
+
|
11 |
+
self.select_vae = None
|
12 |
+
|
13 |
+
def save_user_metadata(self, name, desc, notes, vae):
|
14 |
+
user_metadata = self.get_user_metadata(name)
|
15 |
+
user_metadata["description"] = desc
|
16 |
+
user_metadata["notes"] = notes
|
17 |
+
user_metadata["vae"] = vae
|
18 |
+
|
19 |
+
self.write_user_metadata(name, user_metadata)
|
20 |
+
|
21 |
+
def update_vae(self, name):
|
22 |
+
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
|
23 |
+
sd_vae.reload_vae_weights()
|
24 |
+
|
25 |
+
def put_values_into_components(self, name):
|
26 |
+
user_metadata = self.get_user_metadata(name)
|
27 |
+
values = super().put_values_into_components(name)
|
28 |
+
|
29 |
+
return [
|
30 |
+
*values[0:5],
|
31 |
+
user_metadata.get('vae', ''),
|
32 |
+
]
|
33 |
+
|
34 |
+
def create_editor(self):
|
35 |
+
self.create_default_editor_elems()
|
36 |
+
|
37 |
+
with gr.Row():
|
38 |
+
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
|
39 |
+
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
|
40 |
+
|
41 |
+
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
42 |
+
|
43 |
+
self.create_default_buttons()
|
44 |
+
|
45 |
+
viewed_components = [
|
46 |
+
self.edit_name,
|
47 |
+
self.edit_description,
|
48 |
+
self.html_filedata,
|
49 |
+
self.html_preview,
|
50 |
+
self.edit_notes,
|
51 |
+
self.select_vae,
|
52 |
+
]
|
53 |
+
|
54 |
+
self.button_edit\
|
55 |
+
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
|
56 |
+
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
57 |
+
|
58 |
+
edited_components = [
|
59 |
+
self.edit_description,
|
60 |
+
self.edit_notes,
|
61 |
+
self.select_vae,
|
62 |
+
]
|
63 |
+
|
64 |
+
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
|
65 |
+
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
|
66 |
+
|
modules/ui_extra_networks_hypernets.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from modules import shared, ui_extra_networks
|
4 |
+
from modules.ui_extra_networks import quote_js
|
5 |
+
from modules.hashes import sha256_from_cache
|
6 |
+
|
7 |
+
|
8 |
+
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
9 |
+
def __init__(self):
|
10 |
+
super().__init__('Hypernetworks')
|
11 |
+
|
12 |
+
def refresh(self):
|
13 |
+
shared.reload_hypernetworks()
|
14 |
+
|
15 |
+
def create_item(self, name, index=None, enable_filter=True):
|
16 |
+
full_path = shared.hypernetworks[name]
|
17 |
+
path, ext = os.path.splitext(full_path)
|
18 |
+
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
19 |
+
shorthash = sha256[0:10] if sha256 else None
|
20 |
+
|
21 |
+
return {
|
22 |
+
"name": name,
|
23 |
+
"filename": full_path,
|
24 |
+
"shorthash": shorthash,
|
25 |
+
"preview": self.find_preview(path),
|
26 |
+
"description": self.find_description(path),
|
27 |
+
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
|
28 |
+
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
29 |
+
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
30 |
+
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
31 |
+
}
|
32 |
+
|
33 |
+
def list_items(self):
|
34 |
+
for index, name in enumerate(shared.hypernetworks):
|
35 |
+
yield self.create_item(name, index)
|
36 |
+
|
37 |
+
def allowed_directories_for_previews(self):
|
38 |
+
return [shared.cmd_opts.hypernetwork_dir]
|
39 |
+
|
modules/ui_extra_networks_textual_inversion.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from modules import ui_extra_networks, sd_hijack, shared
|
4 |
+
from modules.ui_extra_networks import quote_js
|
5 |
+
|
6 |
+
|
7 |
+
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__('Textual Inversion')
|
10 |
+
self.allow_negative_prompt = True
|
11 |
+
|
12 |
+
def refresh(self):
|
13 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
14 |
+
|
15 |
+
def create_item(self, name, index=None, enable_filter=True):
|
16 |
+
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
17 |
+
|
18 |
+
path, ext = os.path.splitext(embedding.filename)
|
19 |
+
return {
|
20 |
+
"name": name,
|
21 |
+
"filename": embedding.filename,
|
22 |
+
"shorthash": embedding.shorthash,
|
23 |
+
"preview": self.find_preview(path),
|
24 |
+
"description": self.find_description(path),
|
25 |
+
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
|
26 |
+
"prompt": quote_js(embedding.name),
|
27 |
+
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
28 |
+
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
29 |
+
}
|
30 |
+
|
31 |
+
def list_items(self):
|
32 |
+
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
33 |
+
yield self.create_item(name, index)
|
34 |
+
|
35 |
+
def allowed_directories_for_previews(self):
|
36 |
+
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
modules/ui_extra_networks_user_metadata.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import html
|
3 |
+
import json
|
4 |
+
import os.path
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
|
9 |
+
|
10 |
+
|
11 |
+
class UserMetadataEditor:
|
12 |
+
|
13 |
+
def __init__(self, ui, tabname, page):
|
14 |
+
self.ui = ui
|
15 |
+
self.tabname = tabname
|
16 |
+
self.page = page
|
17 |
+
self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
|
18 |
+
|
19 |
+
self.box = None
|
20 |
+
|
21 |
+
self.edit_name_input = None
|
22 |
+
self.button_edit = None
|
23 |
+
|
24 |
+
self.edit_name = None
|
25 |
+
self.edit_description = None
|
26 |
+
self.edit_notes = None
|
27 |
+
self.html_filedata = None
|
28 |
+
self.html_preview = None
|
29 |
+
self.html_status = None
|
30 |
+
|
31 |
+
self.button_cancel = None
|
32 |
+
self.button_replace_preview = None
|
33 |
+
self.button_save = None
|
34 |
+
|
35 |
+
def get_user_metadata(self, name):
|
36 |
+
item = self.page.items.get(name, {})
|
37 |
+
|
38 |
+
user_metadata = item.get('user_metadata', None)
|
39 |
+
if not user_metadata:
|
40 |
+
user_metadata = {'description': item.get('description', '')}
|
41 |
+
item['user_metadata'] = user_metadata
|
42 |
+
|
43 |
+
return user_metadata
|
44 |
+
|
45 |
+
def create_extra_default_items_in_left_column(self):
|
46 |
+
pass
|
47 |
+
|
48 |
+
def create_default_editor_elems(self):
|
49 |
+
with gr.Row():
|
50 |
+
with gr.Column(scale=2):
|
51 |
+
self.edit_name = gr.HTML(elem_classes="extra-network-name")
|
52 |
+
self.edit_description = gr.Textbox(label="Description", lines=4)
|
53 |
+
self.html_filedata = gr.HTML()
|
54 |
+
|
55 |
+
self.create_extra_default_items_in_left_column()
|
56 |
+
|
57 |
+
with gr.Column(scale=1, min_width=0):
|
58 |
+
self.html_preview = gr.HTML()
|
59 |
+
|
60 |
+
def create_default_buttons(self):
|
61 |
+
|
62 |
+
with gr.Row(elem_classes="edit-user-metadata-buttons"):
|
63 |
+
self.button_cancel = gr.Button('Cancel')
|
64 |
+
self.button_replace_preview = gr.Button('Replace preview', variant='primary')
|
65 |
+
self.button_save = gr.Button('Save', variant='primary')
|
66 |
+
|
67 |
+
self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
|
68 |
+
|
69 |
+
self.button_cancel.click(fn=None, _js="closePopup")
|
70 |
+
|
71 |
+
def get_card_html(self, name):
|
72 |
+
item = self.page.items.get(name, {})
|
73 |
+
|
74 |
+
preview_url = item.get("preview", None)
|
75 |
+
|
76 |
+
if not preview_url:
|
77 |
+
filename, _ = os.path.splitext(item["filename"])
|
78 |
+
preview_url = self.page.find_preview(filename)
|
79 |
+
item["preview"] = preview_url
|
80 |
+
|
81 |
+
if preview_url:
|
82 |
+
preview = f'''
|
83 |
+
<div class='card standalone-card-preview'>
|
84 |
+
<img src="{html.escape(preview_url)}" class="preview">
|
85 |
+
</div>
|
86 |
+
'''
|
87 |
+
else:
|
88 |
+
preview = "<div class='card standalone-card-preview'></div>"
|
89 |
+
|
90 |
+
return preview
|
91 |
+
|
92 |
+
def relative_path(self, path):
|
93 |
+
for parent_path in self.page.allowed_directories_for_previews():
|
94 |
+
if ui_extra_networks.path_is_parent(parent_path, path):
|
95 |
+
return os.path.relpath(path, parent_path)
|
96 |
+
|
97 |
+
return os.path.basename(path)
|
98 |
+
|
99 |
+
def get_metadata_table(self, name):
|
100 |
+
item = self.page.items.get(name, {})
|
101 |
+
try:
|
102 |
+
filename = item["filename"]
|
103 |
+
shorthash = item.get("shorthash", None)
|
104 |
+
|
105 |
+
stats = os.stat(filename)
|
106 |
+
params = [
|
107 |
+
('Filename: ', self.relative_path(filename)),
|
108 |
+
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
109 |
+
('Hash: ', shorthash),
|
110 |
+
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
111 |
+
]
|
112 |
+
|
113 |
+
return params
|
114 |
+
except Exception as e:
|
115 |
+
errors.display(e, f"reading info for {name}")
|
116 |
+
return []
|
117 |
+
|
118 |
+
def put_values_into_components(self, name):
|
119 |
+
user_metadata = self.get_user_metadata(name)
|
120 |
+
|
121 |
+
try:
|
122 |
+
params = self.get_metadata_table(name)
|
123 |
+
except Exception as e:
|
124 |
+
errors.display(e, f"reading metadata info for {name}")
|
125 |
+
params = []
|
126 |
+
|
127 |
+
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
|
128 |
+
|
129 |
+
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
130 |
+
|
131 |
+
def write_user_metadata(self, name, metadata):
|
132 |
+
item = self.page.items.get(name, {})
|
133 |
+
filename = item.get("filename", None)
|
134 |
+
basename, ext = os.path.splitext(filename)
|
135 |
+
|
136 |
+
with open(basename + '.json', "w", encoding="utf8") as file:
|
137 |
+
json.dump(metadata, file, indent=4)
|
138 |
+
|
139 |
+
def save_user_metadata(self, name, desc, notes):
|
140 |
+
user_metadata = self.get_user_metadata(name)
|
141 |
+
user_metadata["description"] = desc
|
142 |
+
user_metadata["notes"] = notes
|
143 |
+
|
144 |
+
self.write_user_metadata(name, user_metadata)
|
145 |
+
|
146 |
+
def setup_save_handler(self, button, func, components):
|
147 |
+
button\
|
148 |
+
.click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
|
149 |
+
.then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
|
150 |
+
|
151 |
+
def create_editor(self):
|
152 |
+
self.create_default_editor_elems()
|
153 |
+
|
154 |
+
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
155 |
+
|
156 |
+
self.create_default_buttons()
|
157 |
+
|
158 |
+
self.button_edit\
|
159 |
+
.click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
|
160 |
+
.then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
|
161 |
+
|
162 |
+
self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
|
163 |
+
|
164 |
+
def create_ui(self):
|
165 |
+
with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
|
166 |
+
self.box = box
|
167 |
+
|
168 |
+
self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
|
169 |
+
self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
|
170 |
+
|
171 |
+
self.create_editor()
|
172 |
+
|
173 |
+
def save_preview(self, index, gallery, name):
|
174 |
+
if len(gallery) == 0:
|
175 |
+
return self.get_card_html(name), "There is no image in gallery to save as a preview."
|
176 |
+
|
177 |
+
item = self.page.items.get(name, {})
|
178 |
+
|
179 |
+
index = int(index)
|
180 |
+
index = 0 if index < 0 else index
|
181 |
+
index = len(gallery) - 1 if index >= len(gallery) else index
|
182 |
+
|
183 |
+
img_info = gallery[index if index >= 0 else 0]
|
184 |
+
image = generation_parameters_copypaste.image_from_url_text(img_info)
|
185 |
+
geninfo, items = images.read_info_from_image(image)
|
186 |
+
|
187 |
+
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
188 |
+
|
189 |
+
return self.get_card_html(name), ''
|
190 |
+
|
191 |
+
def setup_ui(self, gallery):
|
192 |
+
self.button_replace_preview.click(
|
193 |
+
fn=self.save_preview,
|
194 |
+
_js="function(x, y, z){return [selected_gallery_index(), y, z]}",
|
195 |
+
inputs=[self.edit_name_input, gallery, self.edit_name_input],
|
196 |
+
outputs=[self.html_preview, self.html_status]
|
197 |
+
).then(
|
198 |
+
fn=None,
|
199 |
+
_js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
|
200 |
+
inputs=[self.edit_name_input],
|
201 |
+
outputs=[]
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
|
modules/ui_gradio_extensions.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
from modules import localization, shared, scripts
|
5 |
+
from modules.paths import script_path, data_path
|
6 |
+
|
7 |
+
|
8 |
+
def webpath(fn):
|
9 |
+
if fn.startswith(script_path):
|
10 |
+
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
11 |
+
else:
|
12 |
+
web_path = os.path.abspath(fn)
|
13 |
+
|
14 |
+
return f'file={web_path}?{os.path.getmtime(fn)}'
|
15 |
+
|
16 |
+
|
17 |
+
def javascript_html():
|
18 |
+
# Ensure localization is in `window` before scripts
|
19 |
+
head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
|
20 |
+
|
21 |
+
script_js = os.path.join(script_path, "script.js")
|
22 |
+
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
|
23 |
+
|
24 |
+
for script in scripts.list_scripts("javascript", ".js"):
|
25 |
+
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
|
26 |
+
|
27 |
+
for script in scripts.list_scripts("javascript", ".mjs"):
|
28 |
+
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
|
29 |
+
|
30 |
+
if shared.cmd_opts.theme:
|
31 |
+
head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
|
32 |
+
|
33 |
+
return head
|
34 |
+
|
35 |
+
|
36 |
+
def css_html():
|
37 |
+
head = ""
|
38 |
+
|
39 |
+
def stylesheet(fn):
|
40 |
+
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
|
41 |
+
|
42 |
+
for cssfile in scripts.list_files_with_name("style.css"):
|
43 |
+
if not os.path.isfile(cssfile):
|
44 |
+
continue
|
45 |
+
|
46 |
+
head += stylesheet(cssfile)
|
47 |
+
|
48 |
+
if os.path.exists(os.path.join(data_path, "user.css")):
|
49 |
+
head += stylesheet(os.path.join(data_path, "user.css"))
|
50 |
+
|
51 |
+
return head
|
52 |
+
|
53 |
+
|
54 |
+
def reload_javascript():
|
55 |
+
js = javascript_html()
|
56 |
+
css = css_html()
|
57 |
+
|
58 |
+
def template_response(*args, **kwargs):
|
59 |
+
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
|
60 |
+
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
|
61 |
+
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
|
62 |
+
res.init_headers()
|
63 |
+
return res
|
64 |
+
|
65 |
+
gr.routes.templates.TemplateResponse = template_response
|
66 |
+
|
67 |
+
|
68 |
+
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
|
69 |
+
shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
modules/ui_loadsave.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from modules import errors
|
7 |
+
from modules.ui_components import ToolButton
|
8 |
+
|
9 |
+
|
10 |
+
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
|
11 |
+
return [x[0] if isinstance(x, tuple) else x for x in getattr(comp, 'choices', [])]
|
12 |
+
|
13 |
+
|
14 |
+
class UiLoadsave:
|
15 |
+
"""allows saving and restoring default values for gradio components"""
|
16 |
+
|
17 |
+
def __init__(self, filename):
|
18 |
+
self.filename = filename
|
19 |
+
self.ui_settings = {}
|
20 |
+
self.component_mapping = {}
|
21 |
+
self.error_loading = False
|
22 |
+
self.finalized_ui = False
|
23 |
+
|
24 |
+
self.ui_defaults_view = None
|
25 |
+
self.ui_defaults_apply = None
|
26 |
+
self.ui_defaults_review = None
|
27 |
+
|
28 |
+
try:
|
29 |
+
if os.path.exists(self.filename):
|
30 |
+
self.ui_settings = self.read_from_file()
|
31 |
+
except Exception as e:
|
32 |
+
self.error_loading = True
|
33 |
+
errors.display(e, "loading settings")
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
def add_component(self, path, x):
|
38 |
+
"""adds component to the registry of tracked components"""
|
39 |
+
|
40 |
+
assert not self.finalized_ui
|
41 |
+
|
42 |
+
def apply_field(obj, field, condition=None, init_field=None):
|
43 |
+
key = f"{path}/{field}"
|
44 |
+
|
45 |
+
if getattr(obj, 'custom_script_source', None) is not None:
|
46 |
+
key = f"customscript/{obj.custom_script_source}/{key}"
|
47 |
+
|
48 |
+
if getattr(obj, 'do_not_save_to_config', False):
|
49 |
+
return
|
50 |
+
|
51 |
+
saved_value = self.ui_settings.get(key, None)
|
52 |
+
if saved_value is None:
|
53 |
+
self.ui_settings[key] = getattr(obj, field)
|
54 |
+
elif condition and not condition(saved_value):
|
55 |
+
pass
|
56 |
+
else:
|
57 |
+
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
58 |
+
saved_value = str(saved_value)
|
59 |
+
elif isinstance(x, gr.Number) and field == 'value':
|
60 |
+
try:
|
61 |
+
saved_value = float(saved_value)
|
62 |
+
except ValueError:
|
63 |
+
return
|
64 |
+
|
65 |
+
setattr(obj, field, saved_value)
|
66 |
+
if init_field is not None:
|
67 |
+
init_field(saved_value)
|
68 |
+
|
69 |
+
if field == 'value' and key not in self.component_mapping:
|
70 |
+
self.component_mapping[key] = x
|
71 |
+
|
72 |
+
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
|
73 |
+
apply_field(x, 'visible')
|
74 |
+
|
75 |
+
if type(x) == gr.Slider:
|
76 |
+
apply_field(x, 'value')
|
77 |
+
apply_field(x, 'minimum')
|
78 |
+
apply_field(x, 'maximum')
|
79 |
+
apply_field(x, 'step')
|
80 |
+
|
81 |
+
if type(x) == gr.Radio:
|
82 |
+
apply_field(x, 'value', lambda val: val in radio_choices(x))
|
83 |
+
|
84 |
+
if type(x) == gr.Checkbox:
|
85 |
+
apply_field(x, 'value')
|
86 |
+
|
87 |
+
if type(x) == gr.Textbox:
|
88 |
+
apply_field(x, 'value')
|
89 |
+
|
90 |
+
if type(x) == gr.Number:
|
91 |
+
apply_field(x, 'value')
|
92 |
+
|
93 |
+
if type(x) == gr.Dropdown:
|
94 |
+
def check_dropdown(val):
|
95 |
+
choices = radio_choices(x)
|
96 |
+
if getattr(x, 'multiselect', False):
|
97 |
+
return all(value in choices for value in val)
|
98 |
+
else:
|
99 |
+
return val in choices
|
100 |
+
|
101 |
+
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
102 |
+
|
103 |
+
def check_tab_id(tab_id):
|
104 |
+
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
105 |
+
if type(tab_id) == str:
|
106 |
+
tab_ids = [t.id for t in tab_items]
|
107 |
+
return tab_id in tab_ids
|
108 |
+
elif type(tab_id) == int:
|
109 |
+
return 0 <= tab_id < len(tab_items)
|
110 |
+
else:
|
111 |
+
return False
|
112 |
+
|
113 |
+
if type(x) == gr.Tabs:
|
114 |
+
apply_field(x, 'selected', check_tab_id)
|
115 |
+
|
116 |
+
def add_block(self, x, path=""):
|
117 |
+
"""adds all components inside a gradio block x to the registry of tracked components"""
|
118 |
+
|
119 |
+
if hasattr(x, 'children'):
|
120 |
+
if isinstance(x, gr.Tabs) and x.elem_id is not None:
|
121 |
+
# Tabs element can't have a label, have to use elem_id instead
|
122 |
+
self.add_component(f"{path}/Tabs@{x.elem_id}", x)
|
123 |
+
for c in x.children:
|
124 |
+
self.add_block(c, path)
|
125 |
+
elif x.label is not None:
|
126 |
+
self.add_component(f"{path}/{x.label}", x)
|
127 |
+
elif isinstance(x, gr.Button) and x.value is not None:
|
128 |
+
self.add_component(f"{path}/{x.value}", x)
|
129 |
+
|
130 |
+
def read_from_file(self):
|
131 |
+
with open(self.filename, "r", encoding="utf8") as file:
|
132 |
+
return json.load(file)
|
133 |
+
|
134 |
+
def write_to_file(self, current_ui_settings):
|
135 |
+
with open(self.filename, "w", encoding="utf8") as file:
|
136 |
+
json.dump(current_ui_settings, file, indent=4)
|
137 |
+
|
138 |
+
def dump_defaults(self):
|
139 |
+
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
140 |
+
|
141 |
+
if self.error_loading and os.path.exists(self.filename):
|
142 |
+
return
|
143 |
+
|
144 |
+
self.write_to_file(self.ui_settings)
|
145 |
+
|
146 |
+
def iter_changes(self, current_ui_settings, values):
|
147 |
+
"""
|
148 |
+
given a dictionary with defaults from a file and current values from gradio elements, returns
|
149 |
+
an iterator over tuples of values that are not the same between the file and the current;
|
150 |
+
tuple contents are: path, old value, new value
|
151 |
+
"""
|
152 |
+
|
153 |
+
for (path, component), new_value in zip(self.component_mapping.items(), values):
|
154 |
+
old_value = current_ui_settings.get(path)
|
155 |
+
|
156 |
+
choices = radio_choices(component)
|
157 |
+
if isinstance(new_value, int) and choices:
|
158 |
+
if new_value >= len(choices):
|
159 |
+
continue
|
160 |
+
|
161 |
+
new_value = choices[new_value]
|
162 |
+
if isinstance(new_value, tuple):
|
163 |
+
new_value = new_value[0]
|
164 |
+
|
165 |
+
if new_value == old_value:
|
166 |
+
continue
|
167 |
+
|
168 |
+
if old_value is None and new_value == '' or new_value == []:
|
169 |
+
continue
|
170 |
+
|
171 |
+
yield path, old_value, new_value
|
172 |
+
|
173 |
+
def ui_view(self, *values):
|
174 |
+
text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
|
175 |
+
|
176 |
+
for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
|
177 |
+
if old_value is None:
|
178 |
+
old_value = "<span class='ui-defaults-none'>None</span>"
|
179 |
+
|
180 |
+
text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
|
181 |
+
|
182 |
+
if len(text) == 1:
|
183 |
+
text.append("<tr><td colspan=3>No changes</td></tr>")
|
184 |
+
|
185 |
+
text.append("</tbody>")
|
186 |
+
return "".join(text)
|
187 |
+
|
188 |
+
def ui_apply(self, *values):
|
189 |
+
num_changed = 0
|
190 |
+
|
191 |
+
current_ui_settings = self.read_from_file()
|
192 |
+
|
193 |
+
for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
|
194 |
+
num_changed += 1
|
195 |
+
current_ui_settings[path] = new_value
|
196 |
+
|
197 |
+
if num_changed == 0:
|
198 |
+
return "No changes."
|
199 |
+
|
200 |
+
self.write_to_file(current_ui_settings)
|
201 |
+
|
202 |
+
return f"Wrote {num_changed} changes."
|
203 |
+
|
204 |
+
def create_ui(self):
|
205 |
+
"""creates ui elements for editing defaults UI, without adding any logic to them"""
|
206 |
+
|
207 |
+
gr.HTML(
|
208 |
+
f"This page allows you to change default values in UI elements on other tabs.<br />"
|
209 |
+
f"Make your changes, press 'View changes' to review the changed default values,<br />"
|
210 |
+
f"then press 'Apply' to write them to {self.filename}.<br />"
|
211 |
+
f"New defaults will apply after you restart the UI.<br />"
|
212 |
+
)
|
213 |
+
|
214 |
+
with gr.Row():
|
215 |
+
self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
|
216 |
+
self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
|
217 |
+
|
218 |
+
self.ui_defaults_review = gr.HTML("")
|
219 |
+
|
220 |
+
def setup_ui(self):
|
221 |
+
"""adds logic to elements created with create_ui; all add_block class must be made before this"""
|
222 |
+
|
223 |
+
assert not self.finalized_ui
|
224 |
+
self.finalized_ui = True
|
225 |
+
|
226 |
+
self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
227 |
+
self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
|
modules/ui_postprocessing.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from modules import scripts, shared, ui_common, postprocessing, call_queue
|
3 |
+
import modules.generation_parameters_copypaste as parameters_copypaste
|
4 |
+
|
5 |
+
|
6 |
+
def create_ui():
|
7 |
+
tab_index = gr.State(value=0)
|
8 |
+
|
9 |
+
with gr.Row(equal_height=False, variant='compact'):
|
10 |
+
with gr.Column(variant='compact'):
|
11 |
+
with gr.Tabs(elem_id="mode_extras"):
|
12 |
+
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
13 |
+
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
14 |
+
|
15 |
+
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
16 |
+
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
17 |
+
|
18 |
+
with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
19 |
+
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
20 |
+
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
21 |
+
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
22 |
+
|
23 |
+
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
24 |
+
|
25 |
+
script_inputs = scripts.scripts_postproc.setup_ui()
|
26 |
+
|
27 |
+
with gr.Column():
|
28 |
+
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
29 |
+
|
30 |
+
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
31 |
+
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
|
32 |
+
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
33 |
+
|
34 |
+
submit.click(
|
35 |
+
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
36 |
+
inputs=[
|
37 |
+
tab_index,
|
38 |
+
extras_image,
|
39 |
+
image_batch,
|
40 |
+
extras_batch_input_dir,
|
41 |
+
extras_batch_output_dir,
|
42 |
+
show_extras_results,
|
43 |
+
*script_inputs
|
44 |
+
],
|
45 |
+
outputs=[
|
46 |
+
result_images,
|
47 |
+
html_info_x,
|
48 |
+
html_info,
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
53 |
+
|
54 |
+
extras_image.change(
|
55 |
+
fn=scripts.scripts_postproc.image_changed,
|
56 |
+
inputs=[], outputs=[]
|
57 |
+
)
|
modules/ui_prompt_styles.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules import shared, ui_common, ui_components, styles
|
4 |
+
|
5 |
+
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
|
6 |
+
styles_materialize_symbol = '\U0001f4cb' # 📋
|
7 |
+
|
8 |
+
|
9 |
+
def select_style(name):
|
10 |
+
style = shared.prompt_styles.styles.get(name)
|
11 |
+
existing = style is not None
|
12 |
+
empty = not name
|
13 |
+
|
14 |
+
prompt = style.prompt if style else gr.update()
|
15 |
+
negative_prompt = style.negative_prompt if style else gr.update()
|
16 |
+
|
17 |
+
return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
|
18 |
+
|
19 |
+
|
20 |
+
def save_style(name, prompt, negative_prompt):
|
21 |
+
if not name:
|
22 |
+
return gr.update(visible=False)
|
23 |
+
|
24 |
+
style = styles.PromptStyle(name, prompt, negative_prompt)
|
25 |
+
shared.prompt_styles.styles[style.name] = style
|
26 |
+
shared.prompt_styles.save_styles(shared.styles_filename)
|
27 |
+
|
28 |
+
return gr.update(visible=True)
|
29 |
+
|
30 |
+
|
31 |
+
def delete_style(name):
|
32 |
+
if name == "":
|
33 |
+
return
|
34 |
+
|
35 |
+
shared.prompt_styles.styles.pop(name, None)
|
36 |
+
shared.prompt_styles.save_styles(shared.styles_filename)
|
37 |
+
|
38 |
+
return '', '', ''
|
39 |
+
|
40 |
+
|
41 |
+
def materialize_styles(prompt, negative_prompt, styles):
|
42 |
+
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
43 |
+
negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
|
44 |
+
|
45 |
+
return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
|
46 |
+
|
47 |
+
|
48 |
+
def refresh_styles():
|
49 |
+
return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
|
50 |
+
|
51 |
+
|
52 |
+
class UiPromptStyles:
|
53 |
+
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
|
54 |
+
self.tabname = tabname
|
55 |
+
|
56 |
+
with gr.Row(elem_id=f"{tabname}_styles_row"):
|
57 |
+
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
|
58 |
+
edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
|
59 |
+
|
60 |
+
with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
|
61 |
+
with gr.Row():
|
62 |
+
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
63 |
+
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
64 |
+
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
65 |
+
|
66 |
+
with gr.Row():
|
67 |
+
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
68 |
+
|
69 |
+
with gr.Row():
|
70 |
+
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
71 |
+
|
72 |
+
with gr.Row():
|
73 |
+
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
74 |
+
self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
|
75 |
+
self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
|
76 |
+
|
77 |
+
self.selection.change(
|
78 |
+
fn=select_style,
|
79 |
+
inputs=[self.selection],
|
80 |
+
outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
|
81 |
+
show_progress=False,
|
82 |
+
)
|
83 |
+
|
84 |
+
self.save.click(
|
85 |
+
fn=save_style,
|
86 |
+
inputs=[self.selection, self.prompt, self.neg_prompt],
|
87 |
+
outputs=[self.delete],
|
88 |
+
show_progress=False,
|
89 |
+
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
90 |
+
|
91 |
+
self.delete.click(
|
92 |
+
fn=delete_style,
|
93 |
+
_js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
|
94 |
+
inputs=[self.selection],
|
95 |
+
outputs=[self.selection, self.prompt, self.neg_prompt],
|
96 |
+
show_progress=False,
|
97 |
+
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
98 |
+
|
99 |
+
self.materialize.click(
|
100 |
+
fn=materialize_styles,
|
101 |
+
inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
102 |
+
outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
103 |
+
show_progress=False,
|
104 |
+
).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
|
105 |
+
|
106 |
+
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
modules/ui_settings.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
4 |
+
from modules.call_queue import wrap_gradio_call
|
5 |
+
from modules.shared import opts
|
6 |
+
from modules.ui_components import FormRow
|
7 |
+
from modules.ui_gradio_extensions import reload_javascript
|
8 |
+
|
9 |
+
|
10 |
+
def get_value_for_setting(key):
|
11 |
+
value = getattr(opts, key)
|
12 |
+
|
13 |
+
info = opts.data_labels[key]
|
14 |
+
args = info.component_args() if callable(info.component_args) else info.component_args or {}
|
15 |
+
args = {k: v for k, v in args.items() if k not in {'precision'}}
|
16 |
+
|
17 |
+
return gr.update(value=value, **args)
|
18 |
+
|
19 |
+
|
20 |
+
def create_setting_component(key, is_quicksettings=False):
|
21 |
+
def fun():
|
22 |
+
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
23 |
+
|
24 |
+
info = opts.data_labels[key]
|
25 |
+
t = type(info.default)
|
26 |
+
|
27 |
+
args = info.component_args() if callable(info.component_args) else info.component_args
|
28 |
+
|
29 |
+
if info.component is not None:
|
30 |
+
comp = info.component
|
31 |
+
elif t == str:
|
32 |
+
comp = gr.Textbox
|
33 |
+
elif t == int:
|
34 |
+
comp = gr.Number
|
35 |
+
elif t == bool:
|
36 |
+
comp = gr.Checkbox
|
37 |
+
else:
|
38 |
+
raise Exception(f'bad options item type: {t} for key {key}')
|
39 |
+
|
40 |
+
elem_id = f"setting_{key}"
|
41 |
+
|
42 |
+
if info.refresh is not None:
|
43 |
+
if is_quicksettings:
|
44 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
45 |
+
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
46 |
+
else:
|
47 |
+
with FormRow():
|
48 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
49 |
+
ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
|
50 |
+
else:
|
51 |
+
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
|
52 |
+
|
53 |
+
return res
|
54 |
+
|
55 |
+
|
56 |
+
class UiSettings:
|
57 |
+
submit = None
|
58 |
+
result = None
|
59 |
+
interface = None
|
60 |
+
components = None
|
61 |
+
component_dict = None
|
62 |
+
dummy_component = None
|
63 |
+
quicksettings_list = None
|
64 |
+
quicksettings_names = None
|
65 |
+
text_settings = None
|
66 |
+
|
67 |
+
def run_settings(self, *args):
|
68 |
+
changed = []
|
69 |
+
|
70 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
71 |
+
assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
|
72 |
+
|
73 |
+
for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
|
74 |
+
if comp == self.dummy_component:
|
75 |
+
continue
|
76 |
+
|
77 |
+
if opts.set(key, value):
|
78 |
+
changed.append(key)
|
79 |
+
|
80 |
+
try:
|
81 |
+
opts.save(shared.config_filename)
|
82 |
+
except RuntimeError:
|
83 |
+
return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
|
84 |
+
return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
|
85 |
+
|
86 |
+
def run_settings_single(self, value, key):
|
87 |
+
if not opts.same_type(value, opts.data_labels[key].default):
|
88 |
+
return gr.update(visible=True), opts.dumpjson()
|
89 |
+
|
90 |
+
if value is None or not opts.set(key, value):
|
91 |
+
return gr.update(value=getattr(opts, key)), opts.dumpjson()
|
92 |
+
|
93 |
+
opts.save(shared.config_filename)
|
94 |
+
|
95 |
+
return get_value_for_setting(key), opts.dumpjson()
|
96 |
+
|
97 |
+
def create_ui(self, loadsave, dummy_component):
|
98 |
+
self.components = []
|
99 |
+
self.component_dict = {}
|
100 |
+
self.dummy_component = dummy_component
|
101 |
+
|
102 |
+
shared.settings_components = self.component_dict
|
103 |
+
|
104 |
+
script_callbacks.ui_settings_callback()
|
105 |
+
opts.reorder()
|
106 |
+
|
107 |
+
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
108 |
+
with gr.Row():
|
109 |
+
with gr.Column(scale=6):
|
110 |
+
self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
|
111 |
+
with gr.Column():
|
112 |
+
restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
|
113 |
+
|
114 |
+
self.result = gr.HTML(elem_id="settings_result")
|
115 |
+
|
116 |
+
self.quicksettings_names = opts.quicksettings_list
|
117 |
+
self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
|
118 |
+
|
119 |
+
self.quicksettings_list = []
|
120 |
+
|
121 |
+
previous_section = None
|
122 |
+
current_tab = None
|
123 |
+
current_row = None
|
124 |
+
with gr.Tabs(elem_id="settings"):
|
125 |
+
for i, (k, item) in enumerate(opts.data_labels.items()):
|
126 |
+
section_must_be_skipped = item.section[0] is None
|
127 |
+
|
128 |
+
if previous_section != item.section and not section_must_be_skipped:
|
129 |
+
elem_id, text = item.section
|
130 |
+
|
131 |
+
if current_tab is not None:
|
132 |
+
current_row.__exit__()
|
133 |
+
current_tab.__exit__()
|
134 |
+
|
135 |
+
gr.Group()
|
136 |
+
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
137 |
+
current_tab.__enter__()
|
138 |
+
current_row = gr.Column(variant='compact')
|
139 |
+
current_row.__enter__()
|
140 |
+
|
141 |
+
previous_section = item.section
|
142 |
+
|
143 |
+
if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
|
144 |
+
self.quicksettings_list.append((i, k, item))
|
145 |
+
self.components.append(dummy_component)
|
146 |
+
elif section_must_be_skipped:
|
147 |
+
self.components.append(dummy_component)
|
148 |
+
else:
|
149 |
+
component = create_setting_component(k)
|
150 |
+
self.component_dict[k] = component
|
151 |
+
self.components.append(component)
|
152 |
+
|
153 |
+
if current_tab is not None:
|
154 |
+
current_row.__exit__()
|
155 |
+
current_tab.__exit__()
|
156 |
+
|
157 |
+
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
|
158 |
+
loadsave.create_ui()
|
159 |
+
|
160 |
+
with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
|
161 |
+
gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
with gr.Column(scale=1):
|
165 |
+
sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
|
166 |
+
with gr.Column(scale=1):
|
167 |
+
sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
|
168 |
+
with gr.Column(scale=100):
|
169 |
+
pass
|
170 |
+
|
171 |
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
|
172 |
+
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
173 |
+
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
174 |
+
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
175 |
+
with gr.Row():
|
176 |
+
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
177 |
+
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
178 |
+
|
179 |
+
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
180 |
+
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
181 |
+
|
182 |
+
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
183 |
+
|
184 |
+
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
185 |
+
|
186 |
+
unload_sd_model.click(
|
187 |
+
fn=sd_models.unload_model_weights,
|
188 |
+
inputs=[],
|
189 |
+
outputs=[]
|
190 |
+
)
|
191 |
+
|
192 |
+
reload_sd_model.click(
|
193 |
+
fn=sd_models.reload_model_weights,
|
194 |
+
inputs=[],
|
195 |
+
outputs=[]
|
196 |
+
)
|
197 |
+
|
198 |
+
request_notifications.click(
|
199 |
+
fn=lambda: None,
|
200 |
+
inputs=[],
|
201 |
+
outputs=[],
|
202 |
+
_js='function(){}'
|
203 |
+
)
|
204 |
+
|
205 |
+
download_localization.click(
|
206 |
+
fn=lambda: None,
|
207 |
+
inputs=[],
|
208 |
+
outputs=[],
|
209 |
+
_js='download_localization'
|
210 |
+
)
|
211 |
+
|
212 |
+
def reload_scripts():
|
213 |
+
scripts.reload_script_body_only()
|
214 |
+
reload_javascript() # need to refresh the html page
|
215 |
+
|
216 |
+
reload_script_bodies.click(
|
217 |
+
fn=reload_scripts,
|
218 |
+
inputs=[],
|
219 |
+
outputs=[]
|
220 |
+
)
|
221 |
+
|
222 |
+
restart_gradio.click(
|
223 |
+
fn=shared.state.request_restart,
|
224 |
+
_js='restart_reload',
|
225 |
+
inputs=[],
|
226 |
+
outputs=[],
|
227 |
+
)
|
228 |
+
|
229 |
+
def check_file(x):
|
230 |
+
if x is None:
|
231 |
+
return ''
|
232 |
+
|
233 |
+
if sysinfo.check(x.decode('utf8', errors='ignore')):
|
234 |
+
return 'Valid'
|
235 |
+
|
236 |
+
return 'Invalid'
|
237 |
+
|
238 |
+
sysinfo_check_file.change(
|
239 |
+
fn=check_file,
|
240 |
+
inputs=[sysinfo_check_file],
|
241 |
+
outputs=[sysinfo_check_output],
|
242 |
+
)
|
243 |
+
|
244 |
+
self.interface = settings_interface
|
245 |
+
|
246 |
+
def add_quicksettings(self):
|
247 |
+
with gr.Row(elem_id="quicksettings", variant="compact"):
|
248 |
+
for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
|
249 |
+
component = create_setting_component(k, is_quicksettings=True)
|
250 |
+
self.component_dict[k] = component
|
251 |
+
|
252 |
+
def add_functionality(self, demo):
|
253 |
+
self.submit.click(
|
254 |
+
fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
|
255 |
+
inputs=self.components,
|
256 |
+
outputs=[self.text_settings, self.result],
|
257 |
+
)
|
258 |
+
|
259 |
+
for _i, k, _item in self.quicksettings_list:
|
260 |
+
component = self.component_dict[k]
|
261 |
+
info = opts.data_labels[k]
|
262 |
+
|
263 |
+
if isinstance(component, gr.Textbox):
|
264 |
+
methods = [component.submit, component.blur]
|
265 |
+
elif hasattr(component, 'release'):
|
266 |
+
methods = [component.release]
|
267 |
+
else:
|
268 |
+
methods = [component.change]
|
269 |
+
|
270 |
+
for method in methods:
|
271 |
+
method(
|
272 |
+
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
273 |
+
inputs=[component],
|
274 |
+
outputs=[component, self.text_settings],
|
275 |
+
show_progress=info.refresh is not None,
|
276 |
+
)
|
277 |
+
|
278 |
+
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
279 |
+
button_set_checkpoint.click(
|
280 |
+
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
|
281 |
+
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
282 |
+
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
|
283 |
+
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
|
284 |
+
)
|
285 |
+
|
286 |
+
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
|
287 |
+
|
288 |
+
def get_settings_values():
|
289 |
+
return [get_value_for_setting(key) for key in component_keys]
|
290 |
+
|
291 |
+
demo.load(
|
292 |
+
fn=get_settings_values,
|
293 |
+
inputs=[],
|
294 |
+
outputs=[self.component_dict[k] for k in component_keys],
|
295 |
+
queue=False,
|
296 |
+
)
|
modules/ui_tempdir.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
from collections import namedtuple
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import gradio.components
|
7 |
+
|
8 |
+
from PIL import PngImagePlugin
|
9 |
+
|
10 |
+
from modules import shared
|
11 |
+
|
12 |
+
|
13 |
+
Savedfile = namedtuple("Savedfile", ["name"])
|
14 |
+
|
15 |
+
|
16 |
+
def register_tmp_file(gradio, filename):
|
17 |
+
if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
|
18 |
+
gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
|
19 |
+
|
20 |
+
if hasattr(gradio, 'temp_dirs'): # gradio 3.9
|
21 |
+
gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
|
22 |
+
|
23 |
+
|
24 |
+
def check_tmp_file(gradio, filename):
|
25 |
+
if hasattr(gradio, 'temp_file_sets'):
|
26 |
+
return any(filename in fileset for fileset in gradio.temp_file_sets)
|
27 |
+
|
28 |
+
if hasattr(gradio, 'temp_dirs'):
|
29 |
+
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
|
30 |
+
|
31 |
+
return False
|
32 |
+
|
33 |
+
|
34 |
+
def save_pil_to_file(self, pil_image, dir=None, format="png"):
|
35 |
+
already_saved_as = getattr(pil_image, 'already_saved_as', None)
|
36 |
+
if already_saved_as and os.path.isfile(already_saved_as):
|
37 |
+
register_tmp_file(shared.demo, already_saved_as)
|
38 |
+
filename = already_saved_as
|
39 |
+
|
40 |
+
if not shared.opts.save_images_add_number:
|
41 |
+
filename += f'?{os.path.getmtime(already_saved_as)}'
|
42 |
+
|
43 |
+
return filename
|
44 |
+
|
45 |
+
if shared.opts.temp_dir != "":
|
46 |
+
dir = shared.opts.temp_dir
|
47 |
+
else:
|
48 |
+
os.makedirs(dir, exist_ok=True)
|
49 |
+
|
50 |
+
use_metadata = False
|
51 |
+
metadata = PngImagePlugin.PngInfo()
|
52 |
+
for key, value in pil_image.info.items():
|
53 |
+
if isinstance(key, str) and isinstance(value, str):
|
54 |
+
metadata.add_text(key, value)
|
55 |
+
use_metadata = True
|
56 |
+
|
57 |
+
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
58 |
+
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
59 |
+
return file_obj.name
|
60 |
+
|
61 |
+
|
62 |
+
def install_ui_tempdir_override():
|
63 |
+
"""override save to file function so that it also writes PNG info"""
|
64 |
+
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
|
65 |
+
|
66 |
+
|
67 |
+
def on_tmpdir_changed():
|
68 |
+
if shared.opts.temp_dir == "" or shared.demo is None:
|
69 |
+
return
|
70 |
+
|
71 |
+
os.makedirs(shared.opts.temp_dir, exist_ok=True)
|
72 |
+
|
73 |
+
register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
|
74 |
+
|
75 |
+
|
76 |
+
def cleanup_tmpdr():
|
77 |
+
temp_dir = shared.opts.temp_dir
|
78 |
+
if temp_dir == "" or not os.path.isdir(temp_dir):
|
79 |
+
return
|
80 |
+
|
81 |
+
for root, _, files in os.walk(temp_dir, topdown=False):
|
82 |
+
for name in files:
|
83 |
+
_, extension = os.path.splitext(name)
|
84 |
+
if extension != ".png":
|
85 |
+
continue
|
86 |
+
|
87 |
+
filename = os.path.join(root, name)
|
88 |
+
os.remove(filename)
|
modules/upscaler.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import abstractmethod
|
3 |
+
|
4 |
+
import PIL
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
import modules.shared
|
8 |
+
from modules import modelloader, shared
|
9 |
+
|
10 |
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
11 |
+
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
|
12 |
+
|
13 |
+
|
14 |
+
class Upscaler:
|
15 |
+
name = None
|
16 |
+
model_path = None
|
17 |
+
model_name = None
|
18 |
+
model_url = None
|
19 |
+
enable = True
|
20 |
+
filter = None
|
21 |
+
model = None
|
22 |
+
user_path = None
|
23 |
+
scalers: []
|
24 |
+
tile = True
|
25 |
+
|
26 |
+
def __init__(self, create_dirs=False):
|
27 |
+
self.mod_pad_h = None
|
28 |
+
self.tile_size = modules.shared.opts.ESRGAN_tile
|
29 |
+
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
30 |
+
self.device = modules.shared.device
|
31 |
+
self.img = None
|
32 |
+
self.output = None
|
33 |
+
self.scale = 1
|
34 |
+
self.half = not modules.shared.cmd_opts.no_half
|
35 |
+
self.pre_pad = 0
|
36 |
+
self.mod_scale = None
|
37 |
+
self.model_download_path = None
|
38 |
+
|
39 |
+
if self.model_path is None and self.name:
|
40 |
+
self.model_path = os.path.join(shared.models_path, self.name)
|
41 |
+
if self.model_path and create_dirs:
|
42 |
+
os.makedirs(self.model_path, exist_ok=True)
|
43 |
+
|
44 |
+
try:
|
45 |
+
import cv2 # noqa: F401
|
46 |
+
self.can_tile = True
|
47 |
+
except Exception:
|
48 |
+
pass
|
49 |
+
|
50 |
+
@abstractmethod
|
51 |
+
def do_upscale(self, img: PIL.Image, selected_model: str):
|
52 |
+
return img
|
53 |
+
|
54 |
+
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
|
55 |
+
self.scale = scale
|
56 |
+
dest_w = int((img.width * scale) // 8 * 8)
|
57 |
+
dest_h = int((img.height * scale) // 8 * 8)
|
58 |
+
|
59 |
+
for _ in range(3):
|
60 |
+
shape = (img.width, img.height)
|
61 |
+
|
62 |
+
img = self.do_upscale(img, selected_model)
|
63 |
+
|
64 |
+
if shape == (img.width, img.height):
|
65 |
+
break
|
66 |
+
|
67 |
+
if img.width >= dest_w and img.height >= dest_h:
|
68 |
+
break
|
69 |
+
|
70 |
+
if img.width != dest_w or img.height != dest_h:
|
71 |
+
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
72 |
+
|
73 |
+
return img
|
74 |
+
|
75 |
+
@abstractmethod
|
76 |
+
def load_model(self, path: str):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def find_models(self, ext_filter=None) -> list:
|
80 |
+
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
|
81 |
+
|
82 |
+
def update_status(self, prompt):
|
83 |
+
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
84 |
+
|
85 |
+
|
86 |
+
class UpscalerData:
|
87 |
+
name = None
|
88 |
+
data_path = None
|
89 |
+
scale: int = 4
|
90 |
+
scaler: Upscaler = None
|
91 |
+
model: None
|
92 |
+
|
93 |
+
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
94 |
+
self.name = name
|
95 |
+
self.data_path = path
|
96 |
+
self.local_data_path = path
|
97 |
+
self.scaler = upscaler
|
98 |
+
self.scale = scale
|
99 |
+
self.model = model
|
100 |
+
|
101 |
+
|
102 |
+
class UpscalerNone(Upscaler):
|
103 |
+
name = "None"
|
104 |
+
scalers = []
|
105 |
+
|
106 |
+
def load_model(self, path):
|
107 |
+
pass
|
108 |
+
|
109 |
+
def do_upscale(self, img, selected_model=None):
|
110 |
+
return img
|
111 |
+
|
112 |
+
def __init__(self, dirname=None):
|
113 |
+
super().__init__(False)
|
114 |
+
self.scalers = [UpscalerData("None", None, self)]
|
115 |
+
|
116 |
+
|
117 |
+
class UpscalerLanczos(Upscaler):
|
118 |
+
scalers = []
|
119 |
+
|
120 |
+
def do_upscale(self, img, selected_model=None):
|
121 |
+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
122 |
+
|
123 |
+
def load_model(self, _):
|
124 |
+
pass
|
125 |
+
|
126 |
+
def __init__(self, dirname=None):
|
127 |
+
super().__init__(False)
|
128 |
+
self.name = "Lanczos"
|
129 |
+
self.scalers = [UpscalerData("Lanczos", None, self)]
|
130 |
+
|
131 |
+
|
132 |
+
class UpscalerNearest(Upscaler):
|
133 |
+
scalers = []
|
134 |
+
|
135 |
+
def do_upscale(self, img, selected_model=None):
|
136 |
+
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
|
137 |
+
|
138 |
+
def load_model(self, _):
|
139 |
+
pass
|
140 |
+
|
141 |
+
def __init__(self, dirname=None):
|
142 |
+
super().__init__(False)
|
143 |
+
self.name = "Nearest"
|
144 |
+
self.scalers = [UpscalerData("Nearest", None, self)]
|
modules/util.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
|
4 |
+
from modules import shared
|
5 |
+
from modules.paths_internal import script_path
|
6 |
+
|
7 |
+
|
8 |
+
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
9 |
+
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
10 |
+
|
11 |
+
|
12 |
+
def listfiles(dirname):
|
13 |
+
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
14 |
+
return [file for file in filenames if os.path.isfile(file)]
|
15 |
+
|
16 |
+
|
17 |
+
def html_path(filename):
|
18 |
+
return os.path.join(script_path, "html", filename)
|
19 |
+
|
20 |
+
|
21 |
+
def html(filename):
|
22 |
+
path = html_path(filename)
|
23 |
+
|
24 |
+
if os.path.exists(path):
|
25 |
+
with open(path, encoding="utf8") as file:
|
26 |
+
return file.read()
|
27 |
+
|
28 |
+
return ""
|
29 |
+
|
30 |
+
|
31 |
+
def walk_files(path, allowed_extensions=None):
|
32 |
+
if not os.path.exists(path):
|
33 |
+
return
|
34 |
+
|
35 |
+
if allowed_extensions is not None:
|
36 |
+
allowed_extensions = set(allowed_extensions)
|
37 |
+
|
38 |
+
items = list(os.walk(path, followlinks=True))
|
39 |
+
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
40 |
+
|
41 |
+
for root, _, files in items:
|
42 |
+
for filename in sorted(files, key=natural_sort_key):
|
43 |
+
if allowed_extensions is not None:
|
44 |
+
_, ext = os.path.splitext(filename)
|
45 |
+
if ext not in allowed_extensions:
|
46 |
+
continue
|
47 |
+
|
48 |
+
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
|
49 |
+
continue
|
50 |
+
|
51 |
+
yield os.path.join(root, filename)
|
52 |
+
|
53 |
+
|
54 |
+
def ldm_print(*args, **kwargs):
|
55 |
+
if shared.opts.hide_ldm_prints:
|
56 |
+
return
|
57 |
+
|
58 |
+
print(*args, **kwargs)
|
modules/xlmr.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import BertPreTrainedModel, BertConfig
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
5 |
+
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
class BertSeriesConfig(BertConfig):
|
9 |
+
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
10 |
+
|
11 |
+
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
12 |
+
self.project_dim = project_dim
|
13 |
+
self.pooler_fn = pooler_fn
|
14 |
+
self.learn_encoder = learn_encoder
|
15 |
+
|
16 |
+
class RobertaSeriesConfig(XLMRobertaConfig):
|
17 |
+
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
18 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
19 |
+
self.project_dim = project_dim
|
20 |
+
self.pooler_fn = pooler_fn
|
21 |
+
self.learn_encoder = learn_encoder
|
22 |
+
|
23 |
+
|
24 |
+
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
25 |
+
|
26 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
27 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
28 |
+
config_class = BertSeriesConfig
|
29 |
+
|
30 |
+
def __init__(self, config=None, **kargs):
|
31 |
+
# modify initialization for autoloading
|
32 |
+
if config is None:
|
33 |
+
config = XLMRobertaConfig()
|
34 |
+
config.attention_probs_dropout_prob= 0.1
|
35 |
+
config.bos_token_id=0
|
36 |
+
config.eos_token_id=2
|
37 |
+
config.hidden_act='gelu'
|
38 |
+
config.hidden_dropout_prob=0.1
|
39 |
+
config.hidden_size=1024
|
40 |
+
config.initializer_range=0.02
|
41 |
+
config.intermediate_size=4096
|
42 |
+
config.layer_norm_eps=1e-05
|
43 |
+
config.max_position_embeddings=514
|
44 |
+
|
45 |
+
config.num_attention_heads=16
|
46 |
+
config.num_hidden_layers=24
|
47 |
+
config.output_past=True
|
48 |
+
config.pad_token_id=1
|
49 |
+
config.position_embedding_type= "absolute"
|
50 |
+
|
51 |
+
config.type_vocab_size= 1
|
52 |
+
config.use_cache=True
|
53 |
+
config.vocab_size= 250002
|
54 |
+
config.project_dim = 768
|
55 |
+
config.learn_encoder = False
|
56 |
+
super().__init__(config)
|
57 |
+
self.roberta = XLMRobertaModel(config)
|
58 |
+
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
59 |
+
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
60 |
+
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
61 |
+
self.pooler = lambda x: x[:,0]
|
62 |
+
self.post_init()
|
63 |
+
|
64 |
+
def encode(self,c):
|
65 |
+
device = next(self.parameters()).device
|
66 |
+
text = self.tokenizer(c,
|
67 |
+
truncation=True,
|
68 |
+
max_length=77,
|
69 |
+
return_length=False,
|
70 |
+
return_overflowing_tokens=False,
|
71 |
+
padding="max_length",
|
72 |
+
return_tensors="pt")
|
73 |
+
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
74 |
+
text["attention_mask"] = torch.tensor(
|
75 |
+
text['attention_mask']).to(device)
|
76 |
+
features = self(**text)
|
77 |
+
return features['projection_state']
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
input_ids: Optional[torch.Tensor] = None,
|
82 |
+
attention_mask: Optional[torch.Tensor] = None,
|
83 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
84 |
+
position_ids: Optional[torch.Tensor] = None,
|
85 |
+
head_mask: Optional[torch.Tensor] = None,
|
86 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
87 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
88 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
89 |
+
output_attentions: Optional[bool] = None,
|
90 |
+
return_dict: Optional[bool] = None,
|
91 |
+
output_hidden_states: Optional[bool] = None,
|
92 |
+
) :
|
93 |
+
r"""
|
94 |
+
"""
|
95 |
+
|
96 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
97 |
+
|
98 |
+
|
99 |
+
outputs = self.roberta(
|
100 |
+
input_ids=input_ids,
|
101 |
+
attention_mask=attention_mask,
|
102 |
+
token_type_ids=token_type_ids,
|
103 |
+
position_ids=position_ids,
|
104 |
+
head_mask=head_mask,
|
105 |
+
inputs_embeds=inputs_embeds,
|
106 |
+
encoder_hidden_states=encoder_hidden_states,
|
107 |
+
encoder_attention_mask=encoder_attention_mask,
|
108 |
+
output_attentions=output_attentions,
|
109 |
+
output_hidden_states=True,
|
110 |
+
return_dict=return_dict,
|
111 |
+
)
|
112 |
+
|
113 |
+
# last module outputs
|
114 |
+
sequence_output = outputs[0]
|
115 |
+
|
116 |
+
|
117 |
+
# project every module
|
118 |
+
sequence_output_ln = self.pre_LN(sequence_output)
|
119 |
+
|
120 |
+
# pooler
|
121 |
+
pooler_output = self.pooler(sequence_output_ln)
|
122 |
+
pooler_output = self.transformation(pooler_output)
|
123 |
+
projection_state = self.transformation(outputs.last_hidden_state)
|
124 |
+
|
125 |
+
return {
|
126 |
+
'pooler_output':pooler_output,
|
127 |
+
'last_hidden_state':outputs.last_hidden_state,
|
128 |
+
'hidden_states':outputs.hidden_states,
|
129 |
+
'attentions':outputs.attentions,
|
130 |
+
'projection_state':projection_state,
|
131 |
+
'sequence_out': sequence_output
|
132 |
+
}
|
133 |
+
|
134 |
+
|
135 |
+
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
136 |
+
base_model_prefix = 'roberta'
|
137 |
+
config_class= RobertaSeriesConfig
|