ZYMPKU commited on
Commit
bf50ebc
1 Parent(s): 979a298
Files changed (2) hide show
  1. app.py +1 -1
  2. sgm/modules/diffusionmodules/sampling.py +5 -10
app.py CHANGED
@@ -190,7 +190,7 @@ if __name__ == "__main__":
190
  attn_map = gr.Image(show_label=False, show_download_button=False)
191
  with gr.Tab(label="Segmentation Maps"):
192
  gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):")
193
- seg_map = gr.AnnotatedImage(height=384, show_label=False, show_download_button=False)
194
 
195
  # examples
196
  examples = []
 
190
  attn_map = gr.Image(show_label=False, show_download_button=False)
191
  with gr.Tab(label="Segmentation Maps"):
192
  gr.Markdown("### Character-level segmentation maps (using upscaled attention maps):")
193
+ seg_map = gr.AnnotatedImage(height=384, show_label=False)
194
 
195
  # examples
196
  examples = []
sgm/modules/diffusionmodules/sampling.py CHANGED
@@ -268,12 +268,10 @@ class EulerEDMSampler(EDMSampler):
268
 
269
  return colormap
270
 
271
- def save_segment_map(self, image, attn_maps, tokens=None, save_name=None):
272
 
273
  colormap = self.create_pascal_label_colormap()
274
- H, W = image.shape[-2:]
275
 
276
- image_ = image*0.3
277
  sections = []
278
  for i in range(len(tokens)):
279
  attn_map = attn_maps[i]
@@ -285,14 +283,11 @@ class EulerEDMSampler(EDMSampler):
285
  colored_attn_map = attn_map_t * color
286
  colored_attn_map = colored_attn_map.to(device=image_.device)
287
 
288
- image_ += colored_attn_map*0.7
289
  sections.append(attn_map)
290
 
291
  section = np.stack(sections)
292
  np.save(f"./temp/seg_map/seg_{save_name}.npy", section)
293
 
294
- save_image(image_, f"./temp/seg_map/seg_{save_name}.png", normalize=True)
295
-
296
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
297
 
298
  H, W = batch["target_size_as_tuple"][0]
@@ -375,8 +370,8 @@ class EulerEDMSampler(EDMSampler):
375
  local_loss = torch.zeros(1)
376
  if save_attn:
377
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
378
- denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
379
- self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
380
 
381
  d = to_d(x, sigma_hat, denoised)
382
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
@@ -557,8 +552,8 @@ class EulerEDMDualSampler(EulerEDMSampler):
557
  local_loss = torch.zeros(1)
558
  if save_attn:
559
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True)
560
- denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
561
- self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
562
 
563
  d = to_d(x, sigma_hat, denoised)
564
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
 
268
 
269
  return colormap
270
 
271
+ def save_segment_map(self, H, W, attn_maps, tokens=None, save_name=None):
272
 
273
  colormap = self.create_pascal_label_colormap()
 
274
 
 
275
  sections = []
276
  for i in range(len(tokens)):
277
  attn_map = attn_maps[i]
 
283
  colored_attn_map = attn_map_t * color
284
  colored_attn_map = colored_attn_map.to(device=image_.device)
285
 
 
286
  sections.append(attn_map)
287
 
288
  section = np.stack(sections)
289
  np.save(f"./temp/seg_map/seg_{save_name}.npy", section)
290
 
 
 
291
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
292
 
293
  H, W = batch["target_size_as_tuple"][0]
 
370
  local_loss = torch.zeros(1)
371
  if save_attn:
372
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
373
+ H, W = batch["target_size_as_tuple"][0]
374
+ self.save_segment_map(H, W, attn_map, tokens=batch["label"][0], save_name=name)
375
 
376
  d = to_d(x, sigma_hat, denoised)
377
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
 
552
  local_loss = torch.zeros(1)
553
  if save_attn:
554
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True)
555
+ H, W = batch["target_size_as_tuple"][0]
556
+ self.save_segment_map(H, W, attn_map, tokens=batch["label"][0], save_name=name)
557
 
558
  d = to_d(x, sigma_hat, denoised)
559
  dt = append_dims(next_sigma - sigma_hat, x.ndim)