n42 committed on
Commit
c6a81f6
·
1 Parent(s): 56914a9

adding refiner option

Browse files
Files changed (3) hide show
  1. app.py +27 -6
  2. appConfig.json +23 -14
  3. config.py +24 -4
app.py CHANGED
@@ -16,13 +16,15 @@ def models_change(model, scheduler, config):
16
  config = set_config(config, 'model', model)
17
 
18
  use_safetensors = False
 
19
 
20
  # no model selected (because this is UI init run)
21
  if type(model) != list and str(model) != 'None':
22
 
23
  use_safetensors = str(models[model]['use_safetensors'])
24
  model_description = models[model]['description']
25
-
 
26
  # if no scheduler is selected, choose the default one for this model
27
  if scheduler == None:
28
 
@@ -34,11 +36,12 @@ def models_change(model, scheduler, config):
34
 
35
  config["use_safetensors"] = str(use_safetensors)
36
  config["scheduler"] = str(scheduler)
 
37
 
38
  # safety_checker_change(in_safety_checker.value, config)
39
  # requires_safety_checker_change(in_requires_safety_checker.value, config)
40
 
41
- return model_description, use_safetensors, scheduler, config, str(config), assemble_code(config)
42
 
43
  def data_type_change(data_type, config):
44
 
@@ -132,6 +135,15 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
132
  torch_dtype = get_data_type(config["data_type"]),
133
  variant = get_variant(config["variant"])).to(config["device"])
134
 
 
 
 
 
 
 
 
 
 
135
  if str(config["safety_checker"]).lower() == 'false':
136
  pipeline.safety_checker = None
137
 
@@ -151,11 +163,18 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
151
  negative_prompt = config["negative_prompt"],
152
  generator = generator,
153
  num_inference_steps = int(config["inference_steps"]),
154
- guidance_scale = float(config["guidance_scale"])).images[0]
155
-
 
 
 
 
 
 
 
156
  config_history.append(config.copy())
157
 
158
- return image, dict_list_to_markdown_table(config_history), config_history
159
 
160
  else:
161
 
@@ -191,6 +210,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
191
  with gr.Row():
192
  with gr.Column(scale=1):
193
  in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
 
194
  with gr.Column(scale=1):
195
  in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
196
  in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
@@ -229,7 +249,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
229
  in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
230
  in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
231
  in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
232
- in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_use_safetensors, in_schedulers, config, out_config, out_code])
233
  in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
234
  in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
235
  in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
@@ -249,6 +269,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
249
  in_devices,
250
  in_use_safetensors,
251
  in_data_type,
 
252
  in_variant,
253
  in_safety_checker,
254
  in_requires_safety_checker,
 
16
  config = set_config(config, 'model', model)
17
 
18
  use_safetensors = False
19
+ refiner = ""
20
 
21
  # no model selected (because this is UI init run)
22
  if type(model) != list and str(model) != 'None':
23
 
24
  use_safetensors = str(models[model]['use_safetensors'])
25
  model_description = models[model]['description']
26
+ refiner = models[model]['refiner']
27
+
28
  # if no scheduler is selected, choose the default one for this model
29
  if scheduler == None:
30
 
 
36
 
37
  config["use_safetensors"] = str(use_safetensors)
38
  config["scheduler"] = str(scheduler)
39
+ config["refiner"] = str(refiner)
40
 
41
  # safety_checker_change(in_safety_checker.value, config)
42
  # requires_safety_checker_change(in_requires_safety_checker.value, config)
43
 
44
+ return model_description, refiner, use_safetensors, scheduler, config, str(config), assemble_code(config)
45
 
46
  def data_type_change(data_type, config):
47
 
 
135
  torch_dtype = get_data_type(config["data_type"]),
136
  variant = get_variant(config["variant"])).to(config["device"])
137
 
138
+ if config['refiner'] != '':
139
+ refiner = DiffusionPipeline.from_pretrained(
140
+ config['refiner'],
141
+ text_encoder_2=pipeline.text_encoder_2,
142
+ vae=pipeline.vae,
143
+ torch_dtype=get_data_type(config["data_type"]),
144
+ use_safetensors=get_bool(config["use_safetensors"]),
145
+ variant = get_variant(config["variant"])).to(config["device"])
146
+
147
  if str(config["safety_checker"]).lower() == 'false':
148
  pipeline.safety_checker = None
149
 
 
163
  negative_prompt = config["negative_prompt"],
164
  generator = generator,
165
  num_inference_steps = int(config["inference_steps"]),
166
+ guidance_scale = float(config["guidance_scale"])).images
167
+
168
+ if config['refiner'] != '':
169
+ image = refiner(
170
+ prompt = config["prompt"],
171
+ num_inference_steps = int(config["inference_steps"]),
172
+ image=image,
173
+ ).images
174
+
175
  config_history.append(config.copy())
