# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from .common import LayerNorm2d, MLPBlock
from .image_encoder import (
    window_partition,
    window_unpartition,
    add_decomposed_rel_pos,
    ImageEncoderViT,
    Block,
    Attention,
)


class TokenClusteringBlock(nn.Module):
    """Iteratively cluster a channels-last feature map into a grid of super-tokens.

    Centroids start as an adaptive-average-pooled grid of the input. Each
    iteration soft-assigns every pixel to nearby centroids (softmax over
    negative squared distances) and recomputes centroids as affinity-weighted
    means of the pixel features.
    """

    def __init__(self, num_spixels=None, n_iters=5, temperture=0.05, window_size=7):
        """
        Args:
            num_spixels (int | tuple[int, int] | None): number of super-tokens.
                An int must be a perfect square and becomes an (n, n) grid; a
                tuple is interpreted as (height, width) by ``forward``. None
                means the count must be passed to ``forward``.
            n_iters (int): number of assignment / centroid-update iterations.
            temperture (float): factor multiplying the negative squared
                distances before the softmax; larger values give sharper
                assignments.
            window_size (int): odd size of the neighbourhood (measured in
                initial grid cells) a pixel may be assigned within.
        """
        super().__init__()
        if num_spixels is not None:
            num_spixels = self._to_grid(num_spixels)
        self.num_spixels = num_spixels
        self.n_iters = n_iters
        self.temperture = temperture
        assert window_size % 2 == 1
        self.r = window_size // 2

    @staticmethod
    def _to_grid(num_spixels):
        """Normalize ``num_spixels`` (perfect-square int or 2-tuple) to a 2-tuple."""
        if isinstance(num_spixels, tuple):
            assert len(num_spixels) == 2
            return num_spixels
        side = int(math.sqrt(num_spixels))
        assert side * side == num_spixels
        return (side, side)

    def calc_init_centroid(self, images, num_spixels_width, num_spixels_height):
        """
        Calculate the initial super-token centroids and grid label map.

        Args:
            images: torch.Tensor
                A Tensor of shape (B, C, H, W)
            num_spixels_width: int
                number of super-tokens along the width (columns of the grid)
            num_spixels_height: int
                number of super-tokens along the height (rows of the grid)

        Return:
            centroids: torch.Tensor
                A Tensor of shape (B, C, num_spixels_height * num_spixels_width)
            init_label_map: torch.Tensor
                A Tensor of shape (B, H * W) holding, for every pixel, the index
                of the grid cell it falls in (same dtype as ``images``)
        """
        batchsize, channels, height, width = images.shape
        device = images.device

        centroids = torch.nn.functional.adaptive_avg_pool2d(
            images, (num_spixels_height, num_spixels_width)
        )

        with torch.no_grad():
            num_spixels = num_spixels_width * num_spixels_height
            labels = (
                torch.arange(num_spixels, device=device)
                .reshape(1, 1, *centroids.shape[-2:])
                .type_as(centroids)
            )
            # Nearest upsampling spreads each cell index over its pixels.
            # NOTE(review): labels are stored in the input dtype; in low
            # precision (e.g. bfloat16) large indices are not exact — confirm
            # grids stay small enough for the dtypes in use.
            init_label_map = torch.nn.functional.interpolate(
                labels, size=(height, width), mode="nearest"
            ).type_as(centroids)
            init_label_map = init_label_map.repeat(batchsize, 1, 1, 1)

        init_label_map = init_label_map.reshape(batchsize, -1)
        centroids = centroids.reshape(batchsize, channels, -1)

        return centroids, init_label_map

    def forward(self, pixel_features, num_spixels=None):
        """Cluster ``pixel_features`` of shape (B, H, W, C).

        Args:
            pixel_features: channels-last feature map (B, H, W, C).
            num_spixels: optional override of the grid size (int or
                (height, width) tuple); defaults to the constructor value.

        Returns:
            spixel_features: (B, grid_height, grid_width, C) super-token features.
            hard_labels: (B, H * W) index of the nearest super-token per pixel.
        """
        if num_spixels is None:
            num_spixels = self.num_spixels
            assert num_spixels is not None
        else:
            num_spixels = self._to_grid(num_spixels)

        pixel_features = pixel_features.permute(0, 3, 1, 2)  # (B, C, H, W)
        num_spixels_height, num_spixels_width = num_spixels
        num_spixels = num_spixels_width * num_spixels_height
        spixel_features, init_label_map = self.calc_init_centroid(
            pixel_features, num_spixels_width, num_spixels_height
        )

        # Restrict each pixel to super-tokens whose initial grid cell lies
        # within self.r cells (in both directions) of the pixel's own cell.
        device = init_label_map.device
        spixels_number = torch.arange(num_spixels, device=device)[None, :, None]
        relative_labels_widths = (
            init_label_map[:, None] % num_spixels_width
            - spixels_number % num_spixels_width
        )
        relative_labels_heights = torch.div(
            init_label_map[:, None], num_spixels_width, rounding_mode="trunc"
        ) - torch.div(spixels_number, num_spixels_width, rounding_mode="trunc")
        mask = torch.logical_and(
            torch.abs(relative_labels_widths) <= self.r,
            torch.abs(relative_labels_heights) <= self.r,
        )
        # Huge additive distance effectively removes out-of-window pairs.
        mask_dist = (~mask) * 1e16

        pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1)  # (B, C, L)
        permuted_pixel_features = pixel_features.permute(0, 2, 1)  # (B, L, C)

        for _ in range(self.n_iters):
            dist_matrix = self.pairwise_dist(pixel_features, spixel_features)  # (B, L', L)
            dist_matrix += mask_dist
            affinity_matrix = (-dist_matrix * self.temperture).softmax(1)
            # Affinities are detached: gradients reach the centroids only
            # through the pixel features being averaged.
            spixel_features = torch.bmm(affinity_matrix.detach(), permuted_pixel_features)
            spixel_features = spixel_features / affinity_matrix.detach().sum(
                2, keepdim=True
            ).clamp_(min=1e-16)
            spixel_features = spixel_features.permute(0, 2, 1)

        # NOTE(review): hard labels use the unmasked distances, so a pixel may
        # be labelled with a super-token outside its window — confirm intended.
        dist_matrix = self.pairwise_dist(pixel_features, spixel_features)
        hard_labels = torch.argmin(dist_matrix, dim=1)

        B, C, _ = spixel_features.shape
        spixel_features = spixel_features.permute(0, 2, 1).reshape(
            B, num_spixels_height, num_spixels_width, C
        )
        return spixel_features, hard_labels

    def pairwise_dist(self, f1, f2):
        """Squared Euclidean distances: f1 (B, C, N1), f2 (B, C, N2) -> (B, N2, N1)."""
        return (
            (f1 * f1).sum(dim=1).unsqueeze(1)
            + (f2 * f2).sum(dim=1).unsqueeze(2)
            - 2 * torch.einsum("bcm, bcn -> bmn", f2, f1)
        )

    def extra_repr(self):
        return f"num_spixels={self.num_spixels}, n_iters={self.n_iters}"


