|
""" Image to Patch Embedding using Conv2d |
|
|
|
A convolution based approach to patchifying a 2D image w/ embedding projection. |
|
|
|
Based on code in: |
|
* https://github.com/google-research/vision_transformer |
|
* https://github.com/google-research/big_vision/tree/main/big_vision |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import logging |
|
import math |
|
from typing import Callable, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .format import Format, nchw_to |
|
from .helpers import to_2tuple |
|
from .trace_utils import _assert |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Cuts the input image into non-overlapping ``patch_size`` patches with a
    strided ``nn.Conv2d`` and projects each patch to ``embed_dim`` channels.
    The result is optionally flattened to a token sequence (or converted to
    ``output_fmt``) and passed through ``norm_layer``.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        if img_size is None:
            # No fixed input resolution: grid geometry is unknown up front.
            self.img_size = None
            self.grid_size = None
            self.num_patches = None
        else:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple(side // patch for side, patch in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]

        if output_fmt is None:
            # Default: keep the conv's NCHW layout, flattening to tokens if requested.
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        else:
            # An explicit output format overrides (disables) flattening.
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
        """ Spatial reduction of the embedding: the largest patch dim when
        ``as_scalar`` is true, otherwise the full (h, w) patch size tuple.
        """
        return max(self.patch_size) if as_scalar else self.patch_size

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get grid (feature) size for given image size taking account of dynamic padding.

        With dynamic padding the input is padded up to a patch multiple, so the
        grid uses ceil division; otherwise the conv truncates, so floor division.
        NOTE: must be torchscript compatible so using fixed tuple indexing
        """
        if not self.dynamic_img_pad:
            return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
        return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])

    def forward(self, x):
        batch, chans, height, width = x.shape

        # Input validation only applies when a reference image size was configured.
        if self.img_size is not None:
            if self.strict_img_size:
                _assert(height == self.img_size[0], f"Input height ({height}) doesn't match model ({self.img_size[0]}).")
                _assert(width == self.img_size[1], f"Input width ({width}) doesn't match model ({self.img_size[1]}).")
            elif not self.dynamic_img_pad:
                _assert(
                    height % self.patch_size[0] == 0,
                    f"Input height ({height}) should be divisible by patch size ({self.patch_size[0]})."
                )
                _assert(
                    width % self.patch_size[1] == 0,
                    f"Input width ({width}) should be divisible by patch size ({self.patch_size[1]})."
                )

        if self.dynamic_img_pad:
            # Zero-pad right/bottom just enough to reach the next patch multiple.
            extra_h = (self.patch_size[0] - height % self.patch_size[0]) % self.patch_size[0]
            extra_w = (self.patch_size[1] - width % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, extra_w, 0, extra_h))

        x = self.proj(x)
        if self.flatten:
            # NCHW -> NLC token sequence
            x = x.flatten(2).transpose(1, 2)
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return self.norm(x)
|
|
|
|
|
class PatchEmbedWithSize(PatchEmbed):
    """ 2D Image to Patch Embedding that also reports the feature-map size.

    Identical projection to ``PatchEmbed``; ``forward`` returns a pair of the
    embedded output and the (height, width) of the conv feature map, measured
    before any flattening or format conversion.
    """
    output_fmt: Format

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
    ):
        super().__init__(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer, flatten=flatten, output_fmt=output_fmt, bias=bias,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        batch, chans, height, width = x.shape

        # Only divisibility is enforced here (not an exact size match).
        if self.img_size is not None:
            _assert(height % self.patch_size[0] == 0, f"Input image height ({height}) must be divisible by patch size ({self.patch_size[0]}).")
            _assert(width % self.patch_size[1] == 0, f"Input image width ({width}) must be divisible by patch size ({self.patch_size[1]}).")

        x = self.proj(x)
        # Capture grid (H', W') while the tensor is still NCHW.
        grid_hw = x.shape[-2:]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return self.norm(x), grid_hw
|
|
|
|
|
def resample_patch_embed(
        patch_embed,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample the weights of the patch embedding kernel to target resolution.

    We resample the patch embedding kernel by approximately inverting the effect
    of patch resizing.

    Code based on:
      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

    With this resizing, we can for example load a B/8 filter into a B/16 model
    and, on 2x larger input image, the result will match.

    Let ``B`` be the linear operator that resizes a flattened (old_h, old_w)
    patch to (new_h, new_w) via ``F.interpolate``. Each kernel ``w`` is mapped
    to ``pinv(B^T) @ w``, so that ``<B x, w'> ~= <x, w>`` for patches ``x``.

    Args:
        patch_embed: original parameter to be resized, 4D (out_ch, in_ch, h, w).
        new_size (tuple(int, int)): target shape (height, width)-only.
        interpolation (str): interpolation for resize
        antialias (bool): use anti-aliasing filter in resize
        verbose (bool): log operation
    Returns:
        Resized patch embedding kernel, same dtype and device as ``patch_embed``.
        The input object itself is returned if the size already matches.
    """
    import numpy as np

    assert len(patch_embed.shape) == 4, "Four dimensions expected"
    assert len(new_size) == 2, "New shape should only be hw"
    old_size = patch_embed.shape[-2:]
    if tuple(old_size) == tuple(new_size):
        return patch_embed

    if verbose:
        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")

    old_h, old_w = int(old_size[0]), int(old_size[1])
    new_h, new_w = int(new_size[0]), int(new_size[1])
    num_old = old_h * old_w

    # Resizing is linear and applied per (batch, channel) slice, so resizing all
    # one-hot basis images in a single batched call yields the whole operator:
    # row i is the flattened resize of the basis image with a 1 at flat index i.
    basis = torch.eye(num_old, dtype=torch.float32).reshape(num_old, 1, old_h, old_w)
    resized_basis = F.interpolate(basis, size=(new_h, new_w), mode=interpolation, antialias=antialias)
    resize_mat_t = resized_basis.reshape(num_old, new_h * new_w).numpy()  # B^T: (old_hw, new_hw)

    # pinv(B^T): (new_hw, old_hw); computed in float32 with numpy as before.
    resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat_t), device=patch_embed.device)

    orig_dtype = patch_embed.dtype
    out_ch, in_ch = patch_embed.shape[:2]
    # Apply pinv(B^T) to every (out, in) kernel at once instead of a nested vmap.
    kernels = patch_embed.float().reshape(out_ch, in_ch, num_old)
    resampled = kernels @ resize_mat_pinv.T  # (out_ch, in_ch, new_hw)
    return resampled.reshape(out_ch, in_ch, new_h, new_w).to(orig_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|