n42 committed on
Commit
00e74d5
·
1 Parent(s): 5bd02e2

fixing wrong history address

Browse files
Files changed (2) hide show
  1. app.py +26 -73
  2. config.py +13 -15
app.py CHANGED
@@ -104,24 +104,24 @@ def get_tensorfloat32(allow_tensorfloat32):
104
 
105
  return True if str(allow_tensorfloat32).lower() == 'true' else False
106
 
107
- def get_scheduler(scheduler, config):
108
 
109
  if scheduler == "DDPMScheduler":
110
- return DDPMScheduler.from_config(config)
111
  elif scheduler == "DDIMScheduler":
112
- return DDIMScheduler.from_config(config)
113
  elif scheduler == "PNDMScheduler":
114
- return PNDMScheduler.from_config(config)
115
  elif scheduler == "LMSDiscreteScheduler":
116
- return LMSDiscreteScheduler.from_config(config)
117
  elif scheduler == "EulerAncestralDiscreteScheduler":
118
- return EulerAncestralDiscreteScheduler.from_config(config)
119
  elif scheduler == "EulerDiscreteScheduler":
120
- return EulerDiscreteScheduler.from_config(config)
121
  elif scheduler == "DPMSolverMultistepScheduler":
122
- return DPMSolverMultistepScheduler.from_config(config)
123
  else:
124
- return DPMSolverMultistepScheduler.from_config(config)
125
 
126
  # get
127
  # - initial configuration,
@@ -129,41 +129,9 @@ def get_scheduler(scheduler, config):
129
  # - a list of available models from the config file
130
  # - a list of available schedulers from the config file
131
  # - a dict that contains code for reproduction
132
- config.set_inital_config()
133
- # config.current, devices, model_configs, scheduler_configs, code = config.get_inital_config()
134
-
135
- # device = config.current["device"]
136
- # model = config.current["model"]
137
- # scheduler = config.current["scheduler"]
138
- # variant = config.current["variant"]
139
- # allow_tensorfloat32 = config.current["allow_tensorfloat32"]
140
- # use_safetensors = config.current["use_safetensors"]
141
- # data_type = config.current["data_type"]
142
- # safety_checker = config.current["safety_checker"]
143
- # requires_safety_checker = config.current["requires_safety_checker"]
144
- # manual_seed = config.current["manual_seed"]
145
- # inference_steps = config.current["inference_steps"]
146
- # guidance_scale = config.current["guidance_scale"]
147
- # prompt = config.current["prompt"]
148
- # negative_prompt = config.current["negative_prompt"]
149
-
150
- config_history = []
151
 
152
  # pipeline
153
- def run_inference(model,
154
- device,
155
- use_safetensors,
156
- data_type,
157
- variant,
158
- safety_checker,
159
- requires_safety_checker,
160
- scheduler,
161
- prompt,
162
- negative_prompt,
163
- inference_steps,
164
- manual_seed,
165
- guidance_scale,
166
- progress=gr.Progress(track_tqdm=True)):
167
 
168
  if config.current["model"] != None and config.current["scheduler"] != None:
169
 
@@ -175,35 +143,34 @@ def run_inference(model,
175
 
176
  pipeline = DiffusionPipeline.from_pretrained(
177
  config.current["model"],
178
- use_safetensors=config.current["use_safetensors"],
179
- torch_dtype=get_data_type(config.current["data_type"]),
180
- variant=variant).to(config.current["device"])
181
 
182
  if config.current["safety_checker"] is None or str(config.current["safety_checker"]).lower == 'false':
183
  pipeline.safety_checker = None
184
 
185
  pipeline.requires_safety_checker = config.current["requires_safety_checker"]
186
 
187
- pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
188
 
189
- manual_seed = int(manual_seed)
190
- if manual_seed < 0 or manual_seed is None or manual_seed == '':
191
- generator = torch.Generator(device)
192
  else:
193
  generator = torch.manual_seed(42)
194
 
195
  progress((3,3), desc="Creating the result...")
196
 
197
  image = pipeline(
198
- prompt=prompt,
199
- negative_prompt=negative_prompt,
200
- generator=generator,
201
- num_inference_steps=int(inference_steps),
202
- guidance_scale=float(guidance_scale)).images[0]
203
 
204
- config_history.append(config.current.copy())
205
 
206
- return image, dict_list_to_markdown_table(config_history)
207
 
208
  else:
209
 
@@ -297,24 +264,10 @@ with gr.Blocks() as demo:
297
  in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[out_current_config, out_code])
298
  in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_current_config, out_code])
299
  in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[out_current_config, out_code])
300
- btn_start_pipeline.click(run_inference, inputs=[
301
- in_models,
302
- in_devices,
303
- in_use_safetensors,
304
- in_data_type,
305
- in_variant,
306
- in_safety_checker,
307
- in_requires_safety_checker,
308
- in_schedulers,
309
- in_prompt,
310
- in_negative_prompt,
311
- in_inference_steps,
312
- in_manual_seed,
313
- in_guidance_scale
314
- ], outputs=[
315
- out_image,
316
- out_config_history])
317
 
 
 
