ironjr commited on
Commit
ca4247e
·
verified ·
1 Parent(s): 3070184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -423,7 +423,11 @@ def register(state, drawpad, model):
423
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
424
  print('Generate!')
425
 
426
- background = drawpad['background'].convert('RGBA')
 
 
 
 
427
  inpainting_mode = np.asarray(background).sum() != 0
428
  if not inpainting_mode:
429
  background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
@@ -432,9 +436,14 @@ def register(state, drawpad, model):
432
  background_prompt = None
433
  print('Inpainting mode: ', inpainting_mode)
434
 
435
- user_input = np.asarray(drawpad['layers'][0]) # (H, W, 4)
436
- foreground_mask = torch.tensor(user_input[..., -1], device=model.device)[None, None] # (1, 1, H, W)
437
- user_input = torch.tensor(user_input[..., :-1], device=model.device) # (H, W, 3)
 
 
 
 
 
438
 
439
  palette = torch.tensor([
440
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))
 
423
  seed_everything(state.seed if state.seed >=0 else np.random.randint(2147483647))
424
  print('Generate!')
425
 
426
+ background = drawpad['background']
427
+ if background is None:
428
+ background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
429
+ else:
430
+ background = background.convert('RGBA')
431
  inpainting_mode = np.asarray(background).sum() != 0
432
  if not inpainting_mode:
433
  background = Image.new(size=(opt.width, opt.height), mode='RGB', color=(255, 255, 255))
 
436
  background_prompt = None
437
  print('Inpainting mode: ', inpainting_mode)
438
 
439
+ if drawpad['composite'] is None:
440
+ user_input = np.zeros((opt.height, opt.width, 4))
441
+ foreground_mask = torch.zeros((1, 1, opt.height, opt.width))
442
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
443
+ else:
444
+ user_input = np.asarray(drawpad['composite']) # (H, W, 4)
445
+ foreground_mask = torch.tensor(user_input[..., -1])[None, None] # (1, 1, H, W)
446
+ user_input = torch.tensor(user_input[..., :-1]) # (H, W, 3)
447
 
448
  palette = torch.tensor([
449
  tuple(int(s[i+1:i+3], 16) for i in (0, 2, 4))