n42 commited on
Commit
80c0bfe
·
1 Parent(s): 1028fae

sharing feature and refactoring

Browse files
Files changed (2) hide show
  1. app.py +38 -95
  2. config.py +113 -0
app.py CHANGED
@@ -22,31 +22,7 @@ from gradio import Interface
22
  from diffusers import AutoencoderKL
23
  import pandas as pd
24
  import base64
25
-
26
- js_get_url_parameters = """console.log(window.location);
27
- var urlParams = new URLSearchParams(window.location.hash.substr(1));
28
- var decodedParams = {};
29
- console.log(window.location);
30
- urlParams.forEach(function(value, key) {
31
- var decodedValue = atob(value);
32
- decodedParams[key] = decodedValue;
33
- });
34
- console.log(JSON.stringify(decodedParams));
35
- """
36
-
37
- def load_app_config():
38
- global appConfig
39
- try:
40
- with open('appConfig.json', 'r') as f:
41
- appConfig = json.load(f)
42
- except FileNotFoundError:
43
- print("App config file not found.")
44
- except json.JSONDecodeError:
45
- print("Error decoding JSON in app config file.")
46
- except Exception as e:
47
- print("An error occurred while loading app config:", str(e))
48
-
49
- load_app_config()
50
 
51
  # code output order
52
  code = {}
@@ -65,58 +41,9 @@ code_pos_inference_steps = '090_inference_steps'
65
  code_pos_guidance_scale = '095_guidance_scale'
66
  code_pos_run_inference = '100_run_inference'
67
 
68
- # model config
69
- model_configs = appConfig.get("models", {})
70
- models = list(model_configs.keys())
71
- model = None
72
- scheduler_configs = appConfig.get("schedulers", {})
73
- schedulers = list(scheduler_configs.keys())
74
- scheduler = None
75
-
76
- devices = appConfig.get("devices", [])
77
- device = None
78
-
79
- variant = None
80
- allow_tensorfloat32 = False
81
- use_safetensors = False
82
- data_type = 'float16'
83
- safety_checker = False
84
- requires_safety_checker = False
85
- manual_seed = 42
86
- inference_steps = 10
87
- guidance_scale = 0.5
88
- prompt = 'A white rabbit'
89
- negative_prompt = 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly'
90
-
91
- # init device parameters
92
- if torch.cuda.is_available():
93
- device = "cuda"
94
- data_type = 'bfloat16'
95
- allow_tensorfloat32 = True
96
- elif torch.backends.mps.is_available():
97
- device = "mps"
98
- else:
99
- device = "cpu"
100
-
101
- # inference config
102
- current_config = {
103
- "device": device,
104
- "model": model,
105
- "scheduler": scheduler,
106
- "variant": variant,
107
- "allow_tensorflow": allow_tensorfloat32,
108
- "use_safetensors": use_safetensors,
109
- "data_type": data_type,
110
- "safety_checker": safety_checker,
111
- "requires_safety_checker": requires_safety_checker,
112
- "manual_seed": manual_seed,
113
- "inference_steps": inference_steps,
114
- "guidance_scale": guidance_scale,
115
- "prompt": prompt,
116
- "negative_prompt": negative_prompt,
117
- }
118
-
119
- config_history = [current_config]
120
 
121
  def get_sorted_code():
122
 
@@ -131,11 +58,10 @@ def device_change(device):
131
 
132
  def models_change(model, scheduler):
133
 
134
- print(model)
135
  use_safetensors = False
136
 
137
  # no model selected (because this is UI init run)
138
- if type(model) != list:
139
 
140
  use_safetensors = str(model_configs[model]['use_safetensors'])
141
 
@@ -152,8 +78,7 @@ def models_change(model, scheduler):
152
 
153
  safety_checker_change(safety_checker)
154
  requires_safety_checker_change(requires_safety_checker)
155
-
156
- print(use_safetensors)
157
  return get_sorted_code(), use_safetensors, scheduler
158
 
159
  def data_type_change(selected_data_type):
@@ -210,7 +135,7 @@ def requires_safety_checker_change(requires_safety_checker):
210
 
211
  def schedulers_change(scheduler):
212
 
213
- if type(scheduler) != list:
214
 
215
  code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
216
 
@@ -320,30 +245,31 @@ code[code_pos_run_inference] = f'''image = pipeline(
320
  num_inference_steps=inference_steps,
321
  guidance_scale=guidance_scale).images[0]'''
322
 
323
- def dict_list_to_markdown_table(data):
324
- if not data:
 
325
  return ""
