Andranik Sargsyan commited on
Commit
736e88e
·
1 Parent(s): 08504da

refactor code

Browse files
Files changed (4) hide show
  1. app.py +70 -38
  2. lib/methods/rasg.py +24 -6
  3. lib/methods/sd.py +14 -10
  4. lib/methods/sr.py +42 -37
app.py CHANGED
@@ -75,8 +75,12 @@ def set_model_from_name(inp_model_name):
75
  inp_model = inpainting_models[inp_model_name]
76
 
77
 
78
- def rasg_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
79
- guidance_scale=7.5, batch_size=4):
 
 
 
 
80
  torch.cuda.empty_cache()
81
 
82
  seed = int(seed)
@@ -87,35 +91,44 @@ guidance_scale=7.5, batch_size=4):
87
 
88
  method = ['rasg']
89
  if use_painta: method.append('painta')
 
90
 
91
  inpainted_images = []
92
  blended_images = []
93
  for i in range(batch_size):
 
 
94
  inpainted_image = rasg.run(
95
- ddim = inp_model,
96
- method = '-'.join(method),
97
- prompt = prompt,
98
- image = image.padx(64),
99
- mask = mask.alpha().padx(64),
100
- seed = seed+i*1000,
101
- eta = eta,
102
- prefix = '{}',
103
- negative_prompt = negative_prompt,
104
- positive_prompt = f', {positive_prompt}',
105
- dt = 1000 // ddim_steps,
106
- guidance_scale = guidance_scale
107
  ).crop(image.size)
108
- blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0],
109
- mask = mask.data[0], dilation = 12)
110
 
 
 
 
 
 
 
111
  blended_images.append(blended_image)
112
  inpainted_images.append(inpainted_image.numpy()[0])
113
 
114
  return blended_images, inpainted_images
115
 
116
 
117
- def sd_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
118
- guidance_scale=7.5, batch_size=4):
 
 
 
119
  torch.cuda.empty_cache()
120
 
121
  seed = int(seed)
@@ -126,28 +139,33 @@ guidance_scale=7.5, batch_size=4):
126
 
127
  method = ['default']
128
  if use_painta: method.append('painta')
 
129
 
130
  inpainted_images = []
131
  blended_images = []
132
  for i in range(batch_size):
 
 
133
  inpainted_image = sd.run(
134
- ddim = inp_model,
135
- method = '-'.join(method),
136
- prompt = prompt,
137
- image = image.padx(64),
138
- mask = mask.alpha().padx(64),
139
- seed = seed+i*1000,
140
- eta = eta,
141
- prefix = '{}',
142
- negative_prompt = negative_prompt,
143
- positive_prompt = f', {positive_prompt}',
144
- dt = 1000 // ddim_steps,
145
- guidance_scale = guidance_scale
146
  ).crop(image.size)
147
 
148
- blended_image = poisson_blend(orig_img = image.data[0], fake_img = inpainted_image.data[0],
149
- mask = mask.data[0], dilation = 12)
150
-
 
 
 
151
  blended_images.append(blended_image)
152
  inpainted_images.append(inpainted_image.numpy()[0])
153
 
@@ -156,7 +174,9 @@ guidance_scale=7.5, batch_size=4):
156
 
157
  def upscale_run(
158
  prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
159
- negative_prompt='', positive_prompt=', high resolution professional photo'):
 
 
160
  torch.cuda.empty_cache()
161
 
162
  seed = int(seed)
@@ -169,10 +189,22 @@ negative_prompt='', positive_prompt=', high resolution professional photo'):
169
  lr_image = IImage(inpainted_image)
170
  hr_image = IImage(input['image']).resize(2048)
171
  hr_mask = IImage(input['mask']).resize(2048)
172
- output_image = sr.run(sr_model, sam_predictor, lr_image, hr_image, hr_mask, prompt=prompt + positive_prompt,
173
- noise_level=0, blend_trick=True, blend_output=True, negative_prompt=negative_prompt,
174
- seed=seed, use_sam_mask=use_sam_mask)
175
- return output_image.numpy()[0], output_image.numpy()[0]
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
 
178
  def switch_run(use_rasg, model_name, *args):
 
75
  inp_model = inpainting_models[inp_model_name]
76
 
77
 
78
+ def rasg_run(
79
+ use_painta, prompt, input, seed, eta,
80
+ negative_prompt, positive_prompt, ddim_steps,
81
+ guidance_scale=7.5,
82
+ batch_size=1
83
+ ):
84
  torch.cuda.empty_cache()
85
 
86
  seed = int(seed)
 
91
 
92
  method = ['rasg']
