import torch
from torch import nn
import torch.nn.init as init
from contextlib import nullcontext
from typing import Optional

from transformers import Dinov2Model
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config

# paths is expected to provide the checkpoint locations DINO_SMALL,
# DINO_BASE, DINO_LARGE and DINO_GIANT used below.
from paths import *
def get_activation(activation):
    """Map an activation name to the corresponding nn.Module.

    Unrecognized names fall back to nn.ReLU(inplace=True).
    """
    activation = activation.lower()
    if activation == 'gelu':
        return nn.GELU()
    elif activation == 'rrelu':
        return nn.RReLU(inplace=True)
    elif activation == 'selu':
        return nn.SELU(inplace=True)
    elif activation == 'silu':
        return nn.SiLU(inplace=True)
    elif activation == 'hardswish':
        return nn.Hardswish(inplace=True)
    elif activation == 'leakyrelu':
        return nn.LeakyReLU(inplace=True)
    elif activation == 'sigmoid':
        return nn.Sigmoid()
    elif activation == 'tanh':
        return nn.Tanh()
    else:
        return nn.ReLU(inplace=True)
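
# Example (sketch): names are matched case-insensitively, and any
# unrecognized name falls back to ReLU:
#   get_activation('GELU')     # -> nn.GELU()
#   get_activation('unknown')  # -> nn.ReLU(inplace=True)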
class MLP_dim(nn.Module):
    """Two-layer projection head: Linear -> BN -> activation -> Linear -> BN."""

    def __init__(self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
        super().__init__()
        self.act = get_activation(activation)
        self.net1 = nn.Sequential(
            nn.Linear(in_dim, out_dim, bias=bias),
            nn.BatchNorm1d(out_dim),
            self.act,
        )
        self.net2 = nn.Sequential(
            nn.Linear(out_dim, out_dim, bias=bias),
            nn.BatchNorm1d(out_dim),
        )

    def forward(self, x):
        return self.net2(self.net1(x))
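
# Shape sketch: the head maps (B, in_dim) -> (B, out_dim). BatchNorm1d
# needs B > 1 in training mode, so call .eval() for single-sample use:
#   head = MLP_dim(in_dim=768, out_dim=1024).eval()
#   head(torch.randn(1, 768)).shape  # torch.Size([1, 1024])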
class FLIP_Dinov2Embeddings(Dinov2Embeddings):
    """
    Construct the CLS token, position and patch embeddings, optionally
    keeping only a subset of tokens (FLIP-style token dropping).
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        # Unlike the stock Dinov2 embeddings, which replace masked tokens
        # with a learned mask token via torch.where, here `bool_masked_pos`
        # holds the *indices* of the tokens to keep, shape (batch, n_keep);
        # the remaining tokens are dropped from the sequence entirely.
        if bool_masked_pos is not None:
            batch_indices = torch.arange(batch_size, device=embeddings.device).unsqueeze(1)
            embeddings = embeddings[batch_indices, bool_masked_pos]

        embeddings = self.dropout(embeddings)
        return embeddings
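
# Index-selection sketch: a (B, 1) row-index column broadcasts against a
# (B, n_keep) index matrix, so each sample keeps its own token subset:
#   emb = torch.randn(2, 5, 8)                     # (B=2, S=5, D=8)
#   keep = torch.tensor([[0, 2, 4], [0, 1, 3]])    # kept indices per sample
#   emb[torch.arange(2).unsqueeze(1), keep].shape  # torch.Size([2, 3, 8])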
class FLIP_DINOv2(Dinov2Model):
    """Dinov2Model whose embeddings support FLIP-style token dropping."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = FLIP_Dinov2Embeddings(config)
class DINOv2_MLP(nn.Module):
    def __init__(self,
                 dino_mode,
                 in_dim,
                 out_dim,
                 evaluate,
                 mask_dino,
                 frozen_back
                 ) -> None:
        super().__init__()
        if dino_mode == 'base':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_BASE, cache_dir='./')
        elif dino_mode == 'large':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_LARGE, cache_dir='./')
        elif dino_mode == 'small':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_SMALL, cache_dir='./')
        elif dino_mode == 'giant':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_GIANT, cache_dir='./')
        else:
            raise ValueError(f"unknown dino_mode: {dino_mode!r}")

        self.down_sampler = MLP_dim(in_dim=in_dim, out_dim=out_dim)
        self.random_mask = False
        if not evaluate:
            # randomly initialize the projection head and enable token
            # dropping only when training
            self.down_sampler.apply(self.init_weights)
            self.random_mask = mask_dino

        # with a frozen backbone, run the whole forward pass under
        # no_grad; otherwise use a do-nothing context manager
        if frozen_back:
            self.forward_mode = torch.no_grad()
        else:
            self.forward_mode = nullcontext()
    def forward(self, img_inputs):
        device = self.get_device()
        with self.forward_mode:
            if self.random_mask:
                # FLIP-style token dropping: for each image, keep a random
                # half of the S patch tokens (S=256 assumes 224x224 inputs
                # with patch size 14), always keeping the CLS token at
                # position 0; the +1 offset skips past the CLS slot.
                B = len(img_inputs['pixel_values'])
                S = 256
                indices = []
                for i in range(B):
                    tmp = torch.randperm(S)[:S // 2]
                    tmp = tmp.sort().values + 1
                    indices.append(tmp)
                indices = torch.stack(indices, dim=0)
                indices = torch.cat([torch.zeros(B, 1, dtype=torch.long), indices], dim=1)
                img_inputs['bool_masked_pos'] = indices.to(device)

            dino_outputs = self.dinov2(**img_inputs)
            # project only the CLS token through the MLP head
            dino_seq = dino_outputs.last_hidden_state[:, 0, :]
            down_sample_out = self.down_sampler(dino_seq)
        return down_sample_out
    def get_device(self):
        return next(self.parameters()).device

    def init_weights(self, m):
        # Xavier init for Linear layers; intended for use with Module.apply
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
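

# Usage sketch, assuming paths.DINO_BASE names a DINOv2 checkpoint such
# as 'facebook/dinov2-base' (hidden size 768, so in_dim must be 768):
if __name__ == '__main__':
    from PIL import Image
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained(DINO_BASE, cache_dir='./')
    model = DINOv2_MLP(dino_mode='base', in_dim=768, out_dim=1024,
                       evaluate=True, mask_dino=False, frozen_back=True)
    model.eval()

    image = Image.new('RGB', (224, 224))
    inputs = dict(processor(images=image, return_tensors='pt'))
    print(model(inputs).shape)  # torch.Size([1, 1024])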