326
 
327
- headers = list(data[0].keys())
328
  markdown_table = "| share | " + " | ".join(headers) + " |\n"
329
  markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
330
 
331
- for i, row in enumerate(data):
332
- encoded_row = base64.b64encode(str(row).encode()).decode()
333
- share_link = f'<a href="#share/{encoded_row}">📎</a>'
334
- # Construct the row with links
335
- markdown_table += f"| {share_link} | " + " | ".join(str(row.get(key, "")) for key in headers) + " |\n"
336
 
337
- # Wrap the Markdown table in a <div> tag with horizontal scrolling
338
  markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
339
 
340
  return markdown_table
341
 
342
  # interface
343
  with gr.Blocks() as demo:
344
-
 
 
345
  gr.Markdown('''## Text-2-Image Playground
346
- <small>by <a href="https://www.linkedin.com/in/nickyreinert/">Nicky Reinert</a> |
347
  home base: https://huggingface.co/spaces/n42/pictero
348
  </small>''')
349
  gr.Markdown("### Device specific settings")
@@ -389,7 +315,7 @@ with gr.Blocks() as demo:
389
  out_image = gr.Image()
390
  out_code = gr.Code(get_sorted_code(), label="Code")
391
  with gr.Row():
392
- out_current_config = gr.Code(value=str(current_config), label="Current config")
393
  with gr.Row():
394
  out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
395
 
@@ -416,5 +342,22 @@ with gr.Blocks() as demo:
416
  in_manual_seed,
417
  in_guidance_scale
418
  ], outputs=[out_image])
419
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  demo.launch()
 
22
  from diffusers import AutoencoderKL
23
  import pandas as pd
24
  import base64
25
+ from config import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # code output order
28
  code = {}
 
41
  code_pos_guidance_scale = '095_guidance_scale'
42
  code_pos_run_inference = '100_run_inference'
43
 
44
+ initial_config, devices, models, schedulers = get_inital_config()
45
+
46
+ config_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def get_sorted_code():
49
 
 
58
 
59
  def models_change(model, scheduler):
60
 
 
61
  use_safetensors = False
62
 
63
  # no model selected (because this is UI init run)
64
+ if type(model) != list and model is not None:
65
 
66
  use_safetensors = str(model_configs[model]['use_safetensors'])
67
 
 
78
 
79
  safety_checker_change(safety_checker)
80
  requires_safety_checker_change(requires_safety_checker)
81
+
 
82
  return get_sorted_code(), use_safetensors, scheduler
83
 
84
  def data_type_change(selected_data_type):
 
135
 
136
  def schedulers_change(scheduler):
137
 
138
+ if type(scheduler) != list and scheduler is not None:
139
 
140
  code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
141
 
 
245
  num_inference_steps=inference_steps,
246
  guidance_scale=guidance_scale).images[0]'''
247
 
248
+ def dict_list_to_markdown_table(config_history):
249
+
250
+ if not config_history:
251
  return ""
252
 
253
+ headers = list(config_history[0].keys())
254
  markdown_table = "| share | " + " | ".join(headers) + " |\n"
255
  markdown_table += "| --- | " + " | ".join(["---"] * len(headers)) + " |\n"
256
 
257
+ for index, config in enumerate(config_history):
258
+ encoded_config = base64.b64encode(str(config).encode()).decode()
259
+ share_link = f'<a target="_blank" href="?config={encoded_config}">📎</a>'
260
+ markdown_table += f"| {share_link} | " + " | ".join(str(config.get(key, "")) for key in headers) + " |\n"
 
261
 
 
262
  markdown_table = '<div style="overflow-x: auto;">\n\n' + markdown_table + '</div>'
263
 
264
  return markdown_table
265
 
266
  # interface
267
  with gr.Blocks() as demo:
268
+
269
+ in_import_config = gr.Text()
270
+
271
  gr.Markdown('''## Text-2-Image Playground
272
+ <small>by <a target="_blank" href="https://www.linkedin.com/in/nickyreinert/">Nicky Reinert</a> |
273
  home base: https://huggingface.co/spaces/n42/pictero
