Spaces:
Runtime error
Runtime error
Martijn van Beers
committed on
Commit
·
8f3d1af
1
Parent(s):
cf1865f
Remove code for jupyter notebooks
Browse files
There was some partially commented-out code to create matplotlib
figures. Remove it altogether.
- CLIP_explainability/utils.py +7 -25
- app.py +2 -2
CLIP_explainability/utils.py
CHANGED
@@ -69,7 +69,7 @@ def interpret(image, texts, model, device):
|
|
69 |
return text_relevance, image_relevance
|
70 |
|
71 |
|
72 |
-
def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
73 |
# create heatmap from mask on image
|
74 |
def show_cam_on_image(img, mask):
|
75 |
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
@@ -78,15 +78,6 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
|
78 |
cam = cam / np.max(cam)
|
79 |
return cam
|
80 |
|
81 |
-
# plt.axis('off')
|
82 |
-
# f, axarr = plt.subplots(1,2)
|
83 |
-
# axarr[0].imshow(orig_image)
|
84 |
-
|
85 |
-
if show:
|
86 |
-
fig, axs = plt.subplots(1, 2)
|
87 |
-
axs[0].imshow(orig_image);
|
88 |
-
axs[0].axis('off');
|
89 |
-
|
90 |
image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
91 |
image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
|
92 |
image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
|
@@ -97,16 +88,10 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
|
|
97 |
vis = np.uint8(255 * vis)
|
98 |
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
99 |
|
100 |
-
if show:
|
101 |
-
# axar[1].imshow(vis)
|
102 |
-
axs[1].imshow(vis);
|
103 |
-
axs[1].axis('off');
|
104 |
-
# plt.imshow(vis)
|
105 |
-
|
106 |
return image_relevance
|
107 |
|
108 |
|
109 |
-
def show_heatmap_on_text(text, text_encoding, R_text, show=True):
|
110 |
CLS_idx = text_encoding.argmax(dim=-1)
|
111 |
R_text = R_text[CLS_idx, 1:CLS_idx]
|
112 |
text_scores = R_text / R_text.sum()
|
@@ -115,19 +100,16 @@ def show_heatmap_on_text(text, text_encoding, R_text, show=True):
|
|
115 |
text_tokens=_tokenizer.encode(text)
|
116 |
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
117 |
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
|
118 |
-
|
119 |
-
if show:
|
120 |
-
visualization.visualize_text(vis_data_records)
|
121 |
|
122 |
return text_scores, text_tokens_decoded
|
123 |
|
124 |
|
125 |
-
def show_img_heatmap(image_relevance, image, orig_image, device
|
126 |
-
return show_image_relevance(image_relevance, image, orig_image, device
|
127 |
|
128 |
|
129 |
-
def show_txt_heatmap(text, text_encoding, R_text
|
130 |
-
return show_heatmap_on_text(text, text_encoding, R_text
|
131 |
|
132 |
|
133 |
def load_dataset():
|
@@ -149,4 +131,4 @@ class color:
|
|
149 |
RED = '\033[91m'
|
150 |
BOLD = '\033[1m'
|
151 |
UNDERLINE = '\033[4m'
|
152 |
-
END = '\033[0m'
|
|
|
69 |
return text_relevance, image_relevance
|
70 |
|
71 |
|
72 |
+
def show_image_relevance(image_relevance, image, orig_image, device):
|
73 |
# create heatmap from mask on image
|
74 |
def show_cam_on_image(img, mask):
|
75 |
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
|
|
78 |
cam = cam / np.max(cam)
|
79 |
return cam
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
image_relevance = image_relevance.reshape(1, 1, 7, 7)
|
82 |
image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
|
83 |
image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
|
|
|
88 |
vis = np.uint8(255 * vis)
|
89 |
vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
return image_relevance
|
92 |
|
93 |
|
94 |
+
def show_heatmap_on_text(text, text_encoding, R_text):
|
95 |
CLS_idx = text_encoding.argmax(dim=-1)
|
96 |
R_text = R_text[CLS_idx, 1:CLS_idx]
|
97 |
text_scores = R_text / R_text.sum()
|
|
|
100 |
text_tokens=_tokenizer.encode(text)
|
101 |
text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
|
102 |
vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
|
|
|
|
|
|
|
103 |
|
104 |
return text_scores, text_tokens_decoded
|
105 |
|
106 |
|
107 |
+
def show_img_heatmap(image_relevance, image, orig_image, device):
|
108 |
+
return show_image_relevance(image_relevance, image, orig_image, device)
|
109 |
|
110 |
|
111 |
+
def show_txt_heatmap(text, text_encoding, R_text):
|
112 |
+
return show_heatmap_on_text(text, text_encoding, R_text)
|
113 |
|
114 |
|
115 |
def load_dataset():
|
|
|
131 |
RED = '\033[91m'
|
132 |
BOLD = '\033[1m'
|
133 |
UNDERLINE = '\033[4m'
|
134 |
+
END = '\033[0m'
|
app.py
CHANGED
@@ -59,10 +59,10 @@ def run_demo(image, text):
|
|
59 |
|
60 |
R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
|
61 |
|
62 |
-
image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device
|
63 |
overlapped = overlay_relevance_map_on_image(image, image_relevance)
|
64 |
|
65 |
-
text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0]
|
66 |
|
67 |
highlighted_text = []
|
68 |
for i, token in enumerate(text_tokens_decoded):
|
|
|
59 |
|
60 |
R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
|
61 |
|
62 |
+
image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device)
|
63 |
overlapped = overlay_relevance_map_on_image(image, image_relevance)
|
64 |
|
65 |
+
text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0])
|
66 |
|
67 |
highlighted_text = []
|
68 |
for i, token in enumerate(text_tokens_decoded):
|