176
 
177
+ return image[0], dict_list_to_markdown_table(config_history), config_history
178
 
179
  else:
180
 
 
210
  with gr.Row():
211
  with gr.Column(scale=1):
212
  in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
213
+ in_model_refiner = gr.Textbox(value="", label="Refiner")
214
  with gr.Column(scale=1):
215
  in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
216
  in_requires_safety_checker = gr.Radio(label="Requires safety checker:", value=config.value["requires_safety_checker"], choices=["True", "False"])
 
249
  in_data_type.change(data_type_change, inputs=[in_data_type, config], outputs=[config, out_config, out_code])
250
  in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32, config], outputs=[config, out_config, out_code])
251
  in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
252
+ in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_model_refiner, in_use_safetensors, in_schedulers, config, out_config, out_code])
253
  in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
254
  in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
255
  in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
 
269
  in_devices,
270
  in_use_safetensors,
271
  in_data_type,
272
+ in_model_refiner,
273
  in_variant,
274
  in_safety_checker,
275
  in_requires_safety_checker,
appConfig.json CHANGED
@@ -4,47 +4,56 @@
4
  "sd-dreambooth-library/solo-levelling-art-style": {
5
  "use_safetensors": false,
6
  "description": "see https://huggingface.co/sd-dreambooth-library/solo-levelling-art-style",
7
- "scheduler": "DDIMScheduler"
 
8
  },
9
  "CompVis/stable-diffusion-v1-4": {
10
  "use_safetensors": true,
11
  "description": "see https://huggingface.co/CompVis/stable-diffusion-v1-4",
12
- "scheduler": "EulerDiscreteScheduler"
 
13
  },
14
  "runwayml/stable-diffusion-v1-5": {
15
  "use_safetensors": true,
16
  "description": "see https://huggingface.co/runwayml/stable-diffusion-v1-5",
17
- "scheduler": "DDPMScheduler"
 
18
  },
19
  "stabilityai/stable-diffusion-2-1": {
20
  "use_safetensors": true,
21
  "description": "see https://huggingface.co/stabilityai/stable-diffusion-2-1",
22
- "scheduler": "DPMSolverMultistepScheduler"
 
23
  },
24
  "stabilityai/stable-diffusion-xl-base-1.0": {
25
  "use_safetensors": true,
26
  "description": "see https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
27
- "scheduler": "DDPMScheduler"
 
28
  },
29
  "sd-dreambooth-library/house-emblem": {
30
  "use_safetensors": false,
31
  "description": "see https://huggingface.co/sd-dreambooth-library/house-emblem",
32
- "scheduler": "DDPMScheduler"
 
33
  },
34
  "Envvi/Inkpunk-Diffusion": {
35
- "use_safetensors": true,
36
- "description": "Another wonderful model",
37
- "scheduler": "DDPMScheduler"
 
38
  },
39
  "Stelath/textual_inversion_comic_strip_fp16": {
40
  "use_safetensors": true,
41
- "description": "Another wonderful model",
42
- "scheduler": "DDPMScheduler"
 
43
  },
44
  "sd-dreambooth-library/herge-style": {
45
- "use_safetensors": true,
46
- "description": "Another wonderful model",
47
- "scheduler": "DDPMScheduler"
 
48
  }
49
 
50
  },
 
4
  "sd-dreambooth-library/solo-levelling-art-style": {
5
  "use_safetensors": false,
6
  "description": "see https://huggingface.co/sd-dreambooth-library/solo-levelling-art-style",
7
+ "scheduler": "DDIMScheduler",
8
+ "refiner": ""
9
  },
10
  "CompVis/stable-diffusion-v1-4": {
11
  "use_safetensors": true,
12
  "description": "see https://huggingface.co/CompVis/stable-diffusion-v1-4",
13
+ "scheduler": "EulerDiscreteScheduler",
14
+ "refiner": ""
15
  },
16
  "runwayml/stable-diffusion-v1-5": {
17
  "use_safetensors": true,
18
  "description": "see https://huggingface.co/runwayml/stable-diffusion-v1-5",
19
+ "scheduler": "DDPMScheduler",
20
+ "refiner": ""
21
  },
22
  "stabilityai/stable-diffusion-2-1": {
23
  "use_safetensors": true,
24
  "description": "see https://huggingface.co/stabilityai/stable-diffusion-2-1",
25
+ "scheduler": "DPMSolverMultistepScheduler",
26
+ "refiner": ""
27
  },
28
  "stabilityai/stable-diffusion-xl-base-1.0": {
29
  "use_safetensors": true,
30
  "description": "see https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
31
+ "scheduler": "DDPMScheduler",
32
+ "refiner": "stabilityai/stable-diffusion-xl-refiner-1.0"
33
  },
