Froddan commited on
Commit
3d988d8
1 Parent(s): 276ba83

Upload generate_model_grid.py

Browse files

The most recent version, uses checkbox grid

Files changed (1) hide show
  1. generate_model_grid.py +299 -0
generate_model_grid.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+ from copy import copy
3
+ from itertools import permutations, chain
4
+ import random
5
+ import csv
6
+ from io import StringIO
7
+ from PIL import Image
8
+ import numpy as np
9
+ import os
10
+
11
+ import modules.scripts as scripts
12
+ import gradio as gr
13
+
14
+ from modules import images, sd_samplers
15
+ from modules.hypernetworks import hypernetwork
16
+ from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
17
+ from modules.shared import opts, cmd_opts, state
18
+ import modules.shared as shared
19
+ import modules.sd_samplers
20
+ import modules.sd_models
21
+ import re
22
+
23
+
24
+ def apply_field(field):
25
+ def fun(p, x, xs):
26
+ setattr(p, field, x)
27
+
28
+ return fun
29
+
30
+
31
+ def apply_prompt(p, x, xs):
32
+ if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
33
+ raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
34
+
35
+ p.prompt = p.prompt.replace(xs[0], x)
36
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
37
+
38
+ def edit_prompt(p,x,z):
39
+ p.prompt = z + " " + x
40
+
41
+
42
+ def apply_order(p, x, xs):
43
+ token_order = []
44
+
45
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
46
+ for token in x:
47
+ token_order.append((p.prompt.find(token), token))
48
+
49
+ token_order.sort(key=lambda t: t[0])
50
+
51
+ prompt_parts = []
52
+
53
+ # Split the prompt up, taking out the tokens
54
+ for _, token in token_order:
55
+ n = p.prompt.find(token)
56
+ prompt_parts.append(p.prompt[0:n])
57
+ p.prompt = p.prompt[n + len(token):]
58
+
59
+ # Rebuild the prompt with the tokens in the order we want
60
+ prompt_tmp = ""
61
+ for idx, part in enumerate(prompt_parts):
62
+ prompt_tmp += part
63
+ prompt_tmp += x[idx]
64
+ p.prompt = prompt_tmp + p.prompt
65
+
66
+
67
+ def build_samplers_dict():
68
+ samplers_dict = {}
69
+ for i, sampler in enumerate(sd_samplers.all_samplers):
70
+ samplers_dict[sampler.name.lower()] = i
71
+ for alias in sampler.aliases:
72
+ samplers_dict[alias.lower()] = i
73
+ return samplers_dict
74
+
75
+
76
+ def apply_sampler(p, x, xs):
77
+ sampler_index = build_samplers_dict().get(x.lower(), None)
78
+ if sampler_index is None:
79
+ raise RuntimeError(f"Unknown sampler: {x}")
80
+
81
+ p.sampler_index = sampler_index
82
+
83
+
84
+ def confirm_samplers(p, xs):
85
+ samplers_dict = build_samplers_dict()
86
+ for x in xs:
87
+ if x.lower() not in samplers_dict.keys():
88
+ raise RuntimeError(f"Unknown sampler: {x}")
89
+
90
+
91
+ def apply_checkpoint(p, x, xs):
92
+ info = modules.sd_models.get_closet_checkpoint_match(x)
93
+ if info is None:
94
+ raise RuntimeError(f"Unknown checkpoint: {x}")
95
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
96
+ p.sd_model = shared.sd_model
97
+
98
+
99
+ def confirm_checkpoints(p, xs):
100
+ for x in xs:
101
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
102
+ raise RuntimeError(f"Unknown checkpoint: {x}")
103
+
104
+
105
+ def apply_hypernetwork(p, x, xs):
106
+ if x.lower() in ["", "none"]:
107
+ name = None
108
+ else:
109
+ name = hypernetwork.find_closest_hypernetwork_name(x)
110
+ if not name:
111
+ raise RuntimeError(f"Unknown hypernetwork: {x}")
112
+ hypernetwork.load_hypernetwork(name)
113
+
114
+
115
+ def apply_hypernetwork_strength(p, x, xs):
116
+ hypernetwork.apply_strength(x)
117
+
118
+
119
+ def confirm_hypernetworks(p, xs):
120
+ for x in xs:
121
+ if x.lower() in ["", "none"]:
122
+ continue
123
+ if not hypernetwork.find_closest_hypernetwork_name(x):
124
+ raise RuntimeError(f"Unknown hypernetwork: {x}")
125
+
126
+
127
+ def apply_clip_skip(p, x, xs):
128
+ opts.data["CLIP_stop_at_last_layers"] = x
129
+
130
+
131
+ def format_value_add_label(p, opt, x):
132
+ if type(x) == float:
133
+ x = round(x, 8)
134
+
135
+ return f"{opt.label}: {x}"
136
+
137
+
138
+ def format_value(p, opt, x):
139
+ if type(x) == float:
140
+ x = round(x, 8)
141
+ return x
142
+
143
+
144
+ def format_value_join_list(p, opt, x):
145
+ return ", ".join(x)
146
+
147
+
148
+ def do_nothing(p, x, xs):
149
+ pass
150
+
151
+
152
+ def format_nothing(p, opt, x):
153
+ return ""
154
+
155
+
156
+ def str_permutations(x):
157
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
158
+ return x
159
+
160
+ # AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
161
+ # AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
162
+
163
+
164
+ def draw_xy_grid(p, xs, ys, zs, x_labels, y_labels, cell, draw_legend, include_lone_images):
165
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
166
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
167
+
168
+ # Temporary list of all the images that are generated to be populated into the grid.
169
+ # Will be filled with empty images for any individual step that fails to process properly
170
+ image_cache = []
171
+
172
+ processed_result = None
173
+ cell_mode = "P"
174
+ cell_size = (1,1)
175
+
176
+ state.job_count = len(xs) * len(ys) * p.n_iter
177
+
178
+ for iy, y in enumerate(ys):
179
+ for ix, x in enumerate(xs):
180
+ state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
181
+ z = zs[iy]
182
+ processed:Processed = cell(x, y, z)
183
+ try:
184
+ # this dereference will throw an exception if the image was not processed
185
+ # (this happens in cases such as if the user stops the process from the UI)
186
+ processed_image = processed.images[0]
187
+
188
+ if processed_result is None:
189
+ # Use our first valid processed result as a template container to hold our full results
190
+ processed_result = copy(processed)
191
+ cell_mode = processed_image.mode
192
+ cell_size = processed_image.size
193
+ processed_result.images = [Image.new(cell_mode, cell_size)]
194
+
195
+ image_cache.append(processed_image)
196
+ if include_lone_images:
197
+ processed_result.images.append(processed_image)
198
+ processed_result.all_prompts.append(processed.prompt)
199
+ processed_result.all_seeds.append(processed.seed)
200
+ processed_result.infotexts.append(processed.infotexts[0])
201
+ except:
202
+ image_cache.append(Image.new(cell_mode, cell_size))
203
+
204
+ if not processed_result:
205
+ print("Unexpected error: draw_xy_grid failed to return even a single processed image")
206
+ return Processed()
207
+
208
+ grid = images.image_grid(image_cache, rows=len(ys))
209
+ if draw_legend:
210
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
211
+
212
+ processed_result.images[0] = grid
213
+
214
+ return processed_result
215
+
216
+
217
+ class SharedSettingsStackHelper(object):
218
+ def __enter__(self):
219
+ self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
220
+ self.hypernetwork = opts.sd_hypernetwork
221
+ self.model = shared.sd_model
222
+
223
+ def __exit__(self, exc_type, exc_value, tb):
224
+ modules.sd_models.reload_model_weights(self.model)
225
+
226
+ hypernetwork.load_hypernetwork(self.hypernetwork)
227
+ hypernetwork.apply_strength()
228
+
229
+ opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
230
+
231
+
232
+ re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
233
+ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
234
+
235
+ re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
236
+ re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
237
+
238
+ class Script(scripts.Script):
239
+ def title(self):
240
+ return "Generate Model Grid"
241
+
242
+ def ui(self, is_img2img):
243
+ filenames = []
244
+ dirpath = '/content/stable-diffusion-webui/models/Stable-diffusion/'
245
+ for path in os.listdir(dirpath):
246
+ if path.endswith('.ckpt'):
247
+ filenames.append(path)
248
+
249
+ with gr.Row():
250
+ x_values = gr.Textbox(label="Prompts, separated with &", lines=1)
251
+
252
+ with gr.Row():
253
+ y_values = gr.CheckboxGroup(filenames, label="Checkpoint file names, including file ending", lines=1)
254
+
255
+ with gr.Row():
256
+ z_values = gr.Textbox(label="Model tokens", lines=1)
257
+
258
+ draw_legend = gr.Checkbox(label='Draw legend', value=True)
259
+ include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
260
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
261
+
262
+ return [x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds]
263
+
264
+ def run(self, p, x_values, y_values, z_values, draw_legend, include_lone_images, no_fixed_seeds):
265
+ if not no_fixed_seeds:
266
+ modules.processing.fix_seed(p)
267
+
268
+ if not opts.return_grid:
269
+ p.batch_size = 1
270
+
271
+ xs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(x_values), delimiter='&'))]
272
+ ys = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(y_values)))]
273
+ zs = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(z_values)))]
274
+
275
+ def cell(x, y, z):
276
+ pc = copy(p)
277
+ edit_prompt(pc, x, z)
278
+ confirm_checkpoints(pc,ys)
279
+ apply_checkpoint(pc, y, ys)
280
+
281
+ return process_images(pc)
282
+
283
+ with SharedSettingsStackHelper():
284
+ processed = draw_xy_grid(
285
+ p,
286
+ xs=xs,
287
+ ys=ys,
288
+ zs=zs,
289
+ x_labels=xs,
290
+ y_labels=ys,
291
+ cell=cell,
292
+ draw_legend=draw_legend,
293
+ include_lone_images=include_lone_images
294
+ )
295
+
296
+ if opts.grid_save:
297
+ images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
298
+
299
+ return processed