def naive_unpool(f_regions, region_indices):
    """Copy region features back to tokens: (N, M, C) x (N, L) indices -> (N, L, C)."""
    _, _, C = f_regions.shape
    N, L = region_indices.shape
    index = region_indices.view(N, L, 1).expand(N, L, C)
    result = f_regions.gather(1, index)
    return result


class State:
    """Attribute bag passed from pooling to unpooling.

    ``updated`` turns True on the first ``update_state`` call; ``call`` invokes
    the owning unpooling module with this state.
    """

    def __init__(self, unpooling):
        self.unpooling = unpooling
        self.__updated = False

    @property
    def updated(self):
        return self.__updated

    def get(self, name, default=None):
        return getattr(self, name, default)

    def update_state(self, **states: dict):
        self.__updated = True
        for k, v in states.items():
            setattr(self, k, v)

    def call(self, input: torch.Tensor):
        return self.unpooling(input, self)


class UnpoolingBase(nn.Module):
    """Base for unpooling modules; passes ``x`` through until the state is updated."""

    def forward(self, x, state: State):
        if not state.updated:
            return x, False
        return self._forward(x, state)

    def derive_unpooler(self):
        return State(self)


class NaiveUnpooling(UnpoolingBase):
    """Unpool by copying each token's hard-assigned region feature."""

    def _forward(self, x, state: State):
        return naive_unpool(x, state.hard_labels), False


