import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from src.dinov2.models.vision_transformer import vit_base
from src.options import opts


def freeze_model(m):
    """Freeze every parameter of module `m` (recursively)."""
    m.requires_grad_(False)


def freeze_all_but_bn(m):
    """Freeze `m`'s own weight/bias unless `m` is a LayerNorm.

    Despite the `bn` in the name, ViT backbones normalize with LayerNorm,
    so LayerNorm affines are the only parameters this leaves trainable.
    """
    if not isinstance(m, torch.nn.LayerNorm):
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.requires_grad_(False)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.requires_grad_(False)

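# Usage note: both helpers are written for nn.Module.apply, which calls them
# on the module and every submodule. For example (`encoder` is an
# illustrative name, not defined in this file):
#
#   encoder = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
#   encoder.apply(freeze_all_but_bn)  # only LayerNorm affines stay trainable

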
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.opts = opts

        # DINOv2 ViT-B/14 backbone. block_chunks=0 and init_values=1.0 match
        # the upstream vit_base signature; the vendored copy in src.dinov2 is
        # expected to additionally accept prompt tokens in forward.
        self.dino = vit_base(patch_size=14, block_chunks=0, init_values=1.0)
        # Freeze everything in the backbone except the LayerNorm affines,
        # which are the parameters the clip_LN_lr group below is named for.
        self.dino.apply(freeze_all_but_bn)

        # One learnable prompt set per modality, shape (n_prompts, prompt_dim).
        self.sk_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))
        self.img_prompt = nn.Parameter(torch.randn(self.opts.n_prompts, self.opts.prompt_dim))

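    # How the prompts are consumed (a sketch based on standard visual prompt
    # tuning; the vendored vit_base is assumed to do the equivalent):
    #
    #   x = patch_embed(data)                       # (B, N, D)
    #   x = torch.cat([cls_tok, prompt, x], dim=1)  # (B, 1 + n_prompts + N, D)
    #
    # Each modality thus steers the shared, mostly frozen transformer through
    # its own learned tokens.
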
    def configure_optimizers(self):
        # Backbone group: after freeze_all_but_bn, only the LayerNorm affines
        # still require grad, so filter the frozen parameters out.
        backbone_params = [p for p in self.dino.parameters() if p.requires_grad]
        # Prompt group: the prompts live on the LightningModule rather than on
        # self.dino, so without their own group they would never be updated.
        # A dedicated opts.prompt_lr is an assumption; fall back to the LN lr.
        prompt_lr = getattr(self.opts, 'prompt_lr', self.opts.clip_LN_lr)
        optimizer = torch.optim.Adam([
            {'params': backbone_params, 'lr': self.opts.clip_LN_lr},
            {'params': [self.sk_prompt, self.img_prompt], 'lr': prompt_lr},
        ])
        return optimizer

    def forward(self, data, dtype='image'):
        # Encode a batch with the prompt set matching its modality. expand
        # broadcasts the (n_prompts, prompt_dim) prompts to a batch-sized
        # (B, n_prompts, prompt_dim) view without copying memory.
        if dtype == 'image':
            feat = self.dino(data, prompt=self.img_prompt.expand(data.shape[0], -1, -1))
        else:
            feat = self.dino(data, prompt=self.sk_prompt.expand(data.shape[0], -1, -1))
        return feat
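

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the training pipeline:
    # it assumes opts provides n_prompts and prompt_dim and that the vendored
    # vit_base accepts the `prompt` keyword used above. Input sides must be
    # multiples of the 14-pixel patch size, e.g. 224 = 16 * 14.
    model = Model()
    images = torch.randn(2, 3, 224, 224)
    sketches = torch.randn(2, 3, 224, 224)
    print(model(images, dtype='image').shape)
    print(model(sketches, dtype='sketch').shape)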