"""A wrapper for CLIP model to support forward with a list of text inputs.""" |
|
|
|
|
|
import clip |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
_CONTEXT_LENGTH = 77 |
|
|
|
|
|
def forward_clip_single(model, image, text, h, w):
  """Forward a single text input through CLIP.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (str or List[str]): the text input.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits, detached and moved to the CPU.
  """
  if isinstance(text, str):
    text = [text]
  text_tokens = clip.tokenize(text).to(image.device)
  text_prediction = model(image, text_tokens, h, w)
  return text_prediction.detach().cpu()


def forward_clip(model, image, text, h, w):
  """Forward a list of text inputs.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (List[str] or List[List[str]]): the text input. A list of lists is
      treated as multiple prompt templates, each holding one prompt per class.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits.
  """
  if isinstance(text[0], list):
    # Multiple prompt templates: run each template list separately, sum the
    # per-template predictions, then renormalize with a softmax.
    text_prediction = torch.stack(
        [forward_clip_single(model, image, t, h, w) for t in text], dim=0
    )
    text_prediction = torch.sum(text_prediction, dim=0)
    text_prediction = F.softmax(text_prediction.float(), dim=-1)
  else:
    text_prediction = forward_clip_single(model, image, text, h, w)
  return text_prediction.float()


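# Illustrative usage sketch (kept as a comment so it is not executed on
# import): `forward_clip` accepts either a flat prompt list or a list of
# prompt-template lists, each inner list holding one prompt per class.
# The `wrapper` and `image` names below are hypothetical.
#
#   model, _ = clip.load('ViT-B/16')
#   wrapper = CLIPWrapper(model.eval())
#   templates = [
#       ['a photo of a cat', 'a photo of a dog'],
#       ['an origami cat', 'an origami dog'],
#   ]
#   probs = forward_clip(wrapper, image, templates, 224, 224)  # [batch, 2]

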
def upsample_position_embedding(embed, new_size):
  """Upsample the pretrained positional embedding to a higher resolution.

  Args:
    embed (torch.Tensor): the pretrained positional embedding of shape
      (1 + grid_size**2, width), where the first row is the class token.
    new_size (Tuple[int, int]): the new (height, width) of the patch grid.

  Returns:
    torch.Tensor: the upsampled embedding.
  """
  # Keep the class-token embedding unchanged and resize only the patch grid.
  first = embed[:1, :]
  embed = embed[1:, :]
  n = embed.size(0)
  d = embed.size(1)
  size = int(np.sqrt(n))
  if size * size != n:
    raise ValueError(f'The size of embed {n} is not a perfect square.')

  embed = embed.permute(1, 0)
  embed = embed.view(1, d, size, size).contiguous()
  # F.upsample is deprecated; F.interpolate with align_corners=False matches
  # its default behavior.
  embed = F.interpolate(
      embed,
      size=new_size,
      mode='bilinear',
      align_corners=False,
  )
  embed = embed.view(d, -1).contiguous()
  embed = embed.permute(1, 0)
  embed = torch.cat([first, embed], 0)
  embed = nn.Parameter(embed.half())
  return embed


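# Example with hypothetical numbers: ViT-B/16 at 224x224 resolution has a
# positional embedding of shape (1 + 14 * 14, 768) = (197, 768). Resizing it
# for a 448x448 input, i.e. a 28x28 patch grid, gives (1 + 28 * 28, 768):
#
#   new_embed = upsample_position_embedding(
#       model.visual.positional_embedding, (448 // 16, 448 // 16))

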
class CustomBlock(nn.Module):
  """A residual attention block that also returns the attention weights."""

  def __init__(self, block):
    super().__init__()
    # Copy every attribute (parameters, buffers, submodules, attn_mask, ...)
    # from the wrapped ResidualAttentionBlock.
    for k, v in vars(block).items():
      setattr(self, k, v)

  def attention(self, x):
    self.attn_mask = (
        self.attn_mask.to(dtype=x.dtype, device=x.device)
        if self.attn_mask is not None
        else None
    )
    self.attn = self.attn.to(dtype=x.dtype, device=x.device)
    # need_weights=True returns the attention weights (averaged over heads)
    # together with the attention output.
    return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)

  def forward(self, x):
    attn_output, attn_weight = self.attention(self.ln_1(x))
    x = x + attn_output
    x = x + self.mlp(self.ln_2(x))
    return x, attn_weight


class CustomTransformer(nn.Module):
  """A customized Transformer to support CAM calculation."""

  def __init__(self, transformer):
    """Initialize the wrapper.

    Args:
      transformer (nn.Module): the Transformer to be wrapped.
    """
    super().__init__()
    for k, v in vars(transformer).items():
      setattr(self, k, v)

    # Replace each residual block with a CustomBlock that exposes attention.
    self.resblocks = nn.Sequential(
        *[CustomBlock(block) for block in self.resblocks]
    )

  def forward(self, x):
    attn_weights = []
    with torch.no_grad():
      # Text sequences (length == _CONTEXT_LENGTH) pass through every block;
      # visual token sequences stop one block early so that the final block
      # can be applied separately (see CLIPWrapper.forward_last_layer).
      layers = (
          self.layers if x.shape[0] == _CONTEXT_LENGTH else self.layers - 1
      )
      for i in range(layers):
        x, attn_weight = self.resblocks[i](x)
        attn_weights.append(attn_weight)
    return x, attn_weights


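# Note (illustrative): for a 224x224 image with 16x16 patches, the visual
# branch below produces a 197-token sequence, so each entry of the returned
# `attn_weights` list has shape [batch, 197, 197] (nn.MultiheadAttention
# averages the weights over heads by default).

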
class CustomVisionTransformer(nn.Module):
  """A customized VisionTransformer to support CAM calculation."""

  def __init__(self, model):
    """Initialize the wrapper.

    Args:
      model (VisionTransformer): the VisionTransformer to be wrapped.
    """
    super().__init__()
    for k, v in vars(model).items():
      setattr(self, k, v)
    self.patch_size = self.conv1.kernel_size[0]
    self.transformer = CustomTransformer(self.transformer)

  def forward(self, x, h, w):
    # Resize the positional embedding to match the (h, w) input resolution.
    self.positional_embedding_new = upsample_position_embedding(
        self.positional_embedding, (h // self.patch_size, w // self.patch_size)
    )

    x = self.conv1(x)  # [batch, width, grid_h, grid_w]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [batch, width, grid**2]
    x = x.permute(0, 2, 1)  # [batch, grid**2, width]
    zeros = torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    # Prepend the class token, then add the resized positional embedding.
    x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
    x = x + self.positional_embedding_new.to(x.dtype)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x, attn_weight = self.transformer(x)
    return x, attn_weight


class CLIPWrapper(nn.Module):
  """A wrapper for CLIP to support forward with a list of text inputs."""

  def __init__(self, clip_model):
    """Initialize the wrapper.

    Args:
      clip_model (CLIP): the CLIP model to be wrapped.
    """
    super().__init__()

    for k, v in vars(clip_model).items():
      setattr(self, k, v)
    self.visual = CustomVisionTransformer(self.visual)
    self.transformer = CustomTransformer(self.transformer)

  @property
  def dtype(self):
    return self.visual.conv1.weight.dtype

  def encode_image(self, image, h, w):
    return self.visual(image.type(self.dtype), h, w)

  def encode_text(self, text):
    x = self.token_embedding(text).type(self.dtype)  # [batch, n_ctx, width]

    x = x + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x, _ = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)

    # Take features from the EOT embedding: the EOT token has the highest
    # token id in each tokenized sequence, hence the argmax.
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x

  def pool_visual(self, x, use_cls_token=False):
    if use_cls_token:
      return x[:, 0]
    else:
      return torch.mean(x[:, 1:, :], dim=1)

  def forward_last_layer(
      self, image_features, text_features, use_cls_token=False, repeat_last=True
  ):
    """Forward the last layer of CLIP.

    Args:
      image_features (torch.Tensor): the image features.
      text_features (torch.Tensor): the text features.
      use_cls_token (bool, optional): whether to pool with the CLS token
        instead of mean-pooling the patch tokens. Defaults to False.
      repeat_last (bool, optional): whether to apply the final visual
        residual attention block to `image_features`. Defaults to True.

    Returns:
      torch.Tensor: the softmax-normalized logits per image.
      torch.Tensor: the attention weights of the final block, or None if
        repeat_last is False.
    """
    if repeat_last:
      x, attention_weight = self.visual.transformer.resblocks[
          self.visual.transformer.layers - 1
      ](image_features)
    else:
      x = image_features
      attention_weight = None
    x = x.permute(1, 0, 2)  # LND -> NLD

    x = self.visual.ln_post(x)
    x = self.pool_visual(x, use_cls_token=use_cls_token)

    if self.visual.proj is not None:
      x = x @ self.visual.proj

    image_features = x

    # Cosine similarity between the normalized image and text features.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()

    logits_per_image = F.softmax(logits_per_image.float(), dim=-1)

    return logits_per_image, attention_weight

  def forward(self, image, text, h=224, w=224):
    with torch.no_grad():
      text_features = self.encode_text(text)
      feature_map, _ = self.visual(image.type(self.dtype), h, w)

      logits_per_image, _ = self.forward_last_layer(
          feature_map, text_features, use_cls_token=True, repeat_last=False
      )
    return logits_per_image
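

# Minimal smoke-test sketch, assuming the official `clip` package can download
# the ViT-B/16 checkpoint; the prompts and the random input are illustrative.
if __name__ == '__main__':
  clip_model, _ = clip.load('ViT-B/16', device='cpu')
  wrapper = CLIPWrapper(clip_model.eval())
  dummy_image = torch.randn(1, 3, 224, 224)
  templates = [
      ['a photo of a cat', 'a photo of a dog'],
      ['an origami cat', 'an origami dog'],
  ]
  class_probs = forward_clip(wrapper, dummy_image, templates, 224, 224)
  print(class_probs.shape)  # Expected: torch.Size([1, 2])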