93
  if use_painta: method.append('painta')
94
+ method = '-'.join(method)
95
 
96
  inpainted_images = []
97
  blended_images = []
98
  for i in range(batch_size):
99
+ seed = seed + i * 1000
100
+
101
  inpainted_image = rasg.run(
102
+ ddim=inp_model,
103
+ method=method,
104
+ prompt=prompt,
105
+ image=image,
106
+ mask=mask,
107
+ seed=seed,
108
+ eta=eta,
109
+ negative_prompt=negative_prompt,
110
+ positive_prompt=positive_prompt,
111
+ num_steps=ddim_steps,
112
+ guidance_scale=guidance_scale
 
113
  ).crop(image.size)
 
 
114
 
115
+ blended_image = poisson_blend(
116
+ orig_img=image.data[0],
117
+ fake_img=inpainted_image.data[0],
118
+ mask=mask.data[0],
119
+ dilation=12
120
+ )
121
  blended_images.append(blended_image)
122
  inpainted_images.append(inpainted_image.numpy()[0])
123
 
124
  return blended_images, inpainted_images
125
 
126
 
127
+ def sd_run(use_painta, prompt, input, seed, eta,
128
+ negative_prompt, positive_prompt, ddim_steps,
129
+ guidance_scale=7.5,
130
+ batch_size=1
131
+ ):
132
  torch.cuda.empty_cache()
133
 
134
  seed = int(seed)
 
139
 
140
  method = ['default']
141
  if use_painta: method.append('painta')
142
+ method = '-'.join(method)
143
 
144
  inpainted_images = []
145
  blended_images = []
146
  for i in range(batch_size):
147
+ seed = seed + i * 1000
148
+
149
  inpainted_image = sd.run(
150
+ ddim=inp_model,
151
+ method=method,
152
+ prompt=prompt,
153
+ image=image,
154
+ mask=mask,
155
+ seed=seed,
156
+ eta=eta,
157
+ negative_prompt=negative_prompt,
158
+ positive_prompt=positive_prompt,
159
+ num_steps=ddim_steps,
160
+ guidance_scale=guidance_scale
 
161
  ).crop(image.size)
162
 
163
+ blended_image = poisson_blend(
164
+ orig_img=image.data[0],
165
+ fake_img=inpainted_image.data[0],
166
+ mask=mask.data[0],
167
+ dilation=12
168
+ )
169
  blended_images.append(blended_image)
170
  inpainted_images.append(inpainted_image.numpy()[0])
171
 
 
174
 
175
  def upscale_run(
176
  prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
177
+ negative_prompt='',
178
+ positive_prompt=', high resolution professional photo'
179
+ ):
180
  torch.cuda.empty_cache()
181
 
182
  seed = int(seed)
 
189
  lr_image = IImage(inpainted_image)
190
  hr_image = IImage(input['image']).resize(2048)
191
  hr_mask = IImage(input['mask']).resize(2048)
192
+ output_image = sr.run(
193
+ sr_model,
194
+ sam_predictor,
195
+ lr_image,
196
+ hr_image,
197
+ hr_mask,
198
+ prompt=prompt + positive_prompt,
199
+ noise_level=20,
200
+ blend_trick=True,
201
+ blend_output=True,
202
+ negative_prompt=negative_prompt,
203
+ seed=seed,
204
+ use_sam_mask=use_sam_mask
205
+ )
206
+ output_image.info = input['image'].info # save metadata
207
+ return output_image, output_image
208
 
209
 
210
  def switch_run(use_rasg, model_name, *args):
lib/methods/rasg.py CHANGED
@@ -23,12 +23,28 @@ def init_guidance():
23
  router.attention_forward = attentionpatch.default.forward_and_save
24
  router.basic_transformer_forward = transformerpatch.default.forward
25
 
26
- def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, positive_prompt, dt, guidance_scale):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Text condition
28
- prompt = prefix.format(prompt)
29
- context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt])
30
- token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>')))
31
- token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')]
32
 
33
  # Initialize painta
34
  if 'painta' in method: init_painta(token_idx)
@@ -84,7 +100,9 @@ def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, p
84
  grad -= grad.mean()
85
  grad /= grad.std()
86
 
87
- zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps + eta * sigma * grad
 
 
88
 
89
  with torch.no_grad():
90
  output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
 
23
  router.attention_forward = attentionpatch.default.forward_and_save
24
  router.basic_transformer_forward = transformerpatch.default.forward
25
 
