# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A wrapper for the CLIP model to support forward with a list of text inputs."""

# pylint: disable=g-importing-member
import clip
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

_CONTEXT_LENGTH = 77


def forward_clip_single(model, image, text, h, w):
  """Forward a single text input.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (List[str]): the text input.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits.
  """
  if isinstance(text, str):
    text = [text]
  text_tokens = clip.tokenize(text).to(image.device)
  text_prediction = model(image, text_tokens, h, w)
  return text_prediction.detach().cpu()


def forward_clip(model, image, text, h, w):
  """Forward a list of text inputs.

  Args:
    model (CLIPWrapper or CLIP): the CLIP model.
    image (torch.Tensor): the image tensor.
    text (List[str] or List[List[str]]): the text input.
    h (int): the height of the image.
    w (int): the width of the image.

  Returns:
    torch.Tensor: the logits.
  """
  if isinstance(text[0], list):
    text_prediction = torch.stack(
        [forward_clip_single(model, image, t, h, w) for t in text], dim=0
    )
    text_prediction = torch.sum(text_prediction, dim=0)
    text_prediction = F.softmax(text_prediction.float(), dim=-1)
  else:
    text_prediction = forward_clip_single(model, image, text, h, w)
  return text_prediction.float()
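# Illustrative usage (added for clarity; not part of the original module, and
# `model`/`image` below are placeholders): `text` may be a flat list of
# prompts, or a list of prompt groups in which every group names the same
# classes with a different template. Per-group logits are summed before the
# final softmax, which ensembles the templates.
#
#   probs = forward_clip(
#       model, image, ['a photo of a cat', 'a photo of a dog'], 224, 224)
#   probs = forward_clip(
#       model, image,
#       [['a photo of a cat', 'a photo of a dog'],
#        ['a sketch of a cat', 'a sketch of a dog']],
#       224, 224)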
""" # emb size NxD first = embed[:1, :] embed = embed[1:, :] n = embed.size(0) d = embed.size(1) size = int(np.sqrt(n)) if size * size != n: raise ValueError(f'The size of embed {n} is not a perfect square number.') # new_size = size * self.upsample embed = embed.permute(1, 0) embed = embed.view(1, d, size, size).contiguous() embed = F.upsample( embed, size=new_size, mode='bilinear', ) embed = embed.view(d, -1).contiguous() embed = embed.permute(1, 0) embed = torch.cat([first, embed], 0) embed = nn.parameter.Parameter(embed.half()) return embed class CustomBlock(nn.Module): """A customized attention block.""" def __init__(self, block): super().__init__() for k, v in vars(block).items(): setattr(self, k, v) def attention(self, x): self.attn_mask = ( self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None ) self.attn = self.attn.to(dtype=x.dtype, device=x.device) # Setting need_weights to True also returns the attention weights return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask) def forward(self, x): # attn_output: (L,N,E), attn_weight: (N,L,L) attn_output, attn_weight = self.attention(self.ln_1(x)) x = x + attn_output x = x + self.mlp(self.ln_2(x)) return x, attn_weight class CustomTransformer(nn.Module): """A customized Transformer to support CAM calculation.""" def __init__(self, transformer): """Initialize the wrapper. Args: transformer (nn.Module): the Transformer to be wrapped. """ super().__init__() for k, v in vars(transformer).items(): setattr(self, k, v) self.resblocks = nn.Sequential( *[CustomBlock(block) for block in self.resblocks] ) def forward(self, x): attn_weights = [] with torch.no_grad(): layers = self.layers if x.shape[0] == _CONTEXT_LENGTH else self.layers - 1 for i in range(layers): x, attn_weight = self.resblocks[i](x) attn_weights.append(attn_weight) return x, attn_weights class CustomVisionTransformer(nn.Module): """A customized VisionTransformer to support CAM calculation.""" def __init__(self, model): """Initialize the wrapper. Args: model (VisionTransformer): the VisionTransformer to be wrapped. """ super().__init__() for k, v in vars(model).items(): setattr(self, k, v) self.patch_size = self.conv1.kernel_size[0] self.transformer = CustomTransformer(self.transformer) def forward(self, x, h, w): self.positional_embedding_new = upsample_position_embedding( self.positional_embedding, (h // self.patch_size, w // self.patch_size) ) # shape = [*, width, grid, grid] x = self.conv1(x) # shape = [*, width, grid ** 2] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, grid ** 2, width] x = x.permute(0, 2, 1) zeros = torch.zeros( x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device ) # shape = [*, grid ** 2 + 1, width] x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1) x = x + self.positional_embedding_new.to(x.dtype) x = self.ln_pre(x) # NLD -> LND x = x.permute(1, 0, 2) x, attn_weight = self.transformer(x) return x, attn_weight class CLIPWrapper(nn.Module): """A wrapper for CLIP to support forward with a list of text inputs.""" def __init__(self, clip_model): """Initialize the wrapper. Args: clip_model (CLIP): the CLIP model to be wrapped. 
""" super().__init__() # copy all attributes from clip_model to self for k, v in vars(clip_model).items(): setattr(self, k, v) self.visual = CustomVisionTransformer(self.visual) self.transformer = CustomTransformer(self.transformer) @property def dtype(self): return self.visual.conv1.weight.dtype def encode_image(self, image, h, w): return self.visual(image.type(self.dtype), h, w) def encode_text(self, text): x = self.token_embedding(text).type( self.dtype ) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.type(self.dtype) x = x.permute(1, 0, 2) # NLD -> LND x, _ = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x).type(self.dtype) # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding # (eot_token is the highest number in each sequence) x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection return x def pool_visual(self, x, use_cls_token=False): if use_cls_token: return x[:, 0] else: return torch.mean(x[:, 1:, :], dim=1) def forward_last_layer( self, image_features, text_features, use_cls_token=False, repeat_last=True ): """Forward the last layer of CLIP. Args: image_features (torch.Tensor): the image features. text_features (torch.Tensor): the text features. use_cls_token (bool, optional): whether to use the CLS token. Defaults to False. repeat_last (bool, optional): whether to repeat the last layer. Defaults to True. Returns: torch.Tensor: the logits. torch.Tensor: the attention weights. """ if repeat_last: x, attention_weight = self.visual.transformer.resblocks[ self.visual.transformer.layers - 1 ](image_features) else: x = image_features attention_weight = None x = x.permute(1, 0, 2) # LND -> NLD x = self.visual.ln_post(x) x = self.pool_visual(x, use_cls_token=use_cls_token) if self.visual.proj is not None: x = x @ self.visual.proj image_features = x # normalized features image_features = image_features / image_features.norm(dim=1, keepdim=True) text_features = text_features / text_features.norm(dim=1, keepdim=True) # cosine similarity as logits logit_scale = self.logit_scale.exp() logits_per_image = logit_scale * image_features @ text_features.t() # shape = [global_batch_size, global_batch_size] logits_per_image = F.softmax(logits_per_image.float(), dim=-1) return logits_per_image, attention_weight def forward(self, image, text, h=224, w=224): with torch.no_grad(): text_features = self.encode_text(text) feature_map, _ = self.visual(image.type(self.dtype), h, w) logits_per_image, _ = self.forward_last_layer( feature_map, text_features, use_cls_token=True, repeat_last=False ) return logits_per_image