jac12 commited on
Commit
4476e2e
·
1 Parent(s): 2f2ce98

Upload 22 files

Browse files
modules/timer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import argparse
3
+
4
+
5
+ class TimerSubcategory:
6
+ def __init__(self, timer, category):
7
+ self.timer = timer
8
+ self.category = category
9
+ self.start = None
10
+ self.original_base_category = timer.base_category
11
+
12
+ def __enter__(self):
13
+ self.start = time.time()
14
+ self.timer.base_category = self.original_base_category + self.category + "/"
15
+ self.timer.subcategory_level += 1
16
+
17
+ if self.timer.print_log:
18
+ print(f"{' ' * self.timer.subcategory_level}{self.category}:")
19
+
20
+ def __exit__(self, exc_type, exc_val, exc_tb):
21
+ elapsed_for_subcategroy = time.time() - self.start
22
+ self.timer.base_category = self.original_base_category
23
+ self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
24
+ self.timer.subcategory_level -= 1
25
+ self.timer.record(self.category, disable_log=True)
26
+
27
+
28
+ class Timer:
29
+ def __init__(self, print_log=False):
30
+ self.start = time.time()
31
+ self.records = {}
32
+ self.total = 0
33
+ self.base_category = ''
34
+ self.print_log = print_log
35
+ self.subcategory_level = 0
36
+
37
+ def elapsed(self):
38
+ end = time.time()
39
+ res = end - self.start
40
+ self.start = end
41
+ return res
42
+
43
+ def add_time_to_record(self, category, amount):
44
+ if category not in self.records:
45
+ self.records[category] = 0
46
+
47
+ self.records[category] += amount
48
+
49
+ def record(self, category, extra_time=0, disable_log=False):
50
+ e = self.elapsed()
51
+
52
+ self.add_time_to_record(self.base_category + category, e + extra_time)
53
+
54
+ self.total += e + extra_time
55
+
56
+ if self.print_log and not disable_log:
57
+ print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
58
+
59
+ def subcategory(self, name):
60
+ self.elapsed()
61
+
62
+ subcat = TimerSubcategory(self, name)
63
+ return subcat
64
+
65
+ def summary(self):
66
+ res = f"{self.total:.1f}s"
67
+
68
+ additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
69
+ if not additions:
70
+ return res
71
+
72
+ res += " ("
73
+ res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
74
+ res += ")"
75
+
76
+ return res
77
+
78
+ def dump(self):
79
+ return {'total': self.total, 'records': self.records}
80
+
81
+ def reset(self):
82
+ self.__init__()
83
+
84
+
85
+ parser = argparse.ArgumentParser(add_help=False)
86
+ parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
87
+ args = parser.parse_known_args()[0]
88
+
89
+ startup_timer = Timer(print_log=args.log_startup)
90
+
91
+ startup_record = None
modules/txt2img.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import closing
2
+
3
+ import modules.scripts
4
+ from modules import processing
5
+ from modules.generation_parameters_copypaste import create_override_settings_dict
6
+ from modules.shared import opts, cmd_opts
7
+ import modules.shared as shared
8
+ from modules.ui import plaintext_to_html
9
+ import gradio as gr
10
+
11
+
12
+ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
13
+ override_settings = create_override_settings_dict(override_settings_texts)
14
+
15
+ p = processing.StableDiffusionProcessingTxt2Img(
16
+ sd_model=shared.sd_model,
17
+ outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
18
+ outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
19
+ prompt=prompt,
20
+ styles=prompt_styles,
21
+ negative_prompt=negative_prompt,
22
+ sampler_name=sampler_name,
23
+ batch_size=batch_size,
24
+ n_iter=n_iter,
25
+ steps=steps,
26
+ cfg_scale=cfg_scale,
27
+ width=width,
28
+ height=height,
29
+ enable_hr=enable_hr,
30
+ denoising_strength=denoising_strength if enable_hr else None,
31
+ hr_scale=hr_scale,
32
+ hr_upscaler=hr_upscaler,
33
+ hr_second_pass_steps=hr_second_pass_steps,
34
+ hr_resize_x=hr_resize_x,
35
+ hr_resize_y=hr_resize_y,
36
+ hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
37
+ hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
38
+ hr_prompt=hr_prompt,
39
+ hr_negative_prompt=hr_negative_prompt,
40
+ override_settings=override_settings,
41
+ )
42
+
43
+ p.scripts = modules.scripts.scripts_txt2img
44
+ p.script_args = args
45
+
46
+ p.user = request.username
47
+
48
+ if cmd_opts.enable_console_prompts:
49
+ print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
50
+
51
+ with closing(p):
52
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
53
+
54
+ if processed is None:
55
+ processed = processing.process_images(p)
56
+
57
+ shared.total_tqdm.clear()
58
+
59
+ generation_info_js = processed.js()
60
+ if opts.samples_log_stdout:
61
+ print(generation_info_js)
62
+
63
+ if opts.do_not_show_images:
64
+ processed.images = []
65
+
66
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
modules/ui.py ADDED
@@ -0,0 +1,1366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import mimetypes
3
+ import os
4
+ import sys
5
+ from functools import reduce
6
+ import warnings
7
+
8
+ import gradio as gr
9
+ import gradio.utils
10
+ import numpy as np
11
+ from PIL import Image, PngImagePlugin # noqa: F401
12
+ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
13
+
14
+ from modules import gradio_extensons # noqa: F401
15
+ from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
16
+ from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
17
+ from modules.paths import script_path
18
+ from modules.ui_common import create_refresh_button
19
+ from modules.ui_gradio_extensions import reload_javascript
20
+
21
+ from modules.shared import opts, cmd_opts
22
+
23
+ import modules.generation_parameters_copypaste as parameters_copypaste
24
+ import modules.hypernetworks.ui as hypernetworks_ui
25
+ import modules.textual_inversion.ui as textual_inversion_ui
26
+ import modules.textual_inversion.textual_inversion as textual_inversion
27
+ import modules.shared as shared
28
+ import modules.images
29
+ from modules import prompt_parser
30
+ from modules.sd_hijack import model_hijack
31
+ from modules.generation_parameters_copypaste import image_from_url_text
32
+
33
+ create_setting_component = ui_settings.create_setting_component
34
+
35
+ warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
36
+ warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)
37
+
38
+ # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
39
+ mimetypes.init()
40
+ mimetypes.add_type('application/javascript', '.js')
41
+
42
+ # Likewise, add explicit content-type header for certain missing image types
43
+ mimetypes.add_type('image/webp', '.webp')
44
+
45
+ if not cmd_opts.share and not cmd_opts.listen:
46
+ # fix gradio phoning home
47
+ gradio.utils.version_check = lambda: None
48
+ gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
49
+
50
+ if cmd_opts.ngrok is not None:
51
+ import modules.ngrok as ngrok
52
+ print('ngrok authtoken detected, trying to connect...')
53
+ ngrok.connect(
54
+ cmd_opts.ngrok,
55
+ cmd_opts.port if cmd_opts.port is not None else 7860,
56
+ cmd_opts.ngrok_options
57
+ )
58
+
59
+
60
+ def gr_show(visible=True):
61
+ return {"visible": visible, "__type__": "update"}
62
+
63
+
64
+ sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
65
+ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
66
+
67
+ # Using constants for these since the variation selector isn't visible.
68
+ # Important that they exactly match script.js for tooltip to work.
69
+ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
70
+ reuse_symbol = '\u267b\ufe0f' # ♻️
71
+ paste_symbol = '\u2199\ufe0f' # ↙
72
+ refresh_symbol = '\U0001f504' # 🔄
73
+ save_style_symbol = '\U0001f4be' # 💾
74
+ apply_style_symbol = '\U0001f4cb' # 📋
75
+ clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
76
+ extra_networks_symbol = '\U0001F3B4' # 🎴
77
+ switch_values_symbol = '\U000021C5' # ⇅
78
+ restore_progress_symbol = '\U0001F300' # 🌀
79
+ detect_image_size_symbol = '\U0001F4D0' # 📐
80
+
81
+
82
+ plaintext_to_html = ui_common.plaintext_to_html
83
+
84
+
85
+ def send_gradio_gallery_to_image(x):
86
+ if len(x) == 0:
87
+ return None
88
+ return image_from_url_text(x[0])
89
+
90
+
91
+ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
92
+ if not enable:
93
+ return ""
94
+
95
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
96
+ p.calculate_target_resolution()
97
+
98
+ return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
99
+
100
+
101
+ def resize_from_to_html(width, height, scale_by):
102
+ target_width = int(width * scale_by)
103
+ target_height = int(height * scale_by)
104
+
105
+ if not target_width or not target_height:
106
+ return "no image selected"
107
+
108
+ return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
109
+
110
+
111
+ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
112
+ if mode in {0, 1, 3, 4}:
113
+ return [interrogation_function(ii_singles[mode]), None]
114
+ elif mode == 2:
115
+ return [interrogation_function(ii_singles[mode]["image"]), None]
116
+ elif mode == 5:
117
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
118
+ images = shared.listfiles(ii_input_dir)
119
+ print(f"Will process {len(images)} images.")
120
+ if ii_output_dir != "":
121
+ os.makedirs(ii_output_dir, exist_ok=True)
122
+ else:
123
+ ii_output_dir = ii_input_dir
124
+
125
+ for image in images:
126
+ img = Image.open(image)
127
+ filename = os.path.basename(image)
128
+ left, _ = os.path.splitext(filename)
129
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
130
+
131
+ return [gr.update(), None]
132
+
133
+
134
+ def interrogate(image):
135
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
136
+ return gr.update() if prompt is None else prompt
137
+
138
+
139
+ def interrogate_deepbooru(image):
140
+ prompt = deepbooru.model.tag(image)
141
+ return gr.update() if prompt is None else prompt
142
+
143
+
144
+ def connect_clear_prompt(button):
145
+ """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
146
+ button.click(
147
+ _js="clear_prompt",
148
+ fn=None,
149
+ inputs=[],
150
+ outputs=[],
151
+ )
152
+
153
+
154
+ def update_token_counter(text, steps):
155
+ try:
156
+ text, _ = extra_networks.parse_prompt(text)
157
+
158
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
159
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
160
+
161
+ except Exception:
162
+ # a parsing error can happen here during typing, and we don't want to bother the user with
163
+ # messages related to it in console
164
+ prompt_schedules = [[[steps, text]]]
165
+
166
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
167
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
168
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
169
+ return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
170
+
171
+
172
+ class Toprow:
173
+ """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
174
+
175
+ def __init__(self, is_img2img):
176
+ id_part = "img2img" if is_img2img else "txt2img"
177
+ self.id_part = id_part
178
+
179
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
180
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
181
+ with gr.Row():
182
+ with gr.Column(scale=80):
183
+ with gr.Row():
184
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
185
+ self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
186
+
187
+ with gr.Row():
188
+ with gr.Column(scale=80):
189
+ with gr.Row():
190
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
191
+
192
+ self.button_interrogate = None
193
+ self.button_deepbooru = None
194
+ if is_img2img:
195
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
196
+ self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
197
+ self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
198
+
199
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
200
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
201
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
202
+ self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
203
+ self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
204
+
205
+ self.skip.click(
206
+ fn=lambda: shared.state.skip(),
207
+ inputs=[],
208
+ outputs=[],
209
+ )
210
+
211
+ self.interrupt.click(
212
+ fn=lambda: shared.state.interrupt(),
213
+ inputs=[],
214
+ outputs=[],
215
+ )
216
+
217
+ with gr.Row(elem_id=f"{id_part}_tools"):
218
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste")
219
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
220
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
221
+
222
+ self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
223
+ self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
224
+ self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
225
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
226
+
227
+ self.clear_prompt_button.click(
228
+ fn=lambda *x: x,
229
+ _js="confirm_clear_prompt",
230
+ inputs=[self.prompt, self.negative_prompt],
231
+ outputs=[self.prompt, self.negative_prompt],
232
+ )
233
+
234
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
235
+
236
+ self.prompt_img.change(
237
+ fn=modules.images.image_data,
238
+ inputs=[self.prompt_img],
239
+ outputs=[self.prompt, self.prompt_img],
240
+ show_progress=False,
241
+ )
242
+
243
+
244
+ def setup_progressbar(*args, **kwargs):
245
+ pass
246
+
247
+
248
+ def apply_setting(key, value):
249
+ if value is None:
250
+ return gr.update()
251
+
252
+ if shared.cmd_opts.freeze_settings:
253
+ return gr.update()
254
+
255
+ # dont allow model to be swapped when model hash exists in prompt
256
+ if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
257
+ return gr.update()
258
+
259
+ if key == "sd_model_checkpoint":
260
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
261
+
262
+ if ckpt_info is not None:
263
+ value = ckpt_info.title
264
+ else:
265
+ return gr.update()
266
+
267
+ comp_args = opts.data_labels[key].component_args
268
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
269
+ return
270
+
271
+ valtype = type(opts.data_labels[key].default)
272
+ oldval = opts.data.get(key, None)
273
+ opts.data[key] = valtype(value) if valtype != type(None) else value
274
+ if oldval != value and opts.data_labels[key].onchange is not None:
275
+ opts.data_labels[key].onchange()
276
+
277
+ opts.save(shared.config_filename)
278
+ return getattr(opts, key)
279
+
280
+
281
+ def create_output_panel(tabname, outdir):
282
+ return ui_common.create_output_panel(tabname, outdir)
283
+
284
+
285
+ def create_sampler_and_steps_selection(choices, tabname):
286
+ if opts.samplers_in_dropdown:
287
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
288
+ sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
289
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
290
+ else:
291
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
292
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
293
+ sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
294
+
295
+ return steps, sampler_name
296
+
297
+
298
+ def ordered_ui_categories():
299
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
300
+
301
+ for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
302
+ yield category
303
+
304
+
305
+ def create_override_settings_dropdown(tabname, row):
306
+ dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
307
+
308
+ dropdown.change(
309
+ fn=lambda x: gr.Dropdown.update(visible=bool(x)),
310
+ inputs=[dropdown],
311
+ outputs=[dropdown],
312
+ )
313
+
314
+ return dropdown
315
+
316
+
317
+ def create_ui():
318
+ import modules.img2img
319
+ import modules.txt2img
320
+
321
+ reload_javascript()
322
+
323
+ parameters_copypaste.reset()
324
+
325
+ scripts.scripts_current = scripts.scripts_txt2img
326
+ scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
327
+
328
+ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
329
+ toprow = Toprow(is_img2img=False)
330
+
331
+ dummy_component = gr.Label(visible=False)
332
+
333
+ extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
334
+ extra_tabs.__enter__()
335
+
336
+ with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
337
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
338
+ scripts.scripts_txt2img.prepare_ui()
339
+
340
+ for category in ordered_ui_categories():
341
+ if category == "sampler":
342
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
343
+
344
+ elif category == "dimensions":
345
+ with FormRow():
346
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
347
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
348
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
349
+
350
+ with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
351
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
352
+
353
+ if opts.dimensions_and_batch_together:
354
+ with gr.Column(elem_id="txt2img_column_batch"):
355
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
356
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
357
+
358
+ elif category == "cfg":
359
+ with gr.Row():
360
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
361
+
362
+ elif category == "checkboxes":
363
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
364
+ pass
365
+
366
+ elif category == "accordions":
367
+ with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
368
+ with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
369
+ with enable_hr.extra():
370
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
371
+
372
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
373
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
374
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
375
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
376
+
377
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
378
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
379
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
380
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
381
+
382
+ with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
383
+
384
+ hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
385
+ create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
386
+
387
+ hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
388
+
389
+ with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
390
+ with gr.Column(scale=80):
391
+ with gr.Row():
392
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
393
+ with gr.Column(scale=80):
394
+ with gr.Row():
395
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
396
+
397
+ scripts.scripts_txt2img.setup_ui_for_section(category)
398
+
399
+ elif category == "batch":
400
+ if not opts.dimensions_and_batch_together:
401
+ with FormRow(elem_id="txt2img_column_batch"):
402
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
403
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
404
+
405
+ elif category == "override_settings":
406
+ with FormRow(elem_id="txt2img_override_settings_row") as row:
407
+ override_settings = create_override_settings_dropdown('txt2img', row)
408
+
409
+ elif category == "scripts":
410
+ with FormGroup(elem_id="txt2img_script_container"):
411
+ custom_inputs = scripts.scripts_txt2img.setup_ui()
412
+
413
+ if category not in {"accordions"}:
414
+ scripts.scripts_txt2img.setup_ui_for_section(category)
415
+
416
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
417
+
418
+ for component in hr_resolution_preview_inputs:
419
+ event = component.release if isinstance(component, gr.Slider) else component.change
420
+
421
+ event(
422
+ fn=calc_resolution_hires,
423
+ inputs=hr_resolution_preview_inputs,
424
+ outputs=[hr_final_resolution],
425
+ show_progress=False,
426
+ )
427
+ event(
428
+ None,
429
+ _js="onCalcResolutionHires",
430
+ inputs=hr_resolution_preview_inputs,
431
+ outputs=[],
432
+ show_progress=False,
433
+ )
434
+
435
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
436
+
437
+ txt2img_args = dict(
438
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
439
+ _js="submit",
440
+ inputs=[
441
+ dummy_component,
442
+ toprow.prompt,
443
+ toprow.negative_prompt,
444
+ toprow.ui_styles.dropdown,
445
+ steps,
446
+ sampler_name,
447
+ batch_count,
448
+ batch_size,
449
+ cfg_scale,
450
+ height,
451
+ width,
452
+ enable_hr,
453
+ denoising_strength,
454
+ hr_scale,
455
+ hr_upscaler,
456
+ hr_second_pass_steps,
457
+ hr_resize_x,
458
+ hr_resize_y,
459
+ hr_checkpoint_name,
460
+ hr_sampler_name,
461
+ hr_prompt,
462
+ hr_negative_prompt,
463
+ override_settings,
464
+
465
+ ] + custom_inputs,
466
+
467
+ outputs=[
468
+ txt2img_gallery,
469
+ generation_info,
470
+ html_info,
471
+ html_log,
472
+ ],
473
+ show_progress=False,
474
+ )
475
+
476
+ toprow.prompt.submit(**txt2img_args)
477
+ toprow.submit.click(**txt2img_args)
478
+
479
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
480
+
481
+ toprow.restore_progress_button.click(
482
+ fn=progress.restore_progress,
483
+ _js="restoreProgressTxt2img",
484
+ inputs=[dummy_component],
485
+ outputs=[
486
+ txt2img_gallery,
487
+ generation_info,
488
+ html_info,
489
+ html_log,
490
+ ],
491
+ show_progress=False,
492
+ )
493
+
494
+ txt2img_paste_fields = [
495
+ (toprow.prompt, "Prompt"),
496
+ (toprow.negative_prompt, "Negative prompt"),
497
+ (steps, "Steps"),
498
+ (sampler_name, "Sampler"),
499
+ (cfg_scale, "CFG scale"),
500
+ (width, "Size-1"),
501
+ (height, "Size-2"),
502
+ (batch_size, "Batch size"),
503
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
504
+ (denoising_strength, "Denoising strength"),
505
+ (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
506
+ (hr_scale, "Hires upscale"),
507
+ (hr_upscaler, "Hires upscaler"),
508
+ (hr_second_pass_steps, "Hires steps"),
509
+ (hr_resize_x, "Hires resize-1"),
510
+ (hr_resize_y, "Hires resize-2"),
511
+ (hr_checkpoint_name, "Hires checkpoint"),
512
+ (hr_sampler_name, "Hires sampler"),
513
+ (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
514
+ (hr_prompt, "Hires prompt"),
515
+ (hr_negative_prompt, "Hires negative prompt"),
516
+ (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
517
+ *scripts.scripts_txt2img.infotext_fields
518
+ ]
519
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
520
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
521
+ paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
522
+ ))
523
+
524
+ txt2img_preview_params = [
525
+ toprow.prompt,
526
+ toprow.negative_prompt,
527
+ steps,
528
+ sampler_name,
529
+ cfg_scale,
530
+ scripts.scripts_txt2img.script('Seed').seed,
531
+ width,
532
+ height,
533
+ ]
534
+
535
+ toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
536
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
537
+
538
+ extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
539
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
540
+
541
+ extra_tabs.__exit__()
542
+
543
+ scripts.scripts_current = scripts.scripts_img2img
544
+ scripts.scripts_img2img.initialize_scripts(is_img2img=True)
545
+
546
+ with gr.Blocks(analytics_enabled=False) as img2img_interface:
547
+ toprow = Toprow(is_img2img=True)
548
+
549
+ extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
550
+ extra_tabs.__enter__()
551
+
552
+ with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
553
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
554
+ copy_image_buttons = []
555
+ copy_image_destinations = {}
556
+
557
+ def add_copy_image_controls(tab_name, elem):
558
+ with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
559
+ gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
560
+
561
+ for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
562
+ if name == tab_name:
563
+ gr.Button(title, interactive=False)
564
+ copy_image_destinations[name] = elem
565
+ continue
566
+
567
+ button = gr.Button(title)
568
+ copy_image_buttons.append((button, name, elem))
569
+
570
+ with gr.Tabs(elem_id="mode_img2img"):
571
+ img2img_selected_tab = gr.State(0)
572
+
573
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
574
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
575
+ add_copy_image_controls('img2img', init_img)
576
+
577
+ with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
578
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
579
+ add_copy_image_controls('sketch', sketch)
580
+
581
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
582
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
583
+ add_copy_image_controls('inpaint', init_img_with_mask)
584
+
585
+ with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
586
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
587
+ inpaint_color_sketch_orig = gr.State(None)
588
+ add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
589
+
590
+ def update_orig(image, state):
591
+ if image is not None:
592
+ same_size = state is not None and state.size == image.size
593
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
594
+ edited = same_size and has_exact_match
595
+ return image if not edited or state is None else state
596
+
597
+ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
598
+
599
+ with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
600
+ init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
601
+ init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
602
+
603
+ with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
604
+ hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
605
+ gr.HTML(
606
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
607
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
608
+ f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
609
+ f"{hidden}</p>"
610
+ )
611
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
612
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
613
+ img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
614
+ with gr.Accordion("PNG info", open=False):
615
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
616
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
617
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
618
+
619
+ img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
620
+
621
+ for i, tab in enumerate(img2img_tabs):
622
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
623
+
624
+ def copy_image(img):
625
+ if isinstance(img, dict) and 'image' in img:
626
+ return img['image']
627
+
628
+ return img
629
+
630
+ for button, name, elem in copy_image_buttons:
631
+ button.click(
632
+ fn=copy_image,
633
+ inputs=[elem],
634
+ outputs=[copy_image_destinations[name]],
635
+ )
636
+ button.click(
637
+ fn=lambda: None,
638
+ _js=f"switch_to_{name.replace(' ', '_')}",
639
+ inputs=[],
640
+ outputs=[],
641
+ )
642
+
643
+ with FormRow():
644
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
645
+
646
+ scripts.scripts_img2img.prepare_ui()
647
+
648
+ for category in ordered_ui_categories():
649
+ if category == "sampler":
650
+ steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
651
+
652
+ elif category == "dimensions":
653
+ with FormRow():
654
+ with gr.Column(elem_id="img2img_column_size", scale=4):
655
+ selected_scale_tab = gr.State(value=0)
656
+
657
+ with gr.Tabs():
658
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
659
+ with FormRow():
660
+ with gr.Column(elem_id="img2img_column_size", scale=4):
661
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
662
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
663
+ with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
664
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
665
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
666
+
667
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
668
+ scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
669
+
670
+ with FormRow():
671
+ scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
672
+ gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
673
+ button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
674
+
675
+ on_change_args = dict(
676
+ fn=resize_from_to_html,
677
+ _js="currentImg2imgSourceResolution",
678
+ inputs=[dummy_component, dummy_component, scale_by],
679
+ outputs=scale_by_html,
680
+ show_progress=False,
681
+ )
682
+
683
+ scale_by.release(**on_change_args)
684
+ button_update_resize_to.click(**on_change_args)
685
+
686
+ # the code below is meant to update the resolution label after the image in the image selection UI has changed.
687
+ # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
688
+ # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
689
+ for component in [init_img, sketch]:
690
+ component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
691
+
692
+ tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
693
+ tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
694
+
695
+ if opts.dimensions_and_batch_together:
696
+ with gr.Column(elem_id="img2img_column_batch"):
697
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
698
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
699
+
700
+ elif category == "denoising":
701
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
702
+
703
+ elif category == "cfg":
704
+ with gr.Row():
705
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
706
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
707
+
708
+ elif category == "checkboxes":
709
+ with FormRow(elem_classes="checkboxes-row", variant="compact"):
710
+ pass
711
+
712
+ elif category == "accordions":
713
+ with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
714
+ scripts.scripts_img2img.setup_ui_for_section(category)
715
+
716
+ elif category == "batch":
717
+ if not opts.dimensions_and_batch_together:
718
+ with FormRow(elem_id="img2img_column_batch"):
719
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
720
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
721
+
722
+ elif category == "override_settings":
723
+ with FormRow(elem_id="img2img_override_settings_row") as row:
724
+ override_settings = create_override_settings_dropdown('img2img', row)
725
+
726
+ elif category == "scripts":
727
+ with FormGroup(elem_id="img2img_script_container"):
728
+ custom_inputs = scripts.scripts_img2img.setup_ui()
729
+
730
+ elif category == "inpaint":
731
+ with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
732
+ with FormRow():
733
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
734
+ mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
735
+
736
+ with FormRow():
737
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
738
+
739
+ with FormRow():
740
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
741
+
742
+ with FormRow():
743
+ with gr.Column():
744
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
745
+
746
+ with gr.Column(scale=4):
747
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
748
+
749
+ def select_img2img_tab(tab):
750
+ return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
751
+
752
+ for i, elem in enumerate(img2img_tabs):
753
+ elem.select(
754
+ fn=lambda tab=i: select_img2img_tab(tab),
755
+ inputs=[],
756
+ outputs=[inpaint_controls, mask_alpha],
757
+ )
758
+
759
+ if category not in {"accordions"}:
760
+ scripts.scripts_img2img.setup_ui_for_section(category)
761
+
762
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
763
+
764
+ img2img_args = dict(
765
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
766
+ _js="submit_img2img",
767
+ inputs=[
768
+ dummy_component,
769
+ dummy_component,
770
+ toprow.prompt,
771
+ toprow.negative_prompt,
772
+ toprow.ui_styles.dropdown,
773
+ init_img,
774
+ sketch,
775
+ init_img_with_mask,
776
+ inpaint_color_sketch,
777
+ inpaint_color_sketch_orig,
778
+ init_img_inpaint,
779
+ init_mask_inpaint,
780
+ steps,
781
+ sampler_name,
782
+ mask_blur,
783
+ mask_alpha,
784
+ inpainting_fill,
785
+ batch_count,
786
+ batch_size,
787
+ cfg_scale,
788
+ image_cfg_scale,
789
+ denoising_strength,
790
+ selected_scale_tab,
791
+ height,
792
+ width,
793
+ scale_by,
794
+ resize_mode,
795
+ inpaint_full_res,
796
+ inpaint_full_res_padding,
797
+ inpainting_mask_invert,
798
+ img2img_batch_input_dir,
799
+ img2img_batch_output_dir,
800
+ img2img_batch_inpaint_mask_dir,
801
+ override_settings,
802
+ img2img_batch_use_png_info,
803
+ img2img_batch_png_info_props,
804
+ img2img_batch_png_info_dir,
805
+ ] + custom_inputs,
806
+ outputs=[
807
+ img2img_gallery,
808
+ generation_info,
809
+ html_info,
810
+ html_log,
811
+ ],
812
+ show_progress=False,
813
+ )
814
+
815
+ interrogate_args = dict(
816
+ _js="get_img2img_tab_index",
817
+ inputs=[
818
+ dummy_component,
819
+ img2img_batch_input_dir,
820
+ img2img_batch_output_dir,
821
+ init_img,
822
+ sketch,
823
+ init_img_with_mask,
824
+ inpaint_color_sketch,
825
+ init_img_inpaint,
826
+ ],
827
+ outputs=[toprow.prompt, dummy_component],
828
+ )
829
+
830
+ toprow.prompt.submit(**img2img_args)
831
+ toprow.submit.click(**img2img_args)
832
+
833
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
834
+
835
+ detect_image_size_btn.click(
836
+ fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
837
+ _js="currentImg2imgSourceResolution",
838
+ inputs=[dummy_component, dummy_component, dummy_component],
839
+ outputs=[width, height],
840
+ show_progress=False,
841
+ )
842
+
843
+ toprow.restore_progress_button.click(
844
+ fn=progress.restore_progress,
845
+ _js="restoreProgressImg2img",
846
+ inputs=[dummy_component],
847
+ outputs=[
848
+ img2img_gallery,
849
+ generation_info,
850
+ html_info,
851
+ html_log,
852
+ ],
853
+ show_progress=False,
854
+ )
855
+
856
+ toprow.button_interrogate.click(
857
+ fn=lambda *args: process_interrogate(interrogate, *args),
858
+ **interrogate_args,
859
+ )
860
+
861
+ toprow.button_deepbooru.click(
862
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
863
+ **interrogate_args,
864
+ )
865
+
866
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
867
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
868
+
869
+ img2img_paste_fields = [
870
+ (toprow.prompt, "Prompt"),
871
+ (toprow.negative_prompt, "Negative prompt"),
872
+ (steps, "Steps"),
873
+ (sampler_name, "Sampler"),
874
+ (cfg_scale, "CFG scale"),
875
+ (image_cfg_scale, "Image CFG scale"),
876
+ (width, "Size-1"),
877
+ (height, "Size-2"),
878
+ (batch_size, "Batch size"),
879
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
880
+ (denoising_strength, "Denoising strength"),
881
+ (mask_blur, "Mask blur"),
882
+ *scripts.scripts_img2img.infotext_fields
883
+ ]
884
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
885
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
886
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
887
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
888
+ ))
889
+
890
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
891
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
892
+
893
+ extra_tabs.__exit__()
894
+
895
+ scripts.scripts_current = None
896
+
897
+ with gr.Blocks(analytics_enabled=False) as extras_interface:
898
+ ui_postprocessing.create_ui()
899
+
900
+ with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
901
+ with gr.Row(equal_height=False):
902
+ with gr.Column(variant='panel'):
903
+ image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
904
+
905
+ with gr.Column(variant='panel'):
906
+ html = gr.HTML()
907
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
908
+ html2 = gr.HTML()
909
+ with gr.Row():
910
+ buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
911
+
912
+ for tabname, button in buttons.items():
913
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
914
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
915
+ ))
916
+
917
+ image.change(
918
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
919
+ inputs=[image],
920
+ outputs=[html, generation_info, html2],
921
+ )
922
+
923
+ modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()
924
+
925
+ with gr.Blocks(analytics_enabled=False) as train_interface:
926
+ with gr.Row(equal_height=False):
927
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
928
+
929
+ with gr.Row(variant="compact", equal_height=False):
930
+ with gr.Tabs(elem_id="train_tabs"):
931
+
932
+ with gr.Tab(label="Create embedding", id="create_embedding"):
933
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
934
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
935
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
936
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
937
+
938
+ with gr.Row():
939
+ with gr.Column(scale=3):
940
+ gr.HTML(value="")
941
+
942
+ with gr.Column():
943
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
944
+
945
+ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
946
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
947
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
948
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
949
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
950
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
951
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
952
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
953
+ new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'")
954
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
955
+
956
+ with gr.Row():
957
+ with gr.Column(scale=3):
958
+ gr.HTML(value="")
959
+
960
+ with gr.Column():
961
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
962
+
963
+ with gr.Tab(label="Preprocess images", id="preprocess_images"):
964
+ process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
965
+ process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
966
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
967
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
968
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
969
+
970
+ with gr.Row():
971
+ process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
972
+ process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
973
+ process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
974
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
975
+ process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
976
+ process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
977
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
978
+
979
+ with gr.Row(visible=False) as process_split_extra_row:
980
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
981
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
982
+
983
+ with gr.Row(visible=False) as process_focal_crop_row:
984
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
985
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
986
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
987
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
988
+
989
+ with gr.Column(visible=False) as process_multicrop_col:
990
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
991
+ with gr.Row():
992
+ process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
993
+ process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
994
+ with gr.Row():
995
+ process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
996
+ process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
997
+ with gr.Row():
998
+ process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
999
+ process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
1000
+
1001
+ with gr.Row():
1002
+ with gr.Column(scale=3):
1003
+ gr.HTML(value="")
1004
+
1005
+ with gr.Column():
1006
+ with gr.Row():
1007
+ interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
1008
+ run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
1009
+
1010
+ process_split.change(
1011
+ fn=lambda show: gr_show(show),
1012
+ inputs=[process_split],
1013
+ outputs=[process_split_extra_row],
1014
+ )
1015
+
1016
+ process_focal_crop.change(
1017
+ fn=lambda show: gr_show(show),
1018
+ inputs=[process_focal_crop],
1019
+ outputs=[process_focal_crop_row],
1020
+ )
1021
+
1022
+ process_multicrop.change(
1023
+ fn=lambda show: gr_show(show),
1024
+ inputs=[process_multicrop],
1025
+ outputs=[process_multicrop_col],
1026
+ )
1027
+
1028
+ def get_textual_inversion_template_names():
1029
+ return sorted(textual_inversion.textual_inversion_templates)
1030
+
1031
+ with gr.Tab(label="Train", id="train"):
1032
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
1033
+ with FormRow():
1034
+ train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
1035
+ create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
1036
+
1037
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
1038
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
1039
+
1040
+ with FormRow():
1041
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
1042
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
1043
+
1044
+ with FormRow():
1045
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
1046
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
1047
+
1048
+ with FormRow():
1049
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
1050
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
1051
+
1052
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
1053
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
1054
+
1055
+ with FormRow():
1056
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
1057
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
1058
+
1059
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
1060
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
1061
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
1062
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
1063
+
1064
+ with FormRow():
1065
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
1066
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
1067
+
1068
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
1069
+
1070
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
1071
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
1072
+
1073
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
1074
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
1075
+
1076
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
1077
+
1078
+ with gr.Row():
1079
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
1080
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
1081
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
1082
+
1083
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
1084
+
1085
+ script_callbacks.ui_train_tabs_callback(params)
1086
+
1087
+ with gr.Column(elem_id='ti_gallery_container'):
1088
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
1089
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
1090
+ gr.HTML(elem_id="ti_progress", value="")
1091
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
1092
+
1093
+ create_embedding.click(
1094
+ fn=textual_inversion_ui.create_embedding,
1095
+ inputs=[
1096
+ new_embedding_name,
1097
+ initialization_text,
1098
+ nvpt,
1099
+ overwrite_old_embedding,
1100
+ ],
1101
+ outputs=[
1102
+ train_embedding_name,
1103
+ ti_output,
1104
+ ti_outcome,
1105
+ ]
1106
+ )
1107
+
1108
+ create_hypernetwork.click(
1109
+ fn=hypernetworks_ui.create_hypernetwork,
1110
+ inputs=[
1111
+ new_hypernetwork_name,
1112
+ new_hypernetwork_sizes,
1113
+ overwrite_old_hypernetwork,
1114
+ new_hypernetwork_layer_structure,
1115
+ new_hypernetwork_activation_func,
1116
+ new_hypernetwork_initialization_option,
1117
+ new_hypernetwork_add_layer_norm,
1118
+ new_hypernetwork_use_dropout,
1119
+ new_hypernetwork_dropout_structure
1120
+ ],
1121
+ outputs=[
1122
+ train_hypernetwork_name,
1123
+ ti_output,
1124
+ ti_outcome,
1125
+ ]
1126
+ )
1127
+
1128
+ run_preprocess.click(
1129
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
1130
+ _js="start_training_textual_inversion",
1131
+ inputs=[
1132
+ dummy_component,
1133
+ process_src,
1134
+ process_dst,
1135
+ process_width,
1136
+ process_height,
1137
+ preprocess_txt_action,
1138
+ process_keep_original_size,
1139
+ process_flip,
1140
+ process_split,
1141
+ process_caption,
1142
+ process_caption_deepbooru,
1143
+ process_split_threshold,
1144
+ process_overlap_ratio,
1145
+ process_focal_crop,
1146
+ process_focal_crop_face_weight,
1147
+ process_focal_crop_entropy_weight,
1148
+ process_focal_crop_edges_weight,
1149
+ process_focal_crop_debug,
1150
+ process_multicrop,
1151
+ process_multicrop_mindim,
1152
+ process_multicrop_maxdim,
1153
+ process_multicrop_minarea,
1154
+ process_multicrop_maxarea,
1155
+ process_multicrop_objective,
1156
+ process_multicrop_threshold,
1157
+ ],
1158
+ outputs=[
1159
+ ti_output,
1160
+ ti_outcome,
1161
+ ],
1162
+ )
1163
+
1164
+ train_embedding.click(
1165
+ fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
1166
+ _js="start_training_textual_inversion",
1167
+ inputs=[
1168
+ dummy_component,
1169
+ train_embedding_name,
1170
+ embedding_learn_rate,
1171
+ batch_size,
1172
+ gradient_step,
1173
+ dataset_directory,
1174
+ log_directory,
1175
+ training_width,
1176
+ training_height,
1177
+ varsize,
1178
+ steps,
1179
+ clip_grad_mode,
1180
+ clip_grad_value,
1181
+ shuffle_tags,
1182
+ tag_drop_out,
1183
+ latent_sampling_method,
1184
+ use_weight,
1185
+ create_image_every,
1186
+ save_embedding_every,
1187
+ template_file,
1188
+ save_image_with_stored_embedding,
1189
+ preview_from_txt2img,
1190
+ *txt2img_preview_params,
1191
+ ],
1192
+ outputs=[
1193
+ ti_output,
1194
+ ti_outcome,
1195
+ ]
1196
+ )
1197
+
1198
+ train_hypernetwork.click(
1199
+ fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
1200
+ _js="start_training_textual_inversion",
1201
+ inputs=[
1202
+ dummy_component,
1203
+ train_hypernetwork_name,
1204
+ hypernetwork_learn_rate,
1205
+ batch_size,
1206
+ gradient_step,
1207
+ dataset_directory,
1208
+ log_directory,
1209
+ training_width,
1210
+ training_height,
1211
+ varsize,
1212
+ steps,
1213
+ clip_grad_mode,
1214
+ clip_grad_value,
1215
+ shuffle_tags,
1216
+ tag_drop_out,
1217
+ latent_sampling_method,
1218
+ use_weight,
1219
+ create_image_every,
1220
+ save_embedding_every,
1221
+ template_file,
1222
+ preview_from_txt2img,
1223
+ *txt2img_preview_params,
1224
+ ],
1225
+ outputs=[
1226
+ ti_output,
1227
+ ti_outcome,
1228
+ ]
1229
+ )
1230
+
1231
+ interrupt_training.click(
1232
+ fn=lambda: shared.state.interrupt(),
1233
+ inputs=[],
1234
+ outputs=[],
1235
+ )
1236
+
1237
+ interrupt_preprocessing.click(
1238
+ fn=lambda: shared.state.interrupt(),
1239
+ inputs=[],
1240
+ outputs=[],
1241
+ )
1242
+
1243
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
1244
+
1245
+ settings = ui_settings.UiSettings()
1246
+ settings.create_ui(loadsave, dummy_component)
1247
+
1248
+ interfaces = [
1249
+ (txt2img_interface, "txt2img", "txt2img"),
1250
+ (img2img_interface, "img2img", "img2img"),
1251
+ (extras_interface, "Extras", "extras"),
1252
+ (pnginfo_interface, "PNG Info", "pnginfo"),
1253
+ (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
1254
+ (train_interface, "Train", "train"),
1255
+ ]
1256
+
1257
+ interfaces += script_callbacks.ui_tabs_callback()
1258
+ interfaces += [(settings.interface, "Settings", "settings")]
1259
+
1260
+ extensions_interface = ui_extensions.create_ui()
1261
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
1262
+
1263
+ shared.tab_names = []
1264
+ for _interface, label, _ifid in interfaces:
1265
+ shared.tab_names.append(label)
1266
+
1267
+ with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
1268
+ settings.add_quicksettings()
1269
+
1270
+ parameters_copypaste.connect_paste_params_buttons()
1271
+
1272
+ with gr.Tabs(elem_id="tabs") as tabs:
1273
+ tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
1274
+ sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
1275
+
1276
+ for interface, label, ifid in sorted_interfaces:
1277
+ if label in shared.opts.hidden_tabs:
1278
+ continue
1279
+ with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
1280
+ interface.render()
1281
+
1282
+ if ifid not in ["extensions", "settings"]:
1283
+ loadsave.add_block(interface, ifid)
1284
+
1285
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
1286
+
1287
+ loadsave.setup_ui()
1288
+
1289
+ if os.path.exists(os.path.join(script_path, "notification.mp3")):
1290
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
1291
+
1292
+ footer = shared.html("footer.html")
1293
+ footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
1294
+ gr.HTML(footer, elem_id="footer")
1295
+
1296
+ settings.add_functionality(demo)
1297
+
1298
+ update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
1299
+ settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
1300
+ demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
1301
+
1302
+ modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
1303
+
1304
+ loadsave.dump_defaults()
1305
+ demo.ui_loadsave = loadsave
1306
+
1307
+ return demo
1308
+
1309
+
1310
+ def versions_html():
1311
+ import torch
1312
+ import launch
1313
+
1314
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
1315
+ commit = launch.commit_hash()
1316
+ tag = launch.git_tag()
1317
+
1318
+ if shared.xformers_available:
1319
+ import xformers
1320
+ xformers_version = xformers.__version__
1321
+ else:
1322
+ xformers_version = "N/A"
1323
+
1324
+ return f"""
1325
+ version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
1326
+ &#x2000;•&#x2000;
1327
+ python: <span title="{sys.version}">{python_version}</span>
1328
+ &#x2000;•&#x2000;
1329
+ torch: {getattr(torch, '__long_version__',torch.__version__)}
1330
+ &#x2000;•&#x2000;
1331
+ xformers: {xformers_version}
1332
+ &#x2000;•&#x2000;
1333
+ gradio: {gr.__version__}
1334
+ &#x2000;•&#x2000;
1335
+ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
1336
+ """
1337
+
1338
+
1339
+ def setup_ui_api(app):
1340
+ from pydantic import BaseModel, Field
1341
+ from typing import List
1342
+
1343
+ class QuicksettingsHint(BaseModel):
1344
+ name: str = Field(title="Name of the quicksettings field")
1345
+ label: str = Field(title="Label of the quicksettings field")
1346
+
1347
+ def quicksettings_hint():
1348
+ return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
1349
+
1350
+ app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
1351
+
1352
+ app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
1353
+
1354
+ app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
1355
+
1356
+ def download_sysinfo(attachment=False):
1357
+ from fastapi.responses import PlainTextResponse
1358
+
1359
+ text = sysinfo.get()
1360
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
1361
+
1362
+ return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
1363
+
1364
+ app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
1365
+ app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
1366
+
modules/ui_checkpoint_merger.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+
4
+ from modules import sd_models, sd_vae, errors, extras, call_queue
5
+ from modules.ui_components import FormRow
6
+ from modules.ui_common import create_refresh_button
7
+
8
+
9
+ def update_interp_description(value):
10
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
11
+ interp_descriptions = {
12
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
13
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
14
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
15
+ }
16
+ return interp_descriptions[value]
17
+
18
+
19
+ def modelmerger(*args):
20
+ try:
21
+ results = extras.run_modelmerger(*args)
22
+ except Exception as e:
23
+ errors.report("Error loading/saving model file", exc_info=True)
24
+ sd_models.list_models() # to remove the potentially missing models from the list
25
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
26
+ return results
27
+
28
+
29
+ class UiCheckpointMerger:
30
+ def __init__(self):
31
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
32
+ with gr.Row(equal_height=False):
33
+ with gr.Column(variant='compact'):
34
+ self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
35
+
36
+ with FormRow(elem_id="modelmerger_models"):
37
+ self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
38
+ create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
39
+
40
+ self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
41
+ create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
42
+
43
+ self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
44
+ create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
45
+
46
+ self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
47
+ self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
48
+ self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
49
+ self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
50
+
51
+ with FormRow():
52
+ self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
53
+ self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
54
+
55
+ with FormRow():
56
+ with gr.Column():
57
+ self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
58
+
59
+ with gr.Column():
60
+ with FormRow():
61
+ self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
62
+ create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
63
+
64
+ with FormRow():
65
+ self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
66
+
67
+ with gr.Accordion("Metadata", open=False) as metadata_editor:
68
+ with FormRow():
69
+ self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
70
+ self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
71
+ self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
72
+
73
+ self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
74
+ self.read_metadata = gr.Button("Read metadata from selected checkpoints")
75
+
76
+ with FormRow():
77
+ self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
78
+
79
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
80
+ with gr.Group(elem_id="modelmerger_results_panel"):
81
+ self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
82
+
83
+ self.metadata_editor = metadata_editor
84
+ self.blocks = modelmerger_interface
85
+
86
+ def setup_ui(self, dummy_component, sd_model_checkpoint_component):
87
+ self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
88
+
89
+ self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
90
+
91
+ self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
92
+ self.modelmerger_merge.click(
93
+ fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
94
+ _js='modelmerger',
95
+ inputs=[
96
+ dummy_component,
97
+ self.primary_model_name,
98
+ self.secondary_model_name,
99
+ self.tertiary_model_name,
100
+ self.interp_method,
101
+ self.interp_amount,
102
+ self.save_as_half,
103
+ self.custom_name,
104
+ self.checkpoint_format,
105
+ self.config_source,
106
+ self.bake_in_vae,
107
+ self.discard_weights,
108
+ self.save_metadata,
109
+ self.add_merge_recipe,
110
+ self.copy_metadata_fields,
111
+ self.metadata_json,
112
+ ],
113
+ outputs=[
114
+ self.primary_model_name,
115
+ self.secondary_model_name,
116
+ self.tertiary_model_name,
117
+ sd_model_checkpoint_component,
118
+ self.modelmerger_result,
119
+ ]
120
+ )
121
+
122
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
123
+ self.interp_description.value = update_interp_description(self.interp_method.value)
124
+
modules/ui_common.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import html
3
+ import os
4
+ import platform
5
+ import sys
6
+
7
+ import gradio as gr
8
+ import subprocess as sp
9
+
10
+ from modules import call_queue, shared
11
+ from modules.generation_parameters_copypaste import image_from_url_text
12
+ import modules.images
13
+ from modules.ui_components import ToolButton
14
+ import modules.generation_parameters_copypaste as parameters_copypaste
15
+
16
+ folder_symbol = '\U0001f4c2' # 📂
17
+ refresh_symbol = '\U0001f504' # 🔄
18
+
19
+
20
+ def update_generation_info(generation_info, html_info, img_index):
21
+ try:
22
+ generation_info = json.loads(generation_info)
23
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
24
+ return html_info, gr.update()
25
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
26
+ except Exception:
27
+ pass
28
+ # if the json parse or anything else fails, just return the old html_info
29
+ return html_info, gr.update()
30
+
31
+
32
+ def plaintext_to_html(text, classname=None):
33
+ content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
34
+
35
+ return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
36
+
37
+
38
+ def save_files(js_data, images, do_make_zip, index):
39
+ import csv
40
+ filenames = []
41
+ fullfns = []
42
+
43
+ #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
44
+ class MyObject:
45
+ def __init__(self, d=None):
46
+ if d is not None:
47
+ for key, value in d.items():
48
+ setattr(self, key, value)
49
+
50
+ data = json.loads(js_data)
51
+
52
+ p = MyObject(data)
53
+ path = shared.opts.outdir_save
54
+ save_to_dirs = shared.opts.use_save_to_dirs_for_ui
55
+ extension: str = shared.opts.samples_format
56
+ start_index = 0
57
+ only_one = False
58
+
59
+ if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
60
+ only_one = True
61
+ images = [images[index]]
62
+ start_index = index
63
+
64
+ os.makedirs(shared.opts.outdir_save, exist_ok=True)
65
+
66
+ with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
67
+ at_start = file.tell() == 0
68
+ writer = csv.writer(file)
69
+ if at_start:
70
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
71
+
72
+ for image_index, filedata in enumerate(images, start_index):
73
+ image = image_from_url_text(filedata)
74
+
75
+ is_grid = image_index < p.index_of_first_image
76
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
77
+
78
+ p.batch_index = image_index-1
79
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
80
+
81
+ filename = os.path.relpath(fullfn, path)
82
+ filenames.append(filename)
83
+ fullfns.append(fullfn)
84
+ if txt_fullfn:
85
+ filenames.append(os.path.basename(txt_fullfn))
86
+ fullfns.append(txt_fullfn)
87
+
88
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
89
+
90
+ # Make Zip
91
+ if do_make_zip:
92
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
93
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
94
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
95
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
96
+
97
+ from zipfile import ZipFile
98
+ with ZipFile(zip_filepath, "w") as zip_file:
99
+ for i in range(len(fullfns)):
100
+ with open(fullfns[i], mode="rb") as f:
101
+ zip_file.writestr(filenames[i], f.read())
102
+ fullfns.insert(0, zip_filepath)
103
+
104
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
105
+
106
+
107
+ def create_output_panel(tabname, outdir):
108
+
109
+ def open_folder(f):
110
+ if not os.path.exists(f):
111
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
112
+ return
113
+ elif not os.path.isdir(f):
114
+ print(f"""
115
+ WARNING
116
+ An open_folder request was made with an argument that is not a folder.
117
+ This could be an error or a malicious attempt to run code on your computer.
118
+ Requested path was: {f}
119
+ """, file=sys.stderr)
120
+ return
121
+
122
+ if not shared.cmd_opts.hide_ui_dir_config:
123
+ path = os.path.normpath(f)
124
+ if platform.system() == "Windows":
125
+ os.startfile(path)
126
+ elif platform.system() == "Darwin":
127
+ sp.Popen(["open", path])
128
+ elif "microsoft-standard-WSL2" in platform.uname().release:
129
+ sp.Popen(["wsl-open", path])
130
+ else:
131
+ sp.Popen(["xdg-open", path])
132
+
133
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
134
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
135
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
136
+
137
+ generation_info = None
138
+ with gr.Column():
139
+ with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
140
+ open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
141
+
142
+ if tabname != "extras":
143
+ save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
144
+ save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
145
+
146
+ buttons = {
147
+ 'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
148
+ 'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
149
+ 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
150
+ }
151
+
152
+ open_folder_button.click(
153
+ fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
154
+ inputs=[],
155
+ outputs=[],
156
+ )
157
+
158
+ if tabname != "extras":
159
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
160
+
161
+ with gr.Group():
162
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
163
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
164
+
165
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
166
+ if tabname == 'txt2img' or tabname == 'img2img':
167
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
168
+ generation_info_button.click(
169
+ fn=update_generation_info,
170
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
171
+ inputs=[generation_info, html_info, html_info],
172
+ outputs=[html_info, html_info],
173
+ show_progress=False,
174
+ )
175
+
176
+ save.click(
177
+ fn=call_queue.wrap_gradio_call(save_files),
178
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
179
+ inputs=[
180
+ generation_info,
181
+ result_gallery,
182
+ html_info,
183
+ html_info,
184
+ ],
185
+ outputs=[
186
+ download_files,
187
+ html_log,
188
+ ],
189
+ show_progress=False,
190
+ )
191
+
192
+ save_zip.click(
193
+ fn=call_queue.wrap_gradio_call(save_files),
194
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
195
+ inputs=[
196
+ generation_info,
197
+ result_gallery,
198
+ html_info,
199
+ html_info,
200
+ ],
201
+ outputs=[
202
+ download_files,
203
+ html_log,
204
+ ]
205
+ )
206
+
207
+ else:
208
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
209
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
210
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
211
+
212
+ paste_field_names = []
213
+ if tabname == "txt2img":
214
+ paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
215
+ elif tabname == "img2img":
216
+ paste_field_names = modules.scripts.scripts_img2img.paste_field_names
217
+
218
+ for paste_tabname, paste_button in buttons.items():
219
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
220
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
221
+ paste_field_names=paste_field_names
222
+ ))
223
+
224
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
225
+
226
+
227
+ def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
228
+ refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
229
+
230
+ label = None
231
+ for comp in refresh_components:
232
+ label = getattr(comp, 'label', None)
233
+ if label is not None:
234
+ break
235
+
236
+ def refresh():
237
+ refresh_method()
238
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
239
+
240
+ for k, v in args.items():
241
+ for comp in refresh_components:
242
+ setattr(comp, k, v)
243
+
244
+ return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))
245
+
246
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
247
+ refresh_button.click(
248
+ fn=refresh,
249
+ inputs=[],
250
+ outputs=refresh_components
251
+ )
252
+ return refresh_button
253
+
254
+
255
+ def setup_dialog(button_show, dialog, *, button_close=None):
256
+ """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
257
+
258
+ dialog.visible = False
259
+
260
+ button_show.click(
261
+ fn=lambda: gr.update(visible=True),
262
+ inputs=[],
263
+ outputs=[dialog],
264
+ ).then(fn=None, _js="function(){ popupId('" + dialog.elem_id + "'); }")
265
+
266
+ if button_close:
267
+ button_close.click(fn=None, _js="closePopup")
268
+
modules/ui_components.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+
4
+ class FormComponent:
5
+ def get_expected_parent(self):
6
+ return gr.components.Form
7
+
8
+
9
+ gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
10
+
11
+
12
+ class ToolButton(FormComponent, gr.Button):
13
+ """Small button with single emoji as text, fits inside gradio forms"""
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ classes = kwargs.pop("elem_classes", [])
17
+ super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
18
+
19
+ def get_block_name(self):
20
+ return "button"
21
+
22
+
23
+ class ResizeHandleRow(gr.Row):
24
+ """Same as gr.Row but fits inside gradio forms"""
25
+
26
+ def __init__(self, **kwargs):
27
+ super().__init__(**kwargs)
28
+
29
+ self.elem_classes.append("resize-handle-row")
30
+
31
+ def get_block_name(self):
32
+ return "row"
33
+
34
+
35
+ class FormRow(FormComponent, gr.Row):
36
+ """Same as gr.Row but fits inside gradio forms"""
37
+
38
+ def get_block_name(self):
39
+ return "row"
40
+
41
+
42
+ class FormColumn(FormComponent, gr.Column):
43
+ """Same as gr.Column but fits inside gradio forms"""
44
+
45
+ def get_block_name(self):
46
+ return "column"
47
+
48
+
49
+ class FormGroup(FormComponent, gr.Group):
50
+ """Same as gr.Group but fits inside gradio forms"""
51
+
52
+ def get_block_name(self):
53
+ return "group"
54
+
55
+
56
+ class FormHTML(FormComponent, gr.HTML):
57
+ """Same as gr.HTML but fits inside gradio forms"""
58
+
59
+ def get_block_name(self):
60
+ return "html"
61
+
62
+
63
+ class FormColorPicker(FormComponent, gr.ColorPicker):
64
+ """Same as gr.ColorPicker but fits inside gradio forms"""
65
+
66
+ def get_block_name(self):
67
+ return "colorpicker"
68
+
69
+
70
+ class DropdownMulti(FormComponent, gr.Dropdown):
71
+ """Same as gr.Dropdown but always multiselect"""
72
+ def __init__(self, **kwargs):
73
+ super().__init__(multiselect=True, **kwargs)
74
+
75
+ def get_block_name(self):
76
+ return "dropdown"
77
+
78
+
79
+ class DropdownEditable(FormComponent, gr.Dropdown):
80
+ """Same as gr.Dropdown but allows editing value"""
81
+ def __init__(self, **kwargs):
82
+ super().__init__(allow_custom_value=True, **kwargs)
83
+
84
+ def get_block_name(self):
85
+ return "dropdown"
86
+
87
+
88
+ class InputAccordion(gr.Checkbox):
89
+ """A gr.Accordion that can be used as an input - returns True if open, False if closed.
90
+
91
+ Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
92
+ """
93
+
94
+ global_index = 0
95
+
96
+ def __init__(self, value, **kwargs):
97
+ self.accordion_id = kwargs.get('elem_id')
98
+ if self.accordion_id is None:
99
+ self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
100
+ InputAccordion.global_index += 1
101
+
102
+ kwargs_checkbox = {
103
+ **kwargs,
104
+ "elem_id": f"{self.accordion_id}-checkbox",
105
+ "visible": False,
106
+ }
107
+ super().__init__(value, **kwargs_checkbox)
108
+
109
+ self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
110
+
111
+ kwargs_accordion = {
112
+ **kwargs,
113
+ "elem_id": self.accordion_id,
114
+ "label": kwargs.get('label', 'Accordion'),
115
+ "elem_classes": ['input-accordion'],
116
+ "open": value,
117
+ }
118
+ self.accordion = gr.Accordion(**kwargs_accordion)
119
+
120
+ def extra(self):
121
+ """Allows you to put something into the label of the accordion.
122
+
123
+ Use it like this:
124
+
125
+ ```
126
+ with InputAccordion(False, label="Accordion") as acc:
127
+ with acc.extra():
128
+ FormHTML(value="hello", min_width=0)
129
+
130
+ ...
131
+ ```
132
+ """
133
+
134
+ return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
135
+
136
+ def __enter__(self):
137
+ self.accordion.__enter__()
138
+ return self
139
+
140
+ def __exit__(self, exc_type, exc_val, exc_tb):
141
+ self.accordion.__exit__(exc_type, exc_val, exc_tb)
142
+
143
+ def get_block_name(self):
144
+ return "checkbox"
145
+
modules/ui_extensions.py ADDED
@@ -0,0 +1,669 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import threading
4
+ import time
5
+ from datetime import datetime, timezone
6
+
7
+ import git
8
+
9
+ import gradio as gr
10
+ import html
11
+ import shutil
12
+ import errno
13
+
14
+ from modules import extensions, shared, paths, config_states, errors, restart
15
+ from modules.paths_internal import config_states_dir
16
+ from modules.call_queue import wrap_gradio_gpu_call
17
+
18
+ available_extensions = {"extensions": []}
19
+ STYLE_PRIMARY = ' style="color: var(--primary-400)"'
20
+
21
+
22
+ def check_access():
23
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
24
+
25
+
26
+ def apply_and_restart(disable_list, update_list, disable_all):
27
+ check_access()
28
+
29
+ disabled = json.loads(disable_list)
30
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
31
+
32
+ update = json.loads(update_list)
33
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
34
+
35
+ if update:
36
+ save_config_state("Backup (pre-update)")
37
+
38
+ update = set(update)
39
+
40
+ for ext in extensions.extensions:
41
+ if ext.name not in update:
42
+ continue
43
+
44
+ try:
45
+ ext.fetch_and_reset_hard()
46
+ except Exception:
47
+ errors.report(f"Error getting updates for {ext.name}", exc_info=True)
48
+
49
+ shared.opts.disabled_extensions = disabled
50
+ shared.opts.disable_all_extensions = disable_all
51
+ shared.opts.save(shared.config_filename)
52
+
53
+ if restart.is_restartable():
54
+ restart.restart_program()
55
+ else:
56
+ restart.stop_program()
57
+
58
+
59
+ def save_config_state(name):
60
+ current_config_state = config_states.get_config()
61
+ if not name:
62
+ name = "Config"
63
+ current_config_state["name"] = name
64
+ timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
65
+ filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
66
+ print(f"Saving backup of webui/extension state to {filename}.")
67
+ with open(filename, "w", encoding="utf-8") as f:
68
+ json.dump(current_config_state, f, indent=4)
69
+ config_states.list_config_states()
70
+ new_value = next(iter(config_states.all_config_states.keys()), "Current")
71
+ new_choices = ["Current"] + list(config_states.all_config_states.keys())
72
+ return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
73
+
74
+
75
+ def restore_config_state(confirmed, config_state_name, restore_type):
76
+ if config_state_name == "Current":
77
+ return "<span>Select a config to restore from.</span>"
78
+ if not confirmed:
79
+ return "<span>Cancelled.</span>"
80
+
81
+ check_access()
82
+
83
+ config_state = config_states.all_config_states[config_state_name]
84
+
85
+ print(f"*** Restoring webui state from backup: {restore_type} ***")
86
+
87
+ if restore_type == "extensions" or restore_type == "both":
88
+ shared.opts.restore_config_state_file = config_state["filepath"]
89
+ shared.opts.save(shared.config_filename)
90
+
91
+ if restore_type == "webui" or restore_type == "both":
92
+ config_states.restore_webui_config(config_state)
93
+
94
+ shared.state.request_restart()
95
+
96
+ return ""
97
+
98
+
99
+ def check_updates(id_task, disable_list):
100
+ check_access()
101
+
102
+ disabled = json.loads(disable_list)
103
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
104
+
105
+ exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
106
+ shared.state.job_count = len(exts)
107
+
108
+ for ext in exts:
109
+ shared.state.textinfo = ext.name
110
+
111
+ try:
112
+ ext.check_updates()
113
+ except FileNotFoundError as e:
114
+ if 'FETCH_HEAD' not in str(e):
115
+ raise
116
+ except Exception:
117
+ errors.report(f"Error checking updates for {ext.name}", exc_info=True)
118
+
119
+ shared.state.nextjob()
120
+
121
+ return extension_table(), ""
122
+
123
+
124
+ def make_commit_link(commit_hash, remote, text=None):
125
+ if text is None:
126
+ text = commit_hash[:8]
127
+ if remote.startswith("https://github.com/"):
128
+ if remote.endswith(".git"):
129
+ remote = remote[:-4]
130
+ href = remote + "/commit/" + commit_hash
131
+ return f'<a href="{href}" target="_blank">{text}</a>'
132
+ else:
133
+ return text
134
+
135
+
136
+ def extension_table():
137
+ code = f"""<!-- {time.time()} -->
138
+ <table id="extensions">
139
+ <thead>
140
+ <tr>
141
+ <th>
142
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
143
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
144
+ </th>
145
+ <th>URL</th>
146
+ <th>Branch</th>
147
+ <th>Version</th>
148
+ <th>Date</th>
149
+ <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
150
+ </tr>
151
+ </thead>
152
+ <tbody>
153
+ """
154
+
155
+ for ext in extensions.extensions:
156
+ ext: extensions.Extension
157
+ ext.read_info_from_repo()
158
+
159
+ remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
160
+
161
+ if ext.can_update:
162
+ ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
163
+ else:
164
+ ext_status = ext.status
165
+
166
+ style = ""
167
+ if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
168
+ style = STYLE_PRIMARY
169
+
170
+ version_link = ext.version
171
+ if ext.commit_hash and ext.remote:
172
+ version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
173
+
174
+ code += f"""
175
+ <tr>
176
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
177
+ <td>{remote}</td>
178
+ <td>{ext.branch}</td>
179
+ <td>{version_link}</td>
180
+ <td>{datetime.fromtimestamp(ext.commit_date) if ext.commit_date else ""}</td>
181
+ <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
182
+ </tr>
183
+ """
184
+
185
+ code += """
186
+ </tbody>
187
+ </table>
188
+ """
189
+
190
+ return code
191
+
192
+
193
+ def update_config_states_table(state_name):
194
+ if state_name == "Current":
195
+ config_state = config_states.get_config()
196
+ else:
197
+ config_state = config_states.all_config_states[state_name]
198
+
199
+ config_name = config_state.get("name", "Config")
200
+ created_date = time.asctime(time.gmtime(config_state["created_at"]))
201
+ filepath = config_state.get("filepath", "<unknown>")
202
+
203
+ try:
204
+ webui_remote = config_state["webui"]["remote"] or ""
205
+ webui_branch = config_state["webui"]["branch"]
206
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
207
+ webui_commit_date = config_state["webui"]["commit_date"]
208
+ if webui_commit_date:
209
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
210
+ else:
211
+ webui_commit_date = "<unknown>"
212
+
213
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
214
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
215
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
216
+
217
+ current_webui = config_states.get_webui_config()
218
+
219
+ style_remote = ""
220
+ style_branch = ""
221
+ style_commit = ""
222
+ if current_webui["remote"] != webui_remote:
223
+ style_remote = STYLE_PRIMARY
224
+ if current_webui["branch"] != webui_branch:
225
+ style_branch = STYLE_PRIMARY
226
+ if current_webui["commit_hash"] != webui_commit_hash:
227
+ style_commit = STYLE_PRIMARY
228
+
229
+ code = f"""<!-- {time.time()} -->
230
+ <h2>Config Backup: {config_name}</h2>
231
+ <div><b>Filepath:</b> {filepath}</div>
232
+ <div><b>Created at:</b> {created_date}</div>
233
+ <h2>WebUI State</h2>
234
+ <table id="config_state_webui">
235
+ <thead>
236
+ <tr>
237
+ <th>URL</th>
238
+ <th>Branch</th>
239
+ <th>Commit</th>
240
+ <th>Date</th>
241
+ </tr>
242
+ </thead>
243
+ <tbody>
244
+ <tr>
245
+ <td>
246
+ <label{style_remote}>{remote}</label>
247
+ </td>
248
+ <td>
249
+ <label{style_branch}>{webui_branch}</label>
250
+ </td>
251
+ <td>
252
+ <label{style_commit}>{commit_link}</label>
253
+ </td>
254
+ <td>
255
+ <label{style_commit}>{date_link}</label>
256
+ </td>
257
+ </tr>
258
+ </tbody>
259
+ </table>
260
+ <h2>Extension State</h2>
261
+ <table id="config_state_extensions">
262
+ <thead>
263
+ <tr>
264
+ <th>Extension</th>
265
+ <th>URL</th>
266
+ <th>Branch</th>
267
+ <th>Commit</th>
268
+ <th>Date</th>
269
+ </tr>
270
+ </thead>
271
+ <tbody>
272
+ """
273
+
274
+ ext_map = {ext.name: ext for ext in extensions.extensions}
275
+
276
+ for ext_name, ext_conf in config_state["extensions"].items():
277
+ ext_remote = ext_conf["remote"] or ""
278
+ ext_branch = ext_conf["branch"] or "<unknown>"
279
+ ext_enabled = ext_conf["enabled"]
280
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
281
+ ext_commit_date = ext_conf["commit_date"]
282
+ if ext_commit_date:
283
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
284
+ else:
285
+ ext_commit_date = "<unknown>"
286
+
287
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
288
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
289
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
290
+
291
+ style_enabled = ""
292
+ style_remote = ""
293
+ style_branch = ""
294
+ style_commit = ""
295
+ if ext_name in ext_map:
296
+ current_ext = ext_map[ext_name]
297
+ current_ext.read_info_from_repo()
298
+ if current_ext.enabled != ext_enabled:
299
+ style_enabled = STYLE_PRIMARY
300
+ if current_ext.remote != ext_remote:
301
+ style_remote = STYLE_PRIMARY
302
+ if current_ext.branch != ext_branch:
303
+ style_branch = STYLE_PRIMARY
304
+ if current_ext.commit_hash != ext_commit_hash:
305
+ style_commit = STYLE_PRIMARY
306
+
307
+ code += f""" <tr>
308
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
309
+ <td><label{style_remote}>{remote}</label></td>
310
+ <td><label{style_branch}>{ext_branch}</label></td>
311
+ <td><label{style_commit}>{commit_link}</label></td>
312
+ <td><label{style_commit}>{date_link}</label></td>
313
+ </tr>
314
+ """
315
+
316
+ code += """ </tbody>
317
+ </table>"""
318
+
319
+ except Exception as e:
320
+ print(f"[ERROR]: Config states {filepath}, {e}")
321
+ code = f"""<!-- {time.time()} -->
322
+ <h2>Config Backup: {config_name}</h2>
323
+ <div><b>Filepath:</b> {filepath}</div>
324
+ <div><b>Created at:</b> {created_date}</div>
325
+ <h2>This file is corrupted</h2>"""
326
+
327
+ return code
328
+
329
+
330
+ def normalize_git_url(url):
331
+ if url is None:
332
+ return ""
333
+
334
+ url = url.replace(".git", "")
335
+ return url
336
+
337
+
338
+ def install_extension_from_url(dirname, url, branch_name=None):
339
+ check_access()
340
+
341
+ if isinstance(dirname, str):
342
+ dirname = dirname.strip()
343
+ if isinstance(url, str):
344
+ url = url.strip()
345
+
346
+ assert url, 'No URL specified'
347
+
348
+ if dirname is None or dirname == "":
349
+ *parts, last_part = url.split('/')
350
+ last_part = normalize_git_url(last_part)
351
+
352
+ dirname = last_part
353
+
354
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
355
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
356
+
357
+ normalized_url = normalize_git_url(url)
358
+ if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
359
+ raise Exception(f'Extension with this URL is already installed: {url}')
360
+
361
+ tmpdir = os.path.join(paths.data_path, "tmp", dirname)
362
+
363
+ try:
364
+ shutil.rmtree(tmpdir, True)
365
+ if not branch_name:
366
+ # if no branch is specified, use the default branch
367
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
368
+ repo.remote().fetch()
369
+ for submodule in repo.submodules:
370
+ submodule.update()
371
+ else:
372
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
373
+ repo.remote().fetch()
374
+ for submodule in repo.submodules:
375
+ submodule.update()
376
+ try:
377
+ os.rename(tmpdir, target_dir)
378
+ except OSError as err:
379
+ if err.errno == errno.EXDEV:
380
+ # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
381
+ # Since we can't use a rename, do the slower but more versitile shutil.move()
382
+ shutil.move(tmpdir, target_dir)
383
+ else:
384
+ # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
385
+ raise err
386
+
387
+ import launch
388
+ launch.run_extension_installer(target_dir)
389
+
390
+ extensions.list_extensions()
391
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
392
+ finally:
393
+ shutil.rmtree(tmpdir, True)
394
+
395
+
396
+ def install_extension_from_index(url, hide_tags, sort_column, filter_text):
397
+ ext_table, message = install_extension_from_url(None, url)
398
+
399
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
400
+
401
+ return code, ext_table, message, ''
402
+
403
+
404
+ def refresh_available_extensions(url, hide_tags, sort_column):
405
+ global available_extensions
406
+
407
+ import urllib.request
408
+ with urllib.request.urlopen(url) as response:
409
+ text = response.read()
410
+
411
+ available_extensions = json.loads(text)
412
+
413
+ code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
414
+
415
+ return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
416
+
417
+
418
+ def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
419
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
420
+
421
+ return code, ''
422
+
423
+
424
+ def search_extensions(filter_text, hide_tags, sort_column):
425
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
426
+
427
+ return code, ''
428
+
429
+
430
+ sort_ordering = [
431
+ # (reverse, order_by_function)
432
+ (True, lambda x: x.get('added', 'z')),
433
+ (False, lambda x: x.get('added', 'z')),
434
+ (False, lambda x: x.get('name', 'z')),
435
+ (True, lambda x: x.get('name', 'z')),
436
+ (False, lambda x: 'z'),
437
+ (True, lambda x: x.get('commit_time', '')),
438
+ (True, lambda x: x.get('created_at', '')),
439
+ (True, lambda x: x.get('stars', 0)),
440
+ ]
441
+
442
+
443
+ def get_date(info: dict, key):
444
+ try:
445
+ return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc).astimezone().strftime("%Y-%m-%d")
446
+ except (ValueError, TypeError):
447
+ return ''
448
+
449
+
450
+ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
451
+ extlist = available_extensions["extensions"]
452
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
453
+
454
+ tags = available_extensions.get("tags", {})
455
+ tags_to_hide = set(hide_tags)
456
+ hidden = 0
457
+
458
+ code = f"""<!-- {time.time()} -->
459
+ <table id="available_extensions">
460
+ <thead>
461
+ <tr>
462
+ <th>Extension</th>
463
+ <th>Description</th>
464
+ <th>Action</th>
465
+ </tr>
466
+ </thead>
467
+ <tbody>
468
+ """
469
+
470
+ sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
471
+
472
+ for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
473
+ name = ext.get("name", "noname")
474
+ stars = int(ext.get("stars", 0))
475
+ added = ext.get('added', 'unknown')
476
+ update_time = get_date(ext, 'commit_time')
477
+ create_time = get_date(ext, 'created_at')
478
+ url = ext.get("url", None)
479
+ description = ext.get("description", "")
480
+ extension_tags = ext.get("tags", [])
481
+
482
+ if url is None:
483
+ continue
484
+
485
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
486
+ extension_tags = extension_tags + ["installed"] if existing else extension_tags
487
+
488
+ if any(x for x in extension_tags if x in tags_to_hide):
489
+ hidden += 1
490
+ continue
491
+
492
+ if filter_text and filter_text.strip():
493
+ if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
494
+ hidden += 1
495
+ continue
496
+
497
+ install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
498
+
499
+ tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
500
+
501
+ code += f"""
502
+ <tr>
503
+ <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
504
+ <td>{html.escape(description)}<p class="info">
505
+ <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
506
+ <td>{install_code}</td>
507
+ </tr>
508
+
509
+ """
510
+
511
+ for tag in [x for x in extension_tags if x not in tags]:
512
+ tags[tag] = tag
513
+
514
+ code += """
515
+ </tbody>
516
+ </table>
517
+ """
518
+
519
+ if hidden > 0:
520
+ code += f"<p>Extension hidden: {hidden}</p>"
521
+
522
+ return code, list(tags)
523
+
524
+
525
+ def preload_extensions_git_metadata():
526
+ for extension in extensions.extensions:
527
+ extension.read_info_from_repo()
528
+
529
+
530
+ def create_ui():
531
+ import modules.ui
532
+
533
+ config_states.list_config_states()
534
+
535
+ threading.Thread(target=preload_extensions_git_metadata).start()
536
+
537
+ with gr.Blocks(analytics_enabled=False) as ui:
538
+ with gr.Tabs(elem_id="tabs_extensions"):
539
+ with gr.TabItem("Installed", id="installed"):
540
+
541
+ with gr.Row(elem_id="extensions_installed_top"):
542
+ apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
543
+ apply = gr.Button(value=apply_label, variant="primary")
544
+ check = gr.Button(value="Check for updates")
545
+ extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
546
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
547
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)
548
+
549
+ html = ""
550
+
551
+ if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
552
+ if shared.cmd_opts.disable_all_extensions:
553
+ msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
554
+ elif shared.opts.disable_all_extensions != "none":
555
+ msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
556
+ elif shared.cmd_opts.disable_extra_extensions:
557
+ msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
558
+ html = f'<span style="color: var(--primary-400);">{msg}</span>'
559
+
560
+ with gr.Row():
561
+ info = gr.HTML(html)
562
+
563
+ with gr.Row(elem_classes="progress-container"):
564
+ extensions_table = gr.HTML('Loading...', elem_id="extensions_installed_html")
565
+
566
+ ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
567
+
568
+ apply.click(
569
+ fn=apply_and_restart,
570
+ _js="extensions_apply",
571
+ inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
572
+ outputs=[],
573
+ )
574
+
575
+ check.click(
576
+ fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
577
+ _js="extensions_check",
578
+ inputs=[info, extensions_disabled_list],
579
+ outputs=[extensions_table, info],
580
+ )
581
+
582
+ with gr.TabItem("Available", id="available"):
583
+ with gr.Row():
584
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
585
+ extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
586
+ available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
587
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
588
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
589
+
590
+ with gr.Row():
591
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
592
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
593
+
594
+ with gr.Row():
595
+ search_extensions_text = gr.Text(label="Search", container=False)
596
+
597
+ install_result = gr.HTML()
598
+ available_extensions_table = gr.HTML()
599
+
600
+ refresh_available_extensions_button.click(
601
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
602
+ inputs=[available_extensions_index, hide_tags, sort_column],
603
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
604
+ )
605
+
606
+ install_extension_button.click(
607
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
608
+ inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
609
+ outputs=[available_extensions_table, extensions_table, install_result],
610
+ )
611
+
612
+ search_extensions_text.change(
613
+ fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
614
+ inputs=[search_extensions_text, hide_tags, sort_column],
615
+ outputs=[available_extensions_table, install_result],
616
+ )
617
+
618
+ hide_tags.change(
619
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
620
+ inputs=[hide_tags, sort_column, search_extensions_text],
621
+ outputs=[available_extensions_table, install_result]
622
+ )
623
+
624
+ sort_column.change(
625
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
626
+ inputs=[hide_tags, sort_column, search_extensions_text],
627
+ outputs=[available_extensions_table, install_result]
628
+ )
629
+
630
+ with gr.TabItem("Install from URL", id="install_from_url"):
631
+ install_url = gr.Text(label="URL for extension's git repository")
632
+ install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
633
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
634
+ install_button = gr.Button(value="Install", variant="primary")
635
+ install_result = gr.HTML(elem_id="extension_install_result")
636
+
637
+ install_button.click(
638
+ fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
639
+ inputs=[install_dirname, install_url, install_branch],
640
+ outputs=[install_url, extensions_table, install_result],
641
+ )
642
+
643
+ with gr.TabItem("Backup/Restore"):
644
+ with gr.Row(elem_id="extensions_backup_top_row"):
645
+ config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
646
+ modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
647
+ config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
648
+ config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
649
+ with gr.Row(elem_id="extensions_backup_top_row2"):
650
+ config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
651
+ config_save_button = gr.Button(value="Save Current Config")
652
+
653
+ config_states_info = gr.HTML("")
654
+ config_states_table = gr.HTML("Loading...")
655
+ ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
656
+
657
+ config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
658
+
659
+ dummy_component = gr.Label(visible=False)
660
+ config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
661
+
662
+ config_states_list.change(
663
+ fn=update_config_states_table,
664
+ inputs=[config_states_list],
665
+ outputs=[config_states_table],
666
+ )
667
+
668
+
669
+ return ui
modules/ui_extra_networks.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import urllib.parse
3
+ from pathlib import Path
4
+
5
+ from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
6
+ from modules.images import read_info_from_image, save_image_with_geninfo
7
+ import gradio as gr
8
+ import json
9
+ import html
10
+ from fastapi.exceptions import HTTPException
11
+
12
+ from modules.generation_parameters_copypaste import image_from_url_text
13
+ from modules.ui_components import ToolButton
14
+
15
+ extra_pages = []
16
+ allowed_dirs = set()
17
+
18
+
19
+ def register_page(page):
20
+ """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
21
+
22
+ extra_pages.append(page)
23
+ allowed_dirs.clear()
24
+ allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
25
+
26
+
27
+ def fetch_file(filename: str = ""):
28
+ from starlette.responses import FileResponse
29
+
30
+ if not os.path.isfile(filename):
31
+ raise HTTPException(status_code=404, detail="File not found")
32
+
33
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
34
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
35
+
36
+ ext = os.path.splitext(filename)[1].lower()
37
+ if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
38
+ raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
39
+
40
+ # would profit from returning 304
41
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
42
+
43
+
44
+ def get_metadata(page: str = "", item: str = ""):
45
+ from starlette.responses import JSONResponse
46
+
47
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
48
+ if page is None:
49
+ return JSONResponse({})
50
+
51
+ metadata = page.metadata.get(item)
52
+ if metadata is None:
53
+ return JSONResponse({})
54
+
55
+ return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
56
+
57
+
58
+ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
59
+ from starlette.responses import JSONResponse
60
+
61
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
62
+
63
+ try:
64
+ item = page.create_item(name, enable_filter=False)
65
+ page.items[name] = item
66
+ except Exception as e:
67
+ errors.display(e, "creating item for extra network")
68
+ item = page.items.get(name)
69
+
70
+ page.read_user_metadata(item)
71
+ item_html = page.create_html_for_item(item, tabname)
72
+
73
+ return JSONResponse({"html": item_html})
74
+
75
+
76
+ def add_pages_to_demo(app):
77
+ app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
78
+ app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
79
+ app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
80
+
81
+
82
+ def quote_js(s):
83
+ s = s.replace('\\', '\\\\')
84
+ s = s.replace('"', '\\"')
85
+ return f'"{s}"'
86
+
87
+
88
+ class ExtraNetworksPage:
89
+ def __init__(self, title):
90
+ self.title = title
91
+ self.name = title.lower()
92
+ self.id_page = self.name.replace(" ", "_")
93
+ self.card_page = shared.html("extra-networks-card.html")
94
+ self.allow_negative_prompt = False
95
+ self.metadata = {}
96
+ self.items = {}
97
+
98
+ def refresh(self):
99
+ pass
100
+
101
+ def read_user_metadata(self, item):
102
+ filename = item.get("filename", None)
103
+ metadata = extra_networks.get_user_metadata(filename)
104
+
105
+ desc = metadata.get("description", None)
106
+ if desc is not None:
107
+ item["description"] = desc
108
+
109
+ item["user_metadata"] = metadata
110
+
111
+ def link_preview(self, filename):
112
+ quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
113
+ mtime = os.path.getmtime(filename)
114
+ return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
115
+
116
+ def search_terms_from_path(self, filename, possible_directories=None):
117
+ abspath = os.path.abspath(filename)
118
+
119
+ for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
120
+ parentdir = os.path.abspath(parentdir)
121
+ if abspath.startswith(parentdir):
122
+ return abspath[len(parentdir):].replace('\\', '/')
123
+
124
+ return ""
125
+
126
+ def create_html(self, tabname):
127
+ items_html = ''
128
+
129
+ self.metadata = {}
130
+
131
+ subdirs = {}
132
+ for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
133
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
134
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
135
+ x = os.path.join(root, dirname)
136
+
137
+ if not os.path.isdir(x):
138
+ continue
139
+
140
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
141
+ while subdir.startswith("/"):
142
+ subdir = subdir[1:]
143
+
144
+ is_empty = len(os.listdir(x)) == 0
145
+ if not is_empty and not subdir.endswith("/"):
146
+ subdir = subdir + "/"
147
+
148
+ if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
149
+ continue
150
+
151
+ subdirs[subdir] = 1
152
+
153
+ if subdirs:
154
+ subdirs = {"": 1, **subdirs}
155
+
156
+ subdirs_html = "".join([f"""
157
+ <button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
158
+ {html.escape(subdir if subdir!="" else "all")}
159
+ </button>
160
+ """ for subdir in subdirs])
161
+
162
+ self.items = {x["name"]: x for x in self.list_items()}
163
+ for item in self.items.values():
164
+ metadata = item.get("metadata")
165
+ if metadata:
166
+ self.metadata[item["name"]] = metadata
167
+
168
+ if "user_metadata" not in item:
169
+ self.read_user_metadata(item)
170
+
171
+ items_html += self.create_html_for_item(item, tabname)
172
+
173
+ if items_html == '':
174
+ dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
175
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
176
+
177
+ self_name_id = self.name.replace(" ", "_")
178
+
179
+ res = f"""
180
+ <div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
181
+ {subdirs_html}
182
+ </div>
183
+ <div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
184
+ {items_html}
185
+ </div>
186
+ """
187
+
188
+ return res
189
+
190
+ def create_item(self, name, index=None):
191
+ raise NotImplementedError()
192
+
193
+ def list_items(self):
194
+ raise NotImplementedError()
195
+
196
+ def allowed_directories_for_previews(self):
197
+ return []
198
+
199
+ def create_html_for_item(self, item, tabname):
200
+ """
201
+ Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
202
+ """
203
+
204
+ preview = item.get("preview", None)
205
+
206
+ onclick = item.get("onclick", None)
207
+ if onclick is None:
208
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
209
+
210
+ height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
211
+ width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
212
+ background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
213
+ metadata_button = ""
214
+ metadata = item.get("metadata")
215
+ if metadata:
216
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
217
+
218
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
219
+
220
+ local_path = ""
221
+ filename = item.get("filename", "")
222
+ for reldir in self.allowed_directories_for_previews():
223
+ absdir = os.path.abspath(reldir)
224
+
225
+ if filename.startswith(absdir):
226
+ local_path = filename[len(absdir):]
227
+
228
+ # if this is true, the item must not be shown in the default view, and must instead only be
229
+ # shown when searching for it
230
+ if shared.opts.extra_networks_hidden_models == "Always":
231
+ search_only = False
232
+ else:
233
+ search_only = "/." in local_path or "\\." in local_path
234
+
235
+ if search_only and shared.opts.extra_networks_hidden_models == "Never":
236
+ return ""
237
+
238
+ sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
239
+
240
+ args = {
241
+ "background_image": background_image,
242
+ "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
243
+ "prompt": item.get("prompt", None),
244
+ "tabname": quote_js(tabname),
245
+ "local_preview": quote_js(item["local_preview"]),
246
+ "name": html.escape(item["name"]),
247
+ "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
248
+ "card_clicked": onclick,
249
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
250
+ "search_term": item.get("search_term", ""),
251
+ "metadata_button": metadata_button,
252
+ "edit_button": edit_button,
253
+ "search_only": " search_only" if search_only else "",
254
+ "sort_keys": sort_keys,
255
+ }
256
+
257
+ return self.card_page.format(**args)
258
+
259
+ def get_sort_keys(self, path):
260
+ """
261
+ List of default keys used for sorting in the UI.
262
+ """
263
+ pth = Path(path)
264
+ stat = pth.stat()
265
+ return {
266
+ "date_created": int(stat.st_ctime or 0),
267
+ "date_modified": int(stat.st_mtime or 0),
268
+ "name": pth.name.lower(),
269
+ }
270
+
271
+ def find_preview(self, path):
272
+ """
273
+ Find a preview PNG for a given path (without extension) and call link_preview on it.
274
+ """
275
+
276
+ preview_extensions = ["png", "jpg", "jpeg", "webp"]
277
+ if shared.opts.samples_format not in preview_extensions:
278
+ preview_extensions.append(shared.opts.samples_format)
279
+
280
+ potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
281
+
282
+ for file in potential_files:
283
+ if os.path.isfile(file):
284
+ return self.link_preview(file)
285
+
286
+ return None
287
+
288
+ def find_description(self, path):
289
+ """
290
+ Find and read a description file for a given path (without extension).
291
+ """
292
+ for file in [f"{path}.txt", f"{path}.description.txt"]:
293
+ try:
294
+ with open(file, "r", encoding="utf-8", errors="replace") as f:
295
+ return f.read()
296
+ except OSError:
297
+ pass
298
+ return None
299
+
300
+ def create_user_metadata_editor(self, ui, tabname):
301
+ return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
302
+
303
+
304
+ def initialize():
305
+ extra_pages.clear()
306
+
307
+
308
+ def register_default_pages():
309
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
310
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
311
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
312
+ register_page(ExtraNetworksPageTextualInversion())
313
+ register_page(ExtraNetworksPageHypernetworks())
314
+ register_page(ExtraNetworksPageCheckpoints())
315
+
316
+
317
+ class ExtraNetworksUi:
318
+ def __init__(self):
319
+ self.pages = None
320
+ """gradio HTML components related to extra networks' pages"""
321
+
322
+ self.page_contents = None
323
+ """HTML content of the above; empty initially, filled when extra pages have to be shown"""
324
+
325
+ self.stored_extra_pages = None
326
+
327
+ self.button_save_preview = None
328
+ self.preview_target_filename = None
329
+
330
+ self.tabname = None
331
+
332
+
333
+ def pages_in_preferred_order(pages):
334
+ tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
335
+
336
+ def tab_name_score(name):
337
+ name = name.lower()
338
+ for i, possible_match in enumerate(tab_order):
339
+ if possible_match in name:
340
+ return i
341
+
342
+ return len(pages)
343
+
344
+ tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
345
+
346
+ return sorted(pages, key=lambda x: tab_scores[x.name])
347
+
348
+
349
+ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
350
+ from modules.ui import switch_values_symbol
351
+
352
+ ui = ExtraNetworksUi()
353
+ ui.pages = []
354
+ ui.pages_contents = []
355
+ ui.user_metadata_editors = []
356
+ ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
357
+ ui.tabname = tabname
358
+
359
+ related_tabs = []
360
+
361
+ for page in ui.stored_extra_pages:
362
+ with gr.Tab(page.title, id=page.id_page) as tab:
363
+ elem_id = f"{tabname}_{page.id_page}_cards_html"
364
+ page_elem = gr.HTML('Loading...', elem_id=elem_id)
365
+ ui.pages.append(page_elem)
366
+
367
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
368
+
369
+ editor = page.create_user_metadata_editor(ui, tabname)
370
+ editor.create_ui()
371
+ ui.user_metadata_editors.append(editor)
372
+
373
+ related_tabs.append(tab)
374
+
375
+ edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
376
+ dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
377
+ button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
378
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
379
+ checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
380
+
381
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
382
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
383
+
384
+ for tab in unrelated_tabs:
385
+ tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
386
+
387
+ for tab in related_tabs:
388
+ tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
389
+
390
+ def pages_html():
391
+ if not ui.pages_contents:
392
+ return refresh()
393
+
394
+ return ui.pages_contents
395
+
396
+ def refresh():
397
+ for pg in ui.stored_extra_pages:
398
+ pg.refresh()
399
+
400
+ ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
401
+
402
+ return ui.pages_contents
403
+
404
+ interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
405
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
406
+
407
+ return ui
408
+
409
+
410
+ def path_is_parent(parent_path, child_path):
411
+ parent_path = os.path.abspath(parent_path)
412
+ child_path = os.path.abspath(child_path)
413
+
414
+ return child_path.startswith(parent_path)
415
+
416
+
417
+ def setup_ui(ui, gallery):
418
+ def save_preview(index, images, filename):
419
+ # this function is here for backwards compatibility and likely will be removed soon
420
+
421
+ if len(images) == 0:
422
+ print("There is no image in gallery to save as a preview.")
423
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
424
+
425
+ index = int(index)
426
+ index = 0 if index < 0 else index
427
+ index = len(images) - 1 if index >= len(images) else index
428
+
429
+ img_info = images[index if index >= 0 else 0]
430
+ image = image_from_url_text(img_info)
431
+ geninfo, items = read_info_from_image(image)
432
+
433
+ is_allowed = False
434
+ for extra_page in ui.stored_extra_pages:
435
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
436
+ is_allowed = True
437
+ break
438
+
439
+ assert is_allowed, f'writing to {filename} is not allowed'
440
+
441
+ save_image_with_geninfo(image, geninfo, filename)
442
+
443
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
444
+
445
+ ui.button_save_preview.click(
446
+ fn=save_preview,
447
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
448
+ inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
449
+ outputs=[*ui.pages]
450
+ )
451
+
452
+ for editor in ui.user_metadata_editors:
453
+ editor.setup_ui(gallery)
454
+
455
+
modules/ui_extra_networks_checkpoints.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import html
2
+ import os
3
+
4
+ from modules import shared, ui_extra_networks, sd_models
5
+ from modules.ui_extra_networks import quote_js
6
+ from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
7
+
8
+
9
+ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
10
+ def __init__(self):
11
+ super().__init__('Checkpoints')
12
+
13
+ def refresh(self):
14
+ shared.refresh_checkpoints()
15
+
16
+ def create_item(self, name, index=None, enable_filter=True):
17
+ checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
18
+ path, ext = os.path.splitext(checkpoint.filename)
19
+ return {
20
+ "name": checkpoint.name_for_extra,
21
+ "filename": checkpoint.filename,
22
+ "shorthash": checkpoint.shorthash,
23
+ "preview": self.find_preview(path),
24
+ "description": self.find_description(path),
25
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
26
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
27
+ "local_preview": f"{path}.{shared.opts.samples_format}",
28
+ "metadata": checkpoint.metadata,
29
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
30
+ }
31
+
32
+ def list_items(self):
33
+ names = list(sd_models.checkpoints_list)
34
+ for index, name in enumerate(names):
35
+ yield self.create_item(name, index)
36
+
37
+ def allowed_directories_for_previews(self):
38
+ return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
39
+
40
+ def create_user_metadata_editor(self, ui, tabname):
41
+ return CheckpointUserMetadataEditor(ui, tabname, self)
modules/ui_extra_networks_checkpoints_user_metadata.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules import ui_extra_networks_user_metadata, sd_vae, shared
4
+ from modules.ui_common import create_refresh_button
5
+
6
+
7
+ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
8
+ def __init__(self, ui, tabname, page):
9
+ super().__init__(ui, tabname, page)
10
+
11
+ self.select_vae = None
12
+
13
+ def save_user_metadata(self, name, desc, notes, vae):
14
+ user_metadata = self.get_user_metadata(name)
15
+ user_metadata["description"] = desc
16
+ user_metadata["notes"] = notes
17
+ user_metadata["vae"] = vae
18
+
19
+ self.write_user_metadata(name, user_metadata)
20
+
21
+ def update_vae(self, name):
22
+ if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
23
+ sd_vae.reload_vae_weights()
24
+
25
+ def put_values_into_components(self, name):
26
+ user_metadata = self.get_user_metadata(name)
27
+ values = super().put_values_into_components(name)
28
+
29
+ return [
30
+ *values[0:5],
31
+ user_metadata.get('vae', ''),
32
+ ]
33
+
34
+ def create_editor(self):
35
+ self.create_default_editor_elems()
36
+
37
+ with gr.Row():
38
+ self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
39
+ create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
40
+
41
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
42
+
43
+ self.create_default_buttons()
44
+
45
+ viewed_components = [
46
+ self.edit_name,
47
+ self.edit_description,
48
+ self.html_filedata,
49
+ self.html_preview,
50
+ self.edit_notes,
51
+ self.select_vae,
52
+ ]
53
+
54
+ self.button_edit\
55
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
56
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
57
+
58
+ edited_components = [
59
+ self.edit_description,
60
+ self.edit_notes,
61
+ self.select_vae,
62
+ ]
63
+
64
+ self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
65
+ self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
66
+
modules/ui_extra_networks_hypernets.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from modules import shared, ui_extra_networks
4
+ from modules.ui_extra_networks import quote_js
5
+ from modules.hashes import sha256_from_cache
6
+
7
+
8
+ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
9
+ def __init__(self):
10
+ super().__init__('Hypernetworks')
11
+
12
+ def refresh(self):
13
+ shared.reload_hypernetworks()
14
+
15
+ def create_item(self, name, index=None, enable_filter=True):
16
+ full_path = shared.hypernetworks[name]
17
+ path, ext = os.path.splitext(full_path)
18
+ sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
19
+ shorthash = sha256[0:10] if sha256 else None
20
+
21
+ return {
22
+ "name": name,
23
+ "filename": full_path,
24
+ "shorthash": shorthash,
25
+ "preview": self.find_preview(path),
26
+ "description": self.find_description(path),
27
+ "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
28
+ "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
29
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
30
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
31
+ }
32
+
33
+ def list_items(self):
34
+ for index, name in enumerate(shared.hypernetworks):
35
+ yield self.create_item(name, index)
36
+
37
+ def allowed_directories_for_previews(self):
38
+ return [shared.cmd_opts.hypernetwork_dir]
39
+
modules/ui_extra_networks_textual_inversion.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from modules import ui_extra_networks, sd_hijack, shared
4
+ from modules.ui_extra_networks import quote_js
5
+
6
+
7
+ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
8
+ def __init__(self):
9
+ super().__init__('Textual Inversion')
10
+ self.allow_negative_prompt = True
11
+
12
+ def refresh(self):
13
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
14
+
15
+ def create_item(self, name, index=None, enable_filter=True):
16
+ embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
17
+
18
+ path, ext = os.path.splitext(embedding.filename)
19
+ return {
20
+ "name": name,
21
+ "filename": embedding.filename,
22
+ "shorthash": embedding.shorthash,
23
+ "preview": self.find_preview(path),
24
+ "description": self.find_description(path),
25
+ "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
26
+ "prompt": quote_js(embedding.name),
27
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
28
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
29
+ }
30
+
31
+ def list_items(self):
32
+ for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
33
+ yield self.create_item(name, index)
34
+
35
+ def allowed_directories_for_previews(self):
36
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
modules/ui_extra_networks_user_metadata.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import html
3
+ import json
4
+ import os.path
5
+
6
+ import gradio as gr
7
+
8
+ from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
9
+
10
+
11
+ class UserMetadataEditor:
12
+
13
+ def __init__(self, ui, tabname, page):
14
+ self.ui = ui
15
+ self.tabname = tabname
16
+ self.page = page
17
+ self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
18
+
19
+ self.box = None
20
+
21
+ self.edit_name_input = None
22
+ self.button_edit = None
23
+
24
+ self.edit_name = None
25
+ self.edit_description = None
26
+ self.edit_notes = None
27
+ self.html_filedata = None
28
+ self.html_preview = None
29
+ self.html_status = None
30
+
31
+ self.button_cancel = None
32
+ self.button_replace_preview = None
33
+ self.button_save = None
34
+
35
+ def get_user_metadata(self, name):
36
+ item = self.page.items.get(name, {})
37
+
38
+ user_metadata = item.get('user_metadata', None)
39
+ if not user_metadata:
40
+ user_metadata = {'description': item.get('description', '')}
41
+ item['user_metadata'] = user_metadata
42
+
43
+ return user_metadata
44
+
45
+ def create_extra_default_items_in_left_column(self):
46
+ pass
47
+
48
+ def create_default_editor_elems(self):
49
+ with gr.Row():
50
+ with gr.Column(scale=2):
51
+ self.edit_name = gr.HTML(elem_classes="extra-network-name")
52
+ self.edit_description = gr.Textbox(label="Description", lines=4)
53
+ self.html_filedata = gr.HTML()
54
+
55
+ self.create_extra_default_items_in_left_column()
56
+
57
+ with gr.Column(scale=1, min_width=0):
58
+ self.html_preview = gr.HTML()
59
+
60
+ def create_default_buttons(self):
61
+
62
+ with gr.Row(elem_classes="edit-user-metadata-buttons"):
63
+ self.button_cancel = gr.Button('Cancel')
64
+ self.button_replace_preview = gr.Button('Replace preview', variant='primary')
65
+ self.button_save = gr.Button('Save', variant='primary')
66
+
67
+ self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
68
+
69
+ self.button_cancel.click(fn=None, _js="closePopup")
70
+
71
+ def get_card_html(self, name):
72
+ item = self.page.items.get(name, {})
73
+
74
+ preview_url = item.get("preview", None)
75
+
76
+ if not preview_url:
77
+ filename, _ = os.path.splitext(item["filename"])
78
+ preview_url = self.page.find_preview(filename)
79
+ item["preview"] = preview_url
80
+
81
+ if preview_url:
82
+ preview = f'''
83
+ <div class='card standalone-card-preview'>
84
+ <img src="{html.escape(preview_url)}" class="preview">
85
+ </div>
86
+ '''
87
+ else:
88
+ preview = "<div class='card standalone-card-preview'></div>"
89
+
90
+ return preview
91
+
92
+ def relative_path(self, path):
93
+ for parent_path in self.page.allowed_directories_for_previews():
94
+ if ui_extra_networks.path_is_parent(parent_path, path):
95
+ return os.path.relpath(path, parent_path)
96
+
97
+ return os.path.basename(path)
98
+
99
+ def get_metadata_table(self, name):
100
+ item = self.page.items.get(name, {})
101
+ try:
102
+ filename = item["filename"]
103
+ shorthash = item.get("shorthash", None)
104
+
105
+ stats = os.stat(filename)
106
+ params = [
107
+ ('Filename: ', self.relative_path(filename)),
108
+ ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
109
+ ('Hash: ', shorthash),
110
+ ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
111
+ ]
112
+
113
+ return params
114
+ except Exception as e:
115
+ errors.display(e, f"reading info for {name}")
116
+ return []
117
+
118
+ def put_values_into_components(self, name):
119
+ user_metadata = self.get_user_metadata(name)
120
+
121
+ try:
122
+ params = self.get_metadata_table(name)
123
+ except Exception as e:
124
+ errors.display(e, f"reading metadata info for {name}")
125
+ params = []
126
+
127
+ table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
128
+
129
+ return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
130
+
131
+ def write_user_metadata(self, name, metadata):
132
+ item = self.page.items.get(name, {})
133
+ filename = item.get("filename", None)
134
+ basename, ext = os.path.splitext(filename)
135
+
136
+ with open(basename + '.json', "w", encoding="utf8") as file:
137
+ json.dump(metadata, file, indent=4)
138
+
139
+ def save_user_metadata(self, name, desc, notes):
140
+ user_metadata = self.get_user_metadata(name)
141
+ user_metadata["description"] = desc
142
+ user_metadata["notes"] = notes
143
+
144
+ self.write_user_metadata(name, user_metadata)
145
+
146
+ def setup_save_handler(self, button, func, components):
147
+ button\
148
+ .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
149
+ .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
150
+
151
+ def create_editor(self):
152
+ self.create_default_editor_elems()
153
+
154
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
155
+
156
+ self.create_default_buttons()
157
+
158
+ self.button_edit\
159
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
160
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
161
+
162
+ self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
163
+
164
+ def create_ui(self):
165
+ with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
166
+ self.box = box
167
+
168
+ self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
169
+ self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
170
+
171
+ self.create_editor()
172
+
173
+ def save_preview(self, index, gallery, name):
174
+ if len(gallery) == 0:
175
+ return self.get_card_html(name), "There is no image in gallery to save as a preview."
176
+
177
+ item = self.page.items.get(name, {})
178
+
179
+ index = int(index)
180
+ index = 0 if index < 0 else index
181
+ index = len(gallery) - 1 if index >= len(gallery) else index
182
+
183
+ img_info = gallery[index if index >= 0 else 0]
184
+ image = generation_parameters_copypaste.image_from_url_text(img_info)
185
+ geninfo, items = images.read_info_from_image(image)
186
+
187
+ images.save_image_with_geninfo(image, geninfo, item["local_preview"])
188
+
189
+ return self.get_card_html(name), ''
190
+
191
+ def setup_ui(self, gallery):
192
+ self.button_replace_preview.click(
193
+ fn=self.save_preview,
194
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
195
+ inputs=[self.edit_name_input, gallery, self.edit_name_input],
196
+ outputs=[self.html_preview, self.html_status]
197
+ ).then(
198
+ fn=None,
199
+ _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
200
+ inputs=[self.edit_name_input],
201
+ outputs=[]
202
+ )
203
+
204
+
205
+
modules/ui_gradio_extensions.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from modules import localization, shared, scripts
5
+ from modules.paths import script_path, data_path
6
+
7
+
8
+ def webpath(fn):
9
+ if fn.startswith(script_path):
10
+ web_path = os.path.relpath(fn, script_path).replace('\\', '/')
11
+ else:
12
+ web_path = os.path.abspath(fn)
13
+
14
+ return f'file={web_path}?{os.path.getmtime(fn)}'
15
+
16
+
17
+ def javascript_html():
18
+ # Ensure localization is in `window` before scripts
19
+ head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
20
+
21
+ script_js = os.path.join(script_path, "script.js")
22
+ head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
23
+
24
+ for script in scripts.list_scripts("javascript", ".js"):
25
+ head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
26
+
27
+ for script in scripts.list_scripts("javascript", ".mjs"):
28
+ head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
29
+
30
+ if shared.cmd_opts.theme:
31
+ head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
32
+
33
+ return head
34
+
35
+
36
+ def css_html():
37
+ head = ""
38
+
39
+ def stylesheet(fn):
40
+ return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
41
+
42
+ for cssfile in scripts.list_files_with_name("style.css"):
43
+ if not os.path.isfile(cssfile):
44
+ continue
45
+
46
+ head += stylesheet(cssfile)
47
+
48
+ if os.path.exists(os.path.join(data_path, "user.css")):
49
+ head += stylesheet(os.path.join(data_path, "user.css"))
50
+
51
+ return head
52
+
53
+
54
+ def reload_javascript():
55
+ js = javascript_html()
56
+ css = css_html()
57
+
58
+ def template_response(*args, **kwargs):
59
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
60
+ res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
61
+ res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
62
+ res.init_headers()
63
+ return res
64
+
65
+ gr.routes.templates.TemplateResponse = template_response
66
+
67
+
68
+ if not hasattr(shared, 'GradioTemplateResponseOriginal'):
69
+ shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
modules/ui_loadsave.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import gradio as gr
5
+
6
+ from modules import errors
7
+ from modules.ui_components import ToolButton
8
+
9
+
10
+ def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
11
+ return [x[0] if isinstance(x, tuple) else x for x in getattr(comp, 'choices', [])]
12
+
13
+
14
+ class UiLoadsave:
15
+ """allows saving and restoring default values for gradio components"""
16
+
17
+ def __init__(self, filename):
18
+ self.filename = filename
19
+ self.ui_settings = {}
20
+ self.component_mapping = {}
21
+ self.error_loading = False
22
+ self.finalized_ui = False
23
+
24
+ self.ui_defaults_view = None
25
+ self.ui_defaults_apply = None
26
+ self.ui_defaults_review = None
27
+
28
+ try:
29
+ if os.path.exists(self.filename):
30
+ self.ui_settings = self.read_from_file()
31
+ except Exception as e:
32
+ self.error_loading = True
33
+ errors.display(e, "loading settings")
34
+
35
+
36
+
37
+ def add_component(self, path, x):
38
+ """adds component to the registry of tracked components"""
39
+
40
+ assert not self.finalized_ui
41
+
42
+ def apply_field(obj, field, condition=None, init_field=None):
43
+ key = f"{path}/{field}"
44
+
45
+ if getattr(obj, 'custom_script_source', None) is not None:
46
+ key = f"customscript/{obj.custom_script_source}/{key}"
47
+
48
+ if getattr(obj, 'do_not_save_to_config', False):
49
+ return
50
+
51
+ saved_value = self.ui_settings.get(key, None)
52
+ if saved_value is None:
53
+ self.ui_settings[key] = getattr(obj, field)
54
+ elif condition and not condition(saved_value):
55
+ pass
56
+ else:
57
+ if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
58
+ saved_value = str(saved_value)
59
+ elif isinstance(x, gr.Number) and field == 'value':
60
+ try:
61
+ saved_value = float(saved_value)
62
+ except ValueError:
63
+ return
64
+
65
+ setattr(obj, field, saved_value)
66
+ if init_field is not None:
67
+ init_field(saved_value)
68
+
69
+ if field == 'value' and key not in self.component_mapping:
70
+ self.component_mapping[key] = x
71
+
72
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
73
+ apply_field(x, 'visible')
74
+
75
+ if type(x) == gr.Slider:
76
+ apply_field(x, 'value')
77
+ apply_field(x, 'minimum')
78
+ apply_field(x, 'maximum')
79
+ apply_field(x, 'step')
80
+
81
+ if type(x) == gr.Radio:
82
+ apply_field(x, 'value', lambda val: val in radio_choices(x))
83
+
84
+ if type(x) == gr.Checkbox:
85
+ apply_field(x, 'value')
86
+
87
+ if type(x) == gr.Textbox:
88
+ apply_field(x, 'value')
89
+
90
+ if type(x) == gr.Number:
91
+ apply_field(x, 'value')
92
+
93
+ if type(x) == gr.Dropdown:
94
+ def check_dropdown(val):
95
+ choices = radio_choices(x)
96
+ if getattr(x, 'multiselect', False):
97
+ return all(value in choices for value in val)
98
+ else:
99
+ return val in choices
100
+
101
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
102
+
103
+ def check_tab_id(tab_id):
104
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
105
+ if type(tab_id) == str:
106
+ tab_ids = [t.id for t in tab_items]
107
+ return tab_id in tab_ids
108
+ elif type(tab_id) == int:
109
+ return 0 <= tab_id < len(tab_items)
110
+ else:
111
+ return False
112
+
113
+ if type(x) == gr.Tabs:
114
+ apply_field(x, 'selected', check_tab_id)
115
+
116
+ def add_block(self, x, path=""):
117
+ """adds all components inside a gradio block x to the registry of tracked components"""
118
+
119
+ if hasattr(x, 'children'):
120
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
121
+ # Tabs element can't have a label, have to use elem_id instead
122
+ self.add_component(f"{path}/Tabs@{x.elem_id}", x)
123
+ for c in x.children:
124
+ self.add_block(c, path)
125
+ elif x.label is not None:
126
+ self.add_component(f"{path}/{x.label}", x)
127
+ elif isinstance(x, gr.Button) and x.value is not None:
128
+ self.add_component(f"{path}/{x.value}", x)
129
+
130
+ def read_from_file(self):
131
+ with open(self.filename, "r", encoding="utf8") as file:
132
+ return json.load(file)
133
+
134
+ def write_to_file(self, current_ui_settings):
135
+ with open(self.filename, "w", encoding="utf8") as file:
136
+ json.dump(current_ui_settings, file, indent=4)
137
+
138
+ def dump_defaults(self):
139
+ """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
140
+
141
+ if self.error_loading and os.path.exists(self.filename):
142
+ return
143
+
144
+ self.write_to_file(self.ui_settings)
145
+
146
+ def iter_changes(self, current_ui_settings, values):
147
+ """
148
+ given a dictionary with defaults from a file and current values from gradio elements, returns
149
+ an iterator over tuples of values that are not the same between the file and the current;
150
+ tuple contents are: path, old value, new value
151
+ """
152
+
153
+ for (path, component), new_value in zip(self.component_mapping.items(), values):
154
+ old_value = current_ui_settings.get(path)
155
+
156
+ choices = radio_choices(component)
157
+ if isinstance(new_value, int) and choices:
158
+ if new_value >= len(choices):
159
+ continue
160
+
161
+ new_value = choices[new_value]
162
+ if isinstance(new_value, tuple):
163
+ new_value = new_value[0]
164
+
165
+ if new_value == old_value:
166
+ continue
167
+
168
+ if old_value is None and new_value == '' or new_value == []:
169
+ continue
170
+
171
+ yield path, old_value, new_value
172
+
173
+ def ui_view(self, *values):
174
+ text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
175
+
176
+ for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
177
+ if old_value is None:
178
+ old_value = "<span class='ui-defaults-none'>None</span>"
179
+
180
+ text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
181
+
182
+ if len(text) == 1:
183
+ text.append("<tr><td colspan=3>No changes</td></tr>")
184
+
185
+ text.append("</tbody>")
186
+ return "".join(text)
187
+
188
+ def ui_apply(self, *values):
189
+ num_changed = 0
190
+
191
+ current_ui_settings = self.read_from_file()
192
+
193
+ for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
194
+ num_changed += 1
195
+ current_ui_settings[path] = new_value
196
+
197
+ if num_changed == 0:
198
+ return "No changes."
199
+
200
+ self.write_to_file(current_ui_settings)
201
+
202
+ return f"Wrote {num_changed} changes."
203
+
204
+ def create_ui(self):
205
+ """creates ui elements for editing defaults UI, without adding any logic to them"""
206
+
207
+ gr.HTML(
208
+ f"This page allows you to change default values in UI elements on other tabs.<br />"
209
+ f"Make your changes, press 'View changes' to review the changed default values,<br />"
210
+ f"then press 'Apply' to write them to {self.filename}.<br />"
211
+ f"New defaults will apply after you restart the UI.<br />"
212
+ )
213
+
214
+ with gr.Row():
215
+ self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
216
+ self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
217
+
218
+ self.ui_defaults_review = gr.HTML("")
219
+
220
+ def setup_ui(self):
221
+ """adds logic to elements created with create_ui; all add_block class must be made before this"""
222
+
223
+ assert not self.finalized_ui
224
+ self.finalized_ui = True
225
+
226
+ self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
227
+ self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
modules/ui_postprocessing.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from modules import scripts, shared, ui_common, postprocessing, call_queue
3
+ import modules.generation_parameters_copypaste as parameters_copypaste
4
+
5
+
6
+ def create_ui():
7
+ tab_index = gr.State(value=0)
8
+
9
+ with gr.Row(equal_height=False, variant='compact'):
10
+ with gr.Column(variant='compact'):
11
+ with gr.Tabs(elem_id="mode_extras"):
12
+ with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
13
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
14
+
15
+ with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
16
+ image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
17
+
18
+ with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
19
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
20
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
21
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
22
+
23
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
24
+
25
+ script_inputs = scripts.scripts_postproc.setup_ui()
26
+
27
+ with gr.Column():
28
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
29
+
30
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
31
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
32
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
33
+
34
+ submit.click(
35
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
36
+ inputs=[
37
+ tab_index,
38
+ extras_image,
39
+ image_batch,
40
+ extras_batch_input_dir,
41
+ extras_batch_output_dir,
42
+ show_extras_results,
43
+ *script_inputs
44
+ ],
45
+ outputs=[
46
+ result_images,
47
+ html_info_x,
48
+ html_info,
49
+ ]
50
+ )
51
+
52
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
53
+
54
+ extras_image.change(
55
+ fn=scripts.scripts_postproc.image_changed,
56
+ inputs=[], outputs=[]
57
+ )
modules/ui_prompt_styles.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules import shared, ui_common, ui_components, styles
4
+
5
+ styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
6
+ styles_materialize_symbol = '\U0001f4cb' # 📋
7
+
8
+
9
+ def select_style(name):
10
+ style = shared.prompt_styles.styles.get(name)
11
+ existing = style is not None
12
+ empty = not name
13
+
14
+ prompt = style.prompt if style else gr.update()
15
+ negative_prompt = style.negative_prompt if style else gr.update()
16
+
17
+ return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
18
+
19
+
20
+ def save_style(name, prompt, negative_prompt):
21
+ if not name:
22
+ return gr.update(visible=False)
23
+
24
+ style = styles.PromptStyle(name, prompt, negative_prompt)
25
+ shared.prompt_styles.styles[style.name] = style
26
+ shared.prompt_styles.save_styles(shared.styles_filename)
27
+
28
+ return gr.update(visible=True)
29
+
30
+
31
+ def delete_style(name):
32
+ if name == "":
33
+ return
34
+
35
+ shared.prompt_styles.styles.pop(name, None)
36
+ shared.prompt_styles.save_styles(shared.styles_filename)
37
+
38
+ return '', '', ''
39
+
40
+
41
+ def materialize_styles(prompt, negative_prompt, styles):
42
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
43
+ negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
44
+
45
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
46
+
47
+
48
+ def refresh_styles():
49
+ return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
50
+
51
+
52
+ class UiPromptStyles:
53
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
54
+ self.tabname = tabname
55
+
56
+ with gr.Row(elem_id=f"{tabname}_styles_row"):
57
+ self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
58
+ edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
59
+
60
+ with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
61
+ with gr.Row():
62
+ self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
63
+ ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
64
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
65
+
66
+ with gr.Row():
67
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
68
+
69
+ with gr.Row():
70
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
71
+
72
+ with gr.Row():
73
+ self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
74
+ self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
75
+ self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
76
+
77
+ self.selection.change(
78
+ fn=select_style,
79
+ inputs=[self.selection],
80
+ outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
81
+ show_progress=False,
82
+ )
83
+
84
+ self.save.click(
85
+ fn=save_style,
86
+ inputs=[self.selection, self.prompt, self.neg_prompt],
87
+ outputs=[self.delete],
88
+ show_progress=False,
89
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
90
+
91
+ self.delete.click(
92
+ fn=delete_style,
93
+ _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
94
+ inputs=[self.selection],
95
+ outputs=[self.selection, self.prompt, self.neg_prompt],
96
+ show_progress=False,
97
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
98
+
99
+ self.materialize.click(
100
+ fn=materialize_styles,
101
+ inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
102
+ outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
103
+ show_progress=False,
104
+ ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
105
+
106
+ ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
107
+
108
+
109
+
110
+
modules/ui_settings.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
4
+ from modules.call_queue import wrap_gradio_call
5
+ from modules.shared import opts
6
+ from modules.ui_components import FormRow
7
+ from modules.ui_gradio_extensions import reload_javascript
8
+
9
+
10
+ def get_value_for_setting(key):
11
+ value = getattr(opts, key)
12
+
13
+ info = opts.data_labels[key]
14
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
15
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
16
+
17
+ return gr.update(value=value, **args)
18
+
19
+
20
+ def create_setting_component(key, is_quicksettings=False):
21
+ def fun():
22
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
23
+
24
+ info = opts.data_labels[key]
25
+ t = type(info.default)
26
+
27
+ args = info.component_args() if callable(info.component_args) else info.component_args
28
+
29
+ if info.component is not None:
30
+ comp = info.component
31
+ elif t == str:
32
+ comp = gr.Textbox
33
+ elif t == int:
34
+ comp = gr.Number
35
+ elif t == bool:
36
+ comp = gr.Checkbox
37
+ else:
38
+ raise Exception(f'bad options item type: {t} for key {key}')
39
+
40
+ elem_id = f"setting_{key}"
41
+
42
+ if info.refresh is not None:
43
+ if is_quicksettings:
44
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
45
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
46
+ else:
47
+ with FormRow():
48
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
49
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
50
+ else:
51
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
52
+
53
+ return res
54
+
55
+
56
+ class UiSettings:
57
+ submit = None
58
+ result = None
59
+ interface = None
60
+ components = None
61
+ component_dict = None
62
+ dummy_component = None
63
+ quicksettings_list = None
64
+ quicksettings_names = None
65
+ text_settings = None
66
+
67
+ def run_settings(self, *args):
68
+ changed = []
69
+
70
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
71
+ assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
72
+
73
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
74
+ if comp == self.dummy_component:
75
+ continue
76
+
77
+ if opts.set(key, value):
78
+ changed.append(key)
79
+
80
+ try:
81
+ opts.save(shared.config_filename)
82
+ except RuntimeError:
83
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
84
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
85
+
86
+ def run_settings_single(self, value, key):
87
+ if not opts.same_type(value, opts.data_labels[key].default):
88
+ return gr.update(visible=True), opts.dumpjson()
89
+
90
+ if value is None or not opts.set(key, value):
91
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
92
+
93
+ opts.save(shared.config_filename)
94
+
95
+ return get_value_for_setting(key), opts.dumpjson()
96
+
97
+ def create_ui(self, loadsave, dummy_component):
98
+ self.components = []
99
+ self.component_dict = {}
100
+ self.dummy_component = dummy_component
101
+
102
+ shared.settings_components = self.component_dict
103
+
104
+ script_callbacks.ui_settings_callback()
105
+ opts.reorder()
106
+
107
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
108
+ with gr.Row():
109
+ with gr.Column(scale=6):
110
+ self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
111
+ with gr.Column():
112
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
113
+
114
+ self.result = gr.HTML(elem_id="settings_result")
115
+
116
+ self.quicksettings_names = opts.quicksettings_list
117
+ self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
118
+
119
+ self.quicksettings_list = []
120
+
121
+ previous_section = None
122
+ current_tab = None
123
+ current_row = None
124
+ with gr.Tabs(elem_id="settings"):
125
+ for i, (k, item) in enumerate(opts.data_labels.items()):
126
+ section_must_be_skipped = item.section[0] is None
127
+
128
+ if previous_section != item.section and not section_must_be_skipped:
129
+ elem_id, text = item.section
130
+
131
+ if current_tab is not None:
132
+ current_row.__exit__()
133
+ current_tab.__exit__()
134
+
135
+ gr.Group()
136
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
137
+ current_tab.__enter__()
138
+ current_row = gr.Column(variant='compact')
139
+ current_row.__enter__()
140
+
141
+ previous_section = item.section
142
+
143
+ if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
144
+ self.quicksettings_list.append((i, k, item))
145
+ self.components.append(dummy_component)
146
+ elif section_must_be_skipped:
147
+ self.components.append(dummy_component)
148
+ else:
149
+ component = create_setting_component(k)
150
+ self.component_dict[k] = component
151
+ self.components.append(component)
152
+
153
+ if current_tab is not None:
154
+ current_row.__exit__()
155
+ current_tab.__exit__()
156
+
157
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
158
+ loadsave.create_ui()
159
+
160
+ with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
161
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
162
+
163
+ with gr.Row():
164
+ with gr.Column(scale=1):
165
+ sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
166
+ with gr.Column(scale=1):
167
+ sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
168
+ with gr.Column(scale=100):
169
+ pass
170
+
171
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
172
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
173
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
174
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
175
+ with gr.Row():
176
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
177
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
178
+
179
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
180
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
181
+
182
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
183
+
184
+ self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
185
+
186
+ unload_sd_model.click(
187
+ fn=sd_models.unload_model_weights,
188
+ inputs=[],
189
+ outputs=[]
190
+ )
191
+
192
+ reload_sd_model.click(
193
+ fn=sd_models.reload_model_weights,
194
+ inputs=[],
195
+ outputs=[]
196
+ )
197
+
198
+ request_notifications.click(
199
+ fn=lambda: None,
200
+ inputs=[],
201
+ outputs=[],
202
+ _js='function(){}'
203
+ )
204
+
205
+ download_localization.click(
206
+ fn=lambda: None,
207
+ inputs=[],
208
+ outputs=[],
209
+ _js='download_localization'
210
+ )
211
+
212
+ def reload_scripts():
213
+ scripts.reload_script_body_only()
214
+ reload_javascript() # need to refresh the html page
215
+
216
+ reload_script_bodies.click(
217
+ fn=reload_scripts,
218
+ inputs=[],
219
+ outputs=[]
220
+ )
221
+
222
+ restart_gradio.click(
223
+ fn=shared.state.request_restart,
224
+ _js='restart_reload',
225
+ inputs=[],
226
+ outputs=[],
227
+ )
228
+
229
+ def check_file(x):
230
+ if x is None:
231
+ return ''
232
+
233
+ if sysinfo.check(x.decode('utf8', errors='ignore')):
234
+ return 'Valid'
235
+
236
+ return 'Invalid'
237
+
238
+ sysinfo_check_file.change(
239
+ fn=check_file,
240
+ inputs=[sysinfo_check_file],
241
+ outputs=[sysinfo_check_output],
242
+ )
243
+
244
+ self.interface = settings_interface
245
+
246
+ def add_quicksettings(self):
247
+ with gr.Row(elem_id="quicksettings", variant="compact"):
248
+ for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
249
+ component = create_setting_component(k, is_quicksettings=True)
250
+ self.component_dict[k] = component
251
+
252
+ def add_functionality(self, demo):
253
+ self.submit.click(
254
+ fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
255
+ inputs=self.components,
256
+ outputs=[self.text_settings, self.result],
257
+ )
258
+
259
+ for _i, k, _item in self.quicksettings_list:
260
+ component = self.component_dict[k]
261
+ info = opts.data_labels[k]
262
+
263
+ if isinstance(component, gr.Textbox):
264
+ methods = [component.submit, component.blur]
265
+ elif hasattr(component, 'release'):
266
+ methods = [component.release]
267
+ else:
268
+ methods = [component.change]
269
+
270
+ for method in methods:
271
+ method(
272
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
273
+ inputs=[component],
274
+ outputs=[component, self.text_settings],
275
+ show_progress=info.refresh is not None,
276
+ )
277
+
278
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
279
+ button_set_checkpoint.click(
280
+ fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
281
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
282
+ inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
283
+ outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
284
+ )
285
+
286
+ component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
287
+
288
+ def get_settings_values():
289
+ return [get_value_for_setting(key) for key in component_keys]
290
+
291
+ demo.load(
292
+ fn=get_settings_values,
293
+ inputs=[],
294
+ outputs=[self.component_dict[k] for k in component_keys],
295
+ queue=False,
296
+ )
modules/ui_tempdir.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from collections import namedtuple
4
+ from pathlib import Path
5
+
6
+ import gradio.components
7
+
8
+ from PIL import PngImagePlugin
9
+
10
+ from modules import shared
11
+
12
+
13
+ Savedfile = namedtuple("Savedfile", ["name"])
14
+
15
+
16
+ def register_tmp_file(gradio, filename):
17
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
18
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
19
+
20
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
21
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
22
+
23
+
24
+ def check_tmp_file(gradio, filename):
25
+ if hasattr(gradio, 'temp_file_sets'):
26
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
27
+
28
+ if hasattr(gradio, 'temp_dirs'):
29
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
30
+
31
+ return False
32
+
33
+
34
+ def save_pil_to_file(self, pil_image, dir=None, format="png"):
35
+ already_saved_as = getattr(pil_image, 'already_saved_as', None)
36
+ if already_saved_as and os.path.isfile(already_saved_as):
37
+ register_tmp_file(shared.demo, already_saved_as)
38
+ filename = already_saved_as
39
+
40
+ if not shared.opts.save_images_add_number:
41
+ filename += f'?{os.path.getmtime(already_saved_as)}'
42
+
43
+ return filename
44
+
45
+ if shared.opts.temp_dir != "":
46
+ dir = shared.opts.temp_dir
47
+ else:
48
+ os.makedirs(dir, exist_ok=True)
49
+
50
+ use_metadata = False
51
+ metadata = PngImagePlugin.PngInfo()
52
+ for key, value in pil_image.info.items():
53
+ if isinstance(key, str) and isinstance(value, str):
54
+ metadata.add_text(key, value)
55
+ use_metadata = True
56
+
57
+ file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
58
+ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
59
+ return file_obj.name
60
+
61
+
62
+ def install_ui_tempdir_override():
63
+ """override save to file function so that it also writes PNG info"""
64
+ gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
65
+
66
+
67
+ def on_tmpdir_changed():
68
+ if shared.opts.temp_dir == "" or shared.demo is None:
69
+ return
70
+
71
+ os.makedirs(shared.opts.temp_dir, exist_ok=True)
72
+
73
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
74
+
75
+
76
+ def cleanup_tmpdr():
77
+ temp_dir = shared.opts.temp_dir
78
+ if temp_dir == "" or not os.path.isdir(temp_dir):
79
+ return
80
+
81
+ for root, _, files in os.walk(temp_dir, topdown=False):
82
+ for name in files:
83
+ _, extension = os.path.splitext(name)
84
+ if extension != ".png":
85
+ continue
86
+
87
+ filename = os.path.join(root, name)
88
+ os.remove(filename)
modules/upscaler.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import abstractmethod
3
+
4
+ import PIL
5
+ from PIL import Image
6
+
7
+ import modules.shared
8
+ from modules import modelloader, shared
9
+
10
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
11
+ NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
12
+
13
+
14
+ class Upscaler:
15
+ name = None
16
+ model_path = None
17
+ model_name = None
18
+ model_url = None
19
+ enable = True
20
+ filter = None
21
+ model = None
22
+ user_path = None
23
+ scalers: []
24
+ tile = True
25
+
26
+ def __init__(self, create_dirs=False):
27
+ self.mod_pad_h = None
28
+ self.tile_size = modules.shared.opts.ESRGAN_tile
29
+ self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
30
+ self.device = modules.shared.device
31
+ self.img = None
32
+ self.output = None
33
+ self.scale = 1
34
+ self.half = not modules.shared.cmd_opts.no_half
35
+ self.pre_pad = 0
36
+ self.mod_scale = None
37
+ self.model_download_path = None
38
+
39
+ if self.model_path is None and self.name:
40
+ self.model_path = os.path.join(shared.models_path, self.name)
41
+ if self.model_path and create_dirs:
42
+ os.makedirs(self.model_path, exist_ok=True)
43
+
44
+ try:
45
+ import cv2 # noqa: F401
46
+ self.can_tile = True
47
+ except Exception:
48
+ pass
49
+
50
+ @abstractmethod
51
+ def do_upscale(self, img: PIL.Image, selected_model: str):
52
+ return img
53
+
54
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
55
+ self.scale = scale
56
+ dest_w = int((img.width * scale) // 8 * 8)
57
+ dest_h = int((img.height * scale) // 8 * 8)
58
+
59
+ for _ in range(3):
60
+ shape = (img.width, img.height)
61
+
62
+ img = self.do_upscale(img, selected_model)
63
+
64
+ if shape == (img.width, img.height):
65
+ break
66
+
67
+ if img.width >= dest_w and img.height >= dest_h:
68
+ break
69
+
70
+ if img.width != dest_w or img.height != dest_h:
71
+ img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
72
+
73
+ return img
74
+
75
+ @abstractmethod
76
+ def load_model(self, path: str):
77
+ pass
78
+
79
+ def find_models(self, ext_filter=None) -> list:
80
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
81
+
82
+ def update_status(self, prompt):
83
+ print(f"\nextras: {prompt}", file=shared.progress_print_out)
84
+
85
+
86
+ class UpscalerData:
87
+ name = None
88
+ data_path = None
89
+ scale: int = 4
90
+ scaler: Upscaler = None
91
+ model: None
92
+
93
+ def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
94
+ self.name = name
95
+ self.data_path = path
96
+ self.local_data_path = path
97
+ self.scaler = upscaler
98
+ self.scale = scale
99
+ self.model = model
100
+
101
+
102
+ class UpscalerNone(Upscaler):
103
+ name = "None"
104
+ scalers = []
105
+
106
+ def load_model(self, path):
107
+ pass
108
+
109
+ def do_upscale(self, img, selected_model=None):
110
+ return img
111
+
112
+ def __init__(self, dirname=None):
113
+ super().__init__(False)
114
+ self.scalers = [UpscalerData("None", None, self)]
115
+
116
+
117
+ class UpscalerLanczos(Upscaler):
118
+ scalers = []
119
+
120
+ def do_upscale(self, img, selected_model=None):
121
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
122
+
123
+ def load_model(self, _):
124
+ pass
125
+
126
+ def __init__(self, dirname=None):
127
+ super().__init__(False)
128
+ self.name = "Lanczos"
129
+ self.scalers = [UpscalerData("Lanczos", None, self)]
130
+
131
+
132
+ class UpscalerNearest(Upscaler):
133
+ scalers = []
134
+
135
+ def do_upscale(self, img, selected_model=None):
136
+ return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
137
+
138
+ def load_model(self, _):
139
+ pass
140
+
141
+ def __init__(self, dirname=None):
142
+ super().__init__(False)
143
+ self.name = "Nearest"
144
+ self.scalers = [UpscalerData("Nearest", None, self)]
modules/util.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ from modules import shared
5
+ from modules.paths_internal import script_path
6
+
7
+
8
+ def natural_sort_key(s, regex=re.compile('([0-9]+)')):
9
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
10
+
11
+
12
+ def listfiles(dirname):
13
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
14
+ return [file for file in filenames if os.path.isfile(file)]
15
+
16
+
17
+ def html_path(filename):
18
+ return os.path.join(script_path, "html", filename)
19
+
20
+
21
+ def html(filename):
22
+ path = html_path(filename)
23
+
24
+ if os.path.exists(path):
25
+ with open(path, encoding="utf8") as file:
26
+ return file.read()
27
+
28
+ return ""
29
+
30
+
31
+ def walk_files(path, allowed_extensions=None):
32
+ if not os.path.exists(path):
33
+ return
34
+
35
+ if allowed_extensions is not None:
36
+ allowed_extensions = set(allowed_extensions)
37
+
38
+ items = list(os.walk(path, followlinks=True))
39
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
40
+
41
+ for root, _, files in items:
42
+ for filename in sorted(files, key=natural_sort_key):
43
+ if allowed_extensions is not None:
44
+ _, ext = os.path.splitext(filename)
45
+ if ext not in allowed_extensions:
46
+ continue
47
+
48
+ if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
49
+ continue
50
+
51
+ yield os.path.join(root, filename)
52
+
53
+
54
+ def ldm_print(*args, **kwargs):
55
+ if shared.opts.hide_ldm_prints:
56
+ return
57
+
58
+ print(*args, **kwargs)
modules/xlmr.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertPreTrainedModel, BertConfig
2
+ import torch.nn as nn
3
+ import torch
4
+ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
5
+ from transformers import XLMRobertaModel,XLMRobertaTokenizer
6
+ from typing import Optional
7
+
8
+ class BertSeriesConfig(BertConfig):
9
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
10
+
11
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
12
+ self.project_dim = project_dim
13
+ self.pooler_fn = pooler_fn
14
+ self.learn_encoder = learn_encoder
15
+
16
+ class RobertaSeriesConfig(XLMRobertaConfig):
17
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
18
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
19
+ self.project_dim = project_dim
20
+ self.pooler_fn = pooler_fn
21
+ self.learn_encoder = learn_encoder
22
+
23
+
24
+ class BertSeriesModelWithTransformation(BertPreTrainedModel):
25
+
26
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
27
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
28
+ config_class = BertSeriesConfig
29
+
30
+ def __init__(self, config=None, **kargs):
31
+ # modify initialization for autoloading
32
+ if config is None:
33
+ config = XLMRobertaConfig()
34
+ config.attention_probs_dropout_prob= 0.1
35
+ config.bos_token_id=0
36
+ config.eos_token_id=2
37
+ config.hidden_act='gelu'
38
+ config.hidden_dropout_prob=0.1
39
+ config.hidden_size=1024
40
+ config.initializer_range=0.02
41
+ config.intermediate_size=4096
42
+ config.layer_norm_eps=1e-05
43
+ config.max_position_embeddings=514
44
+
45
+ config.num_attention_heads=16
46
+ config.num_hidden_layers=24
47
+ config.output_past=True
48
+ config.pad_token_id=1
49
+ config.position_embedding_type= "absolute"
50
+
51
+ config.type_vocab_size= 1
52
+ config.use_cache=True
53
+ config.vocab_size= 250002
54
+ config.project_dim = 768
55
+ config.learn_encoder = False
56
+ super().__init__(config)
57
+ self.roberta = XLMRobertaModel(config)
58
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
59
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
60
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
61
+ self.pooler = lambda x: x[:,0]
62
+ self.post_init()
63
+
64
+ def encode(self,c):
65
+ device = next(self.parameters()).device
66
+ text = self.tokenizer(c,
67
+ truncation=True,
68
+ max_length=77,
69
+ return_length=False,
70
+ return_overflowing_tokens=False,
71
+ padding="max_length",
72
+ return_tensors="pt")
73
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
74
+ text["attention_mask"] = torch.tensor(
75
+ text['attention_mask']).to(device)
76
+ features = self(**text)
77
+ return features['projection_state']
78
+
79
+ def forward(
80
+ self,
81
+ input_ids: Optional[torch.Tensor] = None,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ token_type_ids: Optional[torch.Tensor] = None,
84
+ position_ids: Optional[torch.Tensor] = None,
85
+ head_mask: Optional[torch.Tensor] = None,
86
+ inputs_embeds: Optional[torch.Tensor] = None,
87
+ encoder_hidden_states: Optional[torch.Tensor] = None,
88
+ encoder_attention_mask: Optional[torch.Tensor] = None,
89
+ output_attentions: Optional[bool] = None,
90
+ return_dict: Optional[bool] = None,
91
+ output_hidden_states: Optional[bool] = None,
92
+ ) :
93
+ r"""
94
+ """
95
+
96
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
97
+
98
+
99
+ outputs = self.roberta(
100
+ input_ids=input_ids,
101
+ attention_mask=attention_mask,
102
+ token_type_ids=token_type_ids,
103
+ position_ids=position_ids,
104
+ head_mask=head_mask,
105
+ inputs_embeds=inputs_embeds,
106
+ encoder_hidden_states=encoder_hidden_states,
107
+ encoder_attention_mask=encoder_attention_mask,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=True,
110
+ return_dict=return_dict,
111
+ )
112
+
113
+ # last module outputs
114
+ sequence_output = outputs[0]
115
+
116
+
117
+ # project every module
118
+ sequence_output_ln = self.pre_LN(sequence_output)
119
+
120
+ # pooler
121
+ pooler_output = self.pooler(sequence_output_ln)
122
+ pooler_output = self.transformation(pooler_output)
123
+ projection_state = self.transformation(outputs.last_hidden_state)
124
+
125
+ return {
126
+ 'pooler_output':pooler_output,
127
+ 'last_hidden_state':outputs.last_hidden_state,
128
+ 'hidden_states':outputs.hidden_states,
129
+ 'attentions':outputs.attentions,
130
+ 'projection_state':projection_state,
131
+ 'sequence_out': sequence_output
132
+ }
133
+
134
+
135
+ class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
136
+ base_model_prefix = 'roberta'
137
+ config_class= RobertaSeriesConfig