26
+ def run(
27
+ ddim,
28
+ method,
29
+ prompt,
30
+ image,
31
+ mask,
32
+ seed=0,
33
+ eta=0.1,
34
+ negative_prompt='',
35
+ positive_prompt='',
36
+ num_steps=50,
37
+ guidance_scale=7.5
38
+ ):
39
+ image = image.padx(64)
40
+ mask = mask.alpha().padx(64)
41
+ full_prompt = f'{prompt}, {positive_prompt}'
42
+ dt = 1000 // num_steps
43
+
44
  # Text condition
45
+ context = ddim.encoder.encode([negative_prompt, full_prompt])
46
+ token_idx = list(range(1, tokenize(prompt).index('<end_of_text>')))
47
+ token_idx += [tokenize(full_prompt).index('<end_of_text>')]
 
48
 
49
  # Initialize painta
50
  if 'painta' in method: init_painta(token_idx)
 
100
  grad -= grad.mean()
101
  grad /= grad.std()
102
 
103
+ zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + \
104
+ torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps + \
105
+ eta * sigma * grad
106
 
107
  with torch.no_grad():
108
  output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
lib/methods/sd.py CHANGED
@@ -24,18 +24,22 @@ def run(
24
  prompt,
25
  image,
26
  mask,
27
- seed,
28
- eta,
29
- prefix,
30
- negative_prompt,
31
- positive_prompt,
32
- dt,
33
- guidance_scale
34
  ):
 
 
 
 
 
35
  # Text condition
36
- context = ddim.encoder.encode([negative_prompt, prompt + positive_prompt])
37
- token_idx = list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index('<end_of_text>')))
38
- token_idx += [tokenize(prompt + positive_prompt).index('<end_of_text>')]
39
 
40
  # Setup painta if needed
41
  if 'painta' in method: init_painta(token_idx)
 
24
  prompt,
25
  image,
26
  mask,
27
+ seed=0,
28
+ eta=0.1,
29
+ negative_prompt='',
30
+ positive_prompt='',
31
+ num_steps=50,
32
+ guidance_scale=7.5
 
33
  ):
34
+ image = image.padx(64)
35
+ mask = mask.alpha().padx(64)
36
+ full_prompt = f'{prompt}, {positive_prompt}'
37
+ dt = 1000 // num_steps
38
+
39
  # Text condition
40
+ context = ddim.encoder.encode([negative_prompt, full_prompt])
41
+ token_idx = list(range(1, tokenize(prompt).index('<end_of_text>')))
42
+ token_idx += [tokenize(full_prompt).index('<end_of_text>')]
43
 
44
  # Setup painta if needed
45
  if 'painta' in method: init_painta(token_idx)
lib/methods/sr.py CHANGED
@@ -57,9 +57,22 @@ def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
57
  return new_mask
58
 
59
 
60
- def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
61
- blend_output = True, blend_trick = True, no_superres = False,
62
- dt = 50, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False):
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  torch.manual_seed(seed)
64
  dtype = ddim.vae.encoder.conv_in.weight.dtype
65
  device = ddim.vae.encoder.conv_in.weight.device
@@ -67,6 +80,9 @@ dt = 50, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
67
  router.attention_forward = attentionpatch.default.forward_xformers
68
  router.basic_transformer_forward = transformerpatch.default.forward
69
 
 
 
 
70
  if use_sam_mask:
71
  with torch.no_grad():
72
  hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)
@@ -74,70 +90,59 @@ dt = 50, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
74
  orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
75
  hr_image = hr_image.padx(256, padding_mode='reflect')
76
  hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
77
- hr_mask_orig = hr_mask
78
- lr_image = lr_image.padx(64, padding_mode='reflect')
79
- lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).to(device)
 
80
  lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
81
-
82
- if no_superres:
83
- output_tensor = lr_image.resize((hr_image.torch().shape[2], hr_image.torch().shape[3]), resample = Image.BICUBIC).torch().cuda()
84
- output_tensor = (255*((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
85
- output_tensor = poisson_blend(
86
- orig_img=hr_image.data[0][:orig_h, :orig_w, :],
87
- fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
88
- mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
89
- )
90
- return IImage(output_tensor[:orig_h, :orig_w, :])
91
 
92
  # encode hr image
93
  with torch.no_grad():
94
- hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor
 
95
 
96
- assert hr_z0.shape[2] == lr_image.torch().shape[2]
97
- assert hr_z0.shape[3] == lr_image.torch().shape[3]
98
 
99
- unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
100
- zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype=dtype, device=device)
101
-
102
  with torch.no_grad():
103
  context = ddim.encoder.encode([negative_prompt, prompt])
104
-
105
  noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
 
106
  unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
107
 
108
  with torch.autocast('cuda'), torch.no_grad():
109
- zt = zT
 
110
  for index,t in enumerate(range(999, 0, -dt)):
111
-
112
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
113
-
114
  eps_uncond, eps = ddim.unet(
115
  torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
116
  timesteps = torch.tensor([t, t]).to(device=device),
117
  context = context,
118
  y=torch.cat([noise_level]*2)
119
  ).chunk(2)
120
-
121
  ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
122
  model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
123
  eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
124
  z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
125
-
126
  if blend_trick:
127
  z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)
128
-
129
  zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps
130
 
131
  with torch.no_grad():
132
- output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)
 
 
 
 
133
 
