refactoring and new feature: history
Browse files- app.py +127 -47
- requirements.txt +2 -1
app.py
CHANGED
@@ -20,7 +20,8 @@ import requests
|
|
20 |
from flask import Flask, render_template_string
|
21 |
from gradio import Interface
|
22 |
from diffusers import AutoencoderKL
|
23 |
-
|
|
|
24 |
|
25 |
def load_app_config():
|
26 |
global appConfig
|
@@ -46,15 +47,24 @@ code_pos_init_pipeline = '050_init_pipe'
|
|
46 |
code_pos_requires_safety_checker = '054_requires_safety_checker'
|
47 |
code_pos_safety_checker = '055_safety_checker'
|
48 |
code_pos_scheduler = '060_scheduler'
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# model config
|
51 |
model_configs = appConfig.get("models", {})
|
52 |
models = list(model_configs.keys())
|
|
|
53 |
scheduler_configs = appConfig.get("schedulers", {})
|
54 |
schedulers = list(scheduler_configs.keys())
|
55 |
-
|
56 |
|
|
|
57 |
device = None
|
|
|
58 |
variant = None
|
59 |
allow_tensorfloat32 = False
|
60 |
use_safetensors = False
|
@@ -77,6 +87,25 @@ elif torch.backends.mps.is_available():
|
|
77 |
else:
|
78 |
device = "cpu"
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
def get_sorted_code():
|
81 |
|
82 |
return '\r\n'.join(value[1] for value in sorted(code.items()))
|
@@ -90,6 +119,7 @@ def device_change(device):
|
|
90 |
|
91 |
def models_change(model, scheduler):
|
92 |
|
|
|
93 |
use_safetensors = False
|
94 |
|
95 |
# no model selected (because this is UI init run)
|
@@ -111,6 +141,7 @@ def models_change(model, scheduler):
|
|
111 |
safety_checker_change(safety_checker)
|
112 |
requires_safety_checker_change(requires_safety_checker)
|
113 |
|
|
|
114 |
return get_sorted_code(), use_safetensors, scheduler
|
115 |
|
116 |
def data_type_change(selected_data_type):
|
@@ -197,7 +228,7 @@ def get_scheduler(scheduler, config):
|
|
197 |
return DPMSolverMultistepScheduler.from_config(config)
|
198 |
|
199 |
# pipeline
|
200 |
-
def
|
201 |
device,
|
202 |
use_safetensors,
|
203 |
data_type,
|
@@ -235,17 +266,22 @@ def start_pipeline(model,
|
|
235 |
|
236 |
pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
|
237 |
|
238 |
-
|
|
|
|
|
|
|
|
|
239 |
|
240 |
progress((3,3), desc="Creating the result...")
|
241 |
|
242 |
image = pipeline(
|
243 |
prompt=prompt,
|
244 |
negative_prompt=negative_prompt,
|
245 |
-
generator=generator
|
246 |
num_inference_steps=int(inference_steps),
|
247 |
guidance_scale=float(guidance_scale)).images[0]
|
248 |
-
|
|
|
249 |
return "Done.", image
|
250 |
|
251 |
else:
|
@@ -260,6 +296,46 @@ code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
|
|
260 |
code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
|
261 |
code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
|
262 |
code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
# interface
|
265 |
with gr.Blocks() as demo:
|
@@ -270,25 +346,25 @@ with gr.Blocks() as demo:
|
|
270 |
</small>''')
|
271 |
gr.Markdown("### Device specific settings")
|
272 |
with gr.Row():
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
|
278 |
gr.Markdown("### Model specific settings")
|
279 |
with gr.Row():
|
280 |
-
|
281 |
with gr.Row():
|
282 |
with gr.Column(scale=1):
|
283 |
-
|
284 |
with gr.Column(scale=1):
|
285 |
-
|
286 |
-
|
287 |
|
288 |
gr.Markdown("### Scheduler")
|
289 |
with gr.Row():
|
290 |
-
|
291 |
-
|
292 |
|
293 |
gr.Markdown("### Adapters")
|
294 |
with gr.Row():
|
@@ -296,43 +372,47 @@ with gr.Blocks() as demo:
|
|
296 |
|
297 |
gr.Markdown("### Inference settings")
|
298 |
with gr.Row():
|
299 |
-
|
300 |
-
|
301 |
with gr.Row():
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
|
306 |
gr.Markdown("### Output")
|
307 |
with gr.Row():
|
308 |
btn_start_pipeline = gr.Button(value="Run inferencing")
|
309 |
with gr.Row():
|
310 |
-
|
311 |
-
|
312 |
-
|
|
|
|
|
|
|
|
|
313 |
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
btn_start_pipeline.click(
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
], outputs=[
|
337 |
|
338 |
demo.launch()
|
|
|
20 |
from flask import Flask, render_template_string
|
21 |
from gradio import Interface
|
22 |
from diffusers import AutoencoderKL
|
23 |
+
import pandas as pd
|
24 |
+
import base64
|
25 |
|
26 |
def load_app_config():
|
27 |
global appConfig
|
|
|
47 |
code_pos_requires_safety_checker = '054_requires_safety_checker'
|
48 |
code_pos_safety_checker = '055_safety_checker'
|
49 |
code_pos_scheduler = '060_scheduler'
|
50 |
+
code_pos_generator = '070_generator'
|
51 |
+
code_pos_prompt = '080_prompt'
|
52 |
+
code_pos_negative_prompt = '085_negative_prompt'
|
53 |
+
code_pos_inference_steps = '090_inference_steps'
|
54 |
+
code_pos_guidance_scale = '095_guidance_scale'
|
55 |
+
code_pos_run_inference = '100_run_inference'
|
56 |
|
57 |
# model config
|
58 |
model_configs = appConfig.get("models", {})
|
59 |
models = list(model_configs.keys())
|
60 |
+
model = None
|
61 |
scheduler_configs = appConfig.get("schedulers", {})
|
62 |
schedulers = list(scheduler_configs.keys())
|
63 |
+
scheduler = None
|
64 |
|
65 |
+
devices = appConfig.get("devices", [])
|
66 |
device = None
|
67 |
+
|
68 |
variant = None
|
69 |
allow_tensorfloat32 = False
|
70 |
use_safetensors = False
|
|
|
87 |
else:
|
88 |
device = "cpu"
|
89 |
|
90 |
+
# inference config
|
91 |
+
current_config = {
|
92 |
+
"device": device,
|
93 |
+
"model": model,
|
94 |
+
"scheduler": scheduler,
|
95 |
+
"variant": variant,
|
96 |
+
"allow_tensorflow": allow_tensorfloat32,
|
97 |
+
"use_safetensors": use_safetensors,
|
98 |
+
"data_type": data_type,
|
99 |
+
"safety_checker": safety_checker,
|
100 |
+
"requires_safety_checker": requires_safety_checker,
|
101 |
+
"manual_seed": manual_seed,
|
102 |
+
"inference_steps": inference_steps,
|
103 |
+
"guidance_scale": guidance_scale,
|
104 |
+
"prompt": prompt,
|
105 |
+
"negative_prompt": negative_prompt,
|
106 |
+
}
|
107 |
+
config_history = [current_config]
|
108 |
+
|
109 |
def get_sorted_code():
|
110 |
|
111 |
return '\r\n'.join(value[1] for value in sorted(code.items()))
|
|
|
119 |
|
120 |
def models_change(model, scheduler):
|
121 |
|
122 |
+
print(model)
|
123 |
use_safetensors = False
|
124 |
|
125 |
# no model selected (because this is UI init run)
|
|
|
141 |
safety_checker_change(safety_checker)
|
142 |
requires_safety_checker_change(requires_safety_checker)
|
143 |
|
144 |
+
print(use_safetensors)
|
145 |
return get_sorted_code(), use_safetensors, scheduler
|
146 |
|
147 |
def data_type_change(selected_data_type):
|
|
|
228 |
return DPMSolverMultistepScheduler.from_config(config)
|
229 |
|
230 |
# pipeline
|
231 |
+
def run_inference(model,
|
232 |
device,
|
233 |
use_safetensors,
|
234 |
data_type,
|
|
|
266 |
|
267 |
pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
|
268 |
|
269 |
+
|
270 |
+
if manual_seed < 0 or manual_seed is None or manual_seed == '':
|
271 |
+
generator = torch.Generator(device)
|
272 |
+
else:
|
273 |
+
generator = torch.manual_seed(42)
|
274 |
|
275 |
progress((3,3), desc="Creating the result...")
|
276 |
|
277 |
image = pipeline(
|
278 |
prompt=prompt,
|
279 |
negative_prompt=negative_prompt,
|
280 |
+
generator=generator,
|
281 |
num_inference_steps=int(inference_steps),
|
282 |
guidance_scale=float(guidance_scale)).images[0]
|
283 |
+
|
284 |
+
|
285 |
return "Done.", image
|
286 |
|
287 |
else:
|
|
|
296 |
code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
|
297 |
code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
|
298 |
code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
|
299 |
+
code[code_pos_generator] = f'generator = torch.Generator("{device}")'
|
300 |
+
code[code_pos_prompt] = f'prompt = "{prompt}"'
|
301 |
+
code[code_pos_negative_prompt] = f'negative_prompt = "{negative_prompt}"'
|
302 |
+
code[code_pos_inference_steps] = f'inference_steps = {inference_steps}'
|
303 |
+
code[code_pos_guidance_scale] = f'guidance_scale = {guidance_scale}'
|
304 |
+
code[code_pos_run_inference] = f'''image = pipeline(
|
305 |
+
prompt=prompt,
|
306 |
+
negative_prompt=negative_prompt,
|
307 |
+
generator=generator.manual_seed(manual_seed),
|
308 |
+
num_inference_steps=inference_steps,
|
309 |
+
guidance_scale=guidance_scale).images[0]'''
|
310 |
+
|
311 |
+
def dict_list_to_markdown_table(data):
|
312 |
+
if not data:
|
313 |
+
return ""
|
314 |
+
|
315 |
+
headers = list(data[0].keys())
|
316 |
+
markdown_table = "| action | " + " | ".join(headers) + " |\n"
|
317 |
+
markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
|
318 |
+
|
319 |
+
for i, row in enumerate(data):
|
320 |
+
# Encode row's content in base64 for sharing
|
321 |
+
encoded_row = base64.b64encode(str(row).encode()).decode()
|
322 |
+
# Create link to share the row's content
|
323 |
+
share_link = f'<a href="share/{encoded_row}">📎</a>'
|
324 |
+
# Create link to remove the row
|
325 |
+
remove_link = f'<a href="remove/{i}">❌</a>'
|
326 |
+
# Construct the row with links
|
327 |
+
markdown_table += f"| {share_link} {remove_link} | " + " | ".join(str(row.get(key, "")) for key in headers) + " |\n"
|
328 |
+
|
329 |
+
# Wrap the Markdown table in a <div> tag with horizontal scrolling
|
330 |
+
markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
|
331 |
+
|
332 |
+
return markdown_table
|
333 |
+
|
334 |
+
@app.route('/remove/<int:index>')
|
335 |
+
def remove_row(index):
|
336 |
+
if 0 <= index < len(data):
|
337 |
+
del data[index]
|
338 |
+
return "Row removed successfully"
|
339 |
|
340 |
# interface
|
341 |
with gr.Blocks() as demo:
|
|
|
346 |
</small>''')
|
347 |
gr.Markdown("### Device specific settings")
|
348 |
with gr.Row():
|
349 |
+
in_devices = gr.Dropdown(label="Device:", value=device, choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
|
350 |
+
in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`blfoat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
|
351 |
+
in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
352 |
+
in_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
353 |
|
354 |
gr.Markdown("### Model specific settings")
|
355 |
with gr.Row():
|
356 |
+
in_models = gr.Dropdown(choices=models, label="Model")
|
357 |
with gr.Row():
|
358 |
with gr.Column(scale=1):
|
359 |
+
in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
360 |
with gr.Column(scale=1):
|
361 |
+
in_safety_checker = gr.Radio(label="Enable safety checker:", value=safety_checker, choices=[True, False])
|
362 |
+
in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=requires_safety_checker, choices=[True, False])
|
363 |
|
364 |
gr.Markdown("### Scheduler")
|
365 |
with gr.Row():
|
366 |
+
in_schedulers = gr.Dropdown(choices=schedulers, label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
|
367 |
+
out_scheduler_description = gr.Textbox(value="", label="Description")
|
368 |
|
369 |
gr.Markdown("### Adapters")
|
370 |
with gr.Row():
|
|
|
372 |
|
373 |
gr.Markdown("### Inference settings")
|
374 |
with gr.Row():
|
375 |
+
in_prompt = gr.TextArea(label="Prompt", value=prompt)
|
376 |
+
in_negative_prompt = gr.TextArea(label="Negative prompt", value=negative_prompt)
|
377 |
with gr.Row():
|
378 |
+
in_inference_steps = gr.Textbox(label="Inference steps", value=inference_steps)
|
379 |
+
in_manual_seed = gr.Textbox(label="Manual seed", value=manual_seed, info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
|
380 |
+
in_guidance_scale = gr.Textbox(label="Guidance Scale", value=guidance_scale, info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
|
381 |
|
382 |
gr.Markdown("### Output")
|
383 |
with gr.Row():
|
384 |
btn_start_pipeline = gr.Button(value="Run inferencing")
|
385 |
with gr.Row():
|
386 |
+
# out_result = gr.Textbox(label="Status", value="")
|
387 |
+
out_image = gr.Image()
|
388 |
+
out_code = gr.Code(get_sorted_code(), label="Code")
|
389 |
+
with gr.Row():
|
390 |
+
out_current_config = gr.Code(value=str(current_config), label="Current config")
|
391 |
+
with gr.Row():
|
392 |
+
out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
|
393 |
|
394 |
+
in_devices.change(device_change, inputs=[in_devices], outputs=[out_code])
|
395 |
+
in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code])
|
396 |
+
in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code])
|
397 |
+
in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code])
|
398 |
+
in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers])
|
399 |
+
in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code])
|
400 |
+
in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code])
|
401 |
+
in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description])
|
402 |
+
btn_start_pipeline.click(run_inference, inputs=[
|
403 |
+
in_models,
|
404 |
+
in_devices,
|
405 |
+
in_use_safetensors,
|
406 |
+
in_data_type,
|
407 |
+
in_variant,
|
408 |
+
in_safety_checker,
|
409 |
+
in_requires_safety_checker,
|
410 |
+
in_schedulers,
|
411 |
+
in_prompt,
|
412 |
+
in_negative_prompt,
|
413 |
+
in_inference_steps,
|
414 |
+
in_manual_seed,
|
415 |
+
in_guidance_scale
|
416 |
+
], outputs=[out_image])
|
417 |
|
418 |
demo.launch()
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ torch==2.2.1
|
|
5 |
gunicorn
|
6 |
urllib3==1.26.6
|
7 |
transformers
|
8 |
-
gradio
|
|
|
|
5 |
gunicorn
|
6 |
urllib3==1.26.6
|
7 |
transformers
|
8 |
+
gradio
|
9 |
+
stripe>=9.0.0
|