class TokenReconstructionBlock(UnpoolingBase):
    """Rebuild full-resolution tokens as weighted sums of super-token features.

    Weights come from the squared distance between each token's features
    before clustering and the super-token features right after clustering;
    only the ``k`` largest weights per token are kept.
    """

    def __init__(self, k=3, temperture=0.05):
        """
        Args:
            k (int): number of super-tokens each token draws from. Values <= 0,
                or >= the number of super-tokens, use all super-tokens.
            temperture (float): factor on negative squared distances before exp.
        """
        super().__init__()
        self.k = k
        self.temperture = temperture

    def _forward(self, x, state: State):
        feat = state.feat_before_pooling  # (B, N, C)
        sfeat = state.feat_after_pooling  # (B, M, C)
        # Squared distance between every token and every super-token.
        ds = (
            (feat * feat).sum(dim=2).unsqueeze(2)
            + (sfeat * sfeat).sum(dim=2).unsqueeze(1)
            - 2 * torch.einsum("bnc, bmc -> bnm", feat, sfeat)
        )
        # The expanded form can dip slightly below zero from rounding error.
        ds[ds < 0] = 0
        weight = torch.exp(-self.temperture * ds)

        # Clamp k so topk never exceeds the number of super-tokens; k <= 0
        # previously crashed (empty min / undefined `attention`) and now
        # means "use every super-token".
        k = min(self.k, weight.shape[-1])
        if k > 0:
            topk = torch.topk(weight, k=k, dim=2).values
            kth = topk.min(dim=-1, keepdim=True).values
            # Ties with the k-th weight are kept as well.
            attention = torch.where(weight >= kth, weight, torch.zeros_like(weight))
        else:
            attention = weight
        # NOTE(review): F.normalize defaults to p=2, so rows have unit L2 norm
        # rather than summing to 1 — confirm this is the intended weighting.
        attention = F.normalize(attention, dim=2)
        ret = torch.einsum("bnm, bmc -> bnc", attention, x)

        return ret, False


