fix wrong state handling
Browse files- app.py +138 -154
- appConfig.json +2 -0
- change_handlers.py +0 -0
- config.py +108 -119
- helpers.py +76 -0
app.py
CHANGED
@@ -6,15 +6,6 @@ import torch
|
|
6 |
import json
|
7 |
from PIL import Image
|
8 |
from diffusers import DiffusionPipeline
|
9 |
-
from diffusers import (
|
10 |
-
DDPMScheduler,
|
11 |
-
DDIMScheduler,
|
12 |
-
PNDMScheduler,
|
13 |
-
LMSDiscreteScheduler,
|
14 |
-
EulerAncestralDiscreteScheduler,
|
15 |
-
EulerDiscreteScheduler,
|
16 |
-
DPMSolverMultistepScheduler,
|
17 |
-
)
|
18 |
import threading
|
19 |
import requests
|
20 |
from flask import Flask, render_template_string
|
@@ -22,181 +13,175 @@ from gradio import Interface
|
|
22 |
from diffusers import AutoencoderKL
|
23 |
import pandas as pd
|
24 |
import base64
|
25 |
-
from config import
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
|
31 |
-
return config
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
def models_change(model, scheduler):
|
34 |
-
|
35 |
use_safetensors = False
|
36 |
|
37 |
# no model selected (because this is UI init run)
|
38 |
-
if type(model) != list and model
|
|
|
|
|
|
|
39 |
|
40 |
-
use_safetensors = str(config.model_configs[model]['use_safetensors'])
|
41 |
-
|
42 |
# if no scheduler is selected, choose the default one for this model
|
43 |
if scheduler == None:
|
44 |
|
45 |
-
scheduler =
|
46 |
|
47 |
-
|
48 |
-
requires_safety_checker_change(config.current["requires_safety_checker"])
|
49 |
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
|
54 |
-
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
|
60 |
-
else:
|
61 |
-
return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
|
62 |
|
63 |
-
def tensorfloat32_change(allow_tensorfloat32):
|
64 |
|
65 |
-
|
|
|
|
|
66 |
|
67 |
-
def inference_steps_change(inference_steps):
|
68 |
|
69 |
-
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
return config.set_config('manual_seed', manual_seed), config.assemble_code()
|
74 |
|
75 |
-
def
|
76 |
|
77 |
-
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
return config.set_config('prompt', prompt), config.assemble_code()
|
82 |
|
83 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
86 |
|
87 |
-
|
|
|
|
|
88 |
|
89 |
-
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
return config.set_config('safety_checker', safety_checker), config.assemble_code()
|
94 |
|
95 |
-
def
|
96 |
|
97 |
-
|
98 |
|
99 |
-
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
def get_tensorfloat32(allow_tensorfloat32):
|
104 |
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
def
|
108 |
|
109 |
-
if scheduler
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
elif scheduler == "PNDMScheduler":
|
114 |
-
return PNDMScheduler.from_config(pipeline_config)
|
115 |
-
elif scheduler == "LMSDiscreteScheduler":
|
116 |
-
return LMSDiscreteScheduler.from_config(pipeline_config)
|
117 |
-
elif scheduler == "EulerAncestralDiscreteScheduler":
|
118 |
-
return EulerAncestralDiscreteScheduler.from_config(pipeline_config)
|
119 |
-
elif scheduler == "EulerDiscreteScheduler":
|
120 |
-
return EulerDiscreteScheduler.from_config(pipeline_config)
|
121 |
-
elif scheduler == "DPMSolverMultistepScheduler":
|
122 |
-
return DPMSolverMultistepScheduler.from_config(pipeline_config)
|
123 |
else:
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
# - initial configuration,
|
128 |
-
# - a list of available devices from the config file
|
129 |
-
# - a list of available models from the config file
|
130 |
-
# - a list of available schedulers from the config file
|
131 |
-
# - a dict that contains code to for reproduction
|
132 |
|
133 |
-
|
134 |
-
|
|
|
135 |
|
136 |
-
|
|
|
|
|
|
|
137 |
|
138 |
progress((1,3), desc="Preparing pipeline initialization...")
|
139 |
|
140 |
-
torch.backends.cuda.matmul.allow_tf32 = config
|
141 |
|
142 |
progress((2,3), desc="Initializing pipeline...")
|
143 |
-
|
144 |
pipeline = DiffusionPipeline.from_pretrained(
|
145 |
-
config
|
146 |
-
use_safetensors = config
|
147 |
-
torch_dtype = get_data_type(config
|
148 |
-
variant = config
|
149 |
|
150 |
-
if
|
151 |
pipeline.safety_checker = None
|
152 |
|
153 |
-
pipeline.requires_safety_checker = config
|
154 |
-
|
155 |
-
pipeline.scheduler = get_scheduler(config
|
156 |
|
157 |
-
if config
|
158 |
-
generator = torch.Generator(config
|
159 |
else:
|
160 |
-
generator = torch.manual_seed(
|
161 |
|
162 |
progress((3,3), desc="Creating the result...")
|
163 |
|
164 |
image = pipeline(
|
165 |
-
prompt = config
|
166 |
-
negative_prompt = config
|
167 |
generator = generator,
|
168 |
-
num_inference_steps = config
|
169 |
-
guidance_scale = config
|
170 |
|
171 |
-
|
172 |
|
173 |
-
return image, dict_list_to_markdown_table(
|
174 |
|
175 |
else:
|
176 |
|
177 |
-
return "Please select a model AND a scheduler.", None
|
178 |
-
|
179 |
-
def dict_list_to_markdown_table(config_history):
|
180 |
-
|
181 |
-
if not config_history:
|
182 |
-
return ""
|
183 |
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
for index, config in enumerate(config_history):
|
189 |
-
|
190 |
-
encoded_config = base64.b64encode(str(config).encode()).decode()
|
191 |
-
share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
|
192 |
-
markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"
|
193 |
-
|
194 |
-
markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
|
195 |
-
|
196 |
-
return markdown_table
|
197 |
|
198 |
# interface
|
199 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
200 |
|
201 |
gr.Markdown('''## Text-2-Image Playground
|
202 |
<small>by <a target="_blank" href="https://www.linkedin.com/in/nickyreinert/">Nicky Reinert</a> |
|
@@ -204,26 +189,25 @@ with gr.Blocks() as demo:
|
|
204 |
</small>''')
|
205 |
gr.Markdown("### Device specific settings")
|
206 |
with gr.Row():
|
207 |
-
in_devices = gr.Dropdown(label="Device:", value=config.
|
208 |
-
in_data_type = gr.Radio(label="Data Type:", value=config.
|
209 |
-
in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.
|
210 |
-
in_variant = gr.Radio(label="Variant:", value=config.
|
211 |
|
212 |
gr.Markdown("### Model specific settings")
|
213 |
with gr.Row():
|
214 |
-
|
215 |
-
|
216 |
with gr.Row():
|
217 |
with gr.Column(scale=1):
|
218 |
in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
219 |
with gr.Column(scale=1):
|
220 |
-
in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.
|
221 |
-
in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.
|
222 |
|
223 |
gr.Markdown("### Scheduler")
|
224 |
with gr.Row():
|
225 |
-
|
226 |
-
in_schedulers = gr.Dropdown(choices=schedulers, label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
|
227 |
out_scheduler_description = gr.Textbox(value="", label="Description")
|
228 |
|
229 |
gr.Markdown("### Adapters")
|
@@ -232,12 +216,12 @@ with gr.Blocks() as demo:
|
|
232 |
|
233 |
gr.Markdown("### Inference settings")
|
234 |
with gr.Row():
|
235 |
-
in_prompt = gr.TextArea(label="Prompt", value=config.
|
236 |
-
in_negative_prompt = gr.TextArea(label="Negative prompt", value=config.
|
237 |
with gr.Row():
|
238 |
-
in_inference_steps = gr.Number(label="Inference steps", value=config.
|
239 |
-
in_manual_seed = gr.Number(label="Manual seed", value=config.
|
240 |
-
in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=config.
|
241 |
|
242 |
gr.Markdown("### Output")
|
243 |
with gr.Row():
|
@@ -245,30 +229,30 @@ with gr.Blocks() as demo:
|
|
245 |
with gr.Row():
|
246 |
# out_result = gr.Textbox(label="Status", value="")
|
247 |
out_image = gr.Image()
|
248 |
-
out_code = gr.Code(config.
|
249 |
with gr.Row():
|
250 |
-
|
251 |
with gr.Row():
|
252 |
-
out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
|
253 |
|
254 |
-
in_devices.change(device_change, inputs=[in_devices], outputs=[
|
255 |
-
in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[
|
256 |
-
in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[
|
257 |
-
in_variant.change(variant_change, inputs=[in_variant], outputs=[
|
258 |
-
in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[in_use_safetensors, in_schedulers,
|
259 |
-
in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[
|
260 |
-
in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[
|
261 |
-
in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_scheduler_description,
|
262 |
-
in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps], outputs=[
|
263 |
-
in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed], outputs=[
|
264 |
-
in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[
|
265 |
-
in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[
|
266 |
-
in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[
|
267 |
-
btn_start_pipeline.click(run_inference, inputs=[], outputs=[out_image, out_config_history])
|
268 |
|
269 |
# send current respect initial config to init_config to populate parameters to all relevant input fields
|
270 |
# if GET parameter is set, it will overwrite initial config parameters
|
271 |
-
demo.load(fn=
|
272 |
outputs=[
|
273 |
in_models,
|
274 |
in_devices,
|
|
|
6 |
import json
|
7 |
from PIL import Image
|
8 |
from diffusers import DiffusionPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import threading
|
10 |
import requests
|
11 |
from flask import Flask, render_template_string
|
|
|
13 |
from diffusers import AutoencoderKL
|
14 |
import pandas as pd
|
15 |
import base64
|
16 |
+
from config import *
|
17 |
+
from helpers import *
|
18 |
|
19 |
+
def device_change(device, config):
|
20 |
+
|
21 |
+
config = set_config(config, 'device', device)
|
22 |
|
23 |
+
return config, str(config), assemble_code(config)
|
24 |
+
|
25 |
+
def models_change(model, scheduler, config):
|
26 |
+
|
27 |
+
config = set_config(config, 'model', model)
|
28 |
|
|
|
|
|
29 |
use_safetensors = False
|
30 |
|
31 |
# no model selected (because this is UI init run)
|
32 |
+
if type(model) != list and str(model) != 'None':
|
33 |
+
|
34 |
+
use_safetensors = str(models[model]['use_safetensors'])
|
35 |
+
model_description = models[model]['description']
|
36 |
|
|
|
|
|
37 |
# if no scheduler is selected, choose the default one for this model
|
38 |
if scheduler == None:
|
39 |
|
40 |
+
scheduler = models[model]['scheduler']
|
41 |
|
42 |
+
else:
|
|
|
43 |
|
44 |
+
model_description = 'Please select a model.'
|
45 |
+
|
46 |
+
config["use_safetensors"] = str(use_safetensors)
|
47 |
+
config["scheduler"] = str(scheduler)
|
48 |
+
|
49 |
+
# safety_checker_change(in_safety_checker.value, config)
|
50 |
+
# requires_safety_checker_change(in_requires_safety_checker.value, config)
|
51 |
|
52 |
+
return model_description, use_safetensors, scheduler, config, str(config), assemble_code(config)
|
53 |
|
54 |
+
def data_type_change(data_type, config):
|
55 |
|
56 |
+
config = set_config(config, 'data_type', data_type)
|
57 |
+
|
58 |
+
return config, str(config), assemble_code(config)
|
|
|
|
|
|
|
59 |
|
60 |
+
def tensorfloat32_change(allow_tensorfloat32, config):
|
61 |
|
62 |
+
config = set_config(config, 'allow_tensorfloat32', allow_tensorfloat32)
|
63 |
+
|
64 |
+
return config, str(config), assemble_code(config)
|
65 |
|
66 |
+
def inference_steps_change(inference_steps, config):
|
67 |
|
68 |
+
config = set_config(config, 'inference_steps', inference_steps)
|
69 |
|
70 |
+
return config, str(config), assemble_code(config)
|
|
|
|
|
71 |
|
72 |
+
def manual_seed_change(manual_seed, config):
|
73 |
|
74 |
+
config = set_config(config, 'manual_seed', manual_seed)
|
75 |
|
76 |
+
return config, str(config), assemble_code(config)
|
|
|
|
|
77 |
|
78 |
+
def guidance_scale_change(guidance_scale, config):
|
79 |
+
|
80 |
+
config = set_config(config, 'guidance_scale', guidance_scale)
|
81 |
+
|
82 |
+
return config, str(config), assemble_code(config)
|
83 |
+
|
84 |
+
def prompt_change(prompt, config):
|
85 |
|
86 |
+
config = set_config(config, 'prompt', prompt)
|
87 |
|
88 |
+
return config, str(config), assemble_code(config)
|
89 |
+
|
90 |
+
def negative_prompt_change(negative_prompt, config):
|
91 |
|
92 |
+
config = set_config(config, 'negative_prompt', negative_prompt)
|
93 |
|
94 |
+
return config, str(config), assemble_code(config)
|
|
|
|
|
95 |
|
96 |
+
def variant_change(variant, config):
|
97 |
|
98 |
+
config = set_config(config, 'variant', variant)
|
99 |
|
100 |
+
return config, str(config), assemble_code(config)
|
101 |
|
102 |
+
def safety_checker_change(safety_checker, config):
|
|
|
|
|
103 |
|
104 |
+
config = set_config(config, 'safety_checker', safety_checker)
|
105 |
+
|
106 |
+
return config, str(config), assemble_code(config)
|
107 |
+
|
108 |
+
def requires_safety_checker_change(requires_safety_checker, config):
|
109 |
+
|
110 |
+
config = set_config(config, 'requires_safety_checker', requires_safety_checker)
|
111 |
+
|
112 |
+
return config, str(config), assemble_code(config)
|
113 |
|
114 |
+
def schedulers_change(scheduler, config):
|
115 |
|
116 |
+
if str(scheduler) != 'None' and type(scheduler) != list:
|
117 |
+
|
118 |
+
scheduler_description = schedulers[scheduler]
|
119 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
else:
|
121 |
+
scheduler_description = 'Please select a scheduler.'
|
122 |
+
|
123 |
+
config = set_config(config, 'scheduler', scheduler)
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
+
return scheduler_description, config, str(config), assemble_code(config)
|
126 |
+
|
127 |
+
def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True)):
|
128 |
|
129 |
+
# str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
|
130 |
+
# config = json.loads(str_config)
|
131 |
+
|
132 |
+
if str(config["model"]) != 'None' and str(config["scheduler"]) != 'None':
|
133 |
|
134 |
progress((1,3), desc="Preparing pipeline initialization...")
|
135 |
|
136 |
+
torch.backends.cuda.matmul.allow_tf32 = get_bool(config["allow_tensorfloat32"]) # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations
|
137 |
|
138 |
progress((2,3), desc="Initializing pipeline...")
|
139 |
+
|
140 |
pipeline = DiffusionPipeline.from_pretrained(
|
141 |
+
config["model"],
|
142 |
+
use_safetensors = get_bool(config["use_safetensors"]),
|
143 |
+
torch_dtype = get_data_type(config["data_type"]),
|
144 |
+
variant = get_variant(config["variant"])).to(config["device"])
|
145 |
|
146 |
+
if str(config["safety_checker"]).lower() == 'false':
|
147 |
pipeline.safety_checker = None
|
148 |
|
149 |
+
pipeline.requires_safety_checker = get_bool(config["requires_safety_checker"])
|
150 |
+
|
151 |
+
pipeline.scheduler = get_scheduler(config["scheduler"], pipeline.scheduler.config)
|
152 |
|
153 |
+
if config["manual_seed"] < 0 or config["manual_seed"] is None or config["manual_seed"] == '':
|
154 |
+
generator = torch.Generator(config["device"])
|
155 |
else:
|
156 |
+
generator = torch.manual_seed(int(config["manual_seed"]))
|
157 |
|
158 |
progress((3,3), desc="Creating the result...")
|
159 |
|
160 |
image = pipeline(
|
161 |
+
prompt = config["prompt"],
|
162 |
+
negative_prompt = config["negative_prompt"],
|
163 |
generator = generator,
|
164 |
+
num_inference_steps = int(config["inference_steps"]),
|
165 |
+
guidance_scale = float(config["guidance_scale"])).images[0]
|
166 |
|
167 |
+
config_history.append(config.copy())
|
168 |
|
169 |
+
return image, dict_list_to_markdown_table(config_history), config_history
|
170 |
|
171 |
else:
|
172 |
|
173 |
+
return "Please select a model AND a scheduler.", None, config_history
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
appConfig = load_app_config()
|
176 |
+
models = appConfig.get("models", {})
|
177 |
+
schedulers = appConfig.get("schedulers", {})
|
178 |
+
devices = appConfig.get("devices", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
# interface
|
181 |
with gr.Blocks() as demo:
|
182 |
+
|
183 |
+
config = gr.State(value=get_initial_config())
|
184 |
+
config_history = gr.State(value=[])
|
185 |
|
186 |
gr.Markdown('''## Text-2-Image Playground
|
187 |
<small>by <a target="_blank" href="https://www.linkedin.com/in/nickyreinert/">Nicky Reinert</a> |
|
|
|
189 |
</small>''')
|
190 |
gr.Markdown("### Device specific settings")
|
191 |
with gr.Row():
|
192 |
+
in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
|
193 |
+
in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
|
194 |
+
in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
195 |
+
in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
|
196 |
|
197 |
gr.Markdown("### Model specific settings")
|
198 |
with gr.Row():
|
199 |
+
in_models = gr.Dropdown(choices=list(models.keys()), label="Model")
|
200 |
+
out_model_description = gr.Textbox(value="", label="Description")
|
201 |
with gr.Row():
|
202 |
with gr.Column(scale=1):
|
203 |
in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
|
204 |
with gr.Column(scale=1):
|
205 |
+
in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
|
206 |
+
in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
|
207 |
|
208 |
gr.Markdown("### Scheduler")
|
209 |
with gr.Row():
|
210 |
+
in_schedulers = gr.Dropdown(choices=list(schedulers.keys()), label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
|
|
|
211 |
out_scheduler_description = gr.Textbox(value="", label="Description")
|
212 |
|
213 |
gr.Markdown("### Adapters")
|
|
|
216 |
|
217 |
gr.Markdown("### Inference settings")
|
218 |
with gr.Row():
|
219 |
+
in_prompt = gr.TextArea(label="Prompt", value=config.value["prompt"])
|
220 |
+
in_negative_prompt = gr.TextArea(label="Negative prompt", value=config.value["negative_prompt"])
|
221 |
with gr.Row():
|
222 |
+
in_inference_steps = gr.Number(label="Inference steps", value=config.value["inference_steps"])
|
223 |
+
in_manual_seed = gr.Number(label="Manual seed", value=config.value["manual_seed"], info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
|
224 |
+
in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=config.value["guidance_scale"], info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
|
225 |
|
226 |
gr.Markdown("### Output")
|
227 |
with gr.Row():
|
|
|
229 |
with gr.Row():
|
230 |
# out_result = gr.Textbox(label="Status", value="")
|
231 |
out_image = gr.Image()
|
232 |
+
out_code = gr.Code(assemble_code(config.value), label="Code")
|
233 |
with gr.Row():
|
234 |
+
out_config = gr.Code(value=str(config.value), label="Current config")
|
235 |
with gr.Row():
|
236 |
+
out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history.value))
|
237 |
|
238 |
+
in_devices.change(device_change, inputs=[in_devices, config], outputs=[config, out_config, out_code])
|
239 |
+
in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
|
240 |
+
in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
|
241 |
+
in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
|
242 |
+
in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_use_safetensors, in_schedulers, config, out_config, out_code])
|
243 |
+
in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
|
244 |
+
in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
|
245 |
+
in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
|
246 |
+
in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps, config], outputs=[config, out_config, out_code])
|
247 |
+
in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed, config], outputs=[config, out_config, out_code])
|
248 |
+
in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale, config], outputs=[config, out_config, out_code])
|
249 |
+
in_prompt.change(prompt_change, inputs=[in_prompt, config], outputs=[config, out_config, out_code])
|
250 |
+
in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt, config], outputs=[config, out_config, out_code])
|
251 |
+
btn_start_pipeline.click(run_inference, inputs=[config, config_history], outputs=[out_image, out_config_history, config_history])
|
252 |
|
253 |
# send current respect initial config to init_config to populate parameters to all relevant input fields
|
254 |
# if GET parameter is set, it will overwrite initial config parameters
|
255 |
+
demo.load(fn=get_config_from_url, inputs=config,
|
256 |
outputs=[
|
257 |
in_models,
|
258 |
in_devices,
|
appConfig.json
CHANGED
@@ -3,10 +3,12 @@
|
|
3 |
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
"use_safetensors": false,
|
|
|
6 |
"scheduler": "DDPMScheduler"
|
7 |
},
|
8 |
"CompVis/stable-diffusion-v1-4": {
|
9 |
"use_safetensors": true,
|
|
|
10 |
"scheduler": "DDPMScheduler"
|
11 |
}
|
12 |
|
|
|
3 |
|
4 |
"sd-dreambooth-library/solo-levelling-art-style": {
|
5 |
"use_safetensors": false,
|
6 |
+
"description": "A wonderful model",
|
7 |
"scheduler": "DDPMScheduler"
|
8 |
},
|
9 |
"CompVis/stable-diffusion-v1-4": {
|
10 |
"use_safetensors": true,
|
11 |
+
"description": "Another wonderful model",
|
12 |
"scheduler": "DDPMScheduler"
|
13 |
}
|
14 |
|
change_handlers.py
DELETED
File without changes
|
config.py
CHANGED
@@ -3,7 +3,41 @@ import base64
|
|
3 |
import json
|
4 |
import torch
|
5 |
|
6 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
encoded_params = request.request.query_params.get('config')
|
9 |
return_config = {}
|
@@ -12,7 +46,7 @@ def init_config(request: gr.Request, inital_config):
|
|
12 |
if encoded_params is not None:
|
13 |
decoded_params = base64.b64decode(encoded_params)
|
14 |
decoded_params = decoded_params.decode('utf-8')
|
15 |
-
decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', '
|
16 |
dict_params = json.loads(decoded_params)
|
17 |
|
18 |
return_config = dict_params
|
@@ -20,10 +54,9 @@ def init_config(request: gr.Request, inital_config):
|
|
20 |
# otherwise use default initial config
|
21 |
else:
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
return_config = dict_inital_config
|
27 |
|
28 |
return [return_config['model'],
|
29 |
return_config['device'],
|
@@ -40,126 +73,82 @@ def init_config(request: gr.Request, inital_config):
|
|
40 |
return_config['guidance_scale']
|
41 |
]
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
self.model_configs = appConfig.get("models", {})
|
55 |
-
self.scheduler_configs = appConfig.get("schedulers", {})
|
56 |
|
57 |
-
|
58 |
-
self.devices = appConfig.get("devices", [])
|
59 |
-
device = None
|
60 |
-
data_type = 'float16'
|
61 |
-
allow_tensorfloat32 = False
|
62 |
-
if torch.cuda.is_available():
|
63 |
-
device = "cuda"
|
64 |
-
data_type = 'bfloat16'
|
65 |
-
allow_tensorfloat32 = True
|
66 |
-
elif torch.backends.mps.is_available():
|
67 |
-
device = "mps"
|
68 |
-
else:
|
69 |
-
device = "cpu"
|
70 |
-
|
71 |
-
self.current = {
|
72 |
-
"device": device,
|
73 |
-
"model": None,
|
74 |
-
"scheduler": None,
|
75 |
-
"variant": None,
|
76 |
-
"allow_tensorfloat32": allow_tensorfloat32,
|
77 |
-
"use_safetensors": False,
|
78 |
-
"data_type": data_type,
|
79 |
-
"safety_checker": False,
|
80 |
-
"requires_safety_checker": False,
|
81 |
-
"manual_seed": 42,
|
82 |
-
"inference_steps": 10,
|
83 |
-
"guidance_scale": 0.5,
|
84 |
-
"prompt": 'A white rabbit',
|
85 |
-
"negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
|
86 |
-
}
|
87 |
-
|
88 |
-
self.assemble_code()
|
89 |
|
90 |
-
|
91 |
-
try:
|
92 |
-
with open('appConfig.json', 'r') as f:
|
93 |
-
appConfig = json.load(f)
|
94 |
-
except FileNotFoundError:
|
95 |
-
print("App config file not found.")
|
96 |
-
except json.JSONDecodeError:
|
97 |
-
print("Error decoding JSON in app config file.")
|
98 |
-
except Exception as e:
|
99 |
-
print("An error occurred while loading app config:", str(e))
|
100 |
-
|
101 |
-
return appConfig
|
102 |
-
|
103 |
-
def set_config(self, key, value):
|
104 |
-
|
105 |
-
self.current[key] = value
|
106 |
|
107 |
-
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
if type(scheduler) != list and scheduler is not None:
|
112 |
-
|
113 |
-
return self.scheduler_configs[scheduler]
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
return ''
|
118 |
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
-
|
122 |
-
if self.current['data_type'] == "bfloat16":
|
123 |
-
self.code['002_data_type'] = 'data_type = torch.bfloat16'
|
124 |
-
else:
|
125 |
-
self.code['002_data_type'] = 'data_type = torch.float16'
|
126 |
-
self.code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {self.current["allow_tensorfloat32"]}'
|
127 |
-
if str(self.current["variant"]) == 'None':
|
128 |
-
self.code['004_variant'] = f'variant = {self.current["variant"]}'
|
129 |
-
else:
|
130 |
-
self.code['004_variant'] = f'variant = "{self.current["variant"]}"'
|
131 |
-
self.code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
|
132 |
-
"{self.current['model']}",
|
133 |
-
use_safetensors=use_safetensors,
|
134 |
-
torch_dtype=data_type,
|
135 |
-
variant=variant).to(device)'''
|
136 |
|
137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
3 |
import json
|
4 |
import torch
|
5 |
|
6 |
+
def get_initial_config():
|
7 |
+
|
8 |
+
device = None
|
9 |
+
data_type = 'float16'
|
10 |
+
allow_tensorfloat32 = "False"
|
11 |
+
if torch.cuda.is_available():
|
12 |
+
device = "cuda"
|
13 |
+
data_type = 'bfloat16'
|
14 |
+
allow_tensorfloat32 = "True"
|
15 |
+
elif torch.backends.mps.is_available():
|
16 |
+
device = "mps"
|
17 |
+
else:
|
18 |
+
device = "cpu"
|
19 |
+
|
20 |
+
config = {
|
21 |
+
"device": device,
|
22 |
+
"model": None,
|
23 |
+
"scheduler": None,
|
24 |
+
"variant": None,
|
25 |
+
"allow_tensorfloat32": allow_tensorfloat32,
|
26 |
+
"use_safetensors": "False",
|
27 |
+
"data_type": data_type,
|
28 |
+
"safety_checker": "False",
|
29 |
+
"requires_safety_checker": "False",
|
30 |
+
"manual_seed": 42,
|
31 |
+
"inference_steps": 10,
|
32 |
+
"guidance_scale": 0.5,
|
33 |
+
"prompt": 'A white rabbit',
|
34 |
+
"negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
return config
|
39 |
+
|
40 |
+
def get_config_from_url(request: gr.Request, initial_config):
|
41 |
|
42 |
encoded_params = request.request.query_params.get('config')
|
43 |
return_config = {}
|
|
|
46 |
if encoded_params is not None:
|
47 |
decoded_params = base64.b64decode(encoded_params)
|
48 |
decoded_params = decoded_params.decode('utf-8')
|
49 |
+
decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'False')
|
50 |
dict_params = json.loads(decoded_params)
|
51 |
|
52 |
return_config = dict_params
|
|
|
54 |
# otherwise use default initial config
|
55 |
else:
|
56 |
|
57 |
+
# initial_config = initial_config.replace("'", '"').replace('None', 'null').replace('False', 'False')
|
58 |
+
# return_config = json.loads(initial_config)
|
59 |
+
return_config = initial_config
|
|
|
60 |
|
61 |
return [return_config['model'],
|
62 |
return_config['device'],
|
|
|
73 |
return_config['guidance_scale']
|
74 |
]
|
75 |
|
76 |
+
def load_app_config():
|
77 |
+
try:
|
78 |
+
with open('appConfig.json', 'r') as f:
|
79 |
+
appConfig = json.load(f)
|
80 |
+
except FileNotFoundError:
|
81 |
+
print("App config file not found.")
|
82 |
+
except json.JSONDecodeError:
|
83 |
+
print("Error decoding JSON in app config file.")
|
84 |
+
except Exception as e:
|
85 |
+
print("An error occurred while loading app config:", str(e))
|
|
|
|
|
|
|
86 |
|
87 |
+
return appConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
+
def set_config(config, key, value):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
+
# str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
|
92 |
+
# config = json.loads(str_config)
|
93 |
+
config[key] = value
|
94 |
|
95 |
+
return config
|
|
|
|
|
|
|
|
|
96 |
|
97 |
+
def assemble_code(str_config):
|
|
|
|
|
98 |
|
99 |
+
# str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
|
100 |
+
# config = json.loads(str_config)
|
101 |
+
config = str_config
|
102 |
+
|
103 |
+
code = {}
|
104 |
+
|
105 |
+
code['001_code'] = f'''device = "{config['device']}"'''
|
106 |
+
if config['data_type'] == "bfloat16":
|
107 |
+
code['002_data_type'] = 'data_type = torch.bfloat16'
|
108 |
+
else:
|
109 |
+
code['002_data_type'] = 'data_type = torch.float16'
|
110 |
+
code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}'
|
111 |
+
|
112 |
+
if str(config["variant"]) == 'None':
|
113 |
+
code['004_variant'] = f'variant = {config["variant"]}'
|
114 |
+
else:
|
115 |
+
code['004_variant'] = f'variant = "{config["variant"]}"'
|
116 |
+
|
117 |
|
118 |
+
code['005_use_safetensors'] = f'''use_safetensors = {config["use_safetensors"]}'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
+
code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
|
121 |
+
"{config['model']}",
|
122 |
+
use_safetensors=use_safetensors,
|
123 |
+
torch_dtype=data_type,
|
124 |
+
variant=variant).to(device)'''
|
125 |
+
|
126 |
+
code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
|
127 |
|
128 |
+
if str(config["safety_checker"]).lower() == 'false':
|
129 |
+
code['055_safety_checker'] = f'pipeline.safety_checker = None'
|
130 |
+
else:
|
131 |
+
code['055_safety_checker'] = ''
|
132 |
+
|
133 |
+
code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
|
134 |
+
|
135 |
+
if config['manual_seed'] < 0 or config['manual_seed'] is None or config['manual_seed'] == '':
|
136 |
+
code['091_manual_seed'] = f'# manual_seed = {config["manual_seed"]}'
|
137 |
+
code['092_generator'] = f'generator = torch.Generator("{config["device"]}")'
|
138 |
+
else:
|
139 |
+
code['091_manual_seed'] = f'manual_seed = {config["manual_seed"]}'
|
140 |
+
code['092_generator'] = f'generator = torch.manual_seed(manual_seed)'
|
141 |
|
142 |
+
code["080_prompt"] = f'prompt = "{config["prompt"]}"'
|
143 |
+
code["085_negative_prompt"] = f'negative_prompt = "{config["negative_prompt"]}"'
|
144 |
+
code["090_inference_steps"] = f'inference_steps = {config["inference_steps"]}'
|
145 |
+
code["095_guidance_scale"] = f'guidance_scale = {config["guidance_scale"]}'
|
146 |
+
|
147 |
+
code["100_run_inference"] = f'''image = pipeline(
|
148 |
+
prompt=prompt,
|
149 |
+
negative_prompt=negative_prompt,
|
150 |
+
generator=generator,
|
151 |
+
num_inference_steps=inference_steps,
|
152 |
+
guidance_scale=guidance_scale).images[0]'''
|
153 |
+
|
154 |
+
return '\r\n'.join(value[1] for value in sorted(code.items()))
|
helpers.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import (
|
3 |
+
DDPMScheduler,
|
4 |
+
DDIMScheduler,
|
5 |
+
PNDMScheduler,
|
6 |
+
LMSDiscreteScheduler,
|
7 |
+
EulerAncestralDiscreteScheduler,
|
8 |
+
EulerDiscreteScheduler,
|
9 |
+
DPMSolverMultistepScheduler,
|
10 |
+
)
|
11 |
+
import base64
|
12 |
+
|
13 |
+
def get_variant(str_variant):
|
14 |
+
|
15 |
+
if str(str_variant).lower() == 'none':
|
16 |
+
return None
|
17 |
+
else:
|
18 |
+
return str_variant
|
19 |
+
|
20 |
+
def get_bool(str_bool):
|
21 |
+
|
22 |
+
if str(str_bool).lower() == 'false':
|
23 |
+
return False
|
24 |
+
else:
|
25 |
+
return True
|
26 |
+
|
27 |
+
|
28 |
+
def get_data_type(str_data_type):
|
29 |
+
|
30 |
+
if str_data_type == "bfloat16":
|
31 |
+
return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
|
32 |
+
else:
|
33 |
+
return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
|
34 |
+
|
35 |
+
def get_tensorfloat32(allow_tensorfloat32):
|
36 |
+
|
37 |
+
return True if str(allow_tensorfloat32).lower() == 'true' else False
|
38 |
+
|
39 |
+
def get_scheduler(scheduler, pipeline_config):
|
40 |
+
|
41 |
+
if scheduler == "DDPMScheduler":
|
42 |
+
return DDPMScheduler.from_config(pipeline_config)
|
43 |
+
elif scheduler == "DDIMScheduler":
|
44 |
+
return DDIMScheduler.from_config(pipeline_config)
|
45 |
+
elif scheduler == "PNDMScheduler":
|
46 |
+
return PNDMScheduler.from_config(pipeline_config)
|
47 |
+
elif scheduler == "LMSDiscreteScheduler":
|
48 |
+
return LMSDiscreteScheduler.from_config(pipeline_config)
|
49 |
+
elif scheduler == "EulerAncestralDiscreteScheduler":
|
50 |
+
return EulerAncestralDiscreteScheduler.from_config(pipeline_config)
|
51 |
+
elif scheduler == "EulerDiscreteScheduler":
|
52 |
+
return EulerDiscreteScheduler.from_config(pipeline_config)
|
53 |
+
elif scheduler == "DPMSolverMultistepScheduler":
|
54 |
+
return DPMSolverMultistepScheduler.from_config(pipeline_config)
|
55 |
+
else:
|
56 |
+
return DPMSolverMultistepScheduler.from_config(pipeline_config)
|
57 |
+
|
58 |
+
def dict_list_to_markdown_table(config_history):
|
59 |
+
|
60 |
+
if not config_history:
|
61 |
+
return ""
|
62 |
+
|
63 |
+
headers = list(config_history[0].keys())
|
64 |
+
markdown_table = "| share | " + " | ".join(headers) + " |\n"
|
65 |
+
markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
|
66 |
+
|
67 |
+
for index, config in enumerate(config_history):
|
68 |
+
|
69 |
+
encoded_config = base64.b64encode(str(config).encode()).decode()
|
70 |
+
share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
|
71 |
+
markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"
|
72 |
+
|
73 |
+
markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
|
74 |
+
|
75 |
+
return markdown_table
|
76 |
+
|