Fabrice-TIERCELIN commited on
Commit
2299926
·
verified ·
1 Parent(s): 3acc825

Fix function

Browse files
Files changed (1) hide show
  1. gradio_demo.py +39 -38
gradio_demo.py CHANGED
@@ -180,7 +180,7 @@ def stage2_process(
180
  input_image = upscale_image(input_image, upscale, unit_resolution=32,
181
  min_size=min_size)
182
 
183
- samples = restore(
184
  model,
185
  edm_steps,
186
  s_stage1,
@@ -199,43 +199,8 @@ def stage2_process(
199
  spt_linear_s_stage2
200
  )
201
 
202
- x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
203
- 0, 255).astype(np.uint8)
204
- results = [x_samples[i] for i in range(num_samples)]
205
-
206
- if args.log_history:
207
- os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
208
- with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
209
- f.write(str(event_dict))
210
- f.close()
211
- Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
212
- for i, result in enumerate(results):
213
- Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
214
-
215
- # All the results have the same size
216
- result_height, result_width, result_channel = np.array(results[0]).shape
217
-
218
- print('<<== stage2_process')
219
- end = time.time()
220
- secondes = int(end - start)
221
- minutes = math.floor(secondes / 60)
222
- secondes = secondes - (minutes * 60)
223
- hours = math.floor(minutes / 60)
224
- minutes = minutes - (hours * 60)
225
- information = ("Start the process again if you want a different result. " if randomize_seed else "") + \
226
- "The new image resolution is " + str(result_width) + \
227
- " pixels large and " + str(result_height) + \
228
- " pixels high, so a resolution of " + f'{result_width * result_height:,}' + " pixels. " + \
229
- "The image(s) has(ve) been generated in " + \
230
- ((str(hours) + " h, ") if hours != 0 else "") + \
231
- ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
232
- str(secondes) + " sec."
233
- print(information)
234
 
235
- # Only one image can be shown in the slider
236
- return [noisy_image] + [results[0]], gr.update(format = output_format, value = [noisy_image] + results), gr.update(value = information, visible = True), event_id
237
-
238
- @spaces.GPU(duration=600)
239
  def restore(
240
  model,
241
  edm_steps,
@@ -269,12 +234,48 @@ def restore(
269
  model.ae_dtype = convert_dtype(ae_dtype)
270
  model.model.dtype = convert_dtype(diff_dtype)
271
 
272
- return model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
273
  s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
274
  num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
275
  use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
276
  cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  def load_and_reset(param_setting):
279
  print('load_and_reset ==>>')
280
  if torch.cuda.device_count() == 0:
 
180
  input_image = upscale_image(input_image, upscale, unit_resolution=32,
181
  min_size=min_size)
182
 
183
+ result_slider, result_gallery, restore_information, event_id = restore(
184
  model,
185
  edm_steps,
186
  s_stage1,
 
199
  spt_linear_s_stage2
200
  )
201
 
202
+ return result_slider, result_gallery, restore_information, event_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
 
 
 
 
204
  def restore(
205
  model,
206
  edm_steps,
 
234
  model.ae_dtype = convert_dtype(ae_dtype)
235
  model.model.dtype = convert_dtype(diff_dtype)
236
 
237
+ samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
238
  s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
239
  num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
240
  use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
241
  cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
242
 
243
+ x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
244
+ 0, 255).astype(np.uint8)
245
+ results = [x_samples[i] for i in range(num_samples)]
246
+
247
+ if args.log_history:
248
+ os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
249
+ with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
250
+ f.write(str(event_dict))
251
+ f.close()
252
+ Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
253
+ for i, result in enumerate(results):
254
+ Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
255
+
256
+ # All the results have the same size
257
+ result_height, result_width, result_channel = np.array(results[0]).shape
258
+
259
+ print('<<== stage2_process')
260
+ end = time.time()
261
+ secondes = int(end - start)
262
+ minutes = math.floor(secondes / 60)
263
+ secondes = secondes - (minutes * 60)
264
+ hours = math.floor(minutes / 60)
265
+ minutes = minutes - (hours * 60)
266
+ information = ("Start the process again if you want a different result. " if randomize_seed else "") + \
267
+ "The new image resolution is " + str(result_width) + \
268
+ " pixels large and " + str(result_height) + \
269
+ " pixels high, so a resolution of " + f'{result_width * result_height:,}' + " pixels. " + \
270
+ "The image(s) has(ve) been generated in " + \
271
+ ((str(hours) + " h, ") if hours != 0 else "") + \
272
+ ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
273
+ str(secondes) + " sec."
274
+ print(information)
275
+
276
+ # Only one image can be shown in the slider
277
+ return [noisy_image] + [results[0]], gr.update(format = output_format, value = [noisy_image] + results), gr.update(value = information, visible = True), event_id
278
+
279
  def load_and_reset(param_setting):
280
  print('load_and_reset ==>>')
281
  if torch.cuda.device_count() == 0: