LiruiZhao commited on
Commit
46506d2
·
1 Parent(s): 621d96f

[Minor] Use The generator function to generate a list

Browse files
Files changed (1) hide show
  1. app.py +60 -16
app.py CHANGED
@@ -273,7 +273,6 @@ def generate(
273
  m_img.astype('float') / 2.0 * red).astype('uint8'))
274
 
275
 
276
-
277
  mask_video_path = "mask.mp4"
278
  fps = 30
279
  with imageio.get_writer(mask_video_path, fps=fps) as video:
@@ -282,7 +281,45 @@ def generate(
282
 
283
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
284
 
285
- @spaces.GPU(duration=120)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  def generate_list(
287
  input_image: Image.Image,
288
  generate_list: str,
@@ -322,9 +359,11 @@ def generate_list(
322
  while generate_index < len(generate_list):
323
  print(f'generate_index: {str(generate_index)}')
324
  instruction = generate_list[generate_index]
 
 
325
  with torch.no_grad(), autocast("cuda"), model.ema_scope():
326
  cond = {}
327
- input_image_torch = 2 * torch.tensor(np.array(input_image_copy.copy())).float() / 255 - 1
328
  input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
329
  cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
330
  cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
@@ -351,8 +390,10 @@ def generate_list(
351
 
352
  x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
353
  x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
 
 
354
 
355
- if torch.sum(x_1).item()/x_1.numel() < -0.99:
356
  seed += 1
357
  retry_number +=1
358
  if retry_number > max_retry:
@@ -384,20 +425,22 @@ def generate_list(
384
 
385
  image_video.append((mix_image_np * 255).astype(np.uint8))
386
  mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
387
- input_image_copy = mix_image
388
-
389
- mix_result_with_red_mask = None
390
- mask_video_path = None
391
- edited_mask_copy = None
 
 
 
 
 
 
 
392
 
393
- image_video_path = "image.mp4"
394
- fps = 2
395
- with imageio.get_writer(image_video_path, fps=fps) as video:
396
- for image in image_video:
397
- video.append_data(image)
398
 
399
-
400
- return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
401
 
402
 
403
  def reset():
@@ -553,4 +596,5 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
553
  # demo.launch(share=True)
554
 
555
 
 
556
  demo.queue().launch()
 
273
  m_img.astype('float') / 2.0 * red).astype('uint8'))
274
 
275
 
 
276
  mask_video_path = "mask.mp4"
277
  fps = 30
278
  with imageio.get_writer(mask_video_path, fps=fps) as video:
 
281
 
282
  return [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image_copy, mix_result_with_red_mask]
283
 
284
+
285
+ def single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width):
286
+ model.cuda()
287
+ with torch.no_grad(), autocast("cuda"), model.ema_scope():
288
+ cond = {}
289
+ input_image_torch = 2 * torch.tensor(np.array(input_image_copy.to(model.device))).float() / 255 - 1
290
+ input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
291
+ cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
292
+ cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
293
+
294
+ uncond = {}
295
+ uncond["c_crossattn"] = [null_token.to(model.device)]
296
+ uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
297
+
298
+ sigmas = model_wrap.get_sigmas(steps).to(model.device)
299
+
300
+ extra_args = {
301
+ "cond": cond,
302
+ "uncond": uncond,
303
+ "text_cfg_scale": text_cfg_scale,
304
+ "image_cfg_scale": image_cfg_scale,
305
+ }
306
+ torch.manual_seed(seed)
307
+ z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
308
+ z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
309
+
310
+ z_0, z_1, _, _ = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
311
+
312
+ x_0 = model.decode_first_stage(z_0)
313
+
314
+ x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
315
+ x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
316
+
317
+ x_1_mean = torch.sum(x_1).item()/x_1.numel()
318
+
319
+ return x_0, x_1, x_1_mean
320
+
321
+
322
+ @spaces.GPU(duration=150)
323
  def generate_list(
324
  input_image: Image.Image,
325
  generate_list: str,
 
359
  while generate_index < len(generate_list):
360
  print(f'generate_index: {str(generate_index)}')
361
  instruction = generate_list[generate_index]
362
+
363
+ # x_0, x_1, x_1_mean = single_generation(model_wrap_cfg, input_image_copy, instruction, steps, seed, text_cfg_scale, image_cfg_scale, height, width)
364
  with torch.no_grad(), autocast("cuda"), model.ema_scope():
365
  cond = {}
366
+ input_image_torch = 2 * torch.tensor(np.array(input_image_copy)).float() / 255 - 1
367
  input_image_torch = rearrange(input_image_torch, "h w c -> 1 c h w").to(model.device)
368
  cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
369
  cond["c_concat"] = [model.encode_first_stage(input_image_torch).mode().to(model.device)]
 
390
 
391
  x_1 = nn.functional.interpolate(z_1, size=(height, width), mode="bilinear", align_corners=False)
392
  x_1 = torch.where(x_1 > 0, 1, -1) # Thresholding step
393
+
394
+ x_1_mean = torch.sum(x_1).item()/x_1.numel()
395
 
396
+ if x_1_mean < -0.99:
397
  seed += 1
398
  retry_number +=1
399
  if retry_number > max_retry:
 
425
 
426
  image_video.append((mix_image_np * 255).astype(np.uint8))
427
  mix_image = Image.fromarray((mix_image_np * 255).astype(np.uint8)).convert('RGB')
428
+
429
+ mix_result_with_red_mask = None
430
+ mask_video_path = None
431
+ image_video_path = None
432
+ edited_mask_copy = None
433
+
434
+ if generate_index == len(generate_list):
435
+ image_video_path = "image.mp4"
436
+ fps = 2
437
+ with imageio.get_writer(image_video_path, fps=fps) as video:
438
+ for image in image_video:
439
+ video.append_data(image)
440
 
441
+ yield [int(seed), text_cfg_scale, image_cfg_scale, edited_image, mix_image, edited_mask_copy, mask_video_path, image_video_path, input_image, mix_result_with_red_mask]
 
 
 
 
442
 
443
+ input_image_copy = mix_image
 
444
 
445
 
446
  def reset():
 
596
  # demo.launch(share=True)
597
 
598
 
599
+ # demo.queue().launch(enable_queue=True)
600
  demo.queue().launch()