class HourglassImageEncoderViT(ImageEncoderViT):
    """ViT image encoder that clusters tokens partway through the blocks.

    Tokens are clustered window-by-window before block
    ``hourglass_clustering_location``. The remaining blocks run on the
    smaller token grid, and the full grid is reconstructed before the neck.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
        hourglass_clustering_location: int = -1,
        hourglass_num_cluster: Optional[int] = None,
        hourglass_cluster_iters: int = 5,
        hourglass_temperture: float = 0.01,
        hourglass_cluster_window_size: int = 5,
        hourglass_reconstruction_k: int = 20,
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
            hourglass_clustering_location (int): index of the block before which
                tokens are clustered. Values outside [0, depth) disable clustering.
            hourglass_num_cluster (int | None): clusters per attention window.
                Must be a perfect square; required when clustering is enabled.
            hourglass_cluster_iters (int): clustering iterations.
            hourglass_temperture (float): temperature shared by clustering and
                reconstruction.
            hourglass_cluster_window_size (int): odd neighbourhood size used by
                the clustering assignment.
            hourglass_reconstruction_k (int): super-tokens used per token during
                reconstruction (<= 0 uses all).

        Raises:
            ValueError: if clustering is enabled but ``hourglass_num_cluster`` is None.
        """
        use_hourglass = 0 <= hourglass_clustering_location < depth
        if use_hourglass and hourglass_num_cluster is None:
            raise ValueError(
                "hourglass_num_cluster must be set when "
                "hourglass_clustering_location is in [0, depth)"
            )

        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            out_chans=out_chans,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            use_abs_pos=use_abs_pos,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            global_attn_indexes=global_attn_indexes,
        )

        self.window_size = window_size
        # Side length of the clustered grid inside one window. The fallback to
        # window_size only applies when clustering is disabled (previously
        # math.sqrt(None) raised TypeError with the default arguments).
        self.ws_new = (
            int(math.sqrt(hourglass_num_cluster))
            if hourglass_num_cluster is not None
            else window_size
        )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            if i in global_attn_indexes:
                block_window_size = 0
            elif use_hourglass and i >= hourglass_clustering_location:
                block_window_size = self.ws_new
            else:
                block_window_size = window_size
            block = HourglassBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=block_window_size,
                window_size_ckpt=window_size,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.clustering_location = hourglass_clustering_location
        self.token_clustering_block = TokenClusteringBlock(
            num_spixels=hourglass_num_cluster,
            n_iters=hourglass_cluster_iters,
            temperture=hourglass_temperture,
            window_size=hourglass_cluster_window_size,
        )
        self.token_reconstruction_block = TokenReconstructionBlock(
            k=hourglass_reconstruction_k,
            temperture=hourglass_temperture,
        )

    def cluster(self, x, reconstructer):
        """Cluster every ``window_size`` window of ``x`` (B, H, W, C) to ``ws_new``².

        Stores the pre-/post-clustering window features and the hard labels on
        ``reconstructer`` for later use by ``reconstruct``.

        Returns:
            The clustered map (B, h, w, C) and the ``pad_hw`` value returned by
            ``window_partition`` (presumably the padded spatial size).
        """
        x, pad_hw = window_partition(x, self.window_size)  # (B*Nw, ws, ws, C)
        Bnw, _, _, C = x.shape
        reconstructer.update_state(
            feat_before_pooling=x.view(-1, self.window_size * self.window_size, C)
        )
        x, hard_labels = self.token_clustering_block(x)  # (B*Nw, ws_new, ws_new, C)
        reconstructer.update_state(hard_labels=hard_labels)
        reconstructer.update_state(feat_after_pooling=x.view(Bnw, -1, C))

        # Reverse the window partition on the smaller grid of clustered tokens.
        h = pad_hw[0] // self.window_size * x.shape[1]
        w = pad_hw[1] // self.window_size * x.shape[2]
        x = window_unpartition(x, self.ws_new, (h, w), (h, w))
        # out: (B, h, w, C)
        return x, pad_hw

    def reconstruct(self, x, H, W, recontructer, pad_hw):
        """Map clustered tokens (B, h, w, C) back to the full (B, H, W, C) grid."""
        x, _ = window_partition(x, self.ws_new)  # (B*Nw, ws_new, ws_new, C)
        Bnw, _, _, C = x.shape
        x = x.view(Bnw, -1, C)

        x, _ = recontructer.call(x)  # (B*Nw, ws*ws, C)

        # Merge the reconstructed windows back into the full map.
        x = x.view(-1, self.window_size, self.window_size, C)
        x = window_unpartition(x, self.window_size, pad_hw, (H, W))  # (B, H, W, C)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        H, W = x.shape[1], x.shape[2]

        reconstructer = self.token_reconstruction_block.derive_unpooler()
        reconstructer.update_state(hw_shape=(H, W))

        # Stays None when clustering never happens (location outside the
        # block range); previously reconstruct() then hit an unbound name.
        pad_hw = None
        for i, blk in enumerate(self.blocks):
            if i == self.clustering_location:
                x, pad_hw = self.cluster(x, reconstructer)
            x = blk(x)

        if pad_hw is not None:
            x = self.reconstruct(x, H, W, reconstructer, pad_hw)

        x = self.neck(x.permute(0, 3, 1, 2))

        return x


class HourglassBlock(Block):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
        window_size_ckpt: int = 0,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
            window_size_ckpt (int): window size used to size the attention's relative
                positional parameters for windowed blocks, independent of
                ``window_size`` (presumably so shapes match a pretrained checkpoint).
        """
        super(HourglassBlock, self).__init__(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            window_size=window_size,
            input_size=input_size,
        )

        # Replace the parent's attention so windowed blocks size their
        # relative-position tables from window_size_ckpt.
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size_ckpt, window_size_ckpt),
        )