import torch
from torch import nn
import torch.nn.init as init
from typing import Optional
from contextlib import nullcontext

from transformers import Dinov2Model
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config

from paths import *  # provides the DINO_SMALL / DINO_BASE / DINO_LARGE / DINO_GIANT checkpoint paths


def get_activation(activation):
    """Map an activation name to the corresponding nn.Module (defaults to ReLU)."""
    activation = activation.lower()
    if activation == 'gelu':
        return nn.GELU()
    elif activation == 'rrelu':
        return nn.RReLU(inplace=True)
    elif activation == 'selu':
        return nn.SELU(inplace=True)
    elif activation == 'silu':
        return nn.SiLU(inplace=True)
    elif activation == 'hardswish':
        return nn.Hardswish(inplace=True)
    elif activation == 'leakyrelu':
        return nn.LeakyReLU(inplace=True)
    elif activation == 'sigmoid':
        return nn.Sigmoid()
    elif activation == 'tanh':
        return nn.Tanh()
    else:
        return nn.ReLU(inplace=True)


class MLP_dim(nn.Module):
    """Two-layer projection head: Linear -> BatchNorm -> activation -> Linear -> BatchNorm."""

    def __init__(self, in_dim=512, out_dim=1024, bias=True, activation='relu'):
        super().__init__()
        self.act = get_activation(activation)
        self.net1 = nn.Sequential(
            nn.Linear(in_dim, out_dim, bias=bias),
            nn.BatchNorm1d(out_dim),
            self.act
        )
        self.net2 = nn.Sequential(
            nn.Linear(out_dim, out_dim, bias=bias),
            nn.BatchNorm1d(out_dim)
        )

    def forward(self, x):
        return self.net2(self.net1(x))


class FLIP_Dinov2Embeddings(Dinov2Embeddings):
    """
    DINOv2 embeddings with FLIP-style token dropping: instead of replacing
    masked patches with the mask token, only the token positions listed in
    `bool_masked_pos` are kept in the sequence.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__(config)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        if bool_masked_pos is not None:
            # Despite the (inherited) argument name, `bool_masked_pos` carries the
            # indices of the tokens to KEEP, including index 0 for [CLS]. The stock
            # implementation would instead swap masked tokens for `self.mask_token`.
            batch_indices = torch.arange(batch_size, device=embeddings.device).unsqueeze(1)
            embeddings = embeddings[batch_indices, bool_masked_pos]

        embeddings = self.dropout(embeddings)
        return embeddings


class FLIP_DINOv2(Dinov2Model):
    """Dinov2Model whose embedding layer drops (rather than masks) the selected tokens."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = FLIP_Dinov2Embeddings(config)


class DINOv2_MLP(nn.Module):
    """A DINOv2 backbone (optionally frozen) followed by an MLP projection head."""

    def __init__(self, dino_mode, in_dim, out_dim, evaluate, mask_dino, frozen_back) -> None:
        super().__init__()
        if dino_mode == 'base':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_BASE, cache_dir='./')
        elif dino_mode == 'large':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_LARGE, cache_dir='./')
        elif dino_mode == 'small':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_SMALL, cache_dir='./')
        elif dino_mode == 'giant':
            self.dinov2 = FLIP_DINOv2.from_pretrained(DINO_GIANT, cache_dir='./')
        else:
            raise ValueError(f"Unknown dino_mode: {dino_mode!r}")

        self.down_sampler = MLP_dim(in_dim=in_dim, out_dim=out_dim)

        self.random_mask = False
        if not evaluate:
            # Re-initialise the projection head and (optionally) enable token dropping during training.
            self.down_sampler.apply(self.init_weights)
            self.random_mask = mask_dino

        # Run the backbone under no_grad when it is frozen; the projection head stays trainable.
        if frozen_back:
            self.forward_mode = torch.no_grad()
        else:
            self.forward_mode = nullcontext()

    def forward(self, img_inputs):
        device = self.get_device()
        with self.forward_mode:
            if self.random_mask:
                B = len(img_inputs['pixel_values'])
                # 256 patch tokens for the default 224x224 input with patch size 14 (16 x 16).
                S = 256
                indices = []
                for _ in range(B):
                    # Keep a random half of the patch tokens; +1 shifts past the [CLS] slot.
                    tmp = torch.randperm(S)[:S // 2]
                    tmp = tmp.sort().values + 1
                    indices.append(tmp)
                indices = torch.stack(indices, dim=0)
                # Always keep the [CLS] token at position 0.
                indices = torch.cat([torch.zeros(B, 1, dtype=torch.long), indices], dim=1)
                # Despite the key name, these are keep-indices consumed by FLIP_Dinov2Embeddings.
                img_inputs['bool_masked_pos'] = indices.to(device)
            dino_outputs = self.dinov2(**img_inputs)

        # Project the [CLS] token of the final hidden state.
        dino_seq = dino_outputs.last_hidden_state[:, 0, :]
        down_sample_out = self.down_sampler(dino_seq)
        return down_sample_out

    def get_device(self):
        return next(self.parameters()).device

    def init_weights(self, m):
        # Xavier initialisation for the Linear layers of the projection head (applied via .apply()).
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
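

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes that `paths.py` maps DINO_BASE to a standard DINOv2 checkpoint
# such as "facebook/dinov2-base" (hidden size 768); the in_dim/out_dim values
# below are example choices matching that assumption.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained(DINO_BASE, cache_dir='./')
    model = DINOv2_MLP(
        dino_mode='base',
        in_dim=768,        # hidden size of dinov2-base
        out_dim=768,       # example projection size
        evaluate=True,     # skip head re-initialisation and random token dropping
        mask_dino=False,
        frozen_back=True,  # run the backbone under torch.no_grad()
    )
    model.eval()

    # A random RGB image stands in for real data; the processor resizes and normalises it.
    dummy = torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy()
    img_inputs = processor(images=dummy, return_tensors="pt")

    with torch.no_grad():
        features = model(img_inputs)
    print(features.shape)  # -> torch.Size([1, 768])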