Spaces:
Runtime error
Runtime error
File size: 5,331 Bytes
0241217 dc15657 0241217 dc15657 0241217 8f3d1af 0241217 5f8002c 0241217 8f3d1af 0241217 8f3d1af 0241217 8f3d1af 0241217 8f3d1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
import os
from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
# Module-wide tokenizer instance; show_heatmap_on_text uses it to encode the
# prompt into token ids and to decode each id back into a string.
_tokenizer = _Tokenizer()
#@title Control context expansion (number of attention layers to consider)
#@title Number of layers for image Transformer
#start_layer = 11#@param {type:"number"}
#@title Number of layers for text Transformer
# Index of the first text-transformer block that interpret() includes when
# accumulating text relevance; blocks with a smaller index are skipped.
start_layer_text = 11#@param {type:"number"}
def _accumulate_relevance(attn_blocks, objective, batch_size, first_layer, device):
    """Accumulate gradient-weighted attention relevance across ``attn_blocks``.

    Starts from a per-sample identity matrix and, for every block whose index
    is at least ``first_layer``, adds ``bmm(cam, R)`` where ``cam`` is the
    head-averaged, non-negative part of ``grad * attn_probs``.

    Args:
        attn_blocks: sequence of blocks, each exposing an ``attn_probs``
            tensor that is part of the autograd graph of ``objective``.
        objective: scalar tensor to differentiate.
        batch_size: number of samples in the batch.
        first_layer: blocks with an index below this value are skipped.
        device: device on which the relevance matrix is created.

    Returns:
        Tensor of shape ``(batch_size, num_tokens, num_tokens)``.
    """
    first_probs = attn_blocks[0].attn_probs
    num_tokens = first_probs.shape[-1]
    relevance = torch.eye(num_tokens, num_tokens, dtype=first_probs.dtype).to(device)
    relevance = relevance.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for layer_idx, blk in enumerate(attn_blocks):
        if layer_idx < first_layer:
            continue
        # retain_graph=True: the same graph is differentiated once per block
        # (and again for the other transformer tower).
        grad = torch.autograd.grad(objective, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        # Regroup as (batch, heads, tokens, tokens), keep positive
        # contributions only, then average over the head axis.
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        cam = cam.clamp(min=0).mean(dim=1)
        relevance = relevance + torch.bmm(cam, relevance)
    return relevance


def interpret(image, texts, model, device, start_layer):
    """Compute text and image relevance maps for image/text pairs.

    The image is repeated once per text, the model is run on the batch, and
    the sum of the diagonal of ``logits_per_image`` (image i vs. text i) is
    used as the scalar objective whose gradients weight the attention maps.

    Args:
        image: image tensor; repeated ``texts.shape[0]`` times along dim 0
            (assumes a leading batch dimension of 1 -- TODO confirm).
        texts: tokenized texts; ``texts.shape[0]`` is the batch size.
        model: callable returning ``(logits_per_image, logits_per_text)``
            and exposing ``visual.transformer.resblocks`` and
            ``transformer.resblocks`` whose children carry ``attn_probs``.
        device: device for the relevance matrices and the selection mask.
        start_layer: index of the first image-transformer block to include.
            The text transformer uses the module-level ``start_layer_text``.

    Returns:
        ``(text_relevance, image_relevance)``: the full text relevance
        matrices ``(batch, tokens, tokens)`` and, for the image, row 0
        without column 0 of each relevance matrix ``(batch, tokens - 1)``.
    """
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, _ = model(images, texts)

    # Select the diagonal logits. The mask is float32 so the product is
    # promoted exactly as it was with the former numpy-built float32 mask.
    selector = torch.eye(
        logits_per_image.shape[0],
        logits_per_image.shape[1],
        dtype=torch.float32,
        device=device,
    )
    objective = torch.sum(selector * logits_per_image)
    model.zero_grad()

    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    image_R = _accumulate_relevance(image_attn_blocks, objective, batch_size, start_layer, device)
    image_relevance = image_R[:, 0, 1:]

    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
    text_relevance = _accumulate_relevance(
        text_attn_blocks, objective, batch_size, start_layer_text, device
    )
    return text_relevance, image_relevance
def show_image_relevance(image_relevance, image, orig_image, device):
    """Upsample a flat relevance vector to image size and scale it to [0, 1].

    The vector is reshaped to a square grid, bilinearly interpolated to
    ``image.shape[-1]`` pixels per side and min-max normalised. Earlier
    revisions also built a heat-map overlay with OpenCV but discarded it;
    that dead computation has been removed, so only the normalised map is
    produced.

    Args:
        image_relevance: 1-D tensor whose length is a perfect square.
        image: tensor whose last dimension gives the output side length.
        orig_image: unused; kept for interface compatibility.
        device: unused; kept for interface compatibility.

    Returns:
        ``numpy.ndarray`` of shape ``(img_size, img_size)`` with values in
        ``[0, 1]``. A constant relevance map yields all zeros rather than
        NaNs from a zero-width range.
    """
    grid_side = int(np.sqrt(image_relevance.shape[0]))
    img_size = image.shape[-1]
    relevance_map = image_relevance.reshape(1, 1, grid_side, grid_side)
    relevance_map = torch.nn.functional.interpolate(relevance_map, size=img_size, mode='bilinear')
    relevance_map = relevance_map.reshape(img_size, img_size).detach().cpu().numpy()

    lowest = relevance_map.min()
    span = relevance_map.max() - lowest
    # Guard the degenerate constant map; NaN inputs still propagate because
    # NaN == 0 is False.
    if span == 0:
        return np.zeros_like(relevance_map)
    return (relevance_map - lowest) / span
def show_heatmap_on_text(text, text_encoding, R_text):
    """Return normalised per-token relevance scores and the decoded tokens.

    The row of ``R_text`` at the position of the largest token id in
    ``text_encoding`` is taken, restricted to columns strictly between
    position 0 and that position, and divided by its sum.
    (Presumably the largest id marks the end-of-text token and position 0
    the start token -- confirm against the tokenizer.)

    Args:
        text: raw prompt string; tokenised with the module-level tokenizer.
        text_encoding: token-id tensor for ``text``.
        R_text: relevance matrix for this text (tokens x tokens).

    Returns:
        ``(text_scores, text_tokens_decoded)``: a flattened score tensor and
        the list of decoded token strings. If the selected relevance sums to
        zero the scores are all zeros instead of NaN.
    """
    CLS_idx = text_encoding.argmax(dim=-1)
    token_relevance = R_text[CLS_idx, 1:CLS_idx]
    total = token_relevance.sum()
    if total == 0:
        # Avoid 0/0 -> NaN when no relevance reached these tokens.
        text_scores = torch.zeros_like(token_relevance)
    else:
        text_scores = token_relevance / total
    text_scores = text_scores.flatten()
    text_tokens_decoded = [
        _tokenizer.decode([token_id]) for token_id in _tokenizer.encode(text)
    ]
    return text_scores, text_tokens_decoded
def show_img_heatmap(image_relevance, image, orig_image, device):
    """Forward all arguments unchanged to ``show_image_relevance``."""
    heatmap = show_image_relevance(image_relevance, image, orig_image, device)
    return heatmap
def show_txt_heatmap(text, text_encoding, R_text):
    """Forward all arguments unchanged to ``show_heatmap_on_text``."""
    scores_and_tokens = show_heatmap_on_text(text, text_encoding, R_text)
    return scores_and_tokens
def load_dataset(dataset_path=None):
    """Load a serialized dataset with ``torch.load``.

    Args:
        dataset_path: path of the ``.pt`` file to load. Defaults to
            ``../../dummy-data/71226_segments.pt``, resolved relative to the
            current working directory.

    Returns:
        Whatever object was stored in the file, with tensors mapped to CUDA
        when available and to the CPU otherwise.
    """
    if dataset_path is None:
        dataset_path = os.path.join('..', '..', 'dummy-data', '71226_segments.pt')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): torch.load unpickles the file; only load trusted data.
    return torch.load(dataset_path, map_location=device)
class color:
    """ANSI escape sequences for styling terminal output."""

    # Reset and text attributes.
    END = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Standard-intensity foreground colour.
    DARKCYAN = '\033[36m'

    # Bright foreground colours.
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
|