""" CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import copy import logging import math from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from functools import partial # from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ text_global_pool from .utils import to_2tuple @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 in_channels: int = 100 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling no_ln_pre: bool = False # disable pre transformer LayerNorm pos_embed_type: str = 'learnable' final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'tok' output_tokens: bool = False act_kwargs: Optional[dict] = None norm_kwargs: Optional[dict] = None timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def get_input_dtype(precision: str): input_dtype = None if precision in ('bf16', 'pure_bf16'): input_dtype = torch.bfloat16 elif precision in ('fp16', 'pure_fp16'): input_dtype = torch.float16 return input_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if vision_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
        if vision_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **vision_cfg.act_kwargs)

        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            in_channels=vision_cfg.in_channels,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            attentional_pool=vision_cfg.attentional_pool,
            attn_pooler_queries=vision_cfg.attn_pooler_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            pos_embed_type=vision_cfg.pos_embed_type,
            no_ln_pre=vision_cfg.no_ln_pre,
            final_ln_after_pool=vision_cfg.final_ln_after_pool,
            pool_type=vision_cfg.pool_type,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # only the visual tower exists in this vision-only adaptation; the text
        # transformer of the original CLIP has been removed
        self.visual.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        return image_features


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
"k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.to(dtype) if isinstance(l, (CLIP, TextTransformer)): # convert text nn.Parameter projections attr = getattr(l, "text_projection", None) if attr is not None: attr.data = attr.data.to(dtype) if isinstance(l, VisionTransformer): # convert vision nn.Parameter projections attr = getattr(l, "proj", None) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) model = CLIP( embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict 
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
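

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the public API): shows how the
# pieces above fit together - build a vision-only CLIP from a CLIPVisionCfg and
# encode a batch of images. The concrete values (embed_dim=512, in_channels=3,
# 224x224 inputs) are assumptions chosen for the example, not defaults of this
# adaptation (which defaults to in_channels=100).
# ---------------------------------------------------------------------------
def _example_usage():  # pragma: no cover - documentation sketch
    vision_cfg = CLIPVisionCfg(
        layers=12,
        width=768,
        patch_size=16,
        image_size=224,
        in_channels=3,  # assumed RGB input for this example
    )
    model = CLIP(embed_dim=512, vision_cfg=vision_cfg, quick_gelu=False)
    images = torch.randn(2, vision_cfg.in_channels, vision_cfg.image_size, vision_cfg.image_size)
    with torch.no_grad():
        features = model(images)  # L2-normalized image embeddings, shape (2, 512)
    return features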