# CLIP_as_RNN/modeling/model/clip_wrapper.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A wrapper for CLIP model to support forward with a list of text inputs."""
# pylint: disable=g-importing-member
import clip
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

_CONTEXT_LENGTH = 77


def forward_clip_single(model, image, text, h, w):
  """Forward a single text input.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (List[str]): the text input.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits.
  """
  if isinstance(text, str):
    text = [text]
  text_tokens = clip.tokenize(text).to(image.device)
  text_prediction = model(image, text_tokens, h, w)
  return text_prediction.detach().cpu()


def forward_clip(model, image, text, h, w):
  """Forward a list of text inputs.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (List[str] or List[List[str]]): the text input.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits.
  """
  if isinstance(text[0], list):
    # Prompt ensembling: each inner list is scored separately and the
    # per-list predictions are summed before the final softmax.
    text_prediction = torch.stack(
        [forward_clip_single(model, image, t, h, w) for t in text], dim=0
    )
    text_prediction = torch.sum(text_prediction, dim=0)
    text_prediction = F.softmax(text_prediction.float(), dim=-1)
  else:
    text_prediction = forward_clip_single(model, image, text, h, w)
  return text_prediction.float()
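

# Example (illustrative sketch, not part of the original module): calling
# `forward_clip` with a flat prompt list versus ensembled prompt lists. The
# names `wrapped_model` and `image_batch` are assumed to exist elsewhere.
#
#   probs = forward_clip(
#       wrapped_model, image_batch,
#       ['a photo of a cat', 'a photo of a dog'], h=224, w=224)
#
#   # Ensembling: every inner list holds one prompt template applied to all
#   # classes; per-template predictions are summed, then softmaxed.
#   probs = forward_clip(
#       wrapped_model, image_batch,
#       [['a photo of a cat', 'a photo of a dog'],
#        ['a drawing of a cat', 'a drawing of a dog']], h=224, w=224)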


def upsample_position_embedding(embed, new_size):
  """Upsample the pretrained embedding to a higher resolution.

  Args:
    embed (torch.Tensor): the pretrained embedding.
    new_size (Tuple[int, int]): the new size of the embedding.

  Returns:
    torch.Tensor: the upsampled embedding.
  """
  # embed has shape (N + 1) x D: the class-token row followed by N grid rows.
  first = embed[:1, :]
  embed = embed[1:, :]
  n = embed.size(0)
  d = embed.size(1)
  size = int(np.sqrt(n))
  if size * size != n:
    raise ValueError(f'The size of embed {n} is not a perfect square number.')
  embed = embed.permute(1, 0)
  embed = embed.view(1, d, size, size).contiguous()
  # F.interpolate replaces the deprecated F.upsample; align_corners=False
  # matches the former default behavior for bilinear mode.
  embed = F.interpolate(
      embed,
      size=new_size,
      mode='bilinear',
      align_corners=False,
  )
  embed = embed.view(d, -1).contiguous()
  embed = embed.permute(1, 0)
  embed = torch.cat([first, embed], 0)
  embed = nn.parameter.Parameter(embed.half())
  return embed
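

# Example (sketch): for CLIP ViT-B/16 the pretrained positional embedding has
# shape (14 * 14 + 1, 768) = (197, 768). Upsampling it for a 448x448 input,
# i.e. new_size = (448 // 16, 448 // 16) = (28, 28), yields a parameter of
# shape (28 * 28 + 1, 768) = (785, 768); the class-token row is kept unchanged.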


class CustomBlock(nn.Module):
  """A customized attention block."""

  def __init__(self, block):
    super().__init__()
    # Copy every attribute (parameters, submodules, attention mask, etc.) from
    # the wrapped residual attention block.
    for k, v in vars(block).items():
      setattr(self, k, v)

  def attention(self, x):
    self.attn_mask = (
        self.attn_mask.to(dtype=x.dtype, device=x.device)
        if self.attn_mask is not None
        else None
    )
    self.attn = self.attn.to(dtype=x.dtype, device=x.device)
    # Setting need_weights to True also returns the attention weights.
    return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)

  def forward(self, x):
    # attn_output: (L, N, E), attn_weight: (N, L, L)
    attn_output, attn_weight = self.attention(self.ln_1(x))
    x = x + attn_output
    x = x + self.mlp(self.ln_2(x))
    return x, attn_weight


class CustomTransformer(nn.Module):
  """A customized Transformer to support CAM calculation."""

  def __init__(self, transformer):
    """Initialize the wrapper.

    Args:
      transformer (nn.Module): the Transformer to be wrapped.
    """
    super().__init__()
    for k, v in vars(transformer).items():
      setattr(self, k, v)
    self.resblocks = nn.Sequential(
        *[CustomBlock(block) for block in self.resblocks]
    )

  def forward(self, x):
    attn_weights = []
    with torch.no_grad():
      # Text inputs (sequence length == _CONTEXT_LENGTH) run through every
      # block; image inputs stop one block early so that the final block can
      # be applied (or skipped) separately in CLIPWrapper.forward_last_layer.
      layers = self.layers if x.shape[0] == _CONTEXT_LENGTH else self.layers - 1
      for i in range(layers):
        x, attn_weight = self.resblocks[i](x)
        attn_weights.append(attn_weight)
    return x, attn_weights


class CustomVisionTransformer(nn.Module):
  """A customized VisionTransformer to support CAM calculation."""

  def __init__(self, model):
    """Initialize the wrapper.

    Args:
      model (VisionTransformer): the VisionTransformer to be wrapped.
    """
    super().__init__()
    for k, v in vars(model).items():
      setattr(self, k, v)
    self.patch_size = self.conv1.kernel_size[0]
    self.transformer = CustomTransformer(self.transformer)

  def forward(self, x, h, w):
    # Resize the pretrained positional embedding to match the input resolution.
    self.positional_embedding_new = upsample_position_embedding(
        self.positional_embedding, (h // self.patch_size, w // self.patch_size)
    )
    # shape = [*, width, grid, grid]
    x = self.conv1(x)
    # shape = [*, width, grid ** 2]
    x = x.reshape(x.shape[0], x.shape[1], -1)
    # shape = [*, grid ** 2, width]
    x = x.permute(0, 2, 1)
    zeros = torch.zeros(
        x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
    )
    # shape = [*, grid ** 2 + 1, width]
    x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
    x = x + self.positional_embedding_new.to(x.dtype)
    x = self.ln_pre(x)
    # NLD -> LND
    x = x.permute(1, 0, 2)
    x, attn_weight = self.transformer(x)
    return x, attn_weight
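

# Example (sketch): feeding a 448x448 batch to a wrapped ViT-B/16 produces
# features of shape (28 * 28 + 1, batch, 768) in LND layout, plus one
# attention-weight tensor of shape (batch, 785, 785) per executed block.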


class CLIPWrapper(nn.Module):
  """A wrapper for CLIP to support forward with a list of text inputs."""

  def __init__(self, clip_model):
    """Initialize the wrapper.

    Args:
      clip_model (CLIP): the CLIP model to be wrapped.
    """
    super().__init__()
    # Copy all attributes from clip_model to self.
    for k, v in vars(clip_model).items():
      setattr(self, k, v)
    self.visual = CustomVisionTransformer(self.visual)
    self.transformer = CustomTransformer(self.transformer)

  @property
  def dtype(self):
    return self.visual.conv1.weight.dtype

  def encode_image(self, image, h, w):
    return self.visual(image.type(self.dtype), h, w)

  def encode_text(self, text):
    x = self.token_embedding(text).type(
        self.dtype
    )  # [batch_size, n_ctx, d_model]
    x = x + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x, _ = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)
    # x.shape = [batch_size, n_ctx, transformer.width]
    # Take features from the eot embedding
    # (eot_token is the highest number in each sequence).
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
    return x

  def pool_visual(self, x, use_cls_token=False):
    """Pool patch tokens (or return the CLS token) from the visual features."""
    if use_cls_token:
      return x[:, 0]
    else:
      return torch.mean(x[:, 1:, :], dim=1)

  def forward_last_layer(
      self, image_features, text_features, use_cls_token=False, repeat_last=True
  ):
    """Forward the last layer of CLIP.

    Args:
      image_features (torch.Tensor): the image features.
      text_features (torch.Tensor): the text features.
      use_cls_token (bool, optional): whether to use the CLS token. Defaults
        to False.
      repeat_last (bool, optional): whether to apply the final visual
        attention block to `image_features`. Defaults to True.

    Returns:
      torch.Tensor: the logits.
      torch.Tensor: the attention weights.
    """
    if repeat_last:
      x, attention_weight = self.visual.transformer.resblocks[
          self.visual.transformer.layers - 1
      ](image_features)
    else:
      x = image_features
      attention_weight = None
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.visual.ln_post(x)
    x = self.pool_visual(x, use_cls_token=use_cls_token)
    if self.visual.proj is not None:
      x = x @ self.visual.proj
    image_features = x
    # Normalized features.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    # Cosine similarity as logits.
    # logits_per_image shape = [image_batch_size, num_text_prompts].
    logit_scale = self.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_image = F.softmax(logits_per_image.float(), dim=-1)
    return logits_per_image, attention_weight
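
  # Example (sketch): running the final visual block (skipped inside
  # CustomTransformer for image inputs) to recover its attention weights for
  # CAM-style visualization. `wrapper`, `image`, and `text_tokens` are assumed
  # to be defined by the caller.
  #
  #   feature_map, _ = wrapper.visual(image.type(wrapper.dtype), 224, 224)
  #   text_features = wrapper.encode_text(text_tokens)
  #   probs, last_attn = wrapper.forward_last_layer(
  #       feature_map, text_features, use_cls_token=False, repeat_last=True)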

  def forward(self, image, text, h=224, w=224):
    """Compute softmax-normalized image-text logits for tokenized `text`."""
    with torch.no_grad():
      text_features = self.encode_text(text)
      feature_map, _ = self.visual(image.type(self.dtype), h, w)
      logits_per_image, _ = self.forward_last_layer(
          feature_map, text_features, use_cls_token=True, repeat_last=False
      )
    return logits_per_image
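

# A minimal end-to-end sketch (not part of the original module): wrap an
# off-the-shelf CLIP checkpoint and score a few prompts against a dummy image.
# The checkpoint name, resolution, and prompts below are illustrative choices.
if __name__ == '__main__':
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  base_model, _ = clip.load('ViT-B/16', device=device)
  wrapper = CLIPWrapper(base_model).eval()
  # A random tensor stands in for a real preprocessed image batch of shape
  # [batch, 3, H, W]; H and W should be multiples of the patch size.
  dummy_image = torch.randn(1, 3, 224, 224, device=device)
  prompts = ['a photo of a cat', 'a photo of a dog']
  probabilities = forward_clip(wrapper, dummy_image, prompts, h=224, w=224)
  print(probabilities)  # Shape [1, len(prompts)]; each row sums to 1.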