274
  </small>''')
275
  gr.Markdown("### Device specific settings")
 
315
  out_image = gr.Image()
316
  out_code = gr.Code(get_sorted_code(), label="Code")
317
  with gr.Row():
318
+ out_current_config = gr.Code(value=str(initial_config), label="Current config")
319
  with gr.Row():
320
  out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
321
 
 
342
  in_manual_seed,
343
  in_guidance_scale
344
  ], outputs=[out_image])
345
+
346
+ demo.load(fn=init_config, inputs=out_current_config,
347
+ outputs=[
348
+ in_models,
349
+ in_devices,
350
+ in_use_safetensors,
351
+ in_data_type,
352
+ in_variant,
353
+ in_safety_checker,
354
+ in_requires_safety_checker,
355
+ in_schedulers,
356
+ in_prompt,
357
+ in_negative_prompt,
358
+ in_inference_steps,
359
+ in_manual_seed,
360
+ in_guidance_scale
361
+ ])
362
+
363
  demo.launch()
config.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import base64
3
+ import json
4
+ import torch
5
+
6
+ def load_app_config():
7
+ try:
8
+ with open('appConfig.json', 'r') as f:
9
+ appConfig = json.load(f)
10
+ except FileNotFoundError:
11
+ print("App config file not found.")
12
+ except json.JSONDecodeError:
13
+ print("Error decoding JSON in app config file.")
14
+ except Exception as e:
15
+ print("An error occurred while loading app config:", str(e))
16
+
17
+ return appConfig
18
+
19
+ appConfig = load_app_config()
20
+
21
+ def get_inital_config():
22
+
23
+ # model config
24
+ model_configs = appConfig.get("models", {})
25
+ models = list(model_configs.keys())
26
+ model = None
27
+ scheduler_configs = appConfig.get("schedulers", {})
28
+ schedulers = list(scheduler_configs.keys())
29
+ scheduler = None
30
+
31
+ devices = appConfig.get("devices", [])
32
+ device = None
33
+
34
+ variant = None
35
+ allow_tensorfloat32 = False
36
+ use_safetensors = False
37
+ data_type = 'float16'
38
+ safety_checker = False
39
+ requires_safety_checker = False
40
+ manual_seed = 42
41
+ inference_steps = 10
42
+ guidance_scale = 0.5
43
+ prompt = 'A white rabbit'
44
+ negative_prompt = 'lowres, cropped, worst quality, low quality, chat bubble, chat bubbles, ugly'
45
+
46
+ # # init device parameters
47
+ if torch.cuda.is_available():
48
+ device = "cuda"
49
+ data_type = 'bfloat16'
50
+ allow_tensorfloat32 = True
51
+ elif torch.backends.mps.is_available():
52
+ device = "mps"
53
+ else:
54
+ device = "cpu"
55
+
56
+ # inference config
57
+ initial_config = {
58
+ "device": device,
59
+ "model": model,
60
+ "scheduler": scheduler,
61
+ "variant": variant,
62
+ "allow_tensorflow": allow_tensorfloat32,
63
+ "use_safetensors": use_safetensors,
64
+ "data_type": data_type,
65
+ "safety_checker": safety_checker,
66
+ "requires_safety_checker": requires_safety_checker,
67
+ "manual_seed": manual_seed,
68
+ "inference_steps": inference_steps,
69
+ "guidance_scale": guidance_scale,
70
+ "prompt": prompt,
71
+ "negative_prompt": negative_prompt,
72
+ }
73
+
74
+
75
+ return initial_config, devices, models, schedulers
76
+
77
+ def init_config(request: gr.Request, inital_config):
78
+
79
+ encoded_params = request.request.query_params.get('config')
80
+ return_config = {}
81
+
82
+ # get configuration from URL if GET parameter `share` is set
83
+ if encoded_params is not None:
84
+ decoded_params = base64.b64decode(encoded_params)
85
+ decoded_params = decoded_params.decode('utf-8')
86
+ decoded_params = decoded_params.replace("'", '"').replace('None', 'null').replace('False', 'false')
87
+ dict_params = json.loads(decoded_params)
88
+
89
+ return_config = dict_params
90
+
91
+ # otherwise use default initial config
92
+ else:
93
+
94
+ inital_config = inital_config.replace("'", '"').replace('None', 'null').replace('False', 'false')
95
+ dict_inital_config = json.loads(inital_config)
96
+
97
+ return_config = dict_inital_config
98
+
99
+ return [return_config['model'],
100
+ return_config['device'],
101
+ return_config['use_safetensors'],
102
+ return_config['data_type'],
103
+ return_config['variant'],
104
+ return_config['safety_checker'],
105
+ return_config['requires_safety_checker'],
106
+ return_config['scheduler'],
107
+ return_config['prompt'],
108
+ return_config['negative_prompt'],
109
+ return_config['inference_steps'],
110
+ return_config['manual_seed'],
111
+ return_config['guidance_scale']
112
+ ]
113
+