|
|
|
import torch |
|
from timm.models.layers import trunc_normal_ |
|
from functools import partial |
|
import timm.models.vision_transformer |
|
import torch.nn as nn |
|
from timm.models.vision_transformer import Block, PatchEmbed |
|
import os |
|
from torchvision.io import read_image |
|
import numpy as np |
|
import sys |
|
import random |
|
import pytorch_lightning as pl |
|
import torch.nn.functional as F |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build fixed 2-D sine/cosine positional embeddings for a square grid (NumPy).

    Args:
        embed_dim: Width of each embedding vector; must be even.
        grid_size: Number of patches along each side of the square grid.
        cls_token: When true, prepend one all-zero row for a class token.

    Returns:
        Array of shape ``[grid_size*grid_size, embed_dim]``, or
        ``[1 + grid_size*grid_size, embed_dim]`` when ``cls_token`` is set.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(w, h): the first stacked plane holds the column (w) coordinate.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])
    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return embedding
    # np.zeros defaults to float64, so the concatenated result is float64 too.
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, embedding], axis=0)
|
|
|
|
|
def get_2d_sincos_pos_embed_with_resolution(
    embed_dim, grid_size, res, cls_token=False, device="cpu"
):
    """Build resolution-scaled 2-D sine/cosine positional embeddings (PyTorch).

    The integer patch coordinates ``0..grid_size-1`` are multiplied by each
    entry of ``res`` before encoding, so one embedding table is produced per
    resolution value.

    Args:
        embed_dim: Width of each embedding vector; must be even.
        grid_size: Number of patches along each side of the square grid.
        res: Resolution value(s) of a pixel (say, in meters). Accepts a tensor,
            a sequence, or a Python scalar; it is converted to a 1-D float32
            tensor of length ``n`` on ``device``.
        cls_token: When true, prepend an all-zero class-token row to each table.
        device: Device on which the embeddings are built.

    Returns:
        Tensor of shape ``[n, grid_size*grid_size, embed_dim]``, or
        ``[n, 1 + grid_size*grid_size, embed_dim]`` when ``cls_token`` is set.
    """
    # Normalise ``res`` to a 1-D float32 tensor on ``device``. Python numbers,
    # lists, and integer/float64 tensors are accepted, and the embeddings stay
    # float32. This is a no-op for an already-matching 1-D float32 tensor.
    res = torch.as_tensor(res, dtype=torch.float32, device=device).reshape(-1)

    grid_h = torch.arange(grid_size, dtype=torch.float32, device=device)
    grid_w = torch.arange(grid_size, dtype=torch.float32, device=device)
    # "xy" indexing with (w, h): grid[0] holds the column coordinate and
    # grid[1] the row coordinate, each of shape (grid_size, grid_size).
    grid = torch.stack(torch.meshgrid(grid_w, grid_h, indexing="xy"), dim=0)

    # Scale every coordinate by every resolution value: (2, n, h, w).
    grid = torch.einsum("chw,n->cnhw", grid, res)
    _, n, h, w = grid.shape

    pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid)
    pos_embed = pos_embed.reshape(n, h * w, embed_dim)
    if cls_token:
        cls_rows = torch.zeros(
            [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device
        )
        pos_embed = torch.cat([cls_rows, pos_embed], dim=1)
    return pos_embed
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid with sine/cosine features (NumPy).

    Args:
        embed_dim: Total output width per position; must be even.
        grid: Array whose first index selects a coordinate plane. ``grid[0]``
            fills the first ``embed_dim // 2`` channels and ``grid[1]`` the rest.

    Returns:
        Array of shape ``(M, embed_dim)``, where ``M`` is the number of
        elements in one plane (each plane is flattened by the 1-D encoder).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(half_dim, plane)
        for plane in (grid[0], grid[1])
    ]
    return np.concatenate(halves, axis=1)
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid):
    """Encode a two-plane coordinate grid with sine/cosine features (PyTorch).

    Args:
        embed_dim: Total output width per position; must be even.
        grid: Tensor whose first index selects a coordinate plane. ``grid[0]``
            fills the first ``embed_dim // 2`` channels and ``grid[1]`` the rest.

    Returns:
        Tensor of shape ``(M, embed_dim)``, where ``M`` is the number of
        elements in one plane (each plane is flattened by the 1-D encoder).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid_torch(half_dim, plane)
        for plane in (grid[0], grid[1])
    ]
    return torch.cat(halves, dim=1)
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """Encode scalar positions with sine/cosine features (PyTorch).

    Args:
        embed_dim: Output dimension for each position; must be even.
        pos: Tensor of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        Tensor of shape ``(M, embed_dim)``. The first half of the columns holds
        ``sin(pos * omega)`` and the second half ``cos(pos * omega)``, with
        ``omega_i = 1 / 10000 ** (i / (embed_dim / 2))``.
    """
    assert embed_dim % 2 == 0
    # Geometric frequency ladder from 1 down towards 1/10000.
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = torch.einsum("m,d->md", pos, omega)  # outer product: (M, D/2)

    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with sine/cosine features (NumPy).

    Args:
        embed_dim: Output dimension for each position; must be even.
        pos: Array of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        Array of shape ``(M, embed_dim)``: ``sin`` features in the first half
        of the columns, ``cos`` features in the second half.
    """
    assert embed_dim % 2 == 0
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs
    angles = np.einsum("m,d->md", pos.reshape(-1), freqs)
    return np.concatenate((np.sin(angles), np.cos(angles)), axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's ``pos_embed`` in place to match ``model``'s patch grid.

    The leading extra tokens are copied unchanged. The remaining square grid of
    patch embeddings is resized with bicubic interpolation. Nothing is modified
    when the checkpoint has no ``"pos_embed"`` entry or the grid sides already
    match.

    Args:
        model: Object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: State-dict-like mapping; updated in place.
    """
    if "pos_embed" not in checkpoint_model:
        return
    ckpt_pos = checkpoint_model["pos_embed"]
    dim = ckpt_pos.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches
    # Both grids are treated as square: side = sqrt(number of patch tokens).
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches**0.5)
    if old_side == new_side:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (old_side, old_side, new_side, new_side)
    )
    prefix = ckpt_pos[:, :n_extra]
    grid = ckpt_pos[:, n_extra:]
    # (B, side*side, D) -> (B, D, side, side) for spatial interpolation.
    grid = grid.reshape(-1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid,
        size=(new_side, new_side),
        mode="bicubic",
        align_corners=False,
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((prefix, grid), dim=1)
|
|
|
class PatchEmbedUnSafe(PatchEmbed):
    """Patch embedding whose forward performs no input-size assertion.

    ``forward`` applies the inherited ``proj`` layer directly and does not
    apply any norm layer. Presumably this lets images whose height/width
    differ from the construction-time ``img_size`` pass through (TODO confirm
    against the pinned timm ``PatchEmbed.forward``).
    """

    def forward(self, x):
        # Unpacking keeps the original requirement that ``x`` is 4-D.
        _batch, _channels, _height, _width = x.shape
        patches = self.proj(x)
        # Assumes ``proj`` yields (B, D, H', W') -> tokens (B, H'*W', D).
        tokens = patches.flatten(2)
        return tokens.transpose(1, 2)
|
|
|
|
|
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ViT feature extractor with resolution-aware sine/cosine position embeddings.

    Position embeddings are not learned. They are recomputed on every forward
    pass from the input's patch grid and the per-call ``input_res``. The timm
    parent's ``head`` and learned ``pos_embed`` are deleted, and ``forward``
    returns the token features of the last block passed through ``fc_norm``.
    """

    def __init__(
        self,
        cls_token_flag=False,
        global_pool=False,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        **kwargs
    ):
        """
        Args:
            cls_token_flag: Keep the parent's class token and prepend it to the
                patch tokens. When false, ``cls_token`` is deleted.
            global_pool: When true, build ``fc_norm`` from ``kwargs["norm_layer"]``
                and delete the parent's ``norm``.
            patch_size: Side length of the square patches.
            in_chans: Number of input image channels.
            embed_dim: Token embedding width.
            **kwargs: Forwarded to the timm ``VisionTransformer`` constructor.
        """
        super().__init__(embed_dim=embed_dim, **kwargs)
        self.cls_token_flag = cls_token_flag

        # img_size is fixed at 224 here; forward_features derives the patch grid
        # from the actual input shape instead.
        self.patch_embed = PatchEmbedUnSafe(
            img_size=224,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs["norm_layer"]
            self.fc_norm = norm_layer(embed_dim)
            del self.norm

        # Only token features are produced, so the classifier is not needed.
        del self.head
        if not self.cls_token_flag:
            del self.cls_token
        # Position embeddings are computed on the fly in forward_features.
        del self.pos_embed

    def forward_features(self, x, input_res=None):
        """Encode a batch of images into token features.

        Args:
            x: Image batch of shape ``(B, C, H, W)``; the number of patches is
                assumed to be a perfect square.
            input_res: Tensor of resolution values, either one entry for the
                whole batch or one per sample. Required despite the default.

        Returns:
            ``fc_norm`` applied to every token of the last block (class token
            first when ``cls_token_flag`` is set).

        Raises:
            ValueError: If ``input_res`` is not given.
        """
        if input_res is None:
            raise ValueError("input_res is required to build position embeddings")

        B, _, h, w = x.shape
        x = self.patch_embed(x)

        num_patches = int(
            (h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1])
        )
        # The helper moves ``input_res`` to ``device`` itself, so no detour
        # through host memory is needed here.
        pos_embed = get_2d_sincos_pos_embed_with_resolution(
            x.shape[-1],
            int(num_patches**0.5),
            input_res,
            cls_token=self.cls_token_flag,
            device=x.device,
        )

        if self.cls_token_flag:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        # NOTE(review): when global_pool is false this relies on the timm
        # parent defining ``fc_norm`` itself — confirm for the pinned timm version.
        return self.fc_norm(x)

    def forward(self, x, input_res=None):
        """Return ``forward_features(x, input_res=input_res)``."""
        return self.forward_features(x, input_res=input_res)
|
|
|
|
|
def vit_large_patch16(**kwargs):
    """Build a ViT-Large/16 ``VisionTransformer``.

    Fixed configuration: patch 16, width 1024, 24 blocks, 16 heads, MLP ratio
    4, qkv bias, and LayerNorm with eps 1e-6. Extra keyword arguments are
    forwarded; repeating one of the fixed keys raises ``TypeError``.
    """
    large_cfg = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return VisionTransformer(**large_cfg, **kwargs)
|
|
|
def get_ScaleMAE_model(global_pool=True, cls_token=True):
    """Construct a ScaleMAE ViT-Large/16 backbone (no weights are loaded here).

    Args:
        global_pool: Forwarded to ``VisionTransformer``; when true the model
            gets a fresh ``fc_norm`` built from its norm layer.
        cls_token: Forwarded as ``cls_token_flag``; whether the class token is
            kept and prepended.

    Returns:
        The model built by ``vit_large_patch16``.
    """
    # No checkpoint is loaded in this function, so there is no
    # ``load_state_dict`` result whose missing keys could be asserted on. The
    # former check referenced an undefined ``msg`` and raised NameError
    # whenever global_pool was true.
    return vit_large_patch16(
        num_classes=1000,
        drop_path_rate=0.1,
        global_pool=global_pool,
        cls_token_flag=cls_token,
    )
|
|
|
|
|
class ScaleMAE_baseline(pl.LightningModule, PyTorchModelHubMixin):
    """Lightning / Hub wrapper around the ScaleMAE ViT-Large/16 backbone.

    Args:
        feat_dim: Unused; kept for interface compatibility.
        fc_dim: Unused; kept for interface compatibility.
        global_pool: Forwarded to ``get_ScaleMAE_model``.
        cls_token_flag: Forwarded to ``get_ScaleMAE_model`` as ``cls_token``.
    """

    def __init__(self, feat_dim=1024, fc_dim=1024, global_pool=False, cls_token_flag=True):
        super().__init__()
        self.model = get_ScaleMAE_model(global_pool=global_pool, cls_token=cls_token_flag)

    def forward(self, x, patch_size, input_res=10.0):
        """Return backbone token features for ``x``.

        Args:
            x: Image batch passed to the backbone.
            patch_size: Unused; the backbone's patch size is fixed at construction.
            input_res: Resolution value(s) used to scale the position
                embeddings. May be a Python number, a sequence, or a tensor
                holding either one value for the whole batch or one per sample.
                Defaults to ``10.0``.
        """
        # Honour the caller's resolution. Previously it was overwritten with a
        # hard-coded 10.0; the default still yields tensor([10.0]).
        res = torch.as_tensor(input_res, dtype=torch.float32, device=x.device).reshape(-1)
        return self.model(x, input_res=res)