# Spaces:
# Sleeping
# Sleeping
from collections import OrderedDict | |
import math | |
from typing import Callable, Optional, Sequence, Tuple | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torch.utils.checkpoint import checkpoint | |
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32.

    The input is upcast to float32 before calling ``F.layer_norm`` (useful for
    fp16/bf16 activations), and the result is cast back to the input's dtype.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normed.to(input_dtype)
class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in the input's own dtype, with the output cast back to it."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        return F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        ).to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""
    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class LayerScale(nn.Module):
    """Multiply the last dimension by a learnable vector ``gamma`` of length ``dim``.

    Args:
        dim: length of ``gamma``.
        init_values: initial value of every ``gamma`` entry.
        inplace: when True, scale ``x`` in place with ``mul_`` and return it.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794

    During training, keep a random subset of tokens along dim 1 of ``x``
    (``max(1, int(num_tokens * (1 - prob)))`` of them, chosen independently per
    batch row). Outside training, or when ``prob == 0``, ``x`` is returned unchanged.

    Args:
        prob: fraction of tokens to drop, in ``[0, 1)``.
        exclude_first_token: when True, the first token (e.g. CLS) is never
            dropped and is re-prepended to the kept tokens.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Only annotated so TorchScript sees a Tensor-typed variable on both branches.
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        # Create indices/scores on x's device: the original built them on the CPU,
        # forcing a host->device index transfer every step when x lives on a GPU.
        batch_indices = torch.arange(batch, device=x.device)[..., None]

        keep_prob = 1 - self.prob
        # Clamp to num_tokens so topk never asks for more tokens than exist
        # (only reachable when x holds just the CLS token).
        num_patches_keep = min(num_tokens, max(1, int(num_tokens * keep_prob)))

        # Top-k of i.i.d. random scores == a uniformly random subset per row.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
class Attention(nn.Module):
    """Multi-head self-attention on sequence-first input ``x`` of shape ``(L, N, C)``.

    Args:
        dim: embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the packed q/k/v input projection has a bias.
        scaled_cosine: if True, logits are cosine similarities of q and k times a
            learned per-head scale ``exp(min(logit_scale, logit_scale_max))``;
            otherwise standard ``q @ k^T / sqrt(head_dim)`` logits are used.
        scale_heads: if True, each head's output is multiplied by a learned scalar.
        logit_scale_max: upper clamp on the (log-space) cosine logit scale.
        attn_drop: dropout on the attention probabilities.
        proj_drop: dropout after the output projection.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=False,
            scale_heads=False,
            logit_scale_max=math.log(1. / 0.01),
            attn_drop=0.,
            proj_drop=0.
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            # Per-head log logit scale, initialized to log(10).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` ``(L, N, C)`` and return a tensor of the same shape.

        ``attn_mask`` is either boolean (True = position may not be attended) or
        an additive float mask; it must broadcast to ``(N * num_heads, L, L)``.
        """
        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # (L, N, C) -> (N * num_heads, L, head_dim); flattened batch index is n * num_heads + h.
        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

        if self.logit_scale is not None:
            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn.view(N, self.num_heads, L, L) * logit_scale
            attn = attn.view(-1, L, L)
        else:
            q = q * self.scale
            attn = torch.bmm(q, k.transpose(-1, -2))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Turn a boolean "disallowed" mask into an additive -inf mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            attn += attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.bmm(attn, v)
        if self.head_scale is not None:
            # x is (N * num_heads, L, head_dim) here. The original viewed it with C
            # as the last dim, which has the wrong element count for num_heads > 1.
            x = x.view(N, self.num_heads, L, self.head_dim) * self.head_scale
            x = x.view(-1, L, self.head_dim)
        x = x.transpose(0, 1).reshape(L, N, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + ls_1(attn(ln_1(x)))`` then ``x + ls_2(mlp(ln_2(x)))``.

    Attention is ``nn.MultiheadAttention`` built without ``batch_first``, which by
    default expects sequence-first ``(L, N, d_model)`` inputs.

    Args:
        d_model: embedding size.
        n_head: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``d_model``.
        ls_init_value: LayerScale init value; ``None`` uses ``nn.Identity`` instead.
        act_layer: activation class instantiated between the two MLP linears.
        norm_layer: normalization class for ``ln_1`` / ``ln_2`` (and ``ln_1_kv``).
        is_cross_attention: if True, adds ``ln_1_kv`` so forward() can take
            separate key/value inputs.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            # Presence of this module is what forward() checks to allow cross-attention.
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run ``self.attn``; missing keys/values default to ``q_x`` (self-attention)."""
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # nn.MultiheadAttention adds float masks to the logits. Casting a bool
                # mask straight to float would add +1 to masked positions instead of
                # excluding them, so build an additive -inf mask explicitly
                # (the same conversion Attention.forward performs).
                additive_mask = torch.zeros_like(attn_mask, dtype=q_x.dtype)
                additive_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = additive_mask
            else:
                # Match the activation dtype (e.g. fp16) of the query.
                attn_mask = attn_mask.to(q_x.dtype)

        return self.attn(
            q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
        )[0]

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the attention and MLP residual sublayers to ``q_x``.

        ``k_x`` / ``v_x`` are only used (after ``ln_1_kv``) when the block was built
        with ``is_cross_attention=True``; otherwise they are discarded and the
        block performs self-attention on ``ln_1(q_x)``.
        """
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None

        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
class Transformer(nn.Module):
    """A stack of ``layers`` identically configured ResidualAttentionBlock modules.

    Args:
        width: embedding size of every block.
        layers: number of blocks.
        heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``width``.
        ls_init_value: LayerScale init value forwarded to each block (``None`` disables it).
        act_layer: activation class forwarded to each block's MLP.
        norm_layer: normalization class forwarded to each block.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        # Off by default; forward() consults this flag on every block.
        self.grad_checkpointing = False

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                )
            )
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's MLP input projection.

        If that layer exposes ``int8_original_dtype`` (presumably set by an int8
        conversion elsewhere — not visible here), that value is returned instead
        of the weight dtype.
        """
        first_fc = self.resblocks[0].mlp.c_fc
        if hasattr(first_fc, 'int8_original_dtype'):
            return first_fc.int8_original_dtype
        return first_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Run ``x`` through every block in order, passing the same ``attn_mask`` to each."""
        for block in self.resblocks:
            # The is_scripting() check stays inline so TorchScript can drop the checkpoint branch.
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                # checkpoint() takes positional args only, so k_x / v_x are passed as None.
                x = checkpoint(block, x, None, None, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x