n42 committed
Commit b96c8c5 · 1 parent: 00e74d5

fix wrong state handling

Files changed (5):
  1. app.py +138 -154
  2. appConfig.json +2 -0
  3. change_handlers.py +0 -0
  4. config.py +108 -119
  5. helpers.py +76 -0
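
The heart of this commit: the old app.py kept a single module-level Config object, which Gradio shares across every session of the app, so one visitor's selections could leak into another's. The new code threads a per-session dict through gr.State instead; each change handler takes the state as an extra input and returns the updated state as an output. A minimal sketch of that pattern, independent of this app (component names here are illustrative, not taken from the diff):

    import gradio as gr

    def name_change(name, state):
        # handlers receive the session state as an input ...
        state["name"] = name
        # ... and hand the updated copy back as an output
        return state, str(state)

    with gr.Blocks() as demo:
        state = gr.State(value={"name": ""})  # one instance per browser session
        in_name = gr.Textbox(label="Name")
        out_state = gr.Code(label="Current state")
        in_name.change(name_change, inputs=[in_name, state], outputs=[state, out_state])

    demo.launch()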
app.py CHANGED
@@ -6,15 +6,6 @@ import torch
 import json
 from PIL import Image
 from diffusers import DiffusionPipeline
-from diffusers import (
-    DDPMScheduler,
-    DDIMScheduler,
-    PNDMScheduler,
-    LMSDiscreteScheduler,
-    EulerAncestralDiscreteScheduler,
-    EulerDiscreteScheduler,
-    DPMSolverMultistepScheduler,
-)
 import threading
 import requests
 from flask import Flask, render_template_string
@@ -22,181 +13,175 @@ from gradio import Interface
 from diffusers import AutoencoderKL
 import pandas as pd
 import base64
-from config import Config, init_config
+from config import *
+from helpers import *
 
-config = Config()
-
-def device_change(device):
+def device_change(device, config):
 
-    return config.set_config('device', device), config.assemble_code()
+    config = set_config(config, 'device', device)
+
+    return config, str(config), assemble_code(config)
 
-def models_change(model, scheduler):
+def models_change(model, scheduler, config):
 
+    config = set_config(config, 'model', model)
+
     use_safetensors = False
 
     # no model selected (because this is UI init run)
-    if type(model) != list and model is not None:
+    if type(model) != list and str(model) != 'None':
 
-        use_safetensors = str(config.model_configs[model]['use_safetensors'])
+        use_safetensors = str(models[model]['use_safetensors'])
+        model_description = models[model]['description']
 
         # if no scheduler is selected, choose the default one for this model
        if scheduler == None:
 
-            scheduler = config.model_configs[model]['scheduler']
+            scheduler = models[model]['scheduler']
 
-    safety_checker_change(config.current["safety_checker"])
-    requires_safety_checker_change(config.current["requires_safety_checker"])
+    else:
 
-    return use_safetensors, scheduler, config.set_config('model', model), config.assemble_code()
+        model_description = 'Please select a model.'
+
+    config["use_safetensors"] = str(use_safetensors)
+    config["scheduler"] = str(scheduler)
+
+    # safety_checker_change(in_safety_checker.value, config)
+    # requires_safety_checker_change(in_requires_safety_checker.value, config)
+
+    return model_description, use_safetensors, scheduler, config, str(config), assemble_code(config)
 
-def data_type_change(data_type):
+def data_type_change(data_type, config):
 
-    return config.set_config('data_type', data_type), config.assemble_code()
+    config = set_config(config, 'data_type', data_type)
+
+    return config, str(config), assemble_code(config)
 
-def get_data_type(str_data_type):
-
-    if str_data_type == "bfloat16":
-        return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
-    else:
-        return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
-
-def tensorfloat32_change(allow_tensorfloat32):
+def tensorfloat32_change(allow_tensorfloat32, config):
 
-    return config.set_config('allow_tensorfloat32', allow_tensorfloat32), config.assemble_code()
+    config = set_config(config, 'allow_tensorfloat32', allow_tensorfloat32)
+
+    return config, str(config), assemble_code(config)
 
-def inference_steps_change(inference_steps):
+def inference_steps_change(inference_steps, config):
 
-    return config.set_config('inference_steps', inference_steps), config.assemble_code()
+    config = set_config(config, 'inference_steps', inference_steps)
+
+    return config, str(config), assemble_code(config)
 
-def manual_seed_change(manual_seed):
+def manual_seed_change(manual_seed, config):
 
-    return config.set_config('manual_seed', manual_seed), config.assemble_code()
+    config = set_config(config, 'manual_seed', manual_seed)
+
+    return config, str(config), assemble_code(config)
 
-def guidance_scale_change(guidance_scale):
+def guidance_scale_change(guidance_scale, config):
 
-    return config.set_config('guidance_scale', guidance_scale), config.assemble_code()
+    config = set_config(config, 'guidance_scale', guidance_scale)
+
+    return config, str(config), assemble_code(config)
 
-def prompt_change(prompt):
+def prompt_change(prompt, config):
 
-    return config.set_config('prompt', prompt), config.assemble_code()
+    config = set_config(config, 'prompt', prompt)
+
+    return config, str(config), assemble_code(config)
 
-def negative_prompt_change(negative_prompt):
+def negative_prompt_change(negative_prompt, config):
 
-    return config.set_config('negative_prompt', negative_prompt), config.assemble_code()
+    config = set_config(config, 'negative_prompt', negative_prompt)
+
+    return config, str(config), assemble_code(config)
 
-def variant_change(variant):
+def variant_change(variant, config):
 
-    return config.set_config('variant', variant), config.assemble_code()
+    config = set_config(config, 'variant', variant)
+
+    return config, str(config), assemble_code(config)
 
-def safety_checker_change(safety_checker):
+def safety_checker_change(safety_checker, config):
 
-    return config.set_config('safety_checker', safety_checker), config.assemble_code()
+    config = set_config(config, 'safety_checker', safety_checker)
+
+    return config, str(config), assemble_code(config)
 
-def requires_safety_checker_change(requires_safety_checker):
+def requires_safety_checker_change(requires_safety_checker, config):
 
-    return config.set_config('requires_safety_checker', requires_safety_checker), config.assemble_code()
+    config = set_config(config, 'requires_safety_checker', requires_safety_checker)
+
+    return config, str(config), assemble_code(config)
 
-def schedulers_change(scheduler):
+def schedulers_change(scheduler, config):
 
-    return config.get_scheduler_description(scheduler), config.set_config('scheduler', scheduler), config.assemble_code()
+    if str(scheduler) != 'None' and type(scheduler) != list:
 
-def get_tensorfloat32(allow_tensorfloat32):
+        scheduler_description = schedulers[scheduler]
 
-    return True if str(allow_tensorfloat32).lower() == 'true' else False
-
-def get_scheduler(scheduler, pipeline_config):
-
-    if scheduler == "DDPMScheduler":
-        return DDPMScheduler.from_config(pipeline_config)
-    elif scheduler == "DDIMScheduler":
-        return DDIMScheduler.from_config(pipeline_config)
-    elif scheduler == "PNDMScheduler":
-        return PNDMScheduler.from_config(pipeline_config)
-    elif scheduler == "LMSDiscreteScheduler":
-        return LMSDiscreteScheduler.from_config(pipeline_config)
-    elif scheduler == "EulerAncestralDiscreteScheduler":
-        return EulerAncestralDiscreteScheduler.from_config(pipeline_config)
-    elif scheduler == "EulerDiscreteScheduler":
-        return EulerDiscreteScheduler.from_config(pipeline_config)
-    elif scheduler == "DPMSolverMultistepScheduler":
-        return DPMSolverMultistepScheduler.from_config(pipeline_config)
     else:
-        return DPMSolverMultistepScheduler.from_config(pipeline_config)
+        scheduler_description = 'Please select a scheduler.'
 
-# get
-# - initial configuration,
-# - a list of available devices from the config file
-# - a list of available models from the config file
-# - a list of available schedulers from the config file
-# - a dict that contains code to for reproduction
+    config = set_config(config, 'scheduler', scheduler)
 
-# pipeline
-def run_inference(progress=gr.Progress(track_tqdm=True)):
+    return scheduler_description, config, str(config), assemble_code(config)
 
-    if config.current["model"] != None and config.current["scheduler"] != None:
+def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True)):
+
+    # str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
+    # config = json.loads(str_config)
+
+    if str(config["model"]) != 'None' and str(config["scheduler"]) != 'None':
 
         progress((1,3), desc="Preparing pipeline initialization...")
 
-        torch.backends.cuda.matmul.allow_tf32 = config.current["allow_tensorfloat32"] # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations
+        torch.backends.cuda.matmul.allow_tf32 = get_bool(config["allow_tensorfloat32"]) # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations
 
         progress((2,3), desc="Initializing pipeline...")
 
         pipeline = DiffusionPipeline.from_pretrained(
-            config.current["model"],
-            use_safetensors = config.current["use_safetensors"],
-            torch_dtype = get_data_type(config.current["data_type"]),
-            variant = config.current["variant"]).to(config.current["device"])
+            config["model"],
+            use_safetensors = get_bool(config["use_safetensors"]),
+            torch_dtype = get_data_type(config["data_type"]),
+            variant = get_variant(config["variant"])).to(config["device"])
 
-        if config.current["safety_checker"] is None or str(config.current["safety_checker"]).lower == 'false':
+        if str(config["safety_checker"]).lower() == 'false':
            pipeline.safety_checker = None
 
-        pipeline.requires_safety_checker = config.current["requires_safety_checker"]
+        pipeline.requires_safety_checker = get_bool(config["requires_safety_checker"])
 
-        pipeline.scheduler = get_scheduler(config.current["scheduler"], pipeline.scheduler.config)
+        pipeline.scheduler = get_scheduler(config["scheduler"], pipeline.scheduler.config)
 
-        if config.current["manual_seed"] < 0 or config.current["manual_seed"] is None or config.current["manual_seed"] == '':
-            generator = torch.Generator(config.current["device"])
+        if config["manual_seed"] < 0 or config["manual_seed"] is None or config["manual_seed"] == '':
+            generator = torch.Generator(config["device"])
        else:
-            generator = torch.manual_seed(42)
+            generator = torch.manual_seed(int(config["manual_seed"]))
 
         progress((3,3), desc="Creating the result...")
 
         image = pipeline(
-            prompt = config.current["prompt"],
-            negative_prompt = config.current["negative_prompt"],
+            prompt = config["prompt"],
+            negative_prompt = config["negative_prompt"],
            generator = generator,
-            num_inference_steps = config.current["inference_steps"],
-            guidance_scale = config.current["guidance_scale"]).images[0]
+            num_inference_steps = int(config["inference_steps"]),
+            guidance_scale = float(config["guidance_scale"])).images[0]
 
-        config.history.append(config.current.copy())
+        config_history.append(config.copy())
 
-        return image, dict_list_to_markdown_table(config.history)
+        return image, dict_list_to_markdown_table(config_history), config_history
 
    else:
 
-        return "Please select a model AND a scheduler.", None
-
-def dict_list_to_markdown_table(config_history):
-
-    if not config_history:
-        return ""
-
-    headers = list(config_history[0].keys())
-    markdown_table = "| share | " + " | ".join(headers) + " |\n"
-    markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
-
-    for index, config in enumerate(config_history):
-
-        encoded_config = base64.b64encode(str(config).encode()).decode()
-        share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
-        markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"
-
-    markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
-
-    return markdown_table
+        return "Please select a model AND a scheduler.", None, config_history
 
+appConfig = load_app_config()
+models = appConfig.get("models", {})
+schedulers = appConfig.get("schedulers", {})
+devices = appConfig.get("devices", [])
+
 # interface
 with gr.Blocks() as demo:
+
+    config = gr.State(value=get_initial_config())
+    config_history = gr.State(value=[])
 
     gr.Markdown('''## Text-2-Image Playground
             <small>by <a target="_blank" href="https://www.linkedin.com/in/nickyreinert/">Nicky Reinert</a> |
@@ -204,26 +189,25 @@ with gr.Blocks() as demo:
             </small>''')
     gr.Markdown("### Device specific settings")
     with gr.Row():
-        in_devices = gr.Dropdown(label="Device:", value=config.current["device"], choices=config.devices, filterable=True, multiselect=False, allow_custom_value=True)
-        in_data_type = gr.Radio(label="Data Type:", value=config.current["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
-        in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.current["allow_tensorfloat32"], choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
-        in_variant = gr.Radio(label="Variant:", value=config.current["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
+        in_devices = gr.Dropdown(label="Device:", value=config.value["device"], choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
+        in_data_type = gr.Radio(label="Data Type:", value=config.value["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
+        in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.value["allow_tensorfloat32"], choices=["True", "False"], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
+        in_variant = gr.Radio(label="Variant:", value=config.value["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
 
     gr.Markdown("### Model specific settings")
     with gr.Row():
-        models = list(config.model_configs.keys())
-        in_models = gr.Dropdown(choices=models, label="Model")
+        in_models = gr.Dropdown(choices=list(models.keys()), label="Model")
+        out_model_description = gr.Textbox(value="", label="Description")
     with gr.Row():
         with gr.Column(scale=1):
             in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
         with gr.Column(scale=1):
-            in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.current["safety_checker"], choices=[True, False])
-            in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.current["requires_safety_checker"], choices=[True, False])
+            in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
+            in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
 
     gr.Markdown("### Scheduler")
     with gr.Row():
-        schedulers = list(config.scheduler_configs.keys())
-        in_schedulers = gr.Dropdown(choices=schedulers, label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
+        in_schedulers = gr.Dropdown(choices=list(schedulers.keys()), label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
         out_scheduler_description = gr.Textbox(value="", label="Description")
 
     gr.Markdown("### Adapters")
@@ -232,12 +216,12 @@ with gr.Blocks() as demo:
 
     gr.Markdown("### Inference settings")
     with gr.Row():
-        in_prompt = gr.TextArea(label="Prompt", value=config.current["prompt"])
-        in_negative_prompt = gr.TextArea(label="Negative prompt", value=config.current["negative_prompt"])
+        in_prompt = gr.TextArea(label="Prompt", value=config.value["prompt"])
+        in_negative_prompt = gr.TextArea(label="Negative prompt", value=config.value["negative_prompt"])
     with gr.Row():
-        in_inference_steps = gr.Number(label="Inference steps", value=config.current["inference_steps"])
-        in_manual_seed = gr.Number(label="Manual seed", value=config.current["manual_seed"], info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
-        in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=config.current["guidance_scale"], info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
+        in_inference_steps = gr.Number(label="Inference steps", value=config.value["inference_steps"])
+        in_manual_seed = gr.Number(label="Manual seed", value=config.value["manual_seed"], info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
+        in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=config.value["guidance_scale"], info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
 
     gr.Markdown("### Output")
     with gr.Row():
@@ -245,30 +229,30 @@ with gr.Blocks() as demo:
     with gr.Row():
         # out_result = gr.Textbox(label="Status", value="")
         out_image = gr.Image()
-        out_code = gr.Code(config.assemble_code(), label="Code")
+        out_code = gr.Code(assemble_code(config.value), label="Code")
     with gr.Row():
-        out_current_config = gr.Code(value=str(config.current), label="Current config")
+        out_config = gr.Code(value=str(config.value), label="Current config")
     with gr.Row():
-        out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
+        out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history.value))
 
-    in_devices.change(device_change, inputs=[in_devices], outputs=[out_current_config, out_code])
-    in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_current_config, out_code])
-    in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_current_config, out_code])
-    in_variant.change(variant_change, inputs=[in_variant], outputs=[out_current_config, out_code])
-    in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[in_use_safetensors, in_schedulers, out_current_config, out_code])
-    in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_current_config, out_code])
-    in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_current_config, out_code])
-    in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_scheduler_description, out_current_config, out_code])
-    in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps], outputs=[out_current_config, out_code])
-    in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed], outputs=[out_current_config, out_code])
-    in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[out_current_config, out_code])
-    in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_current_config, out_code])
-    in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[out_current_config, out_code])
-    btn_start_pipeline.click(run_inference, inputs=[], outputs=[out_image, out_config_history])
+    in_devices.change(device_change, inputs=[in_devices, config], outputs=[config, out_config, out_code])
+    in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
+    in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
+    in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
+    in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_use_safetensors, in_schedulers, config, out_config, out_code])
+    in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
+    in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
+    in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
+    in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps, config], outputs=[config, out_config, out_code])
+    in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed, config], outputs=[config, out_config, out_code])
+    in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale, config], outputs=[config, out_config, out_code])
+    in_prompt.change(prompt_change, inputs=[in_prompt, config], outputs=[config, out_config, out_code])
+    in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt, config], outputs=[config, out_config, out_code])
+    btn_start_pipeline.click(run_inference, inputs=[config, config_history], outputs=[out_image, out_config_history, config_history])
 
     # send current respect initial config to init_config to populate parameters to all relevant input fields
     # if GET parameter is set, it will overwrite initial config parameters
-    demo.load(fn=init_config, inputs=out_current_config,
+    demo.load(fn=get_config_from_url, inputs=config,
               outputs=[
                   in_models,
                   in_devices,
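
A side effect of the rewrite worth noting: the old run_inference seeded with a hard-coded torch.manual_seed(42) whenever a manual seed was set, ignoring the configured value; the new version uses int(config["manual_seed"]). A sketch of that seed logic in isolation (make_generator is a hypothetical helper, not part of the commit):

    import torch

    def make_generator(manual_seed, device):
        # negative or empty seed means "random": return a fresh, unseeded generator
        if manual_seed is None or manual_seed == '' or int(manual_seed) < 0:
            return torch.Generator(device)
        # otherwise seed deterministically, as run_inference now does
        return torch.manual_seed(int(manual_seed))

    gen = make_generator(42, "cpu")
    print(torch.rand(2, generator=gen))  # reproducible across runs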
appConfig.json CHANGED
@@ -3,10 +3,12 @@
 
     "sd-dreambooth-library/solo-levelling-art-style": {
         "use_safetensors": false,
+        "description": "A wonderful model",
         "scheduler": "DDPMScheduler"
     },
     "CompVis/stable-diffusion-v1-4": {
         "use_safetensors": true,
+        "description": "Another wonderful model",
         "scheduler": "DDPMScheduler"
     }
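
The new description field is what feeds the out_model_description textbox in app.py. models_change indexes it directly, so every model entry needs the key; a tolerant lookup via dict.get (a sketch, not what the commit does) would degrade gracefully for entries that lack it:

    import json

    with open('appConfig.json') as f:
        app_config = json.load(f)

    for name, model_config in app_config.get("models", {}).items():
        # .get() tolerates entries without a "description" field
        print(name, '->', model_config.get("description", "(no description)"))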
 
change_handlers.py DELETED
File without changes
config.py CHANGED
@@ -3,7 +3,41 @@ import base64
 import json
 import torch
 
-def init_config(request: gr.Request, inital_config):
+def get_initial_config():
+
+    device = None
+    data_type = 'float16'
+    allow_tensorfloat32 = "False"
+    if torch.cuda.is_available():
+        device = "cuda"
+        data_type = 'bfloat16'
+        allow_tensorfloat32 = "True"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    config = {
+        "device": device,
+        "model": None,
+        "scheduler": None,
+        "variant": None,
+        "allow_tensorfloat32": allow_tensorfloat32,
+        "use_safetensors": "False",
+        "data_type": data_type,
+        "safety_checker": "False",
+        "requires_safety_checker": "False",
+        "manual_seed": 42,
+        "inference_steps": 10,
+        "guidance_scale": 0.5,
+        "prompt": 'A white rabbit',
+        "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
+    }
+
+
+    return config
+
+def get_config_from_url(request: gr.Request, initial_config):
 
     encoded_params = request.request.query_params.get('config')
     return_config = {}
@@ -12,7 +46,7 @@ def init_config(request: gr.Request, inital_config):
     if encoded_params is not None:
         decoded_params = base64.b64decode(encoded_params)
         decoded_params = decoded_params.decode('utf-8')
-        decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'false')
+        decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'False')
         dict_params = json.loads(decoded_params)
 
         return_config = dict_params
@@ -20,10 +54,9 @@ def init_config(request: gr.Request, inital_config):
     # otherwise use default initial config
     else:
 
-        inital_config = inital_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
-        dict_inital_config = json.loads(inital_config)
-
-        return_config = dict_inital_config
+        # initial_config = initial_config.replace("'", '"').replace('None', 'null').replace('False', 'False')
+        # return_config = json.loads(initial_config)
+        return_config = initial_config
 
     return [return_config['model'],
             return_config['device'],
@@ -40,126 +73,82 @@ def init_config(request: gr.Request, inital_config):
             return_config['guidance_scale']
             ]
 
-
-class Config:
-
-    def __init__(self):
-
-        self.code = {}
-        self.history = []
-        self.devices = []
-
-        appConfig = self.load_app_config()
-
-        self.model_configs = appConfig.get("models", {})
-        self.scheduler_configs = appConfig.get("schedulers", {})
-
-        # default device
-        self.devices = appConfig.get("devices", [])
-        device = None
-        data_type = 'float16'
-        allow_tensorfloat32 = False
-        if torch.cuda.is_available():
-            device = "cuda"
-            data_type = 'bfloat16'
-            allow_tensorfloat32 = True
-        elif torch.backends.mps.is_available():
-            device = "mps"
-        else:
-            device = "cpu"
-
-        self.current = {
-            "device": device,
-            "model": None,
-            "scheduler": None,
-            "variant": None,
-            "allow_tensorfloat32": allow_tensorfloat32,
-            "use_safetensors": False,
-            "data_type": data_type,
-            "safety_checker": False,
-            "requires_safety_checker": False,
-            "manual_seed": 42,
-            "inference_steps": 10,
-            "guidance_scale": 0.5,
-            "prompt": 'A white rabbit',
-            "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
-        }
-
-        self.assemble_code()
+def load_app_config():
+    try:
+        with open('appConfig.json', 'r') as f:
+            appConfig = json.load(f)
+    except FileNotFoundError:
+        print("App config file not found.")
+    except json.JSONDecodeError:
+        print("Error decoding JSON in app config file.")
+    except Exception as e:
+        print("An error occurred while loading app config:", str(e))
 
-    def load_app_config(self):
-        try:
-            with open('appConfig.json', 'r') as f:
-                appConfig = json.load(f)
-        except FileNotFoundError:
-            print("App config file not found.")
-        except json.JSONDecodeError:
-            print("Error decoding JSON in app config file.")
-        except Exception as e:
-            print("An error occurred while loading app config:", str(e))
-
-        return appConfig
+    return appConfig
 
-    def set_config(self, key, value):
-
-        self.current[key] = value
-
-        return str(self.current)
+def set_config(config, key, value):
 
-    def get_scheduler_description(self, scheduler):
-
-        if type(scheduler) != list and scheduler is not None:
-
-            return self.scheduler_configs[scheduler]
-
-        else:
-
-            return ''
+    # str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
+    # config = json.loads(str_config)
+    config[key] = value
 
-    def assemble_code(self):
+    return config
 
-        self.code['001_code'] = f'''device = "{self.current['device']}"'''
-        if self.current['data_type'] == "bfloat16":
-            self.code['002_data_type'] = 'data_type = torch.bfloat16'
-        else:
-            self.code['002_data_type'] = 'data_type = torch.float16'
-        self.code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {self.current["allow_tensorfloat32"]}'
-        if str(self.current["variant"]) == 'None':
-            self.code['004_variant'] = f'variant = {self.current["variant"]}'
-        else:
-            self.code['004_variant'] = f'variant = "{self.current["variant"]}"'
-        self.code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
-            "{self.current['model']}",
-            use_safetensors=use_safetensors,
-            torch_dtype=data_type,
-            variant=variant).to(device)'''
+def assemble_code(str_config):
 
-        self.code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {self.current["requires_safety_checker"]}'
+    # str_config = str_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
+    # config = json.loads(str_config)
+    config = str_config
+
+    code = {}
+
+    code['001_code'] = f'''device = "{config['device']}"'''
+    if config['data_type'] == "bfloat16":
+        code['002_data_type'] = 'data_type = torch.bfloat16'
+    else:
+        code['002_data_type'] = 'data_type = torch.float16'
+    code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {config["allow_tensorfloat32"]}'
+
+    if str(config["variant"]) == 'None':
+        code['004_variant'] = f'variant = {config["variant"]}'
+    else:
+        code['004_variant'] = f'variant = "{config["variant"]}"'
+
+    code['005_use_safetensors'] = f'''use_safetensors = {config["use_safetensors"]}'''
 
-        if not self.current["safety_checker"] or str(self.current["safety_checker"]).lower == 'false':
-            self.code['055_safety_checker'] = f'pipeline.safety_checker = None'
-        else:
-            self.code['055_safety_checker'] = ''
-
-        self.code['060_scheduler'] = f'pipeline.scheduler = {self.current["scheduler"]}.from_config(pipeline.scheduler.config)'
-
-        if self.current['manual_seed'] < 0 or self.current['manual_seed'] is None or self.current['manual_seed'] == '':
-            self.code['091_manual_seed'] = f'# manual_seed = {self.current["manual_seed"]}'
-            self.code['092_generator'] = f'generator = torch.Generator("{self.current["device"]}")'
-        else:
-            self.code['091_manual_seed'] = f'manual_seed = {self.current["manual_seed"]}'
-            self.code['092_generator'] = f'generator = torch.manual_seed(manual_seed)'
+    code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
+        "{config['model']}",
+        use_safetensors=use_safetensors,
+        torch_dtype=data_type,
+        variant=variant).to(device)'''
+
+    code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
+
+    if str(config["safety_checker"]).lower() == 'false':
+        code['055_safety_checker'] = f'pipeline.safety_checker = None'
+    else:
+        code['055_safety_checker'] = ''
+
+    code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
+
+    if config['manual_seed'] < 0 or config['manual_seed'] is None or config['manual_seed'] == '':
+        code['091_manual_seed'] = f'# manual_seed = {config["manual_seed"]}'
+        code['092_generator'] = f'generator = torch.Generator("{config["device"]}")'
+    else:
+        code['091_manual_seed'] = f'manual_seed = {config["manual_seed"]}'
+        code['092_generator'] = f'generator = torch.manual_seed(manual_seed)'
 
-        self.code["080_prompt"] = f'prompt = "{self.current["prompt"]}"'
-        self.code["085_negative_prompt"] = f'negative_prompt = "{self.current["negative_prompt"]}"'
-        self.code["090_inference_steps"] = f'inference_steps = {self.current["inference_steps"]}'
-        self.code["095_guidance_scale"] = f'guidance_scale = {self.current["guidance_scale"]}'
-
-        self.code["100_run_inference"] = f'''image = pipeline(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            generator=generator,
-            num_inference_steps=inference_steps,
-            guidance_scale=guidance_scale).images[0]'''
-
-        return '\r\n'.join(value[1] for value in sorted(self.code.items()))
+    code["080_prompt"] = f'prompt = "{config["prompt"]}"'
+    code["085_negative_prompt"] = f'negative_prompt = "{config["negative_prompt"]}"'
+    code["090_inference_steps"] = f'inference_steps = {config["inference_steps"]}'
+    code["095_guidance_scale"] = f'guidance_scale = {config["guidance_scale"]}'
+
+    code["100_run_inference"] = f'''image = pipeline(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        generator=generator,
+        num_inference_steps=inference_steps,
+        guidance_scale=guidance_scale).images[0]'''
+
+    return '\r\n'.join(value[1] for value in sorted(code.items()))
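
One reason the config dict now stores its booleans as the strings "True"/"False": the share links in dict_list_to_markdown_table encode str(config), and get_config_from_url turns that back into JSON by swapping quote styles and mapping None to null. String-typed booleans survive that round trip, whereas a bare Python False would not be valid JSON after the quote swap. A sketch of the round trip (standalone, mirroring those two functions):

    import base64
    import json

    config = {"model": None, "safety_checker": "False", "manual_seed": 42}

    # encode as the share link does: repr() of the dict, base64-encoded
    encoded = base64.b64encode(str(config).encode()).decode()

    # decode as get_config_from_url does: swap quote style, map None -> null
    decoded = base64.b64decode(encoded).decode('utf-8')
    decoded = decoded.replace("'", '"').replace('None', 'null')
    print(json.loads(decoded))  # {'model': None, 'safety_checker': 'False', 'manual_seed': 42}

Note that the string replacements are fragile: a prompt containing an apostrophe or the literal word None would corrupt the JSON. Encoding with json.dumps(config) on the encode side would avoid the replacements entirely.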
helpers.py ADDED
@@ -0,0 +1,76 @@
+import torch
+from diffusers import (
+    DDPMScheduler,
+    DDIMScheduler,
+    PNDMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+)
+import base64
+
+def get_variant(str_variant):
+
+    if str(str_variant).lower() == 'none':
+        return None
+    else:
+        return str_variant
+
+def get_bool(str_bool):
+
+    if str(str_bool).lower() == 'false':
+        return False
+    else:
+        return True
+
+
+def get_data_type(str_data_type):
+
+    if str_data_type == "bfloat16":
+        return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
+    else:
+        return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
+
+def get_tensorfloat32(allow_tensorfloat32):
+
+    return True if str(allow_tensorfloat32).lower() == 'true' else False
+
+def get_scheduler(scheduler, pipeline_config):
+
+    if scheduler == "DDPMScheduler":
+        return DDPMScheduler.from_config(pipeline_config)
+    elif scheduler == "DDIMScheduler":
+        return DDIMScheduler.from_config(pipeline_config)
+    elif scheduler == "PNDMScheduler":
+        return PNDMScheduler.from_config(pipeline_config)
+    elif scheduler == "LMSDiscreteScheduler":
+        return LMSDiscreteScheduler.from_config(pipeline_config)
+    elif scheduler == "EulerAncestralDiscreteScheduler":
+        return EulerAncestralDiscreteScheduler.from_config(pipeline_config)
+    elif scheduler == "EulerDiscreteScheduler":
+        return EulerDiscreteScheduler.from_config(pipeline_config)
+    elif scheduler == "DPMSolverMultistepScheduler":
+        return DPMSolverMultistepScheduler.from_config(pipeline_config)
+    else:
+        return DPMSolverMultistepScheduler.from_config(pipeline_config)
+
+def dict_list_to_markdown_table(config_history):
+
+    if not config_history:
+        return ""
+
+    headers = list(config_history[0].keys())
+    markdown_table = "| share | " + " | ".join(headers) + " |\n"
+    markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
+
+    for index, config in enumerate(config_history):
+
+        encoded_config = base64.b64encode(str(config).encode()).decode()
+        share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
+        markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"
+
+    markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
+
+    return markdown_table
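
The if/elif chain in get_scheduler could equally be a table lookup; a sketch of an equivalent (same fallback to DPMSolverMultistepScheduler for unknown names):

    from diffusers import (
        DDPMScheduler,
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        DPMSolverMultistepScheduler,
    )

    SCHEDULERS = {
        "DDPMScheduler": DDPMScheduler,
        "DDIMScheduler": DDIMScheduler,
        "PNDMScheduler": PNDMScheduler,
        "LMSDiscreteScheduler": LMSDiscreteScheduler,
        "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
        "EulerDiscreteScheduler": EulerDiscreteScheduler,
        "DPMSolverMultistepScheduler": DPMSolverMultistepScheduler,
    }

    def get_scheduler(scheduler, pipeline_config):
        # unknown names fall back to DPMSolverMultistepScheduler, matching helpers.py
        cls = SCHEDULERS.get(scheduler, DPMSolverMultistepScheduler)
        return cls.from_config(pipeline_config)

A dict keeps the name-to-class mapping data-driven and in one place, which also makes it easy to populate the scheduler dropdown from the same source.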