# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
# Root URL for DINOv2 downloads. Not referenced in the visible part of this
# file; presumably joined with a model name by loader code -- confirm against
# the callers.
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    """Build the ``dinov2_*`` model name for an architecture configuration.

    The architecture name is compacted by removing underscores and keeping
    its first four characters, then the patch size is appended. A
    ``_reg<N>`` suffix is added only when ``num_register_tokens`` is truthy
    (i.e. non-zero).

    Args:
        arch_name: Architecture identifier, e.g. ``"vit_small"``.
        patch_size: Patch size embedded verbatim in the name.
        num_register_tokens: Number of register tokens; ``0`` omits the suffix.

    Returns:
        The assembled name, e.g. ``"dinov2_vits14"`` or ``"dinov2_vitg14_reg4"``.
    """
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
class CenterPadding(nn.Module):
    """Pad the trailing dimensions of a tensor up to a multiple of a given size.

    Every dimension from index 2 onwards is padded so that its length becomes
    the next multiple of ``multiple``. The padding is split between both
    sides of the dimension; when the amount is odd, the extra element goes
    on the trailing side. Padding is applied with ``F.pad`` using its
    default mode and fill value.
    """

    def __init__(self, multiple: int):
        """Store the alignment size.

        Args:
            multiple: Positive size each padded dimension is rounded up to.

        Raises:
            ValueError: If ``multiple`` is not strictly positive.
        """
        super().__init__()
        # Fail early: zero would raise ZeroDivisionError on the first forward
        # call and a negative value would silently produce negative padding.
        if multiple <= 0:
            raise ValueError(f"multiple must be a positive integer, got {multiple!r}")
        self.multiple = multiple

    def _get_pad(self, size: int):
        """Return ``(pad_before, pad_after)`` for one dimension of length ``size``.

        The two values sum to the smallest non-negative amount that makes
        ``size`` a multiple of ``self.multiple``.
        """
        # Exact integer arithmetic; float division followed by ceil() loses
        # precision for very large sizes.
        pad_size = (-size) % self.multiple
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad ``x`` so every dimension from index 2 onwards is a multiple.

        Args:
            x: Input tensor; dimensions 0 and 1 are left untouched.

        Returns:
            The padded tensor.
        """
        # F.pad expects (before, after) pairs starting from the LAST dimension,
        # hence the reversed walk over the shape that stops before index 1.
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        return F.pad(x, pads)