# Expedit-SAM: segment_anything/modeling/hourglass_image_encoder.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from .common import LayerNorm2d, MLPBlock
from .image_encoder import (
window_partition,
window_unpartition,
add_decomposed_rel_pos,
ImageEncoderViT,
Block,
Attention,
)


class TokenClusteringBlock(nn.Module):
def __init__(self, num_spixels=None, n_iters=5, temperture=0.05, window_size=7):
super().__init__()
if isinstance(num_spixels, tuple):
assert len(num_spixels) == 2
elif num_spixels is not None:
x = int(math.sqrt(num_spixels))
assert x * x == num_spixels
num_spixels = (x, x)
self.num_spixels = num_spixels
self.n_iters = n_iters
self.temperture = temperture
        assert window_size % 2 == 1, "window_size must be odd"
        self.r = window_size // 2  # neighborhood radius on the superpixel grid

    def calc_init_centroid(self, images, num_spixels_width, num_spixels_height):
        """
        Calculate the initial superpixel centroids.
        Args:
            images: torch.Tensor
                A tensor of shape (B, C, H, W)
            num_spixels_width: int
                number of superpixels along the width
            num_spixels_height: int
                number of superpixels along the height
        Returns:
            centroids: torch.Tensor
                A tensor of shape (B, C, num_spixels_height * num_spixels_width)
            init_label_map: torch.Tensor
                A tensor of shape (B, H * W) mapping each pixel to its initial superpixel
        """
batchsize, channels, height, width = images.shape
device = images.device
centroids = torch.nn.functional.adaptive_avg_pool2d(
images, (num_spixels_height, num_spixels_width)
)
with torch.no_grad():
num_spixels = num_spixels_width * num_spixels_height
labels = (
torch.arange(num_spixels, device=device)
.reshape(1, 1, *centroids.shape[-2:])
.type_as(centroids)
)
init_label_map = torch.nn.functional.interpolate(
labels, size=(height, width), mode="nearest"
).type_as(centroids)
init_label_map = init_label_map.repeat(batchsize, 1, 1, 1)
init_label_map = init_label_map.reshape(batchsize, -1)
centroids = centroids.reshape(batchsize, channels, -1)
return centroids, init_label_map

    def forward(self, pixel_features, num_spixels=None):
if num_spixels is None:
num_spixels = self.num_spixels
assert num_spixels is not None
else:
if isinstance(num_spixels, tuple):
assert len(num_spixels) == 2
else:
x = int(math.sqrt(num_spixels))
assert x * x == num_spixels
num_spixels = (x, x)
        pixel_features = pixel_features.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
num_spixels_height, num_spixels_width = num_spixels
num_spixels = num_spixels_width * num_spixels_height
spixel_features, init_label_map = self.calc_init_centroid(
pixel_features, num_spixels_width, num_spixels_height
)
        device = init_label_map.device
        spixels_number = torch.arange(num_spixels, device=device)[None, :, None]
        # restrict each token to superpixels within a (2r + 1) x (2r + 1) neighborhood
        # of its initial assignment
        relative_labels_widths = init_label_map[:, None] % num_spixels_width - spixels_number % num_spixels_width
        relative_labels_heights = torch.div(init_label_map[:, None], num_spixels_width, rounding_mode='trunc') - torch.div(spixels_number, num_spixels_width, rounding_mode='trunc')
        mask = torch.logical_and(torch.abs(relative_labels_widths) <= self.r, torch.abs(relative_labels_heights) <= self.r)
        mask_dist = (~mask) * 1e16  # large additive penalty for out-of-neighborhood pairs
pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) # (B, C, L)
permuted_pixel_features = pixel_features.permute(0, 2, 1) # (B, L, C)
        # iterative superpixel-style soft clustering: softly assign tokens to nearby
        # clusters, then recompute each cluster as the weighted mean of its tokens
        for _ in range(self.n_iters):
dist_matrix = self.pairwise_dist(pixel_features, spixel_features) # (B, L', L)
dist_matrix += mask_dist
affinity_matrix = (-dist_matrix * self.temperture).softmax(1)
spixel_features = torch.bmm(affinity_matrix.detach(), permuted_pixel_features)
spixel_features = spixel_features / affinity_matrix.detach().sum(2, keepdim=True).clamp_(min=1e-16)
spixel_features = spixel_features.permute(0, 2, 1)
dist_matrix = self.pairwise_dist(pixel_features, spixel_features)
        hard_labels = torch.argmin(dist_matrix, dim=1)  # (B, L) index of the nearest cluster
B, C, _ = spixel_features.shape
spixel_features = spixel_features.permute(0, 2, 1).reshape(B, num_spixels_height, num_spixels_width, C)
return spixel_features, hard_labels

    def pairwise_dist(self, f1, f2):
        # squared Euclidean distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
        # f1: (B, C, L), f2: (B, C, L') -> (B, L', L)
        return ((f1 * f1).sum(dim=1).unsqueeze(1)
                + (f2 * f2).sum(dim=1).unsqueeze(2)
                - 2 * torch.einsum("bcm, bcn -> bmn", f2, f1))

    def extra_repr(self):
return f"num_spixels={self.num_spixels}, n_iters={self.n_iters}"


def naive_unpool(f_regions, region_indices):
    # f_regions: (B, num_regions, C); region_indices: (B, L) hard region assignment.
    # Each output position copies the feature of its assigned region.
_, _, C = f_regions.shape
N, L = region_indices.shape
index = region_indices.view(N, L, 1).expand(N, L, C)
result = f_regions.gather(1, index)
return result


class State:
    """Carries intermediate tensors from the pooling step to the paired unpooling module."""

    def __init__(self, unpooling):
        self.unpooling = unpooling
        self.__updated = False

    @property
    def updated(self):
        return self.__updated

    def get(self, name, default=None):
        return getattr(self, name, default)

    def update_state(self, **states):
        self.__updated = True
        for k, v in states.items():
            setattr(self, k, v)

    def call(self, input: torch.Tensor):
        return self.unpooling(input, self)


class UnpoolingBase(nn.Module):
    def forward(self, x, state: State):
        if not state.updated:
            return x, False
        return self._forward(x, state)

    def derive_unpooler(self):
        return State(self)


class NaiveUnpooling(UnpoolingBase):
    def _forward(self, x, state: State):
        return naive_unpool(x, state.hard_labels), False
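

# Illustrative sketch (assumption, not original code): NaiveUnpooling copies each
# region feature back to every position assigned to it via the hard labels.
def _demo_naive_unpooling() -> None:
    unpooling = NaiveUnpooling()
    state = unpooling.derive_unpooler()
    state.update_state(hard_labels=torch.tensor([[0, 0, 1, 1]]))  # (B=1, L=4)
    regions = torch.tensor([[[1.0, 1.0], [2.0, 2.0]]])  # (B=1, 2 regions, C=2)
    out, _ = state.call(regions)
    assert out.tolist() == [[[1.0, 1.0], [1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]]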


class TokenReconstructionBlock(UnpoolingBase):
    def __init__(self, k=3, temperture=0.05):
        super().__init__()
        self.k = k
        self.temperture = temperture

def _forward(self, x, state: State):
feat = state.feat_before_pooling
sfeat = state.feat_after_pooling
ds = (
(feat * feat).sum(dim=2).unsqueeze(2)
+ (sfeat * sfeat).sum(dim=2).unsqueeze(1)
- 2 * torch.einsum("bnc, bmc -> bnm", feat, sfeat)
) # distance between features and super-features
        ds[ds < 0] = 0  # clamp small negative values caused by floating-point error
        weight = torch.exp(-self.temperture * ds)
        if self.k >= 0:
            # keep only the k nearest clusters for each token
            topk, _ = torch.topk(weight, k=self.k, dim=2)
            mink = torch.min(topk, dim=-1).values
            mink = mink.unsqueeze(-1).repeat(1, 1, weight.shape[-1])
            mask = torch.ge(weight, mink)
            weight = torch.where(mask, weight, torch.zeros_like(weight))
        attention = F.normalize(weight, dim=2)
        ret = torch.einsum("bnm, bmc -> bnc", attention, x)
        return ret, False
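

# Illustrative usage sketch (assumption, not original code): pair the clustering
# block with TokenReconstructionBlock through its State object, mirroring what
# HourglassImageEncoderViT.cluster() / reconstruct() do below.
def _demo_token_reconstruction() -> None:
    clustering = TokenClusteringBlock(num_spixels=16, temperture=0.05, window_size=5)
    reconstruction = TokenReconstructionBlock(k=4, temperture=0.05)
    state = reconstruction.derive_unpooler()
    feats = torch.randn(2, 8, 8, 32)  # (B, H, W, C)
    state.update_state(feat_before_pooling=feats.reshape(2, -1, 32))
    clustered, hard_labels = clustering(feats)  # (2, 4, 4, 32), (2, 64)
    state.update_state(hard_labels=hard_labels)
    state.update_state(feat_after_pooling=clustered.reshape(2, -1, 32))
    restored, _ = state.call(clustered.reshape(2, -1, 32))
    assert restored.shape == (2, 64, 32)  # one reconstructed feature per original token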


class HourglassImageEncoderViT(ImageEncoderViT):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
hourglass_clustering_location: int = -1,
        hourglass_num_cluster: Optional[int] = None,
hourglass_cluster_iters: int = 3,
hourglass_temperture: float = 0.1,
hourglass_cluster_window_size: int = 12,
hourglass_reconstruction_k: int = 36,
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Number of output channels produced by the neck.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
            hourglass_clustering_location (int): Index of the block before which token
                clustering is applied.
            hourglass_num_cluster (int): Number of cluster tokens kept per attention
                window; must be a perfect square.
            hourglass_cluster_iters (int): Number of soft-clustering iterations.
            hourglass_temperture (float): Softmax temperature used for clustering and
                reconstruction.
            hourglass_cluster_window_size (int): Size (odd) of the neighborhood that
                restricts which clusters a token may be assigned to.
            hourglass_reconstruction_k (int): Number of nearest clusters used to
                reconstruct each token; a negative value uses all clusters.
        """
super().__init__(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
out_chans=out_chans,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_abs_pos=use_abs_pos,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size,
global_attn_indexes=global_attn_indexes,
)
self.window_size = window_size
        # side length of each attention window after clustering
        self.ws_new = int(math.sqrt(hourglass_num_cluster))
self.blocks = nn.ModuleList()
for i in range(depth):
block = HourglassBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
                window_size=(
                    (window_size if i < hourglass_clustering_location else self.ws_new)
                    if i not in global_attn_indexes
                    else 0
                ),
window_size_ckpt=window_size,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.clustering_location = hourglass_clustering_location
self.token_clustering_block = TokenClusteringBlock(
num_spixels=hourglass_num_cluster,
n_iters=hourglass_cluster_iters,
temperture=hourglass_temperture,
window_size=hourglass_cluster_window_size,
)
self.token_reconstruction_block = TokenReconstructionBlock(
k=hourglass_reconstruction_k,
temperture=hourglass_temperture,
)

    def cluster(self, x, reconstructer):
        # x: (B, H, W, C)
        H, W = x.shape[1:3]
        x, pad_hw = window_partition(x, self.window_size)  # (B * num_windows, WH, WW, C)
Bnw, _, _, C = x.shape
reconstructer.update_state(
feat_before_pooling=x.view(-1, self.window_size * self.window_size, C)
)
        x, hard_labels = self.token_clustering_block(x)  # (B * num_windows, Wh, Ww, C)
reconstructer.update_state(hard_labels=hard_labels)
reconstructer.update_state(feat_after_pooling=x.view(Bnw, -1, C))
        # reverse the window partition on the clustered (smaller) token grid
        h = pad_hw[0] // self.window_size * x.shape[1]
        w = pad_hw[1] // self.window_size * x.shape[2]
        x = window_unpartition(x, self.ws_new, (h, w), (h, w))
        # out: (B, h, w, C)
        return x, pad_hw

    def reconstruct(self, x, H, W, reconstructer, pad_hw):
        # x: (B, h, w, C) on the clustered token grid
        x, _ = window_partition(x, self.ws_new)  # (B * num_windows, Wh, Ww, C)
        Bnw, _, _, C = x.shape
        x = x.view(Bnw, -1, C)
        x, _ = reconstructer.call(x)  # (B * num_windows, WH * WW, C)
        # merge windows
        x = x.view(-1, self.window_size, self.window_size, C)
        x = window_unpartition(x, self.window_size, pad_hw, (H, W))  # (B, H, W, C)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
H, W = x.shape[1], x.shape[2]
reconstructer = self.token_reconstruction_block.derive_unpooler()
reconstructer.update_state(hw_shape=(H, W))
for i, blk in enumerate(self.blocks):
if i == self.clustering_location:
x, pad_hw = self.cluster(x, reconstructer)
x = blk(x)
        # reconstruct() relies on pad_hw set by cluster(), so clustering_location must
        # index one of the blocks above
        x = self.reconstruct(x, H, W, reconstructer, pad_hw)
x = self.neck(x.permute(0, 3, 1, 2))
return x


class HourglassBlock(Block):
    """Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
window_size_ckpt: int = 0,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
            window_size_ckpt (int): Window size used by the pretrained checkpoint; the
                attention module keeps this as its input size so that relative positional
                weights remain loadable after clustering shrinks the window.
        """
super(HourglassBlock, self).__init__(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size,
input_size=input_size,
)
        # Rebuild attention with the checkpoint window size as input_size so that
        # relative positional parameters keep the pretrained shape even when this
        # block attends over the smaller clustered window.
        self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size_ckpt, window_size_ckpt),
)
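

# Illustrative smoke test (assumption, not part of the original module): a tiny
# HourglassImageEncoderViT with clustering applied before block 1. The settings
# below are chosen only so the sketch runs quickly; released SAM checkpoints use
# much larger configurations (e.g. img_size=1024, depth >= 12).
def _demo_hourglass_encoder() -> None:
    encoder = HourglassImageEncoderViT(
        img_size=256,
        patch_size=16,
        embed_dim=64,
        depth=2,
        num_heads=2,
        window_size=8,
        hourglass_clustering_location=1,
        hourglass_num_cluster=16,  # each 8x8 window pooled to 4x4 cluster tokens
        hourglass_cluster_window_size=5,  # must be odd (asserted by TokenClusteringBlock)
        hourglass_reconstruction_k=4,
    )
    out = encoder(torch.randn(1, 3, 256, 256))
    assert out.shape == (1, 256, 16, 16)  # (B, out_chans, img_size // 16, img_size // 16)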