318
  demo.load(fn=init_config, inputs=out_current_config,
319
  outputs=[
320
  in_models,
 
104
 
105
  return True if str(allow_tensorfloat32).lower() == 'true' else False
106
 
107
+ def get_scheduler(scheduler, pipeline_config):
108
 
109
  if scheduler == "DDPMScheduler":
110
+ return DDPMScheduler.from_config(pipeline_config)
111
  elif scheduler == "DDIMScheduler":
112
+ return DDIMScheduler.from_config(pipeline_config)
113
  elif scheduler == "PNDMScheduler":
114
+ return PNDMScheduler.from_config(pipeline_config)
115
  elif scheduler == "LMSDiscreteScheduler":
116
+ return LMSDiscreteScheduler.from_config(pipeline_config)
117
  elif scheduler == "EulerAncestralDiscreteScheduler":
118
+ return EulerAncestralDiscreteScheduler.from_config(pipeline_config)
119
  elif scheduler == "EulerDiscreteScheduler":
120
+ return EulerDiscreteScheduler.from_config(pipeline_config)
121
  elif scheduler == "DPMSolverMultistepScheduler":
122
+ return DPMSolverMultistepScheduler.from_config(pipeline_config)
123
  else:
124
+ return DPMSolverMultistepScheduler.from_config(pipeline_config)
125
 
126
  # get
127
  # - initial configuration,
 
129
  # - a list of available models from the config file
130
  # - a list of available schedulers from the config file
131
  # - a dict that contains code for reproduction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  # pipeline
134
+ def run_inference(progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if config.current["model"] != None and config.current["scheduler"] != None:
137
 
 
143
 
144
  pipeline = DiffusionPipeline.from_pretrained(
145
  config.current["model"],
146
+ use_safetensors = config.current["use_safetensors"],
147
+ torch_dtype = get_data_type(config.current["data_type"]),
148
+ variant = config.current["variant"]).to(config.current["device"])
149
 
150
  if config.current["safety_checker"] is None or str(config.current["safety_checker"]).lower == 'false':
151
  pipeline.safety_checker = None
152
 
153
  pipeline.requires_safety_checker = config.current["requires_safety_checker"]
154
 
155
+ pipeline.scheduler = get_scheduler(config.current["scheduler"], pipeline.scheduler.config)
156
 
157
+ if config.current["manual_seed"] < 0 or config.current["manual_seed"] is None or config.current["manual_seed"] == '':
158
+ generator = torch.Generator(config.current["device"])
 
159
  else:
160
  generator = torch.manual_seed(42)
161
 
162
  progress((3,3), desc="Creating the result...")
163
 
164
  image = pipeline(
165
+ prompt = config.current["prompt"],
166
+ negative_prompt = config.current["negative_prompt"],
167
+ generator = generator,
168
+ num_inference_steps = config.current["inference_steps"],
169
+ guidance_scale = config.current["guidance_scale"]).images[0]
170
 
171
+ config.history.append(config.current.copy())
172
 
173
+ return image, dict_list_to_markdown_table(config.history)
174
 
175
  else:
176
 
 
264
  in_guidance_scale.change(guidance_scale_change, inputs=[in_guidance_scale], outputs=[out_current_config, out_code])
265
  in_prompt.change(prompt_change, inputs=[in_prompt], outputs=[out_current_config, out_code])
266
  in_negative_prompt.change(negative_prompt_change, inputs=[in_negative_prompt], outputs=[out_current_config, out_code])
267
+ btn_start_pipeline.click(run_inference, inputs=[], outputs=[out_image, out_config_history])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
+ # send the current (or initial) config to init_config to populate all relevant input fields
270
+ # if a GET parameter is set, it will overwrite the initial config parameters
271
  demo.load(fn=init_config, inputs=out_current_config,
272
  outputs=[
273
  in_models,
config.py CHANGED
@@ -49,21 +49,6 @@ class Config:
49
  self.history = []
50
  self.devices = []
51
 
52
- def load_app_config(self):
53
- try:
54
- with open('appConfig.json', 'r') as f:
55
- appConfig = json.load(f)
56
- except FileNotFoundError:
57
- print("App config file not found.")
58
- except json.JSONDecodeError:
59
- print("Error decoding JSON in app config file.")
60
- except Exception as e:
61
- print("An error occurred while loading app config:", str(e))
62
-
63
- return appConfig
64
-
65
- def set_inital_config(self):
66
-
67
  appConfig = self.load_app_config()
68
 
69
  self.model_configs = appConfig.get("models", {})
@@ -102,6 +87,19 @@ class Config:
102
 
103
  self.assemble_code()
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def set_config(self, key, value):
106
 
107
  self.current[key] = value
 
49
  self.history = []
50
  self.devices = []
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  appConfig = self.load_app_config()
53
 
54
  self.model_configs = appConfig.get("models", {})
 
87
 
88
  self.assemble_code()
89
 
90
+ def load_app_config(self):
91
+ try:
92
+ with open('appConfig.json', 'r') as f:
93
+ appConfig = json.load(f)
94
+ except FileNotFoundError:
95
+ print("App config file not found.")
96
+ except json.JSONDecodeError:
97
+ print("Error decoding JSON in app config file.")
98
+ except Exception as e:
99
+ print("An error occurred while loading app config:", str(e))
100
+
101
+ return appConfig
102
+
103
  def set_config(self, key, value):
104
 
105
  self.current[key] = value