134
  if blend_output:
135
- output_tensor = (255*((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
136
- output_tensor = poisson_blend(
137
- orig_img=hr_image.data[0][:orig_h, :orig_w, :],
138
- fake_img=output_tensor.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
139
- mask=hr_mask_orig.alpha().data[0][:orig_h, :orig_w, :]
140
  )
141
- return IImage(output_tensor[:orig_h, :orig_w, :])
142
  else:
143
- return IImage(output_tensor[:, :, :orig_h, :orig_w])
 
57
  return new_mask
58
 
59
 
60
+ def run(
61
+ ddim,
62
+ sam_predictor,
63
+ lr_image,
64
+ hr_image,
65
+ hr_mask,
66
+ prompt = 'high resolution professional photo',
67
+ noise_level=20,
68
+ blend_output = True,
69
+ blend_trick = True,
70
+ dt = 50,
71
+ seed = 1,
72
+ guidance_scale = 7.5,
73
+ negative_prompt = '',
74
+ use_sam_mask = False
75
+ ):
76
  torch.manual_seed(seed)
77
  dtype = ddim.vae.encoder.conv_in.weight.dtype
78
  device = ddim.vae.encoder.conv_in.weight.device
 
80
  router.attention_forward = attentionpatch.default.forward_xformers
81
  router.basic_transformer_forward = transformerpatch.default.forward
82
 
83
+ hr_image_orig = hr_image
84
+ hr_mask_orig = hr_mask
85
+
86
  if use_sam_mask:
87
  with torch.no_grad():
88
  hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)
 
90
  orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
91
  hr_image = hr_image.padx(256, padding_mode='reflect')
92
  hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
93
+
94
+ lr_image = lr_image.padx(64, padding_mode='reflect').torch()
95
+ lr_mask = hr_mask.resize((lr_image.shape[2:]), resample = Image.BICUBIC)
96
+ lr_mask = lr_mask.alpha().torch(vmin=0).to(device)
97
  lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
 
 
 
 
 
 
 
 
 
 
98
 
99
  # encode hr image
100
  with torch.no_grad():
101
+ hr_image = hr_image.torch().to(dtype=dtype, device=device)
102
+ hr_z0 = ddim.vae.encode(hr_image).mean * ddim.config.scale_factor
103
 
104
+ assert hr_z0.shape[2] == lr_image.shape[2]
105
+ assert hr_z0.shape[3] == lr_image.shape[3]
106
 
 
 
 
107
  with torch.no_grad():
108
  context = ddim.encoder.encode([negative_prompt, prompt])
109
+
110
  noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
111
+ unet_condition = lr_image.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
112
  unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
113
 
114
  with torch.autocast('cuda'), torch.no_grad():
115
+ zt = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3]))
116
+ zt = zt.cuda().to(dtype=dtype, device=device)
117
  for index,t in enumerate(range(999, 0, -dt)):
 
118
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
 
119
  eps_uncond, eps = ddim.unet(
120
  torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
121
  timesteps = torch.tensor([t, t]).to(device=device),
122
  context = context,
123
  y=torch.cat([noise_level]*2)
124
  ).chunk(2)
 
125
  ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
126
  model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
127
  eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
128
  z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
 
129
  if blend_trick:
130
  z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)
 
131
  zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps
132
 
133
  with torch.no_grad():
134
+ hr_result = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)
135
+ # postprocess
136
+ hr_result = (255 * ((hr_result + 1) / 2).clip(0, 1)).to(torch.uint8)
137
+ hr_result = hr_result.cpu().permute(0, 2, 3, 1)[0].numpy()
138
+ hr_result = hr_result[:orig_h, :orig_w, :]
139
 
140
  if blend_output:
141
+ hr_result = poisson_blend(
142
+ orig_img=hr_image_orig.data[0],
143
+ fake_img=hr_result,
144
+ mask=hr_mask_orig.alpha().data[0]
 
145
  )
146
+ return Image.fromarray(hr_result)
147
  else:
148
+ return Image.fromarray(hr_result)