# Author: Paul Hilders
# Commit 0241217: "Add new version of demo for IEAI course" (file size 5.69 kB)
import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
import os
from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
# Shared tokenizer instance; show_heatmap_on_text uses it to re-encode the
# prompt and decode each token id back to a display string.
_tokenizer = _Tokenizer()
#@title Control context expansion (number of attention layers to consider)
# NOTE: despite the form titles below, each value is the INDEX of the first
# residual attention block that interpret() includes. Blocks with a lower
# index are skipped (`if i < start_layer: continue`), so a larger value means
# FEWER layers are considered. If the transformers have 12 blocks (presumably
# the case for the CLIP models used here — TODO confirm), 11 means "only the
# last block".
#@title Number of layers for image Transformer
start_layer = 11#@param {type:"number"}
#@title Number of layers for text Transformer
start_layer_text = 11#@param {type:"number"}
def _accumulate_relevance(blocks, first_layer, target, batch_size, device):
    """Propagate gradient-weighted attention relevance through ``blocks``.

    Starts from a per-sample identity matrix. For every block whose index is
    ``>= first_layer`` it adds ``cam @ R``, where ``cam`` is the head-averaged
    positive part of ``d(target)/d(attn_probs) * attn_probs``.

    Args:
        blocks: list of modules exposing an ``attn_probs`` tensor that took
            part in computing ``target``. Its leading dimension is assumed to
            be ``batch_size * heads`` — TODO confirm against the model.
        first_layer: index of the first block to include. Negative values
            include every block.
        target: scalar tensor to differentiate (its graph must be retained).
        batch_size: number of samples in the batch.
        device: device on which the initial identity matrix is created.

    Returns:
        Relevance tensor of shape ``(batch_size, tokens, tokens)``.
    """
    first_attn = blocks[0].attn_probs
    num_tokens = first_attn.shape[-1]
    relevance = torch.eye(num_tokens, num_tokens, dtype=first_attn.dtype).to(device)
    relevance = relevance.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    # Clamp to 0 so a negative start still includes every block. A plain
    # negative slice would instead keep only the last few blocks.
    for blk in blocks[max(first_layer, 0):]:
        # retain_graph: the same target is differentiated once per block.
        grad = torch.autograd.grad(target, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        # Split the leading dim back into (batch, heads), then keep only the
        # positive contributions and average over heads.
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        relevance = relevance + torch.bmm(cam, relevance)
    return relevance


def interpret(image, texts, model, device, image_start_layer=None, text_start_layer=None):
    """Compute gradient-weighted attention relevance for one image and its prompts.

    The image is repeated once per prompt, so image ``i`` is paired with text
    ``i``. Relevance is back-propagated from the sum of those matching logits
    through both the visual and the text transformer blocks.

    Args:
        image: preprocessed image tensor. It is repeated along dim 0 once per
            prompt, so presumably shaped (1, C, H, W) — TODO confirm.
        texts: tokenised prompts, one per row.
        model: CLIP-style model. Calling it returns
            ``(logits_per_image, logits_per_text)``, and its residual blocks
            expose the ``attn_probs`` recorded during that forward pass.
        device: device on which the relevance matrices are created.
        image_start_layer: index of the first visual block to include.
            Defaults to the module-level ``start_layer``.
        text_start_layer: index of the first text block to include.
            Defaults to the module-level ``start_layer_text``.

    Returns:
        ``(text_relevance, image_relevance)``:
            - text_relevance: shape ``(batch, text_tokens, text_tokens)``.
            - image_relevance: row 0 of the visual relevance matrix without
              its self-entry, shape ``(batch, visual_tokens - 1)``.
    """
    if image_start_layer is None:
        image_start_layer = start_layer
    if text_start_layer is None:
        text_start_layer = start_layer_text

    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, _ = model(images, texts)
    # Sum of the matching image/text logits. This is the diagonal of
    # logits_per_image, which is what the original identity one-hot mask
    # selected.
    target = logits_per_image.diagonal().sum()
    model.zero_grad()

    image_blocks = list(model.visual.transformer.resblocks.children())
    image_rel_matrix = _accumulate_relevance(
        image_blocks, image_start_layer, target, batch_size, device)
    # Row 0 is presumably the class token; drop its own column so that only
    # the remaining (patch) tokens are kept.
    image_relevance = image_rel_matrix[:, 0, 1:]

    text_blocks = list(model.transformer.resblocks.children())
    text_relevance = _accumulate_relevance(
        text_blocks, text_start_layer, target, batch_size, device)
    return text_relevance, image_relevance
def _min_max_normalize(arr):
    """Linearly rescale a numpy array to [0, 1]; a constant array maps to zeros.

    Plain ``(x - min) / (max - min)`` yields NaN for constant input, for
    example all-zero relevance when every attention block was skipped.
    """
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return np.zeros_like(arr)
    return (arr - lo) / span


def _overlay_heatmap(img, mask):
    """Blend a JET colour map of ``mask`` onto ``img`` and rescale by the max.

    Args:
        img: float array (H, W, 3) in [0, 1], presumably RGB (CLIP preprocessing).
        mask: float array (H, W) in [0, 1].
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR. Convert before adding to the RGB image, so
    # the matplotlib display shows correct colours for both the photo and the
    # heatmap. The old code swapped channels after blending, which flipped
    # R/B in the photo.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    cam = np.float32(heatmap) / 255 + np.float32(img)
    return cam / np.max(cam)


def show_image_relevance(image_relevance, image, orig_image, device, show=True):
    """Upsample per-patch relevance to image resolution and optionally plot it.

    Args:
        image_relevance: relevance for one image, one value per patch. The
            element count must be a perfect square (e.g. 49 for the 7x7 grid
            the original hard-coded).
        image: preprocessed image tensor. ``image[0]`` is read as (C, H, W)
            and its H, W define the output size.
        orig_image: image drawn unmodified in the left panel when ``show``.
        device: unused; kept for backward compatibility with existing callers.
        show: if True, draw ``orig_image`` and the relevance overlay side by
            side with matplotlib.

    Returns:
        numpy array of shape (H, W) with relevance min-max normalised to [0, 1].

    Raises:
        ValueError: if ``image_relevance`` does not hold a square number of values.
    """
    num_patches = image_relevance.numel()
    grid = int(round(num_patches ** 0.5))
    if grid * grid != num_patches:
        raise ValueError(
            f"image_relevance must have a square number of elements, got {num_patches}")
    height, width = image.shape[-2:]

    relevance = image_relevance.detach().reshape(1, 1, grid, grid)
    # align_corners=False is the default the original relied on; it is
    # passed explicitly to avoid the warning.
    relevance = torch.nn.functional.interpolate(
        relevance, size=(height, width), mode='bilinear', align_corners=False)
    relevance = relevance.reshape(height, width).detach().cpu().numpy()
    relevance = _min_max_normalize(relevance)

    if show:
        img = image[0].permute(1, 2, 0).detach().cpu().numpy()
        img = _min_max_normalize(img)
        vis = np.uint8(255 * _overlay_heatmap(img, relevance))
        fig, axs = plt.subplots(1, 2)
        axs[0].imshow(orig_image)
        axs[0].axis('off')
        axs[1].imshow(vis)
        axs[1].axis('off')
    return relevance
def show_heatmap_on_text(text, text_encoding, R_text, show=True):
    """Turn one prompt's relevance matrix into per-token scores and optionally render them.

    Args:
        text: raw prompt string. It is re-encoded with the module tokenizer to
            obtain the token strings shown next to the scores.
        text_encoding: token ids of the prompt. The end-of-text position is
            found with ``argmax``, which assumes the EOT id is the largest id
            present (presumably true for CLIP's tokenizer — TODO confirm).
        R_text: relevance matrix (tokens x tokens) for this single prompt,
            e.g. one batch entry of ``interpret``'s text relevance.
        show: if True, render the tokens with captum's ``visualize_text``.

    Returns:
        ``(text_scores, text_tokens_decoded)``:
            - text_scores: flat tensor with the relevance of positions
              ``1 .. EOT-1`` from the EOT row, normalised to sum to 1. It is
              all zeros if that row sums to 0.
            - text_tokens_decoded: list of decoded token strings.
    """
    eot_idx = text_encoding.argmax(dim=-1)
    # Row of the EOT token, restricted to the content tokens before it.
    row = R_text[eot_idx, 1:eot_idx]
    total = row.sum()
    # If no text blocks were included, the identity relevance has no mass in
    # these columns. Avoid 0/0 = NaN in that case.
    text_scores = row / total if total != 0 else torch.zeros_like(row)
    text_scores = text_scores.flatten()

    text_tokens = _tokenizer.encode(text)
    text_tokens_decoded = [_tokenizer.decode([tok]) for tok in text_tokens]
    # Only word attributions and tokens matter here; the other positional
    # fields of the captum record are filled with placeholders.
    vis_data_records = [visualization.VisualizationDataRecord(
        text_scores, 0, 0, 0, 0, 0, text_tokens_decoded, 1)]
    if show:
        visualization.visualize_text(vis_data_records)
    return text_scores, text_tokens_decoded
def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
    """Convenience alias for ``show_image_relevance``; arguments are forwarded as-is."""
    return show_image_relevance(
        image_relevance=image_relevance,
        image=image,
        orig_image=orig_image,
        device=device,
        show=show,
    )
def show_txt_heatmap(text, text_encoding, R_text, show=True):
    """Convenience alias for ``show_heatmap_on_text``; arguments are forwarded as-is."""
    return show_heatmap_on_text(
        text=text,
        text_encoding=text_encoding,
        R_text=R_text,
        show=show,
    )
def load_dataset(dataset_path=None, device=None):
    """Load the demo's pre-computed segment file with ``torch.load``.

    Args:
        dataset_path: file to load. Defaults to
            ``../../dummy-data/71226_segments.pt``, resolved against the
            current working directory (not this module's location).
        device: ``map_location`` passed to ``torch.load``. Defaults to
            "cuda" when available, else "cpu".

    Returns:
        The object stored in the file.

    Note:
        ``torch.load`` unpickles its input; only load trusted files.
    """
    if dataset_path is None:
        dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments.pt')
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.load(dataset_path, map_location=device)
class color:
    """Namespace of raw ANSI escape strings for styling terminal output.

    Typical use is ``color.BOLD + text + color.END``.
    """

    # Colour codes.
    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Terminator appended after styled text.
    END = "\033[0m"