improving ui
Browse files- app.py +309 -6
- appConfig.json +13 -2
app.py
CHANGED
@@ -5,9 +5,22 @@ from flask import Flask, render_template, request, send_file, jsonify
|
|
5 |
import torch
|
6 |
import json
|
7 |
from PIL import Image
|
8 |
-
from diffusers import DiffusionPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import threading
|
10 |
import requests
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def load_app_config():
|
13 |
global appConfig
|
@@ -23,9 +36,299 @@ def load_app_config():
|
|
23 |
|
24 |
load_app_config()
|
25 |
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch
|
6 |
import json
|
7 |
from PIL import Image
|
8 |
+
from diffusers import DiffusionPipeline
|
9 |
+
from diffusers import (
|
10 |
+
DDPMScheduler,
|
11 |
+
DDIMScheduler,
|
12 |
+
PNDMScheduler,
|
13 |
+
LMSDiscreteScheduler,
|
14 |
+
EulerAncestralDiscreteScheduler,
|
15 |
+
EulerDiscreteScheduler,
|
16 |
+
DPMSolverMultistepScheduler,
|
17 |
+
)
|
18 |
import threading
|
19 |
import requests
|
20 |
+
from flask import Flask, render_template_string
|
21 |
+
from gradio import Interface
|
22 |
+
from diffusers import AutoencoderKL
|
23 |
+
|
24 |
|
25 |
def load_app_config():
|
26 |
global appConfig
|
|
|
36 |
|
37 |
load_app_config()
|
38 |
|
39 |
+
# code output order
|
40 |
+
code = {}
|
41 |
+
code_pos_device = '001_code'
|
42 |
+
code_pos_data_type = '002_data_type'
|
43 |
+
code_pos_tf32 = '003_tf32'
|
44 |
+
code_pos_variant = '004_variant'
|
45 |
+
code_pos_init_pipeline = '050_init_pipe'
|
46 |
+
code_pos_requires_safety_checker = '054_requires_safety_checker'
|
47 |
+
code_pos_safety_checker = '055_safety_checker'
|
48 |
+
code_pos_scheduler = '060_scheduler'
|
49 |
+
|
50 |
+
# model config
|
51 |
+
model_configs = appConfig.get("models", {})
|
52 |
+
models = list(model_configs.keys())
|
53 |
+
scheduler_configs = appConfig.get("schedulers", {})
|
54 |
+
schedulers = list(scheduler_configs.keys())
|
55 |
+
|
56 |
+
device = None
|
57 |
+
variant = None
|
58 |
+
allow_tensorfloat32 = False
|
59 |
+
use_safetensors = False
|
60 |
+
data_type = 'float16'
|
61 |
+
safety_checker = False
|
62 |
+
requires_safety_checker = False
|
63 |
+
manual_seed = 42
|
64 |
+
inference_steps = 10
|
65 |
+
guidance_scale = 0.5
|
66 |
+
prompt = 'A white rabbit'
|
67 |
+
negative_prompt = 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly'
|
68 |
+
|
69 |
+
# init device parameters
|
70 |
+
if torch.cuda.is_available():
|
71 |
+
device = "cuda"
|
72 |
+
data_type = 'bfloat16'
|
73 |
+
allow_tensorfloat32 = True
|
74 |
+
elif torch.backends.mps.is_available():
|
75 |
+
device = "mps"
|
76 |
+
else:
|
77 |
+
device = "cpu"
|
78 |
+
|
79 |
+
def get_sorted_code():
|
80 |
+
|
81 |
+
return '\r\n'.join(value[1] for value in sorted(code.items()))
|
82 |
+
|
83 |
+
# change methods
|
84 |
+
def device_change(device):
|
85 |
+
|
86 |
+
code[code_pos_device] = f'''device = "{device}"'''
|
87 |
+
|
88 |
+
return get_sorted_code()
|
89 |
+
|
90 |
+
def models_change(model, scheduler):
|
91 |
+
|
92 |
+
use_safetensors = False
|
93 |
+
|
94 |
+
# no model selected (because this is UI init run)
|
95 |
+
if type(model) != list:
|
96 |
+
|
97 |
+
use_safetensors = str(model_configs[model]['use_safetensors'])
|
98 |
+
|
99 |
+
# if no scheduler is selected, choose the default one for this model
|
100 |
+
if scheduler == None:
|
101 |
+
|
102 |
+
scheduler = model_configs[model]['scheduler']
|
103 |
+
|
104 |
+
code[code_pos_init_pipeline] = f'''pipeline = DiffusionPipeline.from_pretrained(
|
105 |
+
"{model}",
|
106 |
+
use_safetensors=use_safetensors,
|
107 |
+
torch_dtype=data_type,
|
108 |
+
variant=variant).to(device)'''
|
109 |
+
|
110 |
+
safety_checker_change(safety_checker)
|
111 |
+
requires_safety_checker_change(requires_safety_checker)
|
112 |
+
|
113 |
+
return get_sorted_code(), use_safetensors, scheduler
|
114 |
+
|
115 |
+
def data_type_change(selected_data_type):
|
116 |
+
|
117 |
+
get_data_type(selected_data_type)
|
118 |
+
return get_sorted_code()
|
119 |
+
|
120 |
+
def get_data_type(selected_data_type):
|
121 |
+
|
122 |
+
if selected_data_type == "bfloat16":
|
123 |
+
code[code_pos_data_type] = 'data_type = torch.bfloat16'
|
124 |
+
data_type = torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
|
125 |
+
else:
|
126 |
+
code[code_pos_data_type] = 'data_type = torch.float16'
|
127 |
+
data_type = torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
|
128 |
+
|
129 |
+
return data_type
|
130 |
+
|
131 |
+
def tensorfloat32_change(allow_tensorfloat32):
|
132 |
+
|
133 |
+
get_tensorfloat32(allow_tensorfloat32)
|
134 |
+
|
135 |
+
return get_sorted_code()
|
136 |
+
|
137 |
+
def get_tensorfloat32(allow_tensorfloat32):
|
138 |
+
|
139 |
+
code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {allow_tensorfloat32}'
|
140 |
+
|
141 |
+
return True if str(allow_tensorfloat32).lower() == 'true' else False
|
142 |
+
|
143 |
+
def variant_change(variant):
|
144 |
+
|
145 |
+
if str(variant) == 'None':
|
146 |
+
code[code_pos_variant] = f'variant = {variant}'
|
147 |
+
else:
|
148 |
+
code[code_pos_variant] = f'variant = "{variant}"'
|
149 |
+
|
150 |
+
return get_sorted_code()
|
151 |
+
|
152 |
+
def safety_checker_change(safety_checker):
|
153 |
+
|
154 |
+
if not safety_checker or str(safety_checker).lower == 'false':
|
155 |
+
code[code_pos_safety_checker] = f'pipeline.safety_checker = None'
|
156 |
+
else:
|
157 |
+
code[code_pos_safety_checker] = ''
|
158 |
+
|
159 |
+
return get_sorted_code()
|
160 |
+
|
161 |
+
def requires_safety_checker_change(requires_safety_checker):
|
162 |
+
|
163 |
+
code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
|
164 |
+
|
165 |
+
return get_sorted_code()
|
166 |
+
|
167 |
+
def schedulers_change(scheduler):
|
168 |
+
|
169 |
+
if type(scheduler) != list:
|
170 |
|
171 |
+
code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
|
172 |
+
|
173 |
+
return get_sorted_code(), scheduler_configs[scheduler]
|
174 |
+
|
175 |
+
else:
|
176 |
+
|
177 |
+
return get_sorted_code(), ''
|
178 |
+
|
179 |
+
def get_scheduler(scheduler, config):
|
180 |
+
|
181 |
+
if scheduler == "DDPMScheduler":
|
182 |
+
return DDPMScheduler.from_config(config)
|
183 |
+
elif scheduler == "DDIMScheduler":
|
184 |
+
return DDIMScheduler.from_config(config)
|
185 |
+
elif scheduler == "PNDMScheduler":
|
186 |
+
return PNDMScheduler.from_config(config)
|
187 |
+
elif scheduler == "LMSDiscreteScheduler":
|
188 |
+
return LMSDiscreteScheduler.from_config(config)
|
189 |
+
elif scheduler == "EulerAncestralDiscreteScheduler":
|
190 |
+
return EulerAncestralDiscreteScheduler.from_config(config)
|
191 |
+
elif scheduler == "EulerDiscreteScheduler":
|
192 |
+
return EulerDiscreteScheduler.from_config(config)
|
193 |
+
elif scheduler == "DPMSolverMultistepScheduler":
|
194 |
+
return DPMSolverMultistepScheduler.from_config(config)
|
195 |
+
else:
|
196 |
+
return DPMSolverMultistepScheduler.from_config(config)
|
197 |
+
|
198 |
+
# pipeline
|
199 |
+
def start_pipeline(model,
|
200 |
+
device,
|
201 |
+
use_safetensors,
|
202 |
+
data_type,
|
203 |
+
variant,
|
204 |
+
safety_checker,
|
205 |
+
requires_safety_checker,
|
206 |
+
scheduler,
|
207 |
+
prompt,
|
208 |
+
negative_prompt,
|
209 |
+
inference_steps,
|
210 |
+
manual_seed,
|
211 |
+
guidance_scale,
|
212 |
+
progress=gr.Progress(track_tqdm=True)):
|
213 |
+
|
214 |
+
if model != None and scheduler != None:
|
215 |
+
|
216 |
+
progress((1,3), desc="Preparing pipeline initialization...")
|
217 |
+
|
218 |
+
torch.backends.cuda.matmul.allow_tf32 = get_tensorfloat32(allow_tensorfloat32) # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations
|
219 |
+
|
220 |
+
bool_use_safetensors = True if use_safetensors.lower() == 'true' else False
|
221 |
+
|
222 |
+
progress((2,3), desc="Initializing pipeline...")
|
223 |
+
|
224 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
225 |
+
model,
|
226 |
+
use_safetensors=bool_use_safetensors,
|
227 |
+
torch_dtype=get_data_type(data_type),
|
228 |
+
variant=variant).to(device)
|
229 |
+
|
230 |
+
if safety_checker is None or str(safety_checker).lower == 'false':
|
231 |
+
pipeline.safety_checker = None
|
232 |
+
|
233 |
+
pipeline.requires_safety_checker = bool(requires_safety_checker)
|
234 |
+
|
235 |
+
pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
|
236 |
+
|
237 |
+
generator = torch.Generator(device)
|
238 |
+
|
239 |
+
progress((3,3), desc="Creating the result...")
|
240 |
+
|
241 |
+
image = pipeline(
|
242 |
+
prompt=prompt,
|
243 |
+
negative_prompt=negative_prompt,
|
244 |
+
generator=generator.manual_seed(int(manual_seed)),
|
245 |
+
num_inference_steps=int(inference_steps),
|
246 |
+
guidance_scale=float(guidance_scale)).images[0]
|
247 |
+
|
248 |
+
return "Done.", image
|
249 |
+
|
250 |
+
else:
|
251 |
+
|
252 |
+
return "Please select a model AND a scheduler.", None
|
253 |
+
|
254 |
+
code[code_pos_device] = f'device = "{device}"'
|
255 |
+
code[code_pos_variant] = f'variant = {variant}'
|
256 |
+
code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {allow_tensorfloat32}'
|
257 |
+
code[code_pos_data_type] = 'data_type = torch.bfloat16'
|
258 |
+
code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
|
259 |
+
code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
|
260 |
+
code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
|
261 |
+
code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
|
262 |
+
|
263 |
+
# interface
|
264 |
+
with gr.Blocks() as demo:
|
265 |
+
|
266 |
+
gr.Markdown("## Image Generation")
|
267 |
+
gr.Markdown("### Device specific settings")
|
268 |
+
with gr.Row():
|
269 |
+
rg_device = gr.Radio(label="Device:", value=device, choices=["cpu", "mps", "gpu"])
|
270 |
+
rg_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`blfoat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
|
271 |
+
rg_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
272 |
+
rg_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
273 |
+
|
274 |
+
gr.Markdown("### Model specific settings")
|
275 |
+
with gr.Row():
|
276 |
+
dd_models = gr.Dropdown(choices=models, label="Model", )
|
277 |
+
with gr.Row():
|
278 |
+
with gr.Column(scale=1):
|
279 |
+
rg_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
280 |
+
with gr.Column(scale=1):
|
281 |
+
rg_safety_checker = gr.Radio(label="Enable safety checker:", value=safety_checker, choices=[True, False])
|
282 |
+
rg_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=requires_safety_checker, choices=[True, False])
|
283 |
+
|
284 |
+
gr.Markdown("### Scheduler")
|
285 |
+
with gr.Row():
|
286 |
+
dd_schedulers = gr.Dropdown(choices=schedulers, label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
|
287 |
+
txt_scheduler = gr.Textbox(value="", label="Description")
|
288 |
+
|
289 |
+
gr.Markdown("### Adapters")
|
290 |
+
with gr.Row():
|
291 |
+
gr.Markdown('Choose an adapter.')
|
292 |
+
|
293 |
+
gr.Markdown("### Inference settings")
|
294 |
+
with gr.Row():
|
295 |
+
el_prompt = gr.TextArea(label="Prompt", value=prompt)
|
296 |
+
el_negative_prompt = gr.TextArea(label="Negative prompt", value=negative_prompt)
|
297 |
+
with gr.Row():
|
298 |
+
el_inference_steps = gr.Textbox(label="Inference steps", value=inference_steps)
|
299 |
+
el_manual_seed = gr.Textbox(label="Manual seed", value=manual_seed)
|
300 |
+
el_guidance_scale = gr.Textbox(label="Guidance Scale", value=guidance_scale, info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
|
301 |
+
|
302 |
+
gr.Markdown("### Output")
|
303 |
+
with gr.Row():
|
304 |
+
btn_start_pipeline = gr.Button(value="Run inferencing")
|
305 |
+
with gr.Row():
|
306 |
+
tb_result = gr.Textbox(label="Status", value="")
|
307 |
+
el_result = gr.Image()
|
308 |
+
txt_code = gr.Code(get_sorted_code(), label="Code")
|
309 |
+
|
310 |
+
rg_device.change(device_change, inputs=[rg_device], outputs=[txt_code])
|
311 |
+
rg_data_type.change(data_type_change, inputs=[rg_data_type], outputs=[txt_code])
|
312 |
+
rg_allow_tensorfloat32.change(tensorfloat32_change, inputs=[rg_allow_tensorfloat32], outputs=[txt_code])
|
313 |
+
rg_variant.change(variant_change, inputs=[rg_variant], outputs=[txt_code])
|
314 |
+
dd_models.change(models_change, inputs=[dd_models, dd_schedulers], outputs=[txt_code, rg_use_safetensors, dd_schedulers])
|
315 |
+
rg_safety_checker.change(safety_checker_change, inputs=[rg_safety_checker], outputs=[txt_code])
|
316 |
+
rg_requires_safety_checker.change(requires_safety_checker_change, inputs=[rg_requires_safety_checker], outputs=[txt_code])
|
317 |
+
dd_schedulers.change(schedulers_change, inputs=[dd_schedulers], outputs=[txt_code, txt_scheduler])
|
318 |
+
btn_start_pipeline.click(start_pipeline, inputs=[
|
319 |
+
dd_models,
|
320 |
+
rg_device,
|
321 |
+
rg_use_safetensors,
|
322 |
+
rg_data_type,
|
323 |
+
rg_variant,
|
324 |
+
rg_safety_checker,
|
325 |
+
rg_requires_safety_checker,
|
326 |
+
dd_schedulers,
|
327 |
+
el_prompt,
|
328 |
+
el_negative_prompt,
|
329 |
+
el_inference_steps,
|
330 |
+
el_manual_seed,
|
331 |
+
el_guidance_scale
|
332 |
+
], outputs=[tb_result, el_result])
|
333 |
+
|
334 |
+
demo.launch()
|
appConfig.json
CHANGED
@@ -2,13 +2,24 @@
|
|
2 |
"models": {
|
3 |
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
-
"use_safetensors": false
|
|
|
6 |
},
|
7 |
"CompVis/stable-diffusion-v1-4": {
|
8 |
-
"use_safetensors": true
|
|
|
9 |
}
|
10 |
|
11 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"negative_prompts": [
|
13 |
"lowres, cropped, worst quality, low quality"
|
14 |
]
|
|
|
2 |
"models": {
|
3 |
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
+
"use_safetensors": false,
|
6 |
+
"scheduler": "DDPMScheduler"
|
7 |
},
|
8 |
"CompVis/stable-diffusion-v1-4": {
|
9 |
+
"use_safetensors": true,
|
10 |
+
"scheduler": "DDPMScheduler"
|
11 |
}
|
12 |
|
13 |
},
|
14 |
+
"schedulers": {
|
15 |
+
"DDPMScheduler": "This is DDPM",
|
16 |
+
"DDIMScheduler": "This is DDIM",
|
17 |
+
"PNDMScheduler": "This is PNDM",
|
18 |
+
"LMSDiscreteScheduler": "This is LMS",
|
19 |
+
"EulerAncestralDiscreteScheduler": "This is Euler, too",
|
20 |
+
"EulerDiscreteScheduler": "This is Euler",
|
21 |
+
"DPMSolverMultistepScheduler": "This is DPM Solve Multistep"
|
22 |
+
},
|
23 |
"negative_prompts": [
|
24 |
"lowres, cropped, worst quality, low quality"
|
25 |
]
|