import torch
import copy
import os
import timm
import transformers
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import Normalize
class RandViT(nn.Module):
    """Randomly initialized ViT encoder (no pretrained weights), built with timm."""

    def __init__(self, model_id, weight_path: str = None):
        super(RandViT, self).__init__()
        self.encoder = timm.create_model(
            model_id,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Per-channel feature normalization statistics (identity by default).
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop prefix (CLS/register) tokens and keep patch tokens only.
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        # (B, N, C) -> (B, C, patch_num_h, patch_num_w)
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature
class MAE(nn.Module):
    """ViT encoder initialized from MAE-pretrained weights via timm."""

    def __init__(self, model_id, weight_path: str):
        super(MAE, self).__init__()
        # Accept either a checkpoint file or a directory containing pytorch_model.bin.
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Per-channel feature normalization statistics (identity by default).
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop prefix (CLS/register) tokens and keep patch tokens only.
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        # (B, N, C) -> (B, C, patch_num_h, patch_num_w)
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature
class DINO(nn.Module):
    """ViT encoder initialized from DINO-pretrained weights via timm."""

    def __init__(self, model_id, weight_path: str):
        super(DINO, self).__init__()
        # Accept either a checkpoint file or a directory containing pytorch_model.bin.
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Per-channel feature normalization statistics (identity by default).
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop prefix (CLS/register) tokens and keep patch tokens only.
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        # (B, N, C) -> (B, C, patch_num_h, patch_num_w)
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature
class CLIP(nn.Module):
    """CLIP vision encoder loaded through Hugging Face transformers."""

    def __init__(self, model_id, weight_path: str):
        super(CLIP, self).__init__()
        self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
        self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
        # Per-channel feature normalization statistics (identity by default).
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop the CLS token and keep patch tokens only.
        feature = self.encoder(x)['last_hidden_state'][:, 1:]
        # (B, N, C) -> (B, C, patch_num_h, patch_num_w)
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature
class DINOv2(nn.Module):
    """DINOv2 encoder loaded through Hugging Face transformers."""

    def __init__(self, model_id, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
        self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        # Drop the CLS token and keep patch tokens only.
        feature = self.encoder(x)['last_hidden_state'][:, 1:]
        # (B, N, C) -> (B, C, patch_num_h, patch_num_w)
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature
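

# --- Usage sketch (not part of the original file) ---
# A minimal example of how these wrappers might be exercised. The timm model id
# below is an illustrative assumption, not a value taken from this repo; the
# pretrained wrappers (MAE, DINO, CLIP, DINOv2) would additionally need a local
# checkpoint path or Hub repo id passed as weight_path.
if __name__ == "__main__":
    model = RandViT("vit_base_patch16_224")   # random weights, no checkpoint required
    dummy = torch.rand(1, 3, 256, 256)        # any RGB batch; forward() resizes to 224x224
    with torch.no_grad():
        feat = model(dummy)
    # For a 16x16 patch size at 224x224 input, expect a (1, C, 14, 14) feature map.
    print(feat.shape)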