34
  "sd-dreambooth-library/house-emblem": {
35
  "use_safetensors": false,
36
  "description": "see https://huggingface.co/sd-dreambooth-library/house-emblem",
37
+ "scheduler": "DDPMScheduler",
38
+ "refiner": ""
39
  },
40
  "Envvi/Inkpunk-Diffusion": {
41
+ "use_safetensors": false,
42
+ "description": "see https://huggingface.co/Envvi/Inkpunk-Diffusion",
43
+ "scheduler": "DDPMScheduler",
44
+ "refiner": ""
45
  },
46
  "Stelath/textual_inversion_comic_strip_fp16": {
47
  "use_safetensors": true,
48
+ "description": "see https://huggingface.co/Stelath/textual_inversion_comic_strip_fp16",
49
+ "scheduler": "DDPMScheduler",
50
+ "refiner": ""
51
  },
52
  "sd-dreambooth-library/herge-style": {
53
+ "use_safetensors": false,
54
+ "description": "see https://huggingface.co/sd-dreambooth-library/herge-style",
55
+ "scheduler": "DDPMScheduler",
56
+ "refiner": ""
57
  }
58
 
59
  },
config.py CHANGED
@@ -42,6 +42,7 @@ def get_initial_config():
42
  "allow_tensorfloat32": allow_tensorfloat32,
43
  "use_safetensors": "False",
44
  "data_type": data_type,
 
45
  "safety_checker": "False",
46
  "requires_safety_checker": "False",
47
  "manual_seed": 42,
@@ -78,6 +79,7 @@ def get_config_from_url(initial_config, request: Request):
78
  return_config['device'],
79
  return_config['use_safetensors'],
80
  return_config['data_type'],
 
81
  return_config['variant'],
82
  return_config['safety_checker'],
83
  return_config['requires_safety_checker'],
@@ -139,12 +141,20 @@ def assemble_code(str_config):
139
  torch_dtype=data_type,
140
  variant=variant).to(device)'''
141
 
 
 
 
 
 
 
 
 
 
 
142
  code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
143
 
144
  if str(config["safety_checker"]).lower() == 'false':
145
  code['055_safety_checker'] = f'pipeline.safety_checker = None'
146
- else:
147
- code['055_safety_checker'] = ''
148
 
149
  code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
150
 
@@ -165,6 +175,16 @@ def assemble_code(str_config):
165
  negative_prompt=negative_prompt,
166
  generator=generator,
167
  num_inference_steps=inference_steps,
168
- guidance_scale=guidance_scale).images[0]'''
169
-
 
 
 
 
 
 
 
 
 
 
170
  return '\r\n'.join(value[1] for value in sorted(code.items()))
 
42
  "allow_tensorfloat32": allow_tensorfloat32,
43
  "use_safetensors": "False",
44
  "data_type": data_type,
45
+ "refiner": "",
46
  "safety_checker": "False",
47
  "requires_safety_checker": "False",
48
  "manual_seed": 42,
 
79
  return_config['device'],
80
  return_config['use_safetensors'],
81
  return_config['data_type'],
82
+ return_config['refiner'],
83
  return_config['variant'],
84
  return_config['safety_checker'],
85
  return_config['requires_safety_checker'],
 
141
  torch_dtype=data_type,
142
  variant=variant).to(device)'''
143
 
144
+ if config['refiner'] != '':
145
+ code['051_refiner'] = f'''refiner = DiffusionPipeline.from_pretrained(
146
+ "{config['refiner']}",
147
+ text_encoder_2 = base.text_encoder_2,
148
+ vae = base.vae,
149
+ torch_dtype = data_type,
150
+ use_safetensors = use_safetensors,
151
+ variant=variant,
152
+ ).to(device)'''
153
+
154
  code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
155
 
156
  if str(config["safety_checker"]).lower() == 'false':
157
  code['055_safety_checker'] = f'pipeline.safety_checker = None'
 
 
158
 
159
  code['060_scheduler'] = f'pipeline.scheduler = {config["scheduler"]}.from_config(pipeline.scheduler.config)'
160
 
 
175
  negative_prompt=negative_prompt,
176
  generator=generator,
177
  num_inference_steps=inference_steps,
178
+ guidance_scale=guidance_scale).images
179
+ '''
180
+
181
+ if config['refiner'] != '':
182
+ code["110_run_refiner"] = f'''image = refiner(
183
+ prompt=prompt,
184
+ negative_prompt=negative_prompt,
185
+ num_inference_steps=inference_steps
186
+ ).images[0]'''
187
+
188
+ code["200_show_image"] = 'image[0]'
189
+
190
  return '\r\n'.join(value[1] for value in sorted(code.items()))