Huiwenshi's picture
Upload folder using huggingface_hub
e3e5f9e verified
raw
history blame
2.39 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from .dinov2.hub.backbones import dinov2_vitb14
class FrozenDinoV2ImageEmbedder(nn.Module):
    """
    Uses the dinov2 image encoder with camera modulation.

    Despite the name, parameters are not frozen here. If you want that,
    set cond_stage_trainable=False in the cfg.
    """

    def __init__(
        self,
        version='dinov2_vitb14',
        ckpt_path=None,
        lrm_mode='plain_lrm',
    ):
        """Build the DINOv2 backbone and optionally load weights into it.

        Args:
            version: Name of the DINOv2 variant. It is checked against the four
                known variant names, but NOTE(review): the backbone built below
                is always ``dinov2_vitb14`` regardless of this value.
            ckpt_path: Optional checkpoint forwarded to :meth:`load_pretrained`.
                When ``None``, the backbone keeps whatever weights
                ``dinov2_vitb14(pretrained=False)`` produces (presumably a
                fresh initialization -- TODO confirm).
            lrm_mode: Stored on ``self.lrm_mode``; not read within this class.
        """
        super().__init__()
        self.lrm_mode = lrm_mode
        # NOTE(review): `assert` is stripped under `python -O`, so this check
        # disappears in optimized runs.
        assert version in ['dinov2_vitb14', 'dinov2_vits14', 'dinov2_vitl14', 'dinov2_vitg14']
        # Only `dinov2_vitb14` is called here, so every accepted `version`
        # currently yields the same ViT-B/14 architecture.
        self.model = dinov2_vitb14(pretrained=False)
        if ckpt_path is not None:
            self.load_pretrained(ckpt_path)
        else:
            print('None pretrained model for dinov2 encoder ...')

    def load_pretrained(self, ckpt_path):
        """Load encoder weights from ``ckpt_path`` into ``self.model``.

        Two checkpoint layouts are accepted:

        * a plain state dict whose keys already match ``self.model``;
        * a wrapper dict with a ``'state_dict'`` entry (a training-framework
          checkpoint -- presumably PyTorch Lightning, TODO confirm). From it,
          only keys containing ``'img_encoder'`` are kept, with the
          ``'img_encoder.model.'`` fragment removed.

        The wrapper layout is detected explicitly. A non-strict
        ``load_state_dict`` on the wrapper does not raise, so trying the raw
        load first would silently load nothing.

        Loading is non-strict: missing/unexpected keys are printed, not raised.

        Args:
            ckpt_path: Anything ``torch.load`` accepts (path or file-like).

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when a matched
                tensor has the wrong shape.
        """
        print('Loading dinov2 encoder ...')
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints.
        orig_state_dict = torch.load(ckpt_path, map_location='cpu')

        is_wrapped = isinstance(orig_state_dict, dict) and 'state_dict' in orig_state_dict
        if not is_wrapped:
            ret = self.model.load_state_dict(orig_state_dict, strict=False)
            print(ret)
            print('Successfully loaded orig state dict')
            return

        new_state_dict = OrderedDict(
            (k.replace('img_encoder.model.', ''), v)
            for k, v in orig_state_dict['state_dict'].items()
            if 'img_encoder' in k
        )
        ret = self.model.load_state_dict(new_state_dict, strict=False)
        print(ret)
        print('Successfully loaded new state dict')

    def forward(self, x, *args, **kwargs):
        """Encode ``x`` and return the CLS token followed by the patch tokens.

        Extra positional/keyword arguments (camera conditioning, judging by the
        method name -- TODO confirm) are forwarded unchanged to
        ``self.model.forward_features_with_camera``.

        Returns:
            ``torch.cat([cls.unsqueeze(1), patch_tokens], dim=1)``. Assuming the
            backbone returns the CLS token as (B, C) and patch tokens as
            (B, N, C), the result is (B, 1 + N, C) -- TODO confirm.
        """
        ret = self.model.forward_features_with_camera(x, *args, **kwargs)
        # Prepend the CLS token as an extra "token" along dim 1.
        output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
        return output