Songwei Ge commited on
Commit
ab7db7f
·
1 Parent(s): b41079f
Files changed (2) hide show
  1. app.py +4 -22
  2. utils/attention_utils.py +11 -22
app.py CHANGED
@@ -23,25 +23,6 @@ Instructions placeholder.
23
  """
24
 
25
 
26
- example_instructions = [
27
- "Make it a picasso painting",
28
- "as if it were by modigliani",
29
- "convert to a bronze statue",
30
- "Turn it into an anime.",
31
- "have it look like a graphic novel",
32
- "make him gain weight",
33
- "what would he look like bald?",
34
- "Have him smile",
35
- "Put him in a cocktail party.",
36
- "move him at the beach.",
37
- "add dramatic lighting",
38
- "Convert to black and white",
39
- "What if it were snowing?",
40
- "Give him a leather jacket",
41
- "Turn him into a cyborg!",
42
- "make him wear a beanie",
43
- ]
44
-
45
  def main():
46
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47
  model = RegionDiffusion(device)
@@ -90,9 +71,9 @@ def main():
90
  height=height, width=width, num_inference_steps=steps,
91
  guidance_scale=guidance_weight)
92
  print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))
93
- color_obj_masks = get_token_maps(
94
  model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
95
- model.masks = get_token_maps(
96
  model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed, base_tokens)
97
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
98
  interpolation=transforms.InterpolationMode.BICUBIC,
@@ -110,7 +91,8 @@ def main():
110
  text_format_dict=text_format_dict)
111
  print('time lapses to generate image from rich text: %.4f' %
112
  (time.time()-begin_time))
113
- return [plain_img[0], rich_img[0]]
 
114
 
115
  with gr.Blocks() as demo:
116
  gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
 
23
  """
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def main():
27
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
  model = RegionDiffusion(device)
 
71
  height=height, width=width, num_inference_steps=steps,
72
  guidance_scale=guidance_weight)
73
  print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))
74
+ color_obj_masks, _ = get_token_maps(
75
  model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
76
+ model.masks, token_maps = get_token_maps(
77
  model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed, base_tokens)
78
  color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
79
  interpolation=transforms.InterpolationMode.BICUBIC,
 
91
  text_format_dict=text_format_dict)
92
  print('time lapses to generate image from rich text: %.4f' %
93
  (time.time()-begin_time))
94
+ cat_img = np.concatenate([plain_img[0], rich_img[0]], 1)
95
+ return [cat_img, token_maps]
96
 
97
  with gr.Blocks() as demo:
98
  gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
utils/attention_utils.py CHANGED
@@ -76,15 +76,19 @@ def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=N
76
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
77
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
78
  fig.colorbar(sm, cax=axs[-1])
 
 
 
 
79
 
80
  fig.tight_layout()
81
  plt.savefig(os.path.join(
82
  save_dir, 'token_mapes_seed%d_%s.png' % (seed, atten_names[i])), dpi=100)
83
  plt.close('all')
 
84
 
85
 
86
- def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
87
- preprocess=False):
88
  r"""Function to visualize attention maps.
89
  Args:
90
  save_dir (str): Path to save attention maps
@@ -177,23 +181,8 @@ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0,
177
  attention_maps_averaged_normalized = [
178
  attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
179
 
180
- if preprocess:
181
- # it is possible to preprocess the attention maps here
182
- selem = square(5)
183
- attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
184
- map[0].numpy()*255), selem) for map in attention_maps_averaged_normalized[:2]]
185
- attention_maps_averaged_eroded = [(torch.from_numpy(map).unsqueeze(
186
- 0)/255. > 0.8).float() for map in attention_maps_averaged_eroded]
187
- attention_maps_averaged_eroded.append(
188
- 1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
189
- plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
190
- attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
191
- attention_maps_averaged_eroded = [attn_mask.unsqueeze(1).repeat(
192
- [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_eroded]
193
- return attention_maps_averaged_eroded
194
- else:
195
- plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
196
- obj_tokens, save_dir, seed, tokens_vis)
197
- attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
198
- [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
199
- return attention_maps_averaged_normalized
 
76
  norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
77
  sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
78
  fig.colorbar(sm, cax=axs[-1])
79
+ canvas = fig.canvas
80
+ canvas.draw()
81
+ width, height = canvas.get_width_height()
82
+ img = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape((height, width, 3))
83
 
84
  fig.tight_layout()
85
  plt.savefig(os.path.join(
86
  save_dir, 'token_mapes_seed%d_%s.png' % (seed, atten_names[i])), dpi=100)
87
  plt.close('all')
88
+ return img
89
 
90
 
91
+ def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
 
92
  r"""Function to visualize attention maps.
93
  Args:
94
  save_dir (str): Path to save attention maps
 
181
  attention_maps_averaged_normalized = [
182
  attention_maps_averaged_normalized[i:i+1] for i in range(attention_maps_averaged_normalized.shape[0])]
183
 
184
+ token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
185
+ obj_tokens, save_dir, seed, tokens_vis)
186
+ attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
187
+ [1, 4, 1, 1]).cuda() for attn_mask in attention_maps_averaged_normalized]
188
+ return attention_maps_averaged_normalized, token_maps_vis