n42 committed on
Commit
58f80cc
1 Parent(s): 285657d

using class as config

Files changed (2)
  1. app.py +94 -176
  2. config.py +175 -123
app.py CHANGED
@@ -22,43 +22,13 @@ from gradio import Interface
  from diffusers import AutoencoderKL
  import pandas as pd
  import base64
- from config import *
- from change_handlers import *
+ from config import Config

- # get
- # - initial configuration,
- # - a list of available devices from the config file
- # - a list of available models from the config file
- # - a list of available schedulers from the config file
- # - a dict that contains code to for reproduction
- config, devices, model_configs, scheduler_configs, code = get_inital_config()
-
- models = list(model_configs.keys())
- schedulers = list(scheduler_configs.keys())
-
- device = config["device"]
- model = config["model"]
- scheduler = config["scheduler"]
- variant = config["variant"]
- allow_tensorfloat32 = config["allow_tensorfloat32"]
- use_safetensors = config["use_safetensors"]
- data_type = config["data_type"]
- safety_checker = config["safety_checker"]
- requires_safety_checker = config["requires_safety_checker"]
- manual_seed = config["manual_seed"]
- inference_steps = config["inference_steps"]
- guidance_scale = config["guidance_scale"]
- prompt = config["prompt"]
- negative_prompt = config["negative_prompt"]
-
- config_history = []
+ config = Config()

  def device_change(device):

- code[code_pos_device] = f'''device = "{device}"'''
- config['device'] = device
-
- return get_sorted_code(), str(config)
+ return config.set_config('device', device), config.assemble_code()

  def models_change(model, scheduler):
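Note on the new handler contract: every `*_change` callback now returns the serialized config followed by the regenerated code listing, in the order expected by the `outputs=[out_current_config, out_code]` wiring further down. A minimal, self-contained sketch of the pattern (the `Config` stub here is illustrative, not the real class from config.py):

```python
# Illustrative stub only; the real Config class lives in config.py.
class Config:
    def __init__(self):
        self.current = {"device": "cpu"}

    def set_config(self, key, value):
        self.current[key] = value
        return str(self.current)          # rendered in the "Current config" box

    def assemble_code(self):
        # the real class assembles a full reproduction script
        return f'device = "{self.current["device"]}"'

config = Config()

def device_change(device):
    # one return value per output component: [out_current_config, out_code]
    return config.set_config('device', device), config.assemble_code()

print(device_change("cuda"))  # ("{'device': 'cuda'}", 'device = "cuda"')
```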
@@ -67,148 +37,72 @@ def models_change(model, scheduler):
  # no model selected (because this is UI init run)
  if type(model) != list and model is not None:

- use_safetensors = str(model_configs[model]['use_safetensors'])
+ use_safetensors = str(config.model_configs[model]['use_safetensors'])

  # if no scheduler is selected, choose the default one for this model
  if scheduler == None:

- scheduler = model_configs[model]['scheduler']
-
- code[code_pos_init_pipeline] = f'''pipeline = DiffusionPipeline.from_pretrained(
- "{model}",
- use_safetensors=use_safetensors,
- torch_dtype=data_type,
- variant=variant).to(device)'''
- config['model'] = model
+ scheduler = config.model_configs[model]['scheduler']

- safety_checker_change(safety_checker)
- requires_safety_checker_change(requires_safety_checker)
+ safety_checker_change(config.current["safety_checker"])
+ requires_safety_checker_change(config.current["requires_safety_checker"])

- return get_sorted_code(), use_safetensors, scheduler, str(config)
+ return use_safetensors, scheduler, config.set_config('model', model), config.assemble_code()

- def data_type_change(selected_data_type):
+ def data_type_change(data_type):

- config['data_type'] = data_type
-
- get_data_type(selected_data_type)
- return get_sorted_code(), str(config)
-
- def get_data_type(selected_data_type):
+ return config.set_config('data_type', data_type), config.assemble_code()
+
+ def get_data_type(str_data_type):

- if selected_data_type == "bfloat16":
- code[code_pos_data_type] = 'data_type = torch.bfloat16'
- data_type = torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
+ if str_data_type == "bfloat16":
+ return torch.bfloat16 # BFloat16 is not supported on MPS as of 01/2024
  else:
- code[code_pos_data_type] = 'data_type = torch.float16'
- data_type = torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory
-
- return data_type
+ return torch.float16 # Half-precision weights, as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 will save GPU memory

  def tensorfloat32_change(allow_tensorfloat32):
-
- config['allow_tensorfloat32'] = allow_tensorfloat32
-
- get_tensorfloat32(allow_tensorfloat32)
-
- return get_sorted_code(), str(config)
-
- def get_tensorfloat32(allow_tensorfloat32):
-
- code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {allow_tensorfloat32}'
-
- return True if str(allow_tensorfloat32).lower() == 'true' else False
+
+ return config.set_config('allow_tensorfloat32', allow_tensorfloat32), config.assemble_code()

  def inference_steps_change(inference_steps):

- config['inference_steps'] = inference_steps
- code[code_pos_inference_steps] = f'inference_steps = {inference_steps}'
-
- return get_sorted_code(), str(config)
+ return config.set_config('inference_steps', inference_steps), config.assemble_code()

  def manual_seed_change(manual_seed):
-
- config['manual_seed'] = manual_seed
-
- if manual_seed < 0 or manual_seed is None or manual_seed == '':
- code[code_pos_manual_seed] = f'# manual_seed = {manual_seed}'
- generator = f'generator = torch.Generator("{device}")'
- else:
- code[code_pos_manual_seed] = f'manual_seed = {manual_seed}'
- generator = f'generator = torch.manual_seed(manual_seed)'
-
- code[code_pos_run_inference] = f'''image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator={generator},
- num_inference_steps=inference_steps,
- guidance_scale=guidance_scale).images[0]'''
-
- return get_sorted_code(), str(config)
+
+ return config.set_config('manual_seed', manual_seed), config.assemble_code()

  def guidance_scale_change(guidance_scale):

- config['guidance_scale'] = guidance_scale
- code[code_pos_guidance_scale] = f'guidance_scale = {guidance_scale}'
-
- return get_sorted_code(), str(config)
+ return config.set_config('guidance_scale', guidance_scale), config.assemble_code()

  def prompt_change(prompt):

- config['prompt'] = prompt
- code[code_pos_prompt] = f'prompt = {prompt}'
-
- return get_sorted_code(), str(config)
+ return config.set_config('prompt', prompt), config.assemble_code()

  def negative_prompt_change(negative_prompt):

- config['negative_prompt'] = negative_prompt
- code[code_pos_negative_prompt] = f'negative_prompt = {negative_prompt}'
-
- return get_sorted_code(), str(config)
+ return config.set_config('negative_prompt', negative_prompt), config.assemble_code()

  def variant_change(variant):

- config['variant'] = variant
-
- if str(variant) == 'None':
- code[code_pos_variant] = f'variant = {variant}'
- else:
- code[code_pos_variant] = f'variant = "{variant}"'
-
- return get_sorted_code(), str(config)
+ return config.set_config('variant', variant), config.assemble_code()

  def safety_checker_change(safety_checker):
-
- config['safety_checker'] = safety_checker
-
- if not safety_checker or str(safety_checker).lower == 'false':
- code[code_pos_safety_checker] = f'pipeline.safety_checker = None'
- else:
- code[code_pos_safety_checker] = ''
-
- return get_sorted_code(), str(config)
+
+ return config.set_config('safety_checker', safety_checker), config.assemble_code()

  def requires_safety_checker_change(requires_safety_checker):

- code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
-
- config['requires_safety_checker'] = requires_safety_checker
-
- return get_sorted_code(), str(config)
+ return config.set_config('requires_safety_checker', requires_safety_checker), config.assemble_code()

  def schedulers_change(scheduler):

- if type(scheduler) != list and scheduler is not None:
-
- code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
-
- config['scheduler'] = scheduler
-
- return get_sorted_code(), scheduler_configs[scheduler], str(config)
-
- else:
-
- return get_sorted_code(), '', str(config)
+ return config.get_scheduler_description(scheduler), config.set_config('scheduler', scheduler), config.assemble_code()
+
+ def get_tensorfloat32(allow_tensorfloat32):
+
+ return True if str(allow_tensorfloat32).lower() == 'true' else False

  def get_scheduler(scheduler, config):
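Note: `get_tensorfloat32` stringifies before comparing because `gr.Radio` may deliver either a real bool or the strings "True"/"False", and the naive `bool("False")` cast is truthy:

```python
# Why the helper normalizes via str(): a direct bool() cast gets strings wrong.
def get_tensorfloat32(allow_tensorfloat32):
    return str(allow_tensorfloat32).lower() == 'true'

assert get_tensorfloat32(True) is True
assert get_tensorfloat32("True") is True
assert get_tensorfloat32("False") is False
assert bool("False") is True  # the pitfall the helper avoids
```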
@@ -228,7 +122,33 @@ def get_scheduler(scheduler, config):
  return DPMSolverMultistepScheduler.from_config(config)
  else:
  return DPMSolverMultistepScheduler.from_config(config)
-
+
+ # get
+ # - initial configuration,
+ # - a list of available devices from the config file
+ # - a list of available models from the config file
+ # - a list of available schedulers from the config file
+ # - a dict that contains code for reproduction
+ config.set_inital_config()
+ # config.current, devices, model_configs, scheduler_configs, code = config.get_inital_config()
+
+ # device = config.current["device"]
+ # model = config.current["model"]
+ # scheduler = config.current["scheduler"]
+ # variant = config.current["variant"]
+ # allow_tensorfloat32 = config.current["allow_tensorfloat32"]
+ # use_safetensors = config.current["use_safetensors"]
+ # data_type = config.current["data_type"]
+ # safety_checker = config.current["safety_checker"]
+ # requires_safety_checker = config.current["requires_safety_checker"]
+ # manual_seed = config.current["manual_seed"]
+ # inference_steps = config.current["inference_steps"]
+ # guidance_scale = config.current["guidance_scale"]
+ # prompt = config.current["prompt"]
+ # negative_prompt = config.current["negative_prompt"]
+
+ config_history = []
+
  # pipeline
  def run_inference(model,
  device,
@@ -245,26 +165,24 @@
  guidance_scale,
  progress=gr.Progress(track_tqdm=True)):

- if model != None and scheduler != None:
+ if config.current["model"] != None and config.current["scheduler"] != None:

  progress((1,3), desc="Preparing pipeline initialization...")

- torch.backends.cuda.matmul.allow_tf32 = get_tensorfloat32(allow_tensorfloat32) # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations
-
- bool_use_safetensors = True if use_safetensors.lower() == 'true' else False
+ torch.backends.cuda.matmul.allow_tf32 = config.current["allow_tensorfloat32"] # Use TensorFloat-32 as of https://huggingface.co/docs/diffusers/main/en/optimization/fp16 faster, but slightly less accurate computations

  progress((2,3), desc="Initializing pipeline...")

  pipeline = DiffusionPipeline.from_pretrained(
- model,
- use_safetensors=bool_use_safetensors,
- torch_dtype=get_data_type(data_type),
- variant=variant).to(device)
+ config.current["model"],
+ use_safetensors=config.current["use_safetensors"],
+ torch_dtype=get_data_type(config.current["data_type"]),
+ variant=variant).to(config.current["device"])

- if safety_checker is None or str(safety_checker).lower == 'false':
+ if config.current["safety_checker"] is None or str(config.current["safety_checker"]).lower() == 'false':
  pipeline.safety_checker = None

- pipeline.requires_safety_checker = bool(requires_safety_checker)
+ pipeline.requires_safety_checker = config.current["requires_safety_checker"]

  pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
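Note on the seed handling that feeds this pipeline call: a negative or empty seed yields an unseeded generator (random output), anything else a reproducible one. A small sketch, with `make_generator` as a hypothetical helper (the app inlines this logic in `assemble_code` and `run_inference`):

```python
import torch

# Negative or empty seed -> fresh random generator; otherwise reproducible.
def make_generator(manual_seed, device="cpu"):
    if manual_seed is None or manual_seed == '' or int(manual_seed) < 0:
        return torch.Generator(device)           # unseeded: different every run
    return torch.manual_seed(int(manual_seed))   # seeded: same image every run

def sample(seed):
    return torch.rand(2, generator=make_generator(seed))

assert torch.equal(sample(42), sample(42))  # a fixed seed reproduces the draw
```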
@@ -283,7 +201,7 @@
  num_inference_steps=int(inference_steps),
  guidance_scale=float(guidance_scale)).images[0]

- config_history.append(config.copy())
+ config_history.append(config.current.copy())

  return image, dict_list_to_markdown_table(config_history)
@@ -300,8 +218,6 @@ def dict_list_to_markdown_table(config_history):
  markdown_table = "| share | " + " | ".join(headers) + " |\n"
  markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"

- print('#######################')
-
  for index, config in enumerate(config_history):

  encoded_config = base64.b64encode(str(config).encode()).decode()
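Note: each row's share link round-trips the whole config through base64; `init_config` in config.py later decodes it and rewrites the Python literals into JSON. Minimal reproduction below. (The rewrite handles `None` and `False` but not `True`, so a shared config containing a literal `True` would still fail to parse.)

```python
import base64, json

# Round trip behind the share links: encode str(config), decode, re-parse.
config = {"model": None, "safety_checker": False, "manual_seed": 42}

encoded = base64.b64encode(str(config).encode()).decode()   # ends up in ?config=...
decoded = base64.b64decode(encoded).decode('utf-8')
decoded = decoded.replace("'", '"').replace('None', 'null').replace('False', 'false')

assert json.loads(decoded) == config
```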
@@ -321,23 +237,25 @@ with gr.Blocks() as demo:
  </small>''')
  gr.Markdown("### Device specific settings")
  with gr.Row():
- in_devices = gr.Dropdown(label="Device:", value=device, choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
- in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
- in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
- in_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
+ in_devices = gr.Dropdown(label="Device:", value=config.current["device"], choices=config.devices, filterable=True, multiselect=False, allow_custom_value=True)
+ in_data_type = gr.Radio(label="Data Type:", value=config.current["data_type"], choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
+ in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=config.current["allow_tensorfloat32"], choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
+ in_variant = gr.Radio(label="Variant:", value=config.current["variant"], choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")

  gr.Markdown("### Model specific settings")
  with gr.Row():
+ models = list(config.model_configs.keys())
  in_models = gr.Dropdown(choices=models, label="Model")
  with gr.Row():
  with gr.Column(scale=1):
  in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
  with gr.Column(scale=1):
- in_safety_checker = gr.Radio(label="Enable safety checker:", value=safety_checker, choices=[True, False])
- in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=requires_safety_checker, choices=[True, False])
+ in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.current["safety_checker"], choices=[True, False])
+ in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.current["requires_safety_checker"], choices=[True, False])

  gr.Markdown("### Scheduler")
  with gr.Row():
+ schedulers = list(config.scheduler_configs.keys())
  in_schedulers = gr.Dropdown(choices=schedulers, label="Scheduler", info="see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
  out_scheduler_description = gr.Textbox(value="", label="Description")
@@ -347,12 +265,12 @@

  gr.Markdown("### Inference settings")
  with gr.Row():
- in_prompt = gr.TextArea(label="Prompt", value=prompt)
- in_negative_prompt = gr.TextArea(label="Negative prompt", value=negative_prompt)
+ in_prompt = gr.TextArea(label="Prompt", value=config.current["prompt"])
+ in_negative_prompt = gr.TextArea(label="Negative prompt", value=config.current["negative_prompt"])
  with gr.Row():
- in_inference_steps = gr.Number(label="Inference steps", value=inference_steps)
- in_manual_seed = gr.Number(label="Manual seed", value=manual_seed, info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
- in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=guidance_scale, info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")
+ in_inference_steps = gr.Number(label="Inference steps", value=config.current["inference_steps"])
+ in_manual_seed = gr.Number(label="Manual seed", value=config.current["manual_seed"], info="Set this to -1 or leave it empty to randomly generate an image. A fixed value will result in a similar image for every run")
+ in_guidance_scale = gr.Slider(minimum=0, maximum=1, step=0.01, label="Guidance Scale", value=config.current["guidance_scale"], info="A low guidance scale leads to a faster inference time, with the drawback that negative prompts don’t have any effect on the denoising process.")

  gr.Markdown("### Output")
  with gr.Row():
@@ -360,25 +278,25 @@
  with gr.Row():
  # out_result = gr.Textbox(label="Status", value="")
  out_image = gr.Image()
- out_code = gr.Code(get_sorted_code(), label="Code")
+ out_code = gr.Code(config.assemble_code(), label="Code")
  with gr.Row():
- out_current_config = gr.Code(value=str(config), label="Current config")
+ out_current_config = gr.Code(value=str(config.current), label="Current config")
  with gr.Row():
  out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))

- in_devices.change(device_change, inputs=[in_devices], outputs=[out_code, out_current_config])
- in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code, out_current_config])
- in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code, out_current_config])
- in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code, out_current_config])
- in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers, out_current_config])
- in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code, out_current_config])
- in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code, out_current_config])
- in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description, out_current_config])
- in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps], outputs=[out_code, out_current_config])
- in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed], outputs=[out_code, out_current_config])
- in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[out_code, out_current_config])
- in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_code, out_current_config])
- in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[out_code, out_current_config])
+ in_devices.change(device_change, inputs=[in_devices], outputs=[out_current_config, out_code])
+ in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_current_config, out_code])
+ in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_current_config, out_code])
+ in_variant.change(variant_change, inputs=[in_variant], outputs=[out_current_config, out_code])
+ in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[in_use_safetensors, in_schedulers, out_current_config, out_code])
+ in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_current_config, out_code])
+ in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_current_config, out_code])
+ in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_scheduler_description, out_current_config, out_code])
+ in_inference_steps.change(inference_steps_change, inputs=[in_inference_steps], outputs=[out_current_config, out_code])
+ in_manual_seed.change(manual_seed_change, inputs=[in_manual_seed], outputs=[out_current_config, out_code])
+ in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[out_current_config, out_code])
+ in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_current_config, out_code])
+ in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[out_current_config, out_code])
  btn_start_pipeline.click(run_inference, inputs=[
  in_models,
  in_devices,
@@ -397,7 +315,7 @@
  out_image,
  out_config_history])

- demo.load(fn=init_config, inputs=out_current_config,
+ demo.load(fn=config.init_config, inputs=out_current_config,
  outputs=[
  in_models,
  in_devices,
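Note: the `.change` wiring above relies on Gradio's positional mapping between a handler's return tuple and its `outputs` list. A minimal sketch of that mechanism (abbreviated components, not the app's full UI):

```python
import gradio as gr

# Gradio assigns a tuple return to outputs by position, which is why every
# handler returns (current config, code) to match [out_current_config, out_code].
def prompt_change(prompt):
    return str({"prompt": prompt}), f'prompt = "{prompt}"'

with gr.Blocks() as demo:
    in_prompt = gr.TextArea(label="Prompt")
    out_current_config = gr.Code(label="Current config")
    out_code = gr.Code(label="Code")
    in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_current_config, out_code])

# demo.launch()
```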
config.py CHANGED
@@ -3,135 +3,187 @@ import base64
  import json
  import torch

- code = {}
- code_pos_device = '001_code'
- code_pos_data_type = '002_data_type'
- code_pos_tf32 = '003_tf32'
- code_pos_variant = '004_variant'
- code_pos_init_pipeline = '050_init_pipe'
- code_pos_requires_safety_checker = '054_requires_safety_checker'
- code_pos_safety_checker = '055_safety_checker'
- code_pos_scheduler = '060_scheduler'
- code_pos_generator = '070_generator'
- code_pos_prompt = '080_prompt'
- code_pos_negative_prompt = '085_negative_prompt'
- code_pos_inference_steps = '090_inference_steps'
- code_pos_manual_seed = '091_manual_seed'
- code_pos_guidance_scale = '095_guidance_scale'
- code_pos_run_inference = '100_run_inference'
-
- def load_app_config():
- try:
- with open('appConfig.json', 'r') as f:
- appConfig = json.load(f)
- except FileNotFoundError:
- print("App config file not found.")
- except json.JSONDecodeError:
- print("Error decoding JSON in app config file.")
- except Exception as e:
- print("An error occurred while loading app config:", str(e))
-
- return appConfig
-
- def get_inital_config():
-
- appConfig = load_app_config()
-
- model_configs = appConfig.get("models", {})
- scheduler_configs = appConfig.get("schedulers", {})
-
- # default device
- devices = appConfig.get("devices", [])
- device = None
- data_type = 'float16'
- allow_tensorfloat32 = False
- if torch.cuda.is_available():
- device = "cuda"
- data_type = 'bfloat16'
- allow_tensorfloat32 = True
- elif torch.backends.mps.is_available():
- device = "mps"
- else:
- device = "cpu"
-
- initial_config = {
- "device": device,
- "model": None,
- "scheduler": None,
- "variant": None,
- "allow_tensorfloat32": allow_tensorfloat32,
- "use_safetensors": False,
- "data_type": data_type,
- "safety_checker": False,
- "requires_safety_checker": False,
- "manual_seed": 42,
- "inference_steps": 10,
- "guidance_scale": 0.5,
- "prompt": 'A white rabbit',
- "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
- }
-
- # code output order
- code[code_pos_device] = f'device = "{device}"'
- code[code_pos_variant] = f'variant = {initial_config["variant"]}'
- code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config["allow_tensorfloat32"]}'
- code[code_pos_data_type] = 'data_type = torch.bfloat16'
- code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
- code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
- code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config["requires_safety_checker"]}'
- code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
- code[code_pos_generator] = f'generator = torch.Generator("{device}")'
- code[code_pos_prompt] = f'prompt = "{initial_config["prompt"]}"'
- code[code_pos_negative_prompt] = f'negative_prompt = "{initial_config["negative_prompt"]}"'
- code[code_pos_inference_steps] = f'inference_steps = {initial_config["inference_steps"]}'
- code[code_pos_manual_seed] = f'manual_seed = {initial_config["manual_seed"]}'
- code[code_pos_guidance_scale] = f'guidance_scale = {initial_config["guidance_scale"]}'
- code[code_pos_run_inference] = f'''image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=generator.manual_seed(manual_seed),
- num_inference_steps=inference_steps,
- guidance_scale=guidance_scale).images[0]'''
-
- return initial_config, devices, model_configs, scheduler_configs, code
-
- def init_config(request: gr.Request, inital_config):
-
- encoded_params = request.request.query_params.get('config')
- return_config = {}
-
- # get configuration from URL if GET parameter `share` is set
- if encoded_params is not None:
- decoded_params = base64.b64decode(encoded_params)
- decoded_params = decoded_params.decode('utf-8')
- decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'false')
- dict_params = json.loads(decoded_params)
-
- return_config = dict_params
-
- # otherwise use default initial config
- else:
-
- inital_config = inital_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
- dict_inital_config = json.loads(inital_config)
-
- return_config = dict_inital_config
-
- return [return_config['model'],
- return_config['device'],
- return_config['use_safetensors'],
- return_config['data_type'],
- return_config['variant'],
- return_config['safety_checker'],
- return_config['requires_safety_checker'],
- return_config['scheduler'],
- return_config['prompt'],
- return_config['negative_prompt'],
- return_config['inference_steps'],
- return_config['manual_seed'],
- return_config['guidance_scale']
- ]
-
- def get_sorted_code():
-
- return '\r\n'.join(value[1] for value in sorted(code.items()))
+ class Config:
+
+ def __init__(self):
+
+ self.code = {}
+ self.history = []
+ self.devices = []
+
+ def load_app_config(self):
+ try:
+ with open('appConfig.json', 'r') as f:
+ appConfig = json.load(f)
+ except FileNotFoundError:
+ print("App config file not found.")
+ except json.JSONDecodeError:
+ print("Error decoding JSON in app config file.")
+ except Exception as e:
+ print("An error occurred while loading app config:", str(e))
+
+ return appConfig
+
+ def set_inital_config(self):
+
+ appConfig = self.load_app_config()
+
+ self.model_configs = appConfig.get("models", {})
+ self.scheduler_configs = appConfig.get("schedulers", {})
+
+ # default device
+ self.devices = appConfig.get("devices", [])
+ device = None
+ data_type = 'float16'
+ allow_tensorfloat32 = False
+ if torch.cuda.is_available():
+ device = "cuda"
+ data_type = 'bfloat16'
+ allow_tensorfloat32 = True
+ elif torch.backends.mps.is_available():
+ device = "mps"
+ else:
+ device = "cpu"
+
+ self.current = {
+ "device": device,
+ "model": None,
+ "scheduler": None,
+ "variant": None,
+ "allow_tensorfloat32": allow_tensorfloat32,
+ "use_safetensors": False,
+ "data_type": data_type,
+ "safety_checker": False,
+ "requires_safety_checker": False,
+ "manual_seed": 42,
+ "inference_steps": 10,
+ "guidance_scale": 0.5,
+ "prompt": 'A white rabbit',
+ "negative_prompt": 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly',
+ }
+
+ self.assemble_code()
+ # code output order
+ # self.code[self.code_pos_device] = f'device = "{device}"'
+ # self.code[self.code_pos_variant] = f'variant = {initial_config["variant"]}'
+ # self.code[self.code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config["allow_tensorfloat32"]}'
+ # self.code[self.code_pos_data_type] = 'data_type = torch.bfloat16'
+ # self.code[self.code_pos_init_pipeline] = 'sys.exit("No model selected!")'
+ # self.code[self.code_pos_safety_checker] = 'pipeline.safety_checker = None'
+ # self.code[self.code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config["requires_safety_checker"]}'
+ # self.code[self.code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
+ # self.code[self.code_pos_generator] = f'generator = torch.Generator("{device}")'
+ # self.code[self.code_pos_prompt] = f'prompt = "{initial_config["prompt"]}"'
+ # self.code[self.code_pos_negative_prompt] = f'negative_prompt = "{initial_config["negative_prompt"]}"'
+ # self.code[self.code_pos_inference_steps] = f'inference_steps = {initial_config["inference_steps"]}'
+ # self.code[self.code_pos_manual_seed] = f'manual_seed = {initial_config["inference_steps"]}'
+ # self.code[self.code_pos_guidance_scale] = f'guidance_scale = {initial_config["guidance_scale"]}'
+ # self.code[self.code_pos_run_inference] = f'''image = pipeline(
+ # prompt=prompt,
+ # negative_prompt=negative_prompt,
+ # generator=generator.manual_seed(manual_seed),
+ # num_inference_steps=inference_steps,
+ # guidance_scale=guidance_scale).images[0]'''
+
+ # return initial_config, devices, model_configs, scheduler_configs, self.code
+
+ def init_config(self, request: gr.Request, inital_config):
+
+ encoded_params = request.request.query_params.get('config')
+ return_config = {}
+
+ # get configuration from URL if GET parameter `share` is set
+ if encoded_params is not None:
+ decoded_params = base64.b64decode(encoded_params)
+ decoded_params = decoded_params.decode('utf-8')
+ decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'false')
+ dict_params = json.loads(decoded_params)
+
+ return_config = dict_params
+
+ # otherwise use default initial config
+ else:
+
+ inital_config = inital_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
+ dict_inital_config = json.loads(inital_config)
+
+ return_config = dict_inital_config
+
+ return [return_config['model'],
+ return_config['device'],
+ return_config['use_safetensors'],
+ return_config['data_type'],
+ return_config['variant'],
+ return_config['safety_checker'],
+ return_config['requires_safety_checker'],
+ return_config['scheduler'],
+ return_config['prompt'],
+ return_config['negative_prompt'],
+ return_config['inference_steps'],
+ return_config['manual_seed'],
+ return_config['guidance_scale']
+ ]
+
+ def set_config(self, key, value):
+
+ self.current[key] = value
+
+ return str(self.current)
+
+ def get_scheduler_description(self, scheduler):
+
+ if type(scheduler) != list and scheduler is not None:
+
+ return self.scheduler_configs[scheduler]
+
+ else:
+
+ return ''
+
+ def assemble_code(self):
+
+ self.code['001_code'] = f'''device = "{self.current['device']}"'''
+ if self.current['data_type'] == "bfloat16":
+ self.code['002_data_type'] = 'data_type = torch.bfloat16'
+ else:
+ self.code['002_data_type'] = 'data_type = torch.float16'
+ self.code['003_tf32'] = f'torch.backends.cuda.matmul.allow_tf32 = {self.current["allow_tensorfloat32"]}'
+ if str(self.current["variant"]) == 'None':
+ self.code['004_variant'] = f'variant = {self.current["variant"]}'
+ else:
+ self.code['004_variant'] = f'variant = "{self.current["variant"]}"'
+ self.code['050_init_pipe'] = f'''pipeline = DiffusionPipeline.from_pretrained(
+ "{self.current['model']}",
+ use_safetensors=use_safetensors,
+ torch_dtype=data_type,
+ variant=variant).to(device)'''
+
+ self.code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {self.current["requires_safety_checker"]}'
+
+ if not self.current["safety_checker"] or str(self.current["safety_checker"]).lower() == 'false':
+ self.code['055_safety_checker'] = 'pipeline.safety_checker = None'
+ else:
+ self.code['055_safety_checker'] = ''
+
+ self.code['060_scheduler'] = f'pipeline.scheduler = {self.current["scheduler"]}.from_config(pipeline.scheduler.config)'
+
+ if self.current['manual_seed'] is None or self.current['manual_seed'] == '' or self.current['manual_seed'] < 0:
+ self.code['070_generator'] = f'generator = torch.Generator("{self.current["device"]}")'
+ self.code['091_manual_seed'] = f'# manual_seed = {self.current["manual_seed"]}'
+ else:
+ self.code['070_generator'] = f'generator = torch.manual_seed({self.current["manual_seed"]})'
+ self.code['091_manual_seed'] = f'manual_seed = {self.current["manual_seed"]}'
+
+ self.code["080_prompt"] = f'prompt = "{self.current["prompt"]}"'
+ self.code["085_negative_prompt"] = f'negative_prompt = "{self.current["negative_prompt"]}"'
+ self.code["090_inference_steps"] = f'inference_steps = {self.current["inference_steps"]}'
+ self.code["095_guidance_scale"] = f'guidance_scale = {self.current["guidance_scale"]}'
+
+ self.code["100_run_inference"] = f'''image = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=generator,
+ num_inference_steps=inference_steps,
+ guidance_scale=guidance_scale).images[0]'''
+
+ return '\r\n'.join(value[1] for value in sorted(self.code.items()))
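
Note: a usage sketch of the new Config class as app.py drives it. The appConfig.json entries below are made up and `load_app_config` is stubbed so the snippet stays self-contained:

```python
from config import Config

config = Config()
# Stub the JSON loading with illustrative entries; a real run reads appConfig.json.
config.load_app_config = lambda: {
    "devices": ["cpu", "cuda"],
    "models": {"stabilityai/sd-turbo": {"use_safetensors": True,
                                        "scheduler": "EulerDiscreteScheduler"}},
    "schedulers": {"EulerDiscreteScheduler": "illustrative description"},
}
config.set_inital_config()               # populates config.current with defaults

config.set_config("model", "stabilityai/sd-turbo")
config.set_config("scheduler", "EulerDiscreteScheduler")
print(config.assemble_code())            # the script shown in the "Code" box
```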