adding refiner option
Browse files- app.py +27 -6
- appConfig.json +23 -14
- config.py +24 -4
app.py
CHANGED
@@ -16,13 +16,15 @@ def models_change(model, scheduler, config):
|
|
16 |
config = set_config(config, 'model', model)
|
17 |
|
18 |
use_safetensors = False
|
|
|
19 |
|
20 |
# no model selected (because this is UI init run)
|
21 |
if type(model) != list and str(model) != 'None':
|
22 |
|
23 |
use_safetensors = str(models[model]['use_safetensors'])
|
24 |
model_description = models[model]['description']
|
25 |
-
|
|
|
26 |
# if no scheduler is selected, choose the default one for this model
|
27 |
if scheduler == None:
|
28 |
|
@@ -34,11 +36,12 @@ def models_change(model, scheduler, config):
|
|
34 |
|
35 |
config["use_safetensors"] = str(use_safetensors)
|
36 |
config["scheduler"] = str(scheduler)
|
|
|
37 |
|
38 |
# safety_checker_change(in_safety_checker.value, config)
|
39 |
# requires_safety_checker_change(in_requires_safety_checker.value, config)
|
40 |
|
41 |
-
return model_description, use_safetensors, scheduler, config, str(config), assemble_code(config)
|
42 |
|
43 |
def data_type_change(data_type, config):
|
44 |
|
@@ -132,6 +135,15 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
|
|
132 |
torch_dtype = get_data_type(config["data_type"]),
|
133 |
variant = get_variant(config["variant"])).to(config["device"])
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
if str(config["safety_checker"]).lower() == 'false':
|
136 |
pipeline.safety_checker = None
|
137 |
|
@@ -151,11 +163,18 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
|
|
151 |
negative_prompt = config["negative_prompt"],
|
152 |
generator = generator,
|
153 |
num_inference_steps = int(config["inference_steps"]),
|
154 |
-
guidance_scale = float(config["guidance_scale"])).images
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
config_history.append(config.copy())
|
157 |
|
158 |
-
return image, dict_list_to_markdown_table(config_history), config_history
|
159 |
|
160 |
else:
|
161 |
|
@@ -191,6 +210,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
191 |
with gr.Row():
|
192 |
with gr.Column(scale=1):
|
193 |
in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
|
|
194 |
with gr.Column(scale=1):
|
195 |
in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
|
196 |
in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
|
@@ -229,7 +249,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
229 |
in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
|
230 |
in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
|
231 |
in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
|
232 |
-
in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_use_safetensors, in_schedulers, config, out_config, out_code])
|
233 |
in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
|
234 |
in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
|
235 |
in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
|
@@ -249,6 +269,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
249 |
in_devices,
|
250 |
in_use_safetensors,
|
251 |
in_data_type,
|
|
|
252 |
in_variant,
|
253 |
in_safety_checker,
|
254 |
in_requires_safety_checker,
|
|
|
16 |
config = set_config(config, 'model', model)
|
17 |
|
18 |
use_safetensors = False
|
19 |
+
refiner = ""
|
20 |
|
21 |
# no model selected (because this is UI init run)
|
22 |
if type(model) != list and str(model) != 'None':
|
23 |
|
24 |
use_safetensors = str(models[model]['use_safetensors'])
|
25 |
model_description = models[model]['description']
|
26 |
+
refiner = models[model]['refiner']
|
27 |
+
|
28 |
# if no scheduler is selected, choose the default one for this model
|
29 |
if scheduler == None:
|
30 |
|
|
|
36 |
|
37 |
config["use_safetensors"] = str(use_safetensors)
|
38 |
config["scheduler"] = str(scheduler)
|
39 |
+
config["refiner"] = str(refiner)
|
40 |
|
41 |
# safety_checker_change(in_safety_checker.value, config)
|
42 |
# requires_safety_checker_change(in_requires_safety_checker.value, config)
|
43 |
|
44 |
+
return model_description, refiner, use_safetensors, scheduler, config, str(config), assemble_code(config)
|
45 |
|
46 |
def data_type_change(data_type, config):
|
47 |
|
|
|
135 |
torch_dtype = get_data_type(config["data_type"]),
|
136 |
variant = get_variant(config["variant"])).to(config["device"])
|
137 |
|
138 |
+
if config['refiner'] != '':
|
139 |
+
refiner = DiffusionPipeline.from_pretrained(
|
140 |
+
config['refiner'],
|
141 |
+
text_encoder_2=pipeline.text_encoder_2,
|
142 |
+
vae=pipeline.vae,
|
143 |
+
torch_dtype=get_data_type(config["data_type"]),
|
144 |
+
use_safetensors=get_bool(config["use_safetensors"]),
|
145 |
+
variant = get_variant(config["variant"])).to(config["device"])
|
146 |
+
|
147 |
if str(config["safety_checker"]).lower() == 'false':
|
148 |
pipeline.safety_checker = None
|
149 |
|
|
|
163 |
negative_prompt = config["negative_prompt"],
|
164 |
generator = generator,
|
165 |
num_inference_steps = int(config["inference_steps"]),
|
166 |
+
guidance_scale = float(config["guidance_scale"])).images
|
167 |
+
|
168 |
+
if config['refiner'] != '':
|
169 |
+
image = refiner(
|
170 |
+
prompt = config["prompt"],
|
171 |
+
num_inference_steps = int(config["inference_steps"]),
|
172 |
+
image=image,
|
173 |
+
).images
|
174 |
+
|
175 |
config_history.append(config.copy())
|
176 |
|
177 |
+
return image[0], dict_list_to_markdown_table(config_history), config_history
|
178 |
|
179 |
else:
|
180 |
|
|
|
210 |
with gr.Row():
|
211 |
with gr.Column(scale=1):
|
212 |
in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
213 |
+
in_model_refiner = gr.Textbox(value="", label="Refiner")
|
214 |
with gr.Column(scale=1):
|
215 |
in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
|
216 |
in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
|
|
|
249 |
in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
|
250 |
in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
|
251 |
in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
|
252 |
+
in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_model_refiner, in_use_safetensors, in_schedulers, config, out_config, out_code])
|
253 |
in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
|
254 |
in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
|
255 |
in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
|
|
|
269 |
in_devices,
|
270 |
in_use_safetensors,
|
271 |
in_data_type,
|
272 |
+
in_model_refiner,
|
273 |
in_variant,
|
274 |
in_safety_checker,
|
275 |
in_requires_safety_checker,
|
appConfig.json
CHANGED
@@ -4,47 +4,56 @@
|
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
"use_safetensors": false,
|
6 |
"description": "see https://huggingface.co/sd-dreambooth-library/solo-levelling-art-style",
|
7 |
-
"scheduler": "DDIMScheduler"
|
|
|
8 |
},
|
9 |
"CompVis/stable-diffusion-v1-4": {
|
10 |
"use_safetensors": true,
|
11 |
"description": "see https://huggingface.co/CompVis/stable-diffusion-v1-4",
|
12 |
-
"scheduler": "EulerDiscreteScheduler"
|
|
|
13 |
},
|
14 |
"runwayml/stable-diffusion-v1-5": {
|
15 |
"use_safetensors": true,
|
16 |
"description": "see https://huggingface.co/runwayml/stable-diffusion-v1-5",
|
17 |
-
"scheduler": "DDPMScheduler"
|
|
|
18 |
},
|
19 |
"stabilityai/stable-diffusion-2-1": {
|
20 |
"use_safetensors": true,
|
21 |
"description": "see https://huggingface.co/stabilityai/stable-diffusion-2-1",
|
22 |
-
"scheduler": "DPMSolverMultistepScheduler"
|
|
|
23 |
},
|
24 |
"stabilityai/stable-diffusion-xl-base-1.0": {
|
25 |
"use_safetensors": true,
|
26 |
"description": "see https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
|
27 |
-
"scheduler": "DDPMScheduler"
|
|
|
28 |
},
|
29 |
"sd-dreambooth-library/house-emblem": {
|
30 |
"use_safetensors": false,
|
31 |
"description": "see https://huggingface.co/sd-dreambooth-library/house-emblem",
|
32 |
-
"scheduler": "DDPMScheduler"
|
|
|
33 |
},
|
34 |
"Envvi/Inkpunk-Diffusion": {
|
35 |
-
"use_safetensors":
|
36 |
-
"description": "
|
37 |
-
"scheduler": "DDPMScheduler"
|
|
|
38 |
},
|
39 |
"Stelath/textual_inversion_comic_strip_fp16": {
|
40 |
"use_safetensors": true,
|
41 |
-
"description": "
|
42 |
-
"scheduler": "DDPMScheduler"
|
|
|
43 |
},
|
44 |
"sd-dreambooth-library/herge-style": {
|
45 |
-
"use_safetensors":
|
46 |
-
"description": "
|
47 |
-
"scheduler": "DDPMScheduler"
|
|
|
48 |
}
|
49 |
|
50 |
},
|
|
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
"use_safetensors": false,
|
6 |
"description": "see https://huggingface.co/sd-dreambooth-library/solo-levelling-art-style",
|
7 |
+
"scheduler": "DDIMScheduler",
|
8 |
+
"refiner": ""
|
9 |
},
|
10 |
"CompVis/stable-diffusion-v1-4": {
|
11 |
"use_safetensors": true,
|
12 |
"description": "see https://huggingface.co/CompVis/stable-diffusion-v1-4",
|
13 |
+
"scheduler": "EulerDiscreteScheduler",
|
14 |
+
"refiner": ""
|
15 |
},
|
16 |
"runwayml/stable-diffusion-v1-5": {
|
17 |
"use_safetensors": true,
|
18 |
"description": "see https://huggingface.co/runwayml/stable-diffusion-v1-5",
|
19 |
+
"scheduler": "DDPMScheduler",
|
20 |
+
"refiner": ""
|
21 |
},
|
22 |
"stabilityai/stable-diffusion-2-1": {
|
23 |
"use_safetensors": true,
|
24 |
"description": "see https://huggingface.co/stabilityai/stable-diffusion-2-1",
|
25 |
+
"scheduler": "DPMSolverMultistepScheduler",
|
26 |
+
"refiner": ""
|
27 |
},
|
28 |
"stabilityai/stable-diffusion-xl-base-1.0": {
|
29 |
"use_safetensors": true,
|
30 |
"description": "see https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
|
31 |
+
"scheduler": "DDPMScheduler",
|
32 |
+
"refiner": "stabilityai/stable-diffusion-xl-refiner-1.0"
|
33 |
},
|
34 |
"sd-dreambooth-library/house-emblem": {
|
35 |
"use_safetensors": false,
|
36 |
"description": "see https://huggingface.co/sd-dreambooth-library/house-emblem",
|
37 |
+
"scheduler": "DDPMScheduler",
|
38 |
+
"refiner": ""
|
39 |
},
|
40 |
"Envvi/Inkpunk-Diffusion": {
|
41 |
+
"use_safetensors": false,
|
42 |
+
"description": "see https://huggingface.co/Envvi/Inkpunk-Diffusion",
|
43 |
+
"scheduler": "DDPMScheduler",
|
44 |
+
"refiner": ""
|
45 |
},
|
46 |
"Stelath/textual_inversion_comic_strip_fp16": {
|
47 |
"use_safetensors": true,
|
48 |
+
"description": "see https://huggingface.co/Stelath/textual_inversion_comic_strip_fp16",
|
49 |
+
"scheduler": "DDPMScheduler",
|
50 |
+
"refiner": ""
|
51 |
},
|
52 |
"sd-dreambooth-library/herge-style": {
|
53 |
+
"use_safetensors": false,
|
54 |
+
"description": "see https://huggingface.co/sd-dreambooth-library/herge-style",
|
55 |
+
"scheduler": "DDPMScheduler",
|
56 |
+
"refiner": ""
|
57 |
}
|
58 |
|
59 |
},
|
config.py
CHANGED
@@ -42,6 +42,7 @@ def get_initial_config():
|
|
42 |
"allow_tensorfloat32": allow_tensorfloat32,
|
43 |
"use_safetensors": "False",
|
44 |
"data_type": data_type,
|
|
|
45 |
"safety_checker": "False",
|
46 |
"requires_safety_checker": "False",
|
47 |
"manual_seed": 42,
|
@@ -78,6 +79,7 @@ def get_config_from_url(initial_config, request: Request):
|
|
78 |
return_config['device'],
|
79 |
return_config['use_safetensors'],
|
80 |
return_config['data_type'],
|
|
|
81 |
return_config['variant'],
|
82 |
return_config['safety_checker'],
|
83 |
return_config['requires_safety_checker'],
|
@@ -139,12 +141,20 @@ def assemble_code(str_config):
|
|
139 |
torch_dtype=data_type,
|
140 |
variant=variant).to(device)'''
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
|
143 |
|
144 |
if str(config["safety_checker"]).lower() == 'false':
|
145 |
code['055_safety_checker'] = f'pipeline.safety_checker = None'
|
146 |
-
else:
|
147 |
-
code['055_safety_checker'] = ''
|
148 |
|
149 |
code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
|
150 |
|
@@ -165,6 +175,16 @@ def assemble_code(str_config):
|
|
165 |
negative_prompt=negative_prompt,
|
166 |
generator=generator,
|
167 |
num_inference_steps=inference_steps,
|
168 |
-
guidance_scale=guidance_scale).images
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
return '\r\n'.join(value[1] for value in sorted(code.items()))
|
|
|
42 |
"allow_tensorfloat32": allow_tensorfloat32,
|
43 |
"use_safetensors": "False",
|
44 |
"data_type": data_type,
|
45 |
+
"refiner": "",
|
46 |
"safety_checker": "False",
|
47 |
"requires_safety_checker": "False",
|
48 |
"manual_seed": 42,
|
|
|
79 |
return_config['device'],
|
80 |
return_config['use_safetensors'],
|
81 |
return_config['data_type'],
|
82 |
+
return_config['refiner'],
|
83 |
return_config['variant'],
|
84 |
return_config['safety_checker'],
|
85 |
return_config['requires_safety_checker'],
|
|
|
141 |
torch_dtype=data_type,
|
142 |
variant=variant).to(device)'''
|
143 |
|
144 |
+
if config['refiner'] != '':
|
145 |
+
code['051_refiner'] = f'''refiner = DiffusionPipeline.from_pretrained(
|
146 |
+
"{config['refiner']}",
|
147 |
+
text_encoder_2 = base.text_encoder_2,
|
148 |
+
vae = base.vae,
|
149 |
+
torch_dtype = data_type,
|
150 |
+
use_safetensors = use_safetensors,
|
151 |
+
variant=variant,
|
152 |
+
).to(device)'''
|
153 |
+
|
154 |
code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
|
155 |
|
156 |
if str(config["safety_checker"]).lower() == 'false':
|
157 |
code['055_safety_checker'] = f'pipeline.safety_checker = None'
|
|
|
|
|
158 |
|
159 |
code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
|
160 |
|
|
|
175 |
negative_prompt=negative_prompt,
|
176 |
generator=generator,
|
177 |
num_inference_steps=inference_steps,
|
178 |
+
guidance_scale=guidance_scale).images
|
179 |
+
'''
|
180 |
+
|
181 |
+
if config['refiner'] != '':
|
182 |
+
code["110_run_refiner"] = f'''image = refiner(
|
183 |
+
prompt=prompt,
|
184 |
+
negative_prompt=negative_prompt,
|
185 |
+
num_inference_steps=inference_steps
|
186 |
+
).images[0]'''
|
187 |
+
|
188 |
+
code["200_show_image"] = 'image[0]'
|
189 |
+
|
190 |
return '\r\n'.join(value[1] for value in sorted(code.items()))
|