paul hilders committed on
Commit
0241217
1 Parent(s): 8d581a7

Add new version of demo for IEAI course

.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "CLIP_explainability/Transformer-MM-Explainability"]
2
+ path = CLIP_explainability/Transformer-MM-Explainability
3
+ url = https://github.com/hila-chefer/Transformer-MM-Explainability.git
CLIP_explainability/Transformer-MM-Explainability ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 6a2c3c9da3fc186878e0c2bcf238c3a4c76d8af8
CLIP_explainability/utils.py ADDED
@@ -0,0 +1,152 @@
1
+ import torch
2
+ import CLIP.clip as clip
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
+ from captum.attr import visualization
8
+ import os
9
+
10
+
11
+ from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
12
+ _tokenizer = _Tokenizer()
13
+
14
+ #@title Control context expansion (number of attention layers to consider)
15
+ #@title Number of layers for image Transformer
16
+ start_layer = 11  #@param {type:"number"}
17
+
18
+ #@title Number of layers for text Transformer
19
+ start_layer_text = 11  #@param {type:"number"}
20
+
21
+
22
+ def interpret(image, texts, model, device):
23
+ batch_size = texts.shape[0]
24
+ images = image.repeat(batch_size, 1, 1, 1)
25
+ logits_per_image, logits_per_text = model(images, texts)
26
+ probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
27
+ index = [i for i in range(batch_size)]
28
+ one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
29
+ one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
30
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
31
+ one_hot = torch.sum(one_hot.to(device) * logits_per_image)
32
+ model.zero_grad()
33
+
34
+ image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
35
+ num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
36
+ R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
37
+ R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
38
+ for i, blk in enumerate(image_attn_blocks):
39
+ if i < start_layer:
40
+ continue
41
+ grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
42
+ cam = blk.attn_probs.detach()
43
+ cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
44
+ grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
45
+ cam = grad * cam
46
+ cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
47
+ cam = cam.clamp(min=0).mean(dim=1)
48
+ R = R + torch.bmm(cam, R)
49
+ image_relevance = R[:, 0, 1:]
50
+
51
+
52
+ text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
53
+ num_tokens = text_attn_blocks[0].attn_probs.shape[-1]
54
+ R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device)
55
+ R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
56
+ for i, blk in enumerate(text_attn_blocks):
57
+ if i < start_layer_text:
58
+ continue
59
+ grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
60
+ cam = blk.attn_probs.detach()
61
+ cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
62
+ grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
63
+ cam = grad * cam
64
+ cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
65
+ cam = cam.clamp(min=0).mean(dim=1)
66
+ R_text = R_text + torch.bmm(cam, R_text)
67
+ text_relevance = R_text
68
+
69
+ return text_relevance, image_relevance
70
+
71
+
72
+ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
73
+ # create heatmap from mask on image
74
+ def show_cam_on_image(img, mask):
75
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
76
+ heatmap = np.float32(heatmap) / 255
77
+ cam = heatmap + np.float32(img)
78
+ cam = cam / np.max(cam)
79
+ return cam
80
+
81
+ # plt.axis('off')
82
+ # f, axarr = plt.subplots(1,2)
83
+ # axarr[0].imshow(orig_image)
84
+
85
+ if show:
86
+ fig, axs = plt.subplots(1, 2)
87
+ axs[0].imshow(orig_image);
88
+ axs[0].axis('off');
89
+
90
+ image_relevance = image_relevance.reshape(1, 1, 7, 7)
91
+ image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
92
+ image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
93
+ image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
94
+ image = image[0].permute(1, 2, 0).data.cpu().numpy()
95
+ image = (image - image.min()) / (image.max() - image.min())
96
+ vis = show_cam_on_image(image, image_relevance)
97
+ vis = np.uint8(255 * vis)
98
+ vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
99
+
100
+ if show:
101
+ # axar[1].imshow(vis)
102
+ axs[1].imshow(vis);
103
+ axs[1].axis('off');
104
+ # plt.imshow(vis)
105
+
106
+ return image_relevance
107
+
108
+
109
+ def show_heatmap_on_text(text, text_encoding, R_text, show=True):
110
+ CLS_idx = text_encoding.argmax(dim=-1)
111
+ R_text = R_text[CLS_idx, 1:CLS_idx]
112
+ text_scores = R_text / R_text.sum()
113
+ text_scores = text_scores.flatten()
114
+ # print(text_scores)
115
+ text_tokens=_tokenizer.encode(text)
116
+ text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
117
+ vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
118
+
119
+ if show:
120
+ visualization.visualize_text(vis_data_records)
121
+
122
+ return text_scores, text_tokens_decoded
123
+
124
+
125
+ def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
126
+ return show_image_relevance(image_relevance, image, orig_image, device, show=show)
127
+
128
+
129
+ def show_txt_heatmap(text, text_encoding, R_text, show=True):
130
+ return show_heatmap_on_text(text, text_encoding, R_text, show=show)
131
+
132
+
133
+ def load_dataset():
134
+ dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments' + '.pt')
135
+ device = "cuda" if torch.cuda.is_available() else "cpu"
136
+
137
+ data = torch.load(dataset_path, map_location=device)
138
+
139
+ return data
140
+
141
+
142
+ class color:
143
+ PURPLE = '\033[95m'
144
+ CYAN = '\033[96m'
145
+ DARKCYAN = '\033[36m'
146
+ BLUE = '\033[94m'
147
+ GREEN = '\033[92m'
148
+ YELLOW = '\033[93m'
149
+ RED = '\033[91m'
150
+ BOLD = '\033[1m'
151
+ UNDERLINE = '\033[4m'
152
+ END = '\033[0m'
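
A minimal usage sketch for the helpers above, mirroring how app.py calls them (assumptions: run from the repository root so CLIP_explainability is importable, the Transformer-MM-Explainability submodule is on sys.path so its CLIP fork with attention hooks is used, and "cat.jpg" is a hypothetical local image):

import sys
sys.path.append("CLIP_explainability/Transformer-MM-Explainability/")

import torch
import CLIP.clip as clip
from PIL import Image
from CLIP_explainability.utils import interpret, show_img_heatmap, show_heatmap_on_text

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

orig_image = Image.open("cat.jpg").convert("RGB")      # hypothetical input image
img = preprocess(orig_image).unsqueeze(0).to(device)
texts = clip.tokenize(["a photo of a cat"]).to(device)

# relevance maps for the (image, text) pair
R_text, R_image = interpret(image=img, texts=texts, model=model, device=device)
image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device, show=False)
text_scores, tokens = show_heatmap_on_text("a photo of a cat", texts[0], R_text[0], show=False)
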
app.py ADDED
@@ -0,0 +1,67 @@
1
+ import sys
2
+ import gradio as gr
3
+
4
+ # sys.path.append("../")
5
+ sys.path.append("CLIP_explainability/Transformer-MM-Explainability/")
6
+
7
+ import torch
8
+ import CLIP.clip as clip
9
+
10
+
11
+ from clip_grounding.utils.image import pad_to_square
12
+ from clip_grounding.datasets.png import (
13
+ overlay_relevance_map_on_image,
14
+ )
15
+ from CLIP_explainability.utils import interpret, show_img_heatmap, show_heatmap_on_text
16
+
17
+ clip.clip._MODELS = {
18
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
19
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
20
+ }
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
24
+
25
+ # Gradio Section:
26
+ def run_demo(image, text):
27
+ orig_image = pad_to_square(image)
28
+ img = preprocess(orig_image).unsqueeze(0).to(device)
29
+ text_input = clip.tokenize([text]).to(device)
30
+
31
+ R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
32
+
33
+ image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device, show=False)
34
+ overlapped = overlay_relevance_map_on_image(image, image_relevance)
35
+
36
+ text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0], show=False)
37
+
38
+ highlighted_text = []
39
+ for i, token in enumerate(text_tokens_decoded):
40
+ highlighted_text.append((str(token), float(text_scores[i])))
41
+
42
+ return overlapped, highlighted_text
43
+
44
+ input_img = gr.inputs.Image(type='pil', label="Original Image")
45
+ input_txt = "text"
46
+ inputs = [input_img, input_txt]
47
+
48
+ outputs = [gr.outputs.Image(type='pil', label="Output Image"), "highlight"]
49
+
50
+
51
+ iface = gr.Interface(fn=run_demo,
52
+ inputs=inputs,
53
+ outputs=outputs,
54
+ title="CLIP Grounding Explainability",
55
+ description="A demonstration based on the Generic Attention-model Explainability method for Interpreting Bi-Modal Transformers by Chefer et al. (2021): https://github.com/hila-chefer/Transformer-MM-Explainability.",
56
+ examples=[["example_images/London.png", "London Eye"],
57
+ ["example_images/London.png", "Big Ben"],
58
+ ["example_images/harrypotter.png", "Harry"],
59
+ ["example_images/harrypotter.png", "Hermione"],
60
+ ["example_images/harrypotter.png", "Ron"],
61
+ ["example_images/Amsterdam.png", "Amsterdam canal"],
62
+ ["example_images/Amsterdam.png", "Old buildings"],
63
+ ["example_images/Amsterdam.png", "Pink flowers"],
64
+ ["example_images/dogs_on_bed.png", "Two dogs"],
65
+ ["example_images/dogs_on_bed.png", "Book"],
66
+ ["example_images/dogs_on_bed.png", "Cat"]])
67
+ iface.launch(debug=True)
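
For local debugging, run_demo() can be exercised directly before the Interface is launched; a small sketch (placed above iface.launch(), using one of the example images added in this commit; the output file name is hypothetical):

from PIL import Image

overlapped, highlighted_text = run_demo(Image.open("example_images/London.png").convert("RGB"), "London Eye")
overlapped.save("london_eye_relevance.png")   # hypothetical output path
print(highlighted_text[:5])                   # list of (token, relevance score) pairs
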
clip_grounding/datasets/png.py ADDED
@@ -0,0 +1,231 @@
1
+ """
2
+ Dataset object for Panoptic Narrative Grounding.
3
+
4
+ Paper: https://openaccess.thecvf.com/content/ICCV2021/papers/Gonzalez_Panoptic_Narrative_Grounding_ICCV_2021_paper.pdf
5
+ """
6
+
7
+ import os
8
+ from os.path import join, isdir, exists
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ import cv2
13
+ from PIL import Image
14
+ from skimage import io
15
+ import numpy as np
16
+ import textwrap
17
+ import matplotlib.pyplot as plt
18
+ from matplotlib import transforms
19
+ from imgaug.augmentables.segmaps import SegmentationMapsOnImage
20
+ import matplotlib.colors as mc
21
+
22
+ from clip_grounding.utils.io import load_json
23
+ from clip_grounding.datasets.png_utils import show_image_and_caption
24
+
25
+
26
+ class PNG(Dataset):
27
+ """Panoptic Narrative Grounding."""
28
+
29
+ def __init__(self, dataset_root, split) -> None:
30
+ """
31
+ Initializer.
32
+
33
+ Args:
34
+ dataset_root (str): path to the folder containing PNG dataset
35
+ split (str): MS-COCO split such as train2017/val2017
36
+ """
37
+ super().__init__()
38
+
39
+ assert isdir(dataset_root)
40
+ self.dataset_root = dataset_root
41
+
42
+ assert split in ["val2017"], f"Split {split} not supported. "\
43
+ "Currently, only supports split `val2017`."
44
+ self.split = split
45
+
46
+ self.ann_dir = join(self.dataset_root, "annotations")
47
+ # feat_dir = join(self.dataset_root, "features")
48
+
49
+ panoptic = load_json(join(self.ann_dir, "panoptic_{:s}.json".format(split)))
50
+ images = panoptic["images"]
51
+ self.images_info = {i["id"]: i for i in images}
52
+ panoptic_anns = panoptic["annotations"]
53
+ self.panoptic_anns = {int(a["image_id"]): a for a in panoptic_anns}
54
+
55
+ # self.panoptic_pred_path = join(
56
+ # feat_dir, split, "panoptic_seg_predictions"
57
+ # )
58
+ # assert isdir(self.panoptic_pred_path)
59
+
60
+ panoptic_narratives_path = join(self.dataset_root, "annotations", f"png_coco_{split}.json")
61
+ self.panoptic_narratives = load_json(panoptic_narratives_path)
62
+
63
+ def __len__(self):
64
+ return len(self.panoptic_narratives)
65
+
66
+ def get_image_path(self, image_id: str):
67
+ image_path = join(self.dataset_root, "images", self.split, f"{image_id.zfill(12)}.jpg")
68
+ return image_path
69
+
70
+ def __getitem__(self, idx: int):
71
+ narr = self.panoptic_narratives[idx]
72
+
73
+ image_id = narr["image_id"]
74
+ image_path = self.get_image_path(image_id)
75
+ assert exists(image_path)
76
+
77
+ image = Image.open(image_path)
78
+ caption = narr["caption"]
79
+
80
+ # show_single_image(image, title=caption, titlesize=12)
81
+
82
+ segments = narr["segments"]
83
+
84
+ image_id = int(narr["image_id"])
85
+ panoptic_ann = self.panoptic_anns[image_id]
86
+ panoptic_ann = self.panoptic_anns[image_id]
87
+ segment_infos = {}
88
+ for s in panoptic_ann["segments_info"]:
89
+ idi = s["id"]
90
+ segment_infos[idi] = s
91
+
92
+ image_info = self.images_info[image_id]
93
+ panoptic_segm = io.imread(
94
+ join(
95
+ self.ann_dir,
96
+ "panoptic_segmentation",
97
+ self.split,
98
+ "{:012d}.png".format(image_id),
99
+ )
100
+ )
101
+ panoptic_segm = (
102
+ panoptic_segm[:, :, 0]
103
+ + panoptic_segm[:, :, 1] * 256
104
+ + panoptic_segm[:, :, 2] * 256 ** 2
105
+ )
106
+
107
+ panoptic_ann = self.panoptic_anns[image_id]
108
+ # panoptic_pred = io.imread(
109
+ # join(self.panoptic_pred_path, "{:012d}.png".format(image_id))
110
+ # )[:, :, 0]
111
+
112
+
113
+ # # select a single utterance to visualize
114
+ # segment = segments[7]
115
+ # segment_ids = segment["segment_ids"]
116
+ # segment_mask = np.zeros((image_info["height"], image_info["width"]))
117
+ # for segment_id in segment_ids:
118
+ # segment_id = int(segment_id)
119
+ # segment_mask[panoptic_segm == segment_id] = 1.
120
+
121
+ utterances = [s["utterance"] for s in segments]
122
+ outputs = []
123
+ for i, segment in enumerate(segments):
124
+
125
+ # create segmentation mask on image
126
+ segment_ids = segment["segment_ids"]
127
+
128
+ # if no annotation for this word, skip
129
+ if not len(segment_ids):
130
+ continue
131
+
132
+ segment_mask = np.zeros((image_info["height"], image_info["width"]))
133
+ for segment_id in segment_ids:
134
+ segment_id = int(segment_id)
135
+ segment_mask[panoptic_segm == segment_id] = 1.
136
+
137
+ # store the outputs
138
+ text_mask = np.zeros(len(utterances))
139
+ text_mask[i] = 1.
140
+ segment_data = dict(
141
+ image=image,
142
+ text=utterances,
143
+ image_mask=segment_mask,
144
+ text_mask=text_mask,
145
+ full_caption=caption,
146
+ )
147
+ outputs.append(segment_data)
148
+
149
+ # # visualize segmentation mask with associated text
150
+ # segment_color = "red"
151
+ # segmap = SegmentationMapsOnImage(
152
+ # segment_mask.astype(np.uint8), shape=segment_mask.shape,
153
+ # )
154
+ # image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, COLORS[segment_color]])[0]
155
+ # image_with_segmap = Image.fromarray(image_with_segmap)
156
+
157
+ # colors = ["black" for _ in range(len(utterances))]
158
+ # colors[i] = segment_color
159
+ # show_image_and_caption(image_with_segmap, utterances, colors)
160
+
161
+ return outputs
162
+
163
+
164
+ def overlay_segmask_on_image(image, image_mask, segment_color="red"):
165
+ segmap = SegmentationMapsOnImage(
166
+ image_mask.astype(np.uint8), shape=image_mask.shape,
167
+ )
168
+ rgb_color = mc.to_rgb(segment_color)
169
+ rgb_color = 255 * np.array(rgb_color)
170
+ image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
171
+ image_with_segmap = Image.fromarray(image_with_segmap)
172
+ return image_with_segmap
173
+
174
+
175
+ def get_text_colors(text, text_mask, segment_color="red"):
176
+ colors = ["black" for _ in range(len(text))]
177
+ colors[text_mask.nonzero()[0][0]] = segment_color
178
+ return colors
179
+
180
+
181
+ def overlay_relevance_map_on_image(image, heatmap):
182
+ width, height = image.size
183
+
184
+ # resize the heatmap to image size
185
+ heatmap = cv2.resize(heatmap, (width, height))
186
+ heatmap = np.uint8(255 * heatmap)
187
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
188
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
189
+
190
+ # create overlapped super image
191
+ img = np.asarray(image)
192
+ super_img = heatmap * 0.4 + img * 0.6
193
+ super_img = np.uint8(super_img)
194
+ super_img = Image.fromarray(super_img)
195
+
196
+ return super_img
197
+
198
+
199
+ def visualize_item(image, text, image_mask, text_mask, segment_color="red"):
200
+
201
+ segmap = SegmentationMapsOnImage(
202
+ image_mask.astype(np.uint8), shape=image_mask.shape,
203
+ )
204
+ rgb_color = mc.to_rgb(segment_color)
205
+ rgb_color = 255 * np.array(rgb_color)
206
+ image_with_segmap = segmap.draw_on_image(np.asarray(image), colors=[0, rgb_color])[0]
207
+ image_with_segmap = Image.fromarray(image_with_segmap)
208
+
209
+ colors = ["black" for _ in range(len(text))]
210
+
211
+ text_idx = text_mask.argmax()
212
+ colors[text_idx] = segment_color
213
+ show_image_and_caption(image_with_segmap, text, colors)
214
+
215
+
216
+
217
+ if __name__ == "__main__":
218
+ from clip_grounding.utils.paths import REPO_PATH, DATASET_ROOTS
219
+
220
+ PNG_ROOT = DATASET_ROOTS["PNG"]
221
+ dataset = PNG(dataset_root=PNG_ROOT, split="val2017")
222
+
223
+ item = dataset[0]
224
+ sub_item = item[1]
225
+ visualize_item(
226
+ image=sub_item["image"],
227
+ text=sub_item["text"],
228
+ image_mask=sub_item["image_mask"],
229
+ text_mask=sub_item["text_mask"],
230
+ segment_color="red",
231
+ )
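
A short sketch of how this dataset is consumed elsewhere in the repo (assumes the PNG annotations and COCO val2017 images have already been downloaded under data/panoptic_narrative_grounding); each item is a list of entries, one per grounded phrase:

from clip_grounding.datasets.png import PNG
from clip_grounding.utils.paths import DATASET_ROOTS

dataset = PNG(dataset_root=DATASET_ROOTS["PNG"], split="val2017")
entries = dataset[0]                    # all grounded phrases of the first narrative
entry = entries[0]
print(entry["full_caption"])            # the full narrative caption
print(entry["text"])                    # list of utterances
print(entry["text_mask"].nonzero())     # which utterance this entry grounds
print(entry["image_mask"].shape)        # (H, W) binary segmentation mask
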
clip_grounding/datasets/png_utils.py ADDED
@@ -0,0 +1,135 @@
1
+ """Helper functions for Panoptic Narrative Grounding."""
2
+
3
+ import os
4
+ from os.path import join, isdir, exists
5
+ from typing import List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from skimage import io
10
+ import numpy as np
11
+ import textwrap
12
+ import matplotlib.pyplot as plt
13
+ from matplotlib import transforms
14
+ from imgaug.augmentables.segmaps import SegmentationMapsOnImage
15
+
16
+
17
+ def rainbow_text(x,y,ls,lc,fig, ax,**kw):
18
+ """
19
+ Take a list of strings ``ls`` and colors ``lc`` and place them next to each
20
+ other, with text ls[i] being shown in color lc[i].
21
+
22
+ Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib
23
+ """
24
+ t = ax.transAxes
25
+
26
+ for s,c in zip(ls,lc):
27
+
28
+ text = ax.text(x,y,s+" ",color=c, transform=t, **kw)
29
+ text.draw(fig.canvas.get_renderer())
30
+ ex = text.get_window_extent()
31
+ t = transforms.offset_copy(text._transform, x=ex.width, units='dots')
32
+
33
+
34
+ def find_first_index_greater_than(elements, key):
35
+ return next(x[0] for x in enumerate(elements) if x[1] > key)
36
+
37
+
38
+ def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50):
39
+ char_lengths = np.cumsum([len(x) for x in caption_phrases])
40
+ thresholds = [max_char_in_a_line * i for i in range(1, 1 + char_lengths[-1] // max_char_in_a_line)]
41
+
42
+ utt_per_line = []
43
+ col_per_line = []
44
+ start_index = 0
45
+ for t in thresholds:
46
+ index = find_first_index_greater_than(char_lengths, t)
47
+ utt_per_line.append(caption_phrases[start_index:index])
48
+ col_per_line.append(colors[start_index:index])
49
+ start_index = index
50
+
51
+ return utt_per_line, col_per_line
52
+
53
+
54
+ def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None):
55
+
56
+ if colors is None:
57
+ colors = ["black" for _ in range(len(caption_phrases))]
58
+
59
+ fig, axes = plt.subplots(1, 2, figsize=(15, 4))
60
+
61
+ ax = axes[0]
62
+ ax.imshow(image)
63
+ ax.set_xticks([])
64
+ ax.set_yticks([])
65
+
66
+ ax = axes[1]
67
+ utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50)
68
+ y = 0.7
69
+ for U, C in zip(utt_per_line, col_per_line):
70
+ rainbow_text(
71
+ 0., y,
72
+ U,
73
+ C,
74
+ size=15, ax=ax, fig=fig,
75
+ horizontalalignment='left',
76
+ verticalalignment='center',
77
+ )
78
+ y -= 0.11
79
+
80
+ ax.axis("off")
81
+
82
+ fig.tight_layout()
83
+ plt.show()
84
+
85
+
86
+ def show_images_and_caption(
87
+ images: List,
88
+ caption_phrases: list,
89
+ colors: list = None,
90
+ image_xlabels: List=[],
91
+ figsize=None,
92
+ show=False,
93
+ xlabelsize=14,
94
+ ):
95
+
96
+ if colors is None:
97
+ colors = ["black" for _ in range(len(caption_phrases))]
98
+ caption_phrases[0] = caption_phrases[0].capitalize()
99
+
100
+ if figsize is None:
101
+ figsize = (5 * len(images) + 8, 4)
102
+
103
+ if image_xlabels is None:
104
+ image_xlabels = ["" for _ in range(len(images))]
105
+
106
+ fig, axes = plt.subplots(1, len(images) + 1, figsize=figsize)
107
+
108
+ for i, image in enumerate(images):
109
+ ax = axes[i]
110
+ ax.imshow(image)
111
+ ax.set_xticks([])
112
+ ax.set_yticks([])
113
+ ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize)
114
+
115
+ ax = axes[-1]
116
+ utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=40)
117
+ y = 0.7
118
+ for U, C in zip(utt_per_line, col_per_line):
119
+ rainbow_text(
120
+ 0., y,
121
+ U,
122
+ C,
123
+ size=23, ax=ax, fig=fig,
124
+ horizontalalignment='left',
125
+ verticalalignment='center',
126
+ # weight='bold'
127
+ )
128
+ y -= 0.11
129
+
130
+ ax.axis("off")
131
+
132
+ fig.tight_layout()
133
+
134
+ if show:
135
+ plt.show()
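
A small sketch for rainbow_text, the lowest-level helper above (assumes a standard matplotlib backend whose canvas exposes get_renderer, e.g. Agg; the words and colours are arbitrary):

import matplotlib.pyplot as plt
from clip_grounding.datasets.png_utils import rainbow_text

fig, ax = plt.subplots(figsize=(4, 1))
ax.axis("off")
# three words, each drawn in its own colour, laid out left to right
rainbow_text(0.0, 0.5, ["Panoptic", "Narrative", "Grounding"], ["black", "red", "blue"], fig, ax, size=14)
plt.show()
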
clip_grounding/evaluation/clip_on_png.py ADDED
@@ -0,0 +1,362 @@
1
+ """Evaluates cross-modal correspondence of CLIP on PNG images."""
2
+
3
+ import os
4
+ import sys
5
+ from os.path import join, exists
6
+
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+ from clip_grounding.utils.paths import REPO_PATH
11
+ sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
12
+
13
+ import torch
14
+ import CLIP.clip as clip
15
+ from PIL import Image
16
+ import numpy as np
17
+ import cv2
18
+ import matplotlib.pyplot as plt
19
+ from captum.attr import visualization
20
+ from torchmetrics import JaccardIndex
21
+ from collections import defaultdict
22
+ from IPython.core.display import display, HTML
23
+ from skimage import filters
24
+
25
+ from CLIP_explainability.utils import interpret, show_img_heatmap, show_txt_heatmap, color, _tokenizer
26
+ from clip_grounding.datasets.png import PNG
27
+ from clip_grounding.utils.image import pad_to_square
28
+ from clip_grounding.utils.visualize import show_grid_of_images
29
+ from clip_grounding.utils.log import tqdm_iterator, print_update
30
+
31
+
32
+ # global usage
33
+ # specify device
34
+ device = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+ # load CLIP model
37
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
38
+
39
+
40
+ def show_cam(mask):
41
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
42
+ heatmap = np.float32(heatmap) / 255
43
+ cam = heatmap
44
+ cam = cam / np.max(cam)
45
+ return cam
46
+
47
+
48
+ def interpret_and_generate(model, img, texts, orig_image, return_outputs=False, show=True):
49
+ text = clip.tokenize(texts).to(device)
50
+ R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
51
+ batch_size = text.shape[0]
52
+
53
+ outputs = []
54
+ for i in range(batch_size):
55
+ text_scores, text_tokens_decoded = show_txt_heatmap(texts[i], text[i], R_text[i], show=show)
56
+ image_relevance = show_img_heatmap(R_image[i], img, orig_image=orig_image, device=device, show=show)
57
+ plt.show()
58
+ outputs.append({"text_scores": text_scores, "image_relevance": image_relevance, "tokens_decoded": text_tokens_decoded})
59
+
60
+ if return_outputs:
61
+ return outputs
62
+
63
+
64
+ def process_entry_text_to_image(entry, unimodal=False):
65
+ image = entry['image']
66
+ text_mask = entry['text_mask']
67
+ text = entry['text']
68
+ orig_image = pad_to_square(image)
69
+
70
+ img = preprocess(orig_image).unsqueeze(0).to(device)
71
+ text_index = text_mask.argmax()
72
+ texts = [text[text_index]] if not unimodal else ['']
73
+
74
+ return img, texts, orig_image
75
+
76
+
77
+ def preprocess_ground_truth_mask(mask, resize_shape):
78
+ mask = Image.fromarray(mask.astype(np.uint8) * 255)
79
+ mask = pad_to_square(mask, color=0)
80
+ mask = mask.resize(resize_shape)
81
+ mask = np.asarray(mask) / 255.
82
+ return mask
83
+
84
+
85
+ def apply_otsu_threshold(relevance_map):
86
+ threshold = filters.threshold_otsu(relevance_map)
87
+ otsu_map = (relevance_map > threshold).astype(np.uint8)
88
+ return otsu_map
89
+
90
+
91
+ def evaluate_text_to_image(method, dataset, debug=False):
92
+
93
+ instance_level_metrics = defaultdict(list)
94
+ entry_level_metrics = defaultdict(list)
95
+
96
+ jaccard = JaccardIndex(num_classes=2)
97
+ jaccard = jaccard.to(device)
98
+
99
+ num_iter = len(dataset)
100
+ if debug:
101
+ num_iter = 100
102
+
103
+ iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
104
+ for idx in iterator:
105
+ instance = dataset[idx]
106
+
107
+ instance_iou = 0.
108
+ for entry in instance:
109
+
110
+ # preprocess the image and text
111
+ unimodal = True if method == "clip-unimodal" else False
112
+ test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=unimodal)
113
+
114
+ if method in ["clip", "clip-unimodal"]:
115
+
116
+ # compute the relevance scores
117
+ outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
118
+
119
+ # use the image relevance score to compute IoU w.r.t. ground truth segmentation masks
120
+
121
+ # NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
122
+ relevance_map = outputs[0]["image_relevance"]
123
+ elif method == "random":
124
+ relevance_map = np.random.uniform(low=0., high=1., size=tuple(test_img.shape[2:]))
125
+
126
+ otsu_relevance_map = apply_otsu_threshold(relevance_map)
127
+
128
+ ground_truth_mask = entry["image_mask"]
129
+ ground_truth_mask = preprocess_ground_truth_mask(ground_truth_mask, relevance_map.shape)
130
+
131
+ entry_iou = jaccard(
132
+ torch.from_numpy(otsu_relevance_map).to(device),
133
+ torch.from_numpy(ground_truth_mask.astype(np.uint8)).to(device),
134
+ )
135
+ entry_iou = entry_iou.item()
136
+ instance_iou += (entry_iou / len(entry))
137
+
138
+ entry_level_metrics["iou"].append(entry_iou)
139
+
140
+ # capture instance (image-sentence pair) level IoU
141
+ instance_level_metrics["iou"].append(instance_iou)
142
+
143
+ average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
144
+
145
+ return (
146
+ average_metrics,
147
+ instance_level_metrics,
148
+ entry_level_metrics
149
+ )
150
+
151
+
152
+ def process_entry_image_to_text(entry, unimodal=False):
153
+
154
+ if not unimodal:
155
+ if len(np.asarray(entry["image"]).shape) == 3:
156
+ mask = np.repeat(np.expand_dims(entry['image_mask'], -1), 3, axis=-1)
157
+ else:
158
+ mask = np.asarray(entry['image_mask'])
159
+
160
+ masked_image = (mask * np.asarray(entry['image'])).astype(np.uint8)
161
+ masked_image = Image.fromarray(masked_image)
162
+ orig_image = pad_to_square(masked_image)
163
+ img = preprocess(orig_image).unsqueeze(0).to(device)
164
+ else:
165
+ orig_image_shape = max(np.asarray(entry['image']).shape[:2])
166
+ orig_image = Image.fromarray(np.zeros((orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
167
+ # orig_image = Image.fromarray(np.random.randint(0, 256, (orig_image_shape, orig_image_shape, 3), dtype=np.uint8))
168
+ img = preprocess(orig_image).unsqueeze(0).to(device)
169
+
170
+ texts = [' '.join(entry['text'])]
171
+
172
+ return img, texts, orig_image
173
+
174
+
175
+ def process_text_mask(text, text_mask, tokens):
176
+
177
+ token_level_mask = np.zeros(len(tokens))
178
+
179
+ for label, subtext in zip(text_mask, text):
180
+
181
+ subtext_tokens=_tokenizer.encode(subtext)
182
+ subtext_tokens_decoded=[_tokenizer.decode([a]) for a in subtext_tokens]
183
+
184
+ if label == 1:
185
+ start = tokens.index(subtext_tokens_decoded[0])
186
+ end = tokens.index(subtext_tokens_decoded[-1])
187
+ token_level_mask[start:end + 1] = 1
188
+
189
+ return token_level_mask
190
+
191
+
192
+ def evaluate_image_to_text(method, dataset, debug=False, clamp_sentence_len=70):
193
+
194
+ instance_level_metrics = defaultdict(list)
195
+ entry_level_metrics = defaultdict(list)
196
+
197
+ # skipped if text length > 77 which is CLIP limit
198
+ num_entries_skipped = 0
199
+ num_total_entries = 0
200
+
201
+ num_iter = len(dataset)
202
+ if debug:
203
+ num_iter = 100
204
+
205
+ jaccard_image_to_text = JaccardIndex(num_classes=2).to(device)
206
+
207
+ iterator = tqdm_iterator(range(num_iter), desc=f"Evaluating on {type(dataset).__name__} dataset")
208
+ for idx in iterator:
209
+ instance = dataset[idx]
210
+
211
+ instance_iou = 0.
212
+ for entry in instance:
213
+ num_total_entries += 1
214
+
215
+ # preprocess the image and text
216
+ unimodal = True if method == "clip-unimodal" else False
217
+ img, texts, orig_image = process_entry_image_to_text(entry, unimodal=unimodal)
218
+
219
+ appx_total_sent_len = np.sum([len(x.split(" ")) for x in texts])
220
+ if appx_total_sent_len > clamp_sentence_len:
221
+ # print(f"Skipping an entry since it's text has appx"\
222
+ # " {appx_total_sent_len} while CLIP cannot process beyond {clamp_sentence_len}")
223
+ num_entries_skipped += 1
224
+ continue
225
+
226
+ # compute the relevance scores
227
+ if method in ["clip", "clip-unimodal"]:
228
+ try:
229
+ outputs = interpret_and_generate(model, img, texts, orig_image, return_outputs=True, show=False)
230
+ except:
231
+ num_entries_skipped += 1
232
+ continue
233
+ elif method == "random":
234
+ text = texts[0]
235
+ text_tokens = _tokenizer.encode(text)
236
+ text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
237
+ outputs = [
238
+ {
239
+ "text_scores": np.random.uniform(low=0., high=1., size=len(text_tokens_decoded)),
240
+ "tokens_decoded": text_tokens_decoded,
241
+ }
242
+ ]
243
+
244
+ # use the text relevance score to compute IoU w.r.t. ground truth text masks
245
+ # NOTE: since we pass single entry (1-sized batch), outputs[0] contains our reqd outputs
246
+ token_relevance_scores = outputs[0]["text_scores"]
247
+ if isinstance(token_relevance_scores, torch.Tensor):
248
+ token_relevance_scores = token_relevance_scores.cpu().numpy()
249
+ token_relevance_scores = apply_otsu_threshold(token_relevance_scores)
250
+ token_ground_truth_mask = process_text_mask(entry["text"], entry["text_mask"], outputs[0]["tokens_decoded"])
251
+
252
+ entry_iou = jaccard_image_to_text(
253
+ torch.from_numpy(token_relevance_scores).to(device),
254
+ torch.from_numpy(token_ground_truth_mask.astype(np.uint8)).to(device),
255
+ )
256
+ entry_iou = entry_iou.item()
257
+
258
+ instance_iou += (entry_iou / len(entry))
259
+ entry_level_metrics["iou"].append(entry_iou)
260
+
261
+ # capture instance (image-sentence pair) level IoU
262
+ instance_level_metrics["iou"].append(instance_iou)
263
+
264
+ print(f"CAUTION: Skipped {(num_entries_skipped / num_total_entries) * 100} % since these had length > 77 (CLIP limit).")
265
+ average_metrics = {k: np.mean(v) for k, v in entry_level_metrics.items()}
266
+
267
+ return (
268
+ average_metrics,
269
+ instance_level_metrics,
270
+ entry_level_metrics
271
+ )
272
+
273
+
274
+ if __name__ == "__main__":
275
+
276
+ import argparse
277
+ parser = argparse.ArgumentParser("Evaluate Image-to-Text & Text-to-Image model")
278
+ parser.add_argument(
279
+ "--eval_method", type=str, default="clip",
280
+ choices=["clip", "random", "clip-unimodal"],
281
+ help="Evaluation method to use",
282
+ )
283
+ parser.add_argument(
284
+ "--ignore_cache", action="store_true",
285
+ help="Ignore cache and force re-generation of the results",
286
+ )
287
+ parser.add_argument(
288
+ "--debug", action="store_true",
289
+ help="Run evaluation on a small subset of the dataset",
290
+ )
291
+ args = parser.parse_args()
292
+
293
+ print_update("Using evaluation method: {}".format(args.eval_method))
294
+
295
+
296
+ clip.clip._MODELS = {
297
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
298
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
299
+ }
300
+
301
+ # specify device
302
+ device = "cuda" if torch.cuda.is_available() else "cpu"
303
+
304
+ # load CLIP model
305
+ print_update("Loading CLIP model...")
306
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
307
+ print()
308
+
309
+ # load PNG dataset
310
+ print_update("Loading PNG dataset...")
311
+ dataset = PNG(dataset_root=join(REPO_PATH, "data", "panoptic_narrative_grounding"), split="val2017")
312
+ print()
313
+
314
+ # evaluate
315
+
316
+ # save metrics
317
+ metrics_dir = join(REPO_PATH, "outputs")
318
+ os.makedirs(metrics_dir, exist_ok=True)
319
+
320
+ metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_text2image_metrics.pt")
321
+ if (not exists(metrics_path)) or args.ignore_cache:
322
+ print_update("Computing metrics for text-to-image grounding")
323
+ average_metrics, instance_level_metrics, entry_level_metrics = evaluate_text_to_image(
324
+ args.eval_method, dataset, debug=args.debug,
325
+ )
326
+ metrics = {
327
+ "average_metrics": average_metrics,
328
+ "instance_level_metrics":instance_level_metrics,
329
+ "entry_level_metrics": entry_level_metrics
330
+ }
331
+
332
+ torch.save(metrics, metrics_path)
333
+ print("TEXT2IMAGE METRICS SAVED TO:", metrics_path)
334
+ else:
335
+ print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
336
+ metrics = torch.load(metrics_path)
337
+ average_metrics = metrics["average_metrics"]
338
+ print("TEXT2IMAGE METRICS:", np.round(average_metrics["iou"], 4))
339
+
340
+ print()
341
+
342
+ metrics_path = join(metrics_dir, f"{args.eval_method}_on_{type(dataset).__name__}_image2text_metrics.pt")
343
+ if (not exists(metrics_path)) or args.ignore_cache:
344
+ print_update("Computing metrics for image-to-text grounding")
345
+ average_metrics, instance_level_metrics, entry_level_metrics = evaluate_image_to_text(
346
+ args.eval_method, dataset, debug=args.debug,
347
+ )
348
+
349
+ torch.save(
350
+ {
351
+ "average_metrics": average_metrics,
352
+ "instance_level_metrics":instance_level_metrics,
353
+ "entry_level_metrics": entry_level_metrics
354
+ },
355
+ metrics_path,
356
+ )
357
+ print("IMAGE2TEXT METRICS SAVED TO:", metrics_path)
358
+ else:
359
+ print(f"Metrics already exist at: {metrics_path}. Loading cached metrics.")
360
+ metrics = torch.load(metrics_path)
361
+ average_metrics = metrics["average_metrics"]
362
+ print("IMAGE2TEXT METRICS:", np.round(average_metrics["iou"], 4))
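
The per-entry metric above reduces to an Otsu threshold on the relevance map followed by a Jaccard/IoU comparison against the ground-truth mask (the script itself is driven by the argparse block, e.g. --eval_method clip with --debug for a 100-sample run). A toy, self-contained sketch of that step, using a synthetic relevance map and mask and the JaccardIndex metric pinned in requirements.txt (torchmetrics 0.9.1):

import numpy as np
import torch
from skimage import filters
from torchmetrics import JaccardIndex

relevance = np.random.uniform(size=(224, 224)).astype(np.float32)   # stand-in relevance map
gt_mask = np.zeros((224, 224), dtype=np.uint8)
gt_mask[64:160, 64:160] = 1                                         # synthetic ground-truth region

otsu_mask = (relevance > filters.threshold_otsu(relevance)).astype(np.uint8)
iou = JaccardIndex(num_classes=2)(torch.from_numpy(otsu_mask), torch.from_numpy(gt_mask))
print(float(iou))
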
clip_grounding/evaluation/qualitative_results.py ADDED
@@ -0,0 +1,93 @@
1
+ """Converts notebook for qualitative results to a python script."""
2
+ import sys
3
+ from os.path import join
4
+
5
+ from clip_grounding.utils.paths import REPO_PATH
6
+ sys.path.append(join(REPO_PATH, "CLIP_explainability/Transformer-MM-Explainability/"))
7
+
8
+ import os
9
+ import torch
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from matplotlib.patches import Patch
13
+ import CLIP.clip as clip
14
+ import cv2
15
+ from PIL import Image
16
+ from glob import glob
17
+ from natsort import natsorted
18
+
19
+ from clip_grounding.utils.paths import REPO_PATH
20
+ from clip_grounding.utils.io import load_json
21
+ from clip_grounding.utils.visualize import set_latex_fonts, show_grid_of_images
22
+ from clip_grounding.utils.image import pad_to_square
23
+ from clip_grounding.datasets.png_utils import show_images_and_caption
24
+ from clip_grounding.datasets.png import (
25
+ PNG,
26
+ visualize_item,
27
+ overlay_segmask_on_image,
28
+ overlay_relevance_map_on_image,
29
+ get_text_colors,
30
+ )
31
+ from clip_grounding.evaluation.clip_on_png import (
32
+ process_entry_image_to_text,
33
+ process_entry_text_to_image,
34
+ interpret_and_generate,
35
+ )
36
+
37
+ # load dataset
38
+ dataset = PNG(dataset_root=join(REPO_PATH, "data/panoptic_narrative_grounding"), split="val2017")
39
+
40
+ # load CLIP model
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
43
+
44
+
45
+ def visualize_entry_text_to_image(entry, pad_images=True, figsize=(18, 5)):
46
+ test_img, test_texts, orig_image = process_entry_text_to_image(entry, unimodal=False)
47
+ outputs = interpret_and_generate(model, test_img, test_texts, orig_image, return_outputs=True, show=False)
48
+ relevance_map = outputs[0]["image_relevance"]
49
+
50
+ image_with_mask = overlay_segmask_on_image(entry["image"], entry["image_mask"])
51
+ if pad_images:
52
+ image_with_mask = pad_to_square(image_with_mask)
53
+
54
+ image_with_relevance_map = overlay_relevance_map_on_image(entry["image"], relevance_map)
55
+ if pad_images:
56
+ image_with_relevance_map = pad_to_square(image_with_relevance_map)
57
+
58
+ text_colors = get_text_colors(entry["text"], entry["text_mask"])
59
+
60
+ show_images_and_caption(
61
+ [image_with_mask, image_with_relevance_map],
62
+ entry["text"], text_colors, figsize=figsize,
63
+ image_xlabels=["Ground truth segmentation", "Predicted relevance map"]
64
+ )
65
+
66
+
67
+ def create_and_save_gif(filenames, save_path, **kwargs):
68
+ import imageio
69
+ images = []
70
+ for filename in filenames:
71
+ images.append(imageio.imread(filename))
72
+ imageio.mimsave(save_path, images, **kwargs)
73
+
74
+
75
+ idx = 100
76
+ instance = dataset[idx]
77
+
78
+ instance_dir = join(REPO_PATH, "figures", f"instance-{idx}")
79
+ os.makedirs(instance_dir, exist_ok=True)
80
+
81
+ for i, entry in enumerate(instance):
82
+ del entry["full_caption"]
83
+
84
+ visualize_entry_text_to_image(entry, pad_images=False, figsize=(19, 4))
85
+
86
+ save_path = instance_dir
87
+ plt.savefig(join(instance_dir, f"viz-{i}.png"), bbox_inches="tight")
88
+
89
+
90
+ filenames = natsorted(glob(join(instance_dir, "viz-*.png")))
91
+ save_path = join(REPO_PATH, "media", "sample.gif")
92
+
93
+ create_and_save_gif(filenames, save_path, duration=3)
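
create_and_save_gif is a thin wrapper over imageio.mimsave; a sketch run in this script's namespace, with hypothetical frame paths and assuming the target directories exist:

frames = ["figures/instance-100/viz-0.png", "figures/instance-100/viz-1.png"]   # hypothetical frames
create_and_save_gif(frames, "media/instance-100.gif", duration=2)
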
clip_grounding/utils/image.py ADDED
@@ -0,0 +1,46 @@
1
+ """Image operations."""
2
+ from copy import deepcopy
3
+ from PIL import Image
4
+
5
+
6
+ def center_crop(im: Image):
7
+ width, height = im.size
8
+ new_width = width if width < height else height
9
+ new_height = height if height < width else width
10
+
11
+ left = (width - new_width)/2
12
+ top = (height - new_height)/2
13
+ right = (width + new_width)/2
14
+ bottom = (height + new_height)/2
15
+
16
+ # Crop the center of the image
17
+ im = im.crop((left, top, right, bottom))
18
+
19
+ return im
20
+
21
+
22
+ def pad_to_square(im: Image, color=(0, 0, 0)):
23
+ im = deepcopy(im)
24
+ width, height = im.size
25
+
26
+ vert_pad = (max(width, height) - height) // 2
27
+ hor_pad = (max(width, height) - width) // 2
28
+
29
+ if len(im.mode) == 3:
30
+ color = (0, 0, 0)
31
+ elif len(im.mode) == 1:
32
+ color = 0
33
+ else:
34
+ raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
35
+
36
+ return add_margin(im, vert_pad, hor_pad, vert_pad, hor_pad, color=color)
37
+
38
+
39
+ def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
40
+ """Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/"""
41
+ width, height = pil_img.size
42
+ new_width = width + right + left
43
+ new_height = height + top + bottom
44
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
45
+ result.paste(pil_img, (left, top))
46
+ return result
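
A quick sketch of pad_to_square on a non-square RGB image:

from PIL import Image
from clip_grounding.utils.image import pad_to_square

im = Image.new("RGB", (300, 200), (255, 255, 255))
squared = pad_to_square(im)
print(squared.size)   # (300, 300): 50px black margins added above and below
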
clip_grounding/utils/io.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Utilities for input-output loading/saving.
3
+ """
4
+
5
+ from typing import Any, List
6
+ import yaml
7
+ import pickle
8
+ import json
9
+
10
+
11
+ class PrettySafeLoader(yaml.SafeLoader):
12
+ """Custom loader for reading YAML files"""
13
+ def construct_python_tuple(self, node):
14
+ return tuple(self.construct_sequence(node))
15
+
16
+
17
+ PrettySafeLoader.add_constructor(
18
+ u'tag:yaml.org,2002:python/tuple',
19
+ PrettySafeLoader.construct_python_tuple
20
+ )
21
+
22
+
23
+ def load_yml(path: str, loader_type: str = 'default'):
24
+ """Read params from a yml file.
25
+
26
+ Args:
27
+ path (str): path to the .yml file
28
+ loader_type (str, optional): type of loader used to load yml files. Defaults to 'default'.
29
+
30
+ Returns:
31
+ Any: object (typically dict) loaded from .yml file
32
+ """
33
+ assert loader_type in ['default', 'safe']
34
+
35
+ loader = yaml.Loader if (loader_type == "default") else PrettySafeLoader
36
+
37
+ with open(path, 'r') as f:
38
+ data = yaml.load(f, Loader=loader)
39
+
40
+ return data
41
+
42
+
43
+ def save_yml(data: dict, path: str):
44
+ """Save params in the given yml file path.
45
+
46
+ Args:
47
+ data (dict): data object to save
48
+ path (str): path to .yml file to be saved
49
+ """
50
+ with open(path, 'w') as f:
51
+ yaml.dump(data, f, default_flow_style=False)
52
+
53
+
54
+ def load_pkl(path: str, encoding: str = "ascii") -> Any:
55
+ """Loads a .pkl file.
56
+
57
+ Args:
58
+ path (str): path to the .pkl file
59
+ encoding (str, optional): encoding to use for loading. Defaults to "ascii".
60
+
61
+ Returns:
62
+ Any: unpickled object
63
+ """
64
+ return pickle.load(open(path, "rb"), encoding=encoding)
65
+
66
+
67
+ def save_pkl(data: Any, path: str) -> None:
68
+ """Saves given object into .pkl file
69
+
70
+ Args:
71
+ data (Any): object to be saved
72
+ path (str): path to the location to be saved at
73
+ """
74
+ with open(path, 'wb') as f:
75
+ pickle.dump(data, f)
76
+
77
+
78
+ def load_json(path: str) -> dict:
79
+ """Helper to load json file"""
80
+ with open(path, 'rb') as f:
81
+ data = json.load(f)
82
+ return data
83
+
84
+
85
+ def save_json(data: dict, path: str):
86
+ """Helper to save `dict` as .json file."""
87
+ with open(path, 'w') as f:
88
+ json.dump(data, f)
89
+
90
+
91
+ def load_txt(path: str) -> List:
92
+ """Loads lines of a .txt file.
93
+
94
+ Args:
95
+ path (str): path to the .txt file
96
+
97
+ Returns:
98
+ List: lines of .txt file
99
+ """
100
+ with open(path) as f:
101
+ lines = f.read().splitlines()
102
+ return lines
103
+
104
+
105
+ def save_txt(data: list, path: str):
106
+ """Writes data (lines) to a txt file.
107
+
108
+ Args:
109
+ data (list): list of strings (lines) to write
110
+ path (str): path to .txt file
111
+ """
112
+ assert isinstance(data, list)
113
+
114
+ lines = "\n".join(data)
115
+ with open(path, "w") as f:
116
+ f.write(str(lines))
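
A round-trip sketch for the JSON helpers (written to a temporary location):

import os
import tempfile
from clip_grounding.utils.io import save_json, load_json

path = os.path.join(tempfile.gettempdir(), "demo_io.json")
save_json({"split": "val2017", "num_images": 3}, path)
print(load_json(path))   # {'split': 'val2017', 'num_images': 3}
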
clip_grounding/utils/log.py ADDED
@@ -0,0 +1,57 @@
1
+ """Utilities for logging"""
2
+ import logging
3
+ from tqdm import tqdm
4
+ from termcolor import colored
5
+
6
+
7
+ def color(string: str, color_name: str = 'yellow') -> str:
8
+ """Returns colored string for output to terminal"""
9
+ return colored(string, color_name)
10
+
11
+
12
+ def print_update(message: str, width: int = 140, fillchar: str = ":", color="yellow") -> str:
13
+ """Prints an update message
14
+
15
+ Args:
16
+ message (str): message
17
+ width (int): width of new update message
18
+ fillchar (str): character to be filled to L and R of message
19
+
20
+ Returns:
21
+ str: print-ready update message
22
+ """
23
+ message = message.center(len(message) + 2, " ")
24
+ print(colored(message.center(width, fillchar), color))
25
+
26
+
27
+ def set_logger(log_path):
28
+ """Set the logger to log info in terminal and file `log_path`.
29
+
30
+ Args:
31
+ log_path (str): path to the log file
32
+ """
33
+ logger = logging.getLogger()
34
+ logger.setLevel(logging.INFO)
35
+
36
+ if not logger.handlers:
37
+ # Logging to a file
38
+ file_handler = logging.FileHandler(log_path)
39
+ file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
40
+ logger.addHandler(file_handler)
41
+
42
+ # Logging to console
43
+ stream_handler = logging.StreamHandler()
44
+ stream_handler.setFormatter(logging.Formatter('%(message)s'))
45
+ logger.addHandler(stream_handler)
46
+
47
+
48
+ def tqdm_iterator(items, desc=None, bar_format=None, **kwargs):
49
+ tqdm._instances.clear()
50
+ iterator = tqdm(
51
+ items,
52
+ desc=desc,
53
+ bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}',
54
+ **kwargs,
55
+ )
56
+
57
+ return iterator
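
A sketch of the two helpers used by the evaluation script: a coloured banner followed by a progress bar over dummy work:

from clip_grounding.utils.log import print_update, tqdm_iterator

print_update("Loading CLIP model")
for _ in tqdm_iterator(range(100), desc="dummy work"):
    pass
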
clip_grounding/utils/paths.py ADDED
@@ -0,0 +1,10 @@
1
+ """Path helpers for the relfm project."""
2
+ from os.path import join, abspath, dirname
3
+
4
+
5
+ REPO_PATH = dirname(dirname(dirname(abspath(__file__))))
6
+ DATA_ROOT = join(REPO_PATH, "data")
7
+
8
+ DATASET_ROOTS = {
9
+ "PNG": join(DATA_ROOT, "panoptic_narrative_grounding"),
10
+ }
clip_grounding/utils/visualize.py ADDED
@@ -0,0 +1,183 @@
1
+ """Helpers for visualization"""
2
+ import numpy as np
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import cv2
6
+ from PIL import Image
7
+
8
+
9
+ # define predominant colors
10
+ COLORS = {
11
+ "pink": (242, 116, 223),
12
+ "cyan": (46, 242, 203),
13
+ "red": (255, 0, 0),
14
+ "green": (0, 255, 0),
15
+ "blue": (0, 0, 255),
16
+ "yellow": (255, 255, 0),
17
+ }
18
+
19
+
20
+ def show_single_image(image: np.ndarray, figsize: tuple = (8, 8), title: str = None, titlesize=18, cmap: str = None, ticks=False, save=False, save_path=None):
21
+ """Show a single image."""
22
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
23
+
24
+ if isinstance(image, Image.Image):
25
+ image = np.asarray(image)
26
+
27
+ ax.set_title(title, fontsize=titlesize)
28
+ ax.imshow(image, cmap=cmap)
29
+
30
+ if not ticks:
31
+ ax.set_xticks([])
32
+ ax.set_yticks([])
33
+
34
+ if save:
35
+ plt.savefig(save_path, bbox_inches='tight')
36
+
37
+ plt.show()
38
+
39
+
40
+ def show_grid_of_images(
41
+ images: np.ndarray, n_cols: int = 4, figsize: tuple = (8, 8),
42
+ cmap=None, subtitles=None, title=None, subtitlesize=18,
43
+ save=False, save_path=None, titlesize=20,
44
+ ):
45
+ """Show a grid of images."""
46
+ n_cols = min(n_cols, len(images))
47
+
48
+ copy_of_images = images.copy()
49
+ for i, image in enumerate(copy_of_images):
50
+ if isinstance(image, Image.Image):
51
+ image = np.asarray(image)
52
+ images[i] = image
53
+
54
+ if subtitles is None:
55
+ subtitles = [None] * len(images)
56
+
57
+ n_rows = int(np.ceil(len(images) / n_cols))
58
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
59
+ for i, ax in enumerate(axes.flat):
60
+ if i < len(images):
61
+ if len(images[i].shape) == 2 and cmap is None:
62
+ cmap="gray"
63
+ ax.imshow(images[i], cmap=cmap)
64
+ ax.set_title(subtitles[i], fontsize=subtitlesize)
65
+ ax.axis('off')
66
+ fig.set_tight_layout(True)
67
+ plt.suptitle(title, y=0.8, fontsize=titlesize)
68
+
69
+ if save:
70
+ plt.savefig(save_path, bbox_inches='tight')
71
+ plt.close()
72
+ else:
73
+ plt.show()
74
+
75
+
76
+ def show_keypoint_matches(
77
+ img1, kp1, img2, kp2, matches,
78
+ K=10, figsize=(10, 5), drawMatches_args=dict(matchesThickness=3, singlePointColor=(0, 0, 0)),
79
+ choose_matches="random",
80
+ ):
81
+ """Displays matches found in the pair of images"""
82
+ if choose_matches == "random":
83
+ selected_matches = np.random.choice(matches, K)
84
+ elif choose_matches == "all":
85
+ K = len(matches)
86
+ selected_matches = matches
87
+ elif choose_matches == "topk":
88
+ selected_matches = matches[:K]
89
+ else:
90
+ raise ValueError(f"Unknown value for choose_matches: {choose_matches}")
91
+
92
+ # color each match with a different color
93
+ cmap = matplotlib.cm.get_cmap('gist_rainbow', K)
94
+ colors = [[int(x*255) for x in cmap(i)[:3]] for i in np.arange(0,K)]
95
+ drawMatches_args.update({"matchColor": -1, "singlePointColor": (100, 100, 100)})
96
+
97
+ img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
98
+ show_single_image(
99
+ img3,
100
+ figsize=figsize,
101
+ title=f"[{choose_matches.upper()}] Selected K = {K} matches between the pair of images.",
102
+ )
103
+ return img3
104
+
105
+
106
+ def draw_kps_on_image(image: np.ndarray, kps: np.ndarray, color=COLORS["red"], radius=3, thickness=-1, return_as="numpy"):
107
+ """
108
+ Draw keypoints on image.
109
+
110
+ Args:
111
+ image: Image to draw keypoints on.
112
+ kps: Keypoints to draw. Note these should be in (x, y) format.
113
+ """
114
+ if isinstance(image, Image.Image):
115
+ image = np.asarray(image)
116
+
117
+ for kp in kps:
118
+ image = cv2.circle(
119
+ image, (int(kp[0]), int(kp[1])), radius=radius, color=color, thickness=thickness)
120
+
121
+ if return_as == "PIL":
122
+ return Image.fromarray(image)
123
+
124
+ return image
125
+
126
+
127
+ def get_concat_h(im1, im2):
128
+ """Concatenate two images horizontally"""
129
+ dst = Image.new('RGB', (im1.width + im2.width, im1.height))
130
+ dst.paste(im1, (0, 0))
131
+ dst.paste(im2, (im1.width, 0))
132
+ return dst
133
+
134
+
135
+ def get_concat_v(im1, im2):
136
+ """Concatenate two images vertically"""
137
+ dst = Image.new('RGB', (im1.width, im1.height + im2.height))
138
+ dst.paste(im1, (0, 0))
139
+ dst.paste(im2, (0, im1.height))
140
+ return dst
141
+
142
+
143
+ def show_images_with_keypoints(images: list, kps: list, radius=15, color=(0, 220, 220), figsize=(10, 8), return_images=False, save=False, save_path="sample.png"):
144
+ assert len(images) == len(kps)
145
+
146
+ # generate
147
+ images_with_kps = []
148
+ for i in range(len(images)):
149
+ img_with_kps = draw_kps_on_image(images[i], kps[i], radius=radius, color=color, return_as="PIL")
150
+ images_with_kps.append(img_with_kps)
151
+
152
+ # show
153
+ show_grid_of_images(images_with_kps, n_cols=len(images), figsize=figsize, save=save, save_path=save_path)
154
+
155
+ if return_images:
156
+ return images_with_kps
157
+
158
+
159
+ def set_latex_fonts(usetex=True, fontsize=14, show_sample=False, **kwargs):
160
+ try:
161
+ plt.rcParams.update({
162
+ "text.usetex": usetex,
163
+ "font.family": "serif",
164
+ "font.serif": ["Computer Modern Roman"],
165
+ "font.size": fontsize,
166
+ **kwargs,
167
+ })
168
+ if show_sample:
169
+ plt.figure()
170
+ plt.title("Sample $y = x^2$")
171
+ plt.plot(np.arange(0, 10), np.arange(0, 10)**2, "--o")
172
+ plt.grid()
173
+ plt.show()
174
+ except:
175
+ print("Failed to setup LaTeX fonts. Proceeding without.")
176
+ pass
177
+
178
+
179
+ def get_colors(num_colors, palette="jet"):
180
+ cmap = plt.get_cmap(palette)
181
+ colors = [cmap(i) for i in np.linspace(0, 1, num_colors)]
182
+ return colors
183
+
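
A sketch for show_grid_of_images (assumes a working matplotlib backend); two random RGB images shown side by side:

import numpy as np
from clip_grounding.utils.visualize import show_grid_of_images

images = [np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)]
show_grid_of_images(images, n_cols=2, figsize=(6, 3), subtitles=["left", "right"], title="Sample grid")
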
example_images/Amsterdam.png ADDED
example_images/London.png ADDED
example_images/dogs_on_bed.png ADDED
example_images/harrypotter.png ADDED
requirements.txt ADDED
@@ -0,0 +1,121 @@
1
+ anyio==3.6.1
2
+ appnope==0.1.3
3
+ argon2-cffi==21.3.0
4
+ argon2-cffi-bindings==21.2.0
5
+ asttokens==2.0.5
6
+ attrs==21.4.0
7
+ Babel==2.10.3
8
+ backcall==0.2.0
9
+ beautifulsoup4==4.11.1
10
+ bleach==5.0.0
11
+ captum==0.5.0
12
+ certifi==2022.6.15
13
+ cffi==1.15.0
14
+ charset-normalizer==2.0.12
15
+ cycler==0.11.0
16
+ debugpy==1.6.0
17
+ decorator==5.1.1
18
+ defusedxml==0.7.1
19
+ entrypoints==0.4
20
+ executing==0.8.3
21
+ fastjsonschema==2.15.3
22
+ fonttools==4.33.3
23
+ ftfy==6.1.1
24
+ htmlmin==0.1.12
25
+ idna==3.3
26
+ ImageHash==4.2.1
27
+ imageio==2.19.3
28
+ imgaug==0.4.0
29
+ importlib-metadata==4.11.4
30
+ ipdb==0.13.9
31
+ ipykernel==6.15.0
32
+ ipython==8.4.0
33
+ ipython-genutils==0.2.0
34
+ ipywidgets==7.7.1
35
+ jedi==0.18.1
36
+ Jinja2==3.1.2
37
+ joblib==1.1.0
38
+ json5==0.9.8
39
+ jsonschema==4.6.0
40
+ jupyter-client==7.3.4
41
+ jupyter-core==4.10.0
42
+ jupyter-server==1.18.0
43
+ jupyterlab==3.4.3
44
+ jupyterlab-pygments==0.2.2
45
+ jupyterlab-server==2.14.0
46
+ jupyterlab-widgets==1.1.1
47
+ kiwisolver==1.4.3
48
+ MarkupSafe==2.1.1
49
+ matplotlib==3.5.2
50
+ matplotlib-inline==0.1.3
51
+ missingno==0.5.1
52
+ mistune==0.8.4
53
+ multimethod==1.8
54
+ natsort==8.1.0
55
+ nbclassic==0.3.7
56
+ nbclient==0.6.4
57
+ nbconvert==6.5.0
58
+ nbformat==5.4.0
59
+ nest-asyncio==1.5.5
60
+ networkx==2.8.4
61
+ notebook==6.4.12
62
+ notebook-shim==0.1.0
63
+ numpy==1.23.0
64
+ opencv-python==4.6.0.66
65
+ packaging==21.3
66
+ pandas==1.4.3
67
+ pandas-profiling==3.2.0
68
+ pandocfilters==1.5.0
69
+ parso==0.8.3
70
+ pexpect==4.8.0
71
+ phik==0.12.2
72
+ pickleshare==0.7.5
73
+ Pillow==9.1.1
74
+ prometheus-client==0.14.1
75
+ prompt-toolkit==3.0.29
76
+ psutil==5.9.1
77
+ ptyprocess==0.7.0
78
+ pure-eval==0.2.2
79
+ pycparser==2.21
80
+ pydantic==1.9.1
81
+ Pygments==2.12.0
82
+ pyparsing==3.0.9
83
+ pyrsistent==0.18.1
84
+ python-dateutil==2.8.2
85
+ pytz==2022.1
86
+ PyWavelets==1.3.0
87
+ PyYAML==6.0
88
+ pyzmq==23.2.0
89
+ regex==2022.6.2
90
+ requests==2.28.0
91
+ scikit-image==0.19.3
92
+ scikit-learn==1.1.1
93
+ scipy==1.8.1
94
+ seaborn==0.11.2
95
+ Send2Trash==1.8.0
96
+ Shapely==1.8.2
97
+ six==1.16.0
98
+ sniffio==1.2.0
99
+ soupsieve==2.3.2.post1
100
+ stack-data==0.3.0
101
+ tangled-up-in-unicode==0.2.0
102
+ termcolor==1.1.0
103
+ terminado==0.15.0
104
+ threadpoolctl==3.1.0
105
+ tifffile==2022.5.4
106
+ tinycss2==1.1.1
107
+ toml==0.10.2
108
+ torch==1.11.0
109
+ torchmetrics==0.9.1
110
+ torchvision==0.12.0
111
+ tornado==6.1
112
+ tqdm==4.64.0
113
+ traitlets==5.3.0
114
+ typing_extensions==4.2.0
115
+ urllib3==1.26.9
116
+ visions==0.7.4
117
+ wcwidth==0.2.5
118
+ webencodings==0.5.1
119
+ websocket-client==1.3.3
120
+ widgetsnbextension==3.6.1
121
+ zipp==3.8.0