Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from typing import Optional, Tuple, Type | |
from .common import LayerNorm2d, MLPBlock | |
from .image_encoder import ( | |
window_partition, | |
window_unpartition, | |
add_decomposed_rel_pos, | |
ImageEncoderViT, | |
Block, | |
Attention, | |
) | |
class TokenClusteringBlock(nn.Module): | |
def __init__(self, num_spixels=None, n_iters=5, temperture=0.01, window_size=5): | |
super().__init__() | |
if isinstance(num_spixels, tuple): | |
assert len(num_spixels) == 2 | |
elif num_spixels is not None: | |
x = int(math.sqrt(num_spixels)) | |
assert x * x == num_spixels | |
num_spixels = (x, x) | |
self.num_spixels = num_spixels | |
self.n_iters = n_iters | |
self.temperture = temperture | |
assert window_size % 2 == 1 | |
self.r = window_size // 2 | |
def calc_init_centroid(self, images, num_spixels_width, num_spixels_height): | |
""" | |
calculate initial superpixels | |
Args: | |
images: torch.Tensor | |
A Tensor of shape (B, C, H, W) | |
spixels_width: int | |
initial superpixel width | |
spixels_height: int | |
initial superpixel height | |
Return: | |
centroids: torch.Tensor | |
A Tensor of shape (B, C, H * W) | |
init_label_map: torch.Tensor | |
A Tensor of shape (B, H * W) | |
num_spixels_width: int | |
A number of superpixels in each column | |
num_spixels_height: int | |
A number of superpixels int each raw | |
""" | |
batchsize, channels, height, width = images.shape | |
device = images.device | |
centroids = torch.nn.functional.adaptive_avg_pool2d( | |
images, (num_spixels_height, num_spixels_width) | |
) | |
with torch.no_grad(): | |
num_spixels = num_spixels_width * num_spixels_height | |
labels = ( | |
torch.arange(num_spixels, device=device) | |
.reshape(1, 1, *centroids.shape[-2:]) | |
.type_as(centroids) | |
) | |
init_label_map = torch.nn.functional.interpolate( | |
labels, size=(height, width), mode="nearest" | |
).type_as(centroids) | |
init_label_map = init_label_map.repeat(batchsize, 1, 1, 1) | |
init_label_map = init_label_map.reshape(batchsize, -1) | |
centroids = centroids.reshape(batchsize, channels, -1) | |
return centroids, init_label_map | |
def forward(self, pixel_features, num_spixels=None): | |
if num_spixels is None: | |
num_spixels = self.num_spixels | |
assert num_spixels is not None | |
else: | |
if isinstance(num_spixels, tuple): | |
assert len(num_spixels) == 2 | |
else: | |
x = int(math.sqrt(num_spixels)) | |
assert x * x == num_spixels | |
num_spixels = (x, x) | |
pixel_features = pixel_features.permute(0, 3, 1, 2) | |
num_spixels_height, num_spixels_width = num_spixels | |
num_spixels = num_spixels_width * num_spixels_height | |
spixel_features, init_label_map = self.calc_init_centroid( | |
pixel_features, num_spixels_width, num_spixels_height | |
) | |
device = init_label_map.device | |
spixels_number = torch.arange(num_spixels, device=device)[None, :, None] | |
relative_labels_widths = init_label_map[:, None] % num_spixels_width - spixels_number % num_spixels_width | |
relative_labels_heights = torch.div(init_label_map[:, None], num_spixels_width, rounding_mode='trunc') - torch.div(spixels_number, num_spixels_width, rounding_mode='trunc') | |
mask = torch.logical_and(torch.abs(relative_labels_widths) <= self.r, torch.abs(relative_labels_heights) <= self.r) | |
mask_dist = (~mask) * 1e16 | |
pixel_features = pixel_features.reshape(*pixel_features.shape[:2], -1) # (B, C, L) | |
permuted_pixel_features = pixel_features.permute(0, 2, 1) # (B, L, C) | |
for _ in range(self.n_iters): | |
dist_matrix = self.pairwise_dist(pixel_features, spixel_features) # (B, L', L) | |
dist_matrix += mask_dist | |
affinity_matrix = (-dist_matrix * self.temperture).softmax(1) | |
spixel_features = torch.bmm(affinity_matrix.detach(), permuted_pixel_features) | |
spixel_features = spixel_features / affinity_matrix.detach().sum(2, keepdim=True).clamp_(min=1e-16) | |
spixel_features = spixel_features.permute(0, 2, 1) | |
dist_matrix = self.pairwise_dist(pixel_features, spixel_features) | |
hard_labels = torch.argmin(dist_matrix, dim=1) | |
B, C, _ = spixel_features.shape | |
spixel_features = spixel_features.permute(0, 2, 1).reshape(B, num_spixels_height, num_spixels_width, C) | |
return spixel_features, hard_labels | |
def pairwise_dist(self, f1, f2): | |
return ((f1 * f1).sum(dim=1).unsqueeze(1) | |
+ (f2 * f2).sum(dim=1).unsqueeze(2) | |
- 2 * torch.einsum("bcm, bcn -> bmn", f2, f1)) | |
def extra_repr(self): | |
return f"num_spixels={self.num_spixels}, n_iters={self.n_iters}" | |
def naive_unpool(f_regions, region_indices): | |
_, _, C = f_regions.shape | |
N, L = region_indices.shape | |
index = region_indices.view(N, L, 1).expand(N, L, C) | |
result = f_regions.gather(1, index) | |
return result | |
class State: | |
def __init__(self, unpooling): | |
self.unpooling = unpooling | |
self.__updated = False | |
def updated(self): | |
return self.__updated | |
def get(self, name, default=None): | |
return getattr(self, name, default) | |
def update_state(self, **states: dict): | |
self.__updated = True | |
for k, v in states.items(): | |
setattr(self, k, v) | |
def call(self, input: torch.Tensor): | |
return self.unpooling(input, self) | |
class UnpoolingBase(nn.Module): | |
def forward(self, x, state: State): | |
if not state.updated: | |
return x, False | |
return self._forward(x, state) | |
def derive_unpooler(self): | |
return State(self) | |
class NaiveUnpooling(UnpoolingBase): | |
def _forward(self, x, state: State): | |
return naive_unpool(x, state.hard_labels), False | |
class TokenReconstructionBlock(UnpoolingBase): | |
def __init__(self, k=20, temperture=0.01): | |
super().__init__() | |
self.k = k | |
self.temperture = temperture | |
def _forward(self, x, state: State): | |
feat = state.feat_before_pooling | |
sfeat = state.feat_after_pooling | |
ds = ( | |
(feat * feat).sum(dim=2).unsqueeze(2) | |
+ (sfeat * sfeat).sum(dim=2).unsqueeze(1) | |
- 2 * torch.einsum("bnc, bmc -> bnm", feat, sfeat) | |
) # distance between features and super-features | |
ds[ds < 0] = 0 | |
weight = torch.exp(-self.temperture * ds) | |
if self.k >= 0: | |
topk, indices = torch.topk(weight, k=self.k, dim=2) | |
mink = torch.min(topk, dim=-1).values | |
mink = mink.unsqueeze(-1).repeat(1, 1, weight.shape[-1]) | |
mask = torch.ge(weight, mink) | |
zero = Variable(torch.zeros_like(weight)).to(weight.device) | |
attention = torch.where(mask, weight, zero) | |
attention = F.normalize(attention, dim=2) | |
ret = torch.einsum("bnm, bmc -> bnc", attention, x) | |
return ret, False | |
class HourglassImageEncoderViT(ImageEncoderViT): | |
def __init__( | |
self, | |
img_size: int = 1024, | |
patch_size: int = 16, | |
in_chans: int = 3, | |
embed_dim: int = 768, | |
depth: int = 12, | |
num_heads: int = 12, | |
mlp_ratio: float = 4.0, | |
out_chans: int = 256, | |
qkv_bias: bool = True, | |
norm_layer: Type[nn.Module] = nn.LayerNorm, | |
act_layer: Type[nn.Module] = nn.GELU, | |
use_abs_pos: bool = True, | |
use_rel_pos: bool = False, | |
rel_pos_zero_init: bool = True, | |
window_size: int = 0, | |
global_attn_indexes: Tuple[int, ...] = (), | |
hourglass_clustering_location: int = -1, | |
hourglass_num_cluster: int = 100, | |
hourglass_cluster_iters: int = 5, | |
hourglass_temperture: float = 0.01, | |
hourglass_cluster_window_size: int = 5, | |
hourglass_reconstruction_k: int = 20, | |
) -> None: | |
""" | |
Args: | |
img_size (int): Input image size. | |
patch_size (int): Patch size. | |
in_chans (int): Number of input image channels. | |
embed_dim (int): Patch embedding dimension. | |
depth (int): Depth of ViT. | |
num_heads (int): Number of attention heads in each ViT block. | |
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |
qkv_bias (bool): If True, add a learnable bias to query, key, value. | |
norm_layer (nn.Module): Normalization layer. | |
act_layer (nn.Module): Activation layer. | |
use_abs_pos (bool): If True, use absolute positional embeddings. | |
use_rel_pos (bool): If True, add relative positional embeddings to the attention map. | |
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. | |
window_size (int): Window size for window attention blocks. | |
global_attn_indexes (list): Indexes for blocks using global attention. | |
""" | |
super().__init__( | |
img_size=img_size, | |
patch_size=patch_size, | |
in_chans=in_chans, | |
embed_dim=embed_dim, | |
depth=depth, | |
num_heads=num_heads, | |
mlp_ratio=mlp_ratio, | |
out_chans=out_chans, | |
qkv_bias=qkv_bias, | |
norm_layer=norm_layer, | |
act_layer=act_layer, | |
use_abs_pos=use_abs_pos, | |
use_rel_pos=use_rel_pos, | |
rel_pos_zero_init=rel_pos_zero_init, | |
window_size=window_size, | |
global_attn_indexes=global_attn_indexes, | |
) | |
hourglass_clustering_location = hourglass_clustering_location if hourglass_clustering_location >= 0 else depth + 1 | |
self.window_size = window_size | |
self.ws_new = int(math.sqrt(hourglass_num_cluster)) | |
self.blocks = nn.ModuleList() | |
for i in range(depth): | |
block = HourglassBlock( | |
dim=embed_dim, | |
num_heads=num_heads, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
norm_layer=norm_layer, | |
act_layer=act_layer, | |
use_rel_pos=use_rel_pos, | |
rel_pos_zero_init=rel_pos_zero_init, | |
window_size=(window_size if i < hourglass_clustering_location else self.ws_new) if i not in global_attn_indexes else 0, | |
window_size_ckpt=window_size, | |
input_size=(img_size // patch_size, img_size // patch_size), | |
) | |
self.blocks.append(block) | |
self.clustering_location = hourglass_clustering_location | |
self.token_clustering_block = TokenClusteringBlock( | |
num_spixels=hourglass_num_cluster, | |
n_iters=hourglass_cluster_iters, | |
temperture=hourglass_temperture, | |
window_size=hourglass_cluster_window_size, | |
) | |
self.token_reconstruction_block = TokenReconstructionBlock( | |
k=hourglass_reconstruction_k, | |
temperture=hourglass_temperture, | |
) | |
def cluster(self, x, reconstructer): | |
# x: B, H, W, C | |
H, W = x.shape[1:3] | |
x, pad_hw = window_partition(x, self.window_size) # B*Nw, WH, WW, C | |
Bnw, _, _, C = x.shape | |
reconstructer.update_state( | |
feat_before_pooling=x.view(-1, self.window_size * self.window_size, C) | |
) | |
x, hard_labels = self.token_clustering_block(x) # B*H*W, Wh, Ww, C | |
reconstructer.update_state(hard_labels=hard_labels) | |
reconstructer.update_state(feat_after_pooling=x.view(Bnw, -1, C)) | |
# merge window | |
# Reverse window partition | |
h = pad_hw[0] // self.window_size * x.shape[1] | |
w = pad_hw[1] // self.window_size * x.shape[2] | |
x = window_unpartition(x, self.ws_new, (h, w), (h, w)) | |
# out: B, h, w, C | |
return x, pad_hw | |
def reconstruct(self, x, H, W, recontructer, pad_hw): | |
# x: B, h, w, C | |
x, _ = window_partition(x, self.ws_new) # B*Nw, Wh, Ww, C | |
Bnw, _, _, C = x.shape | |
x = x.view(Bnw, -1, C) | |
x, _ = recontructer.call(x) # B*Nw, WH*WW, C | |
# merge windows | |
x = x.view(-1, self.window_size, self.window_size, C) | |
x = window_unpartition(x, self.window_size, pad_hw, (H, W)) # B, H, W, C | |
return x | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
x = self.patch_embed(x) | |
if self.pos_embed is not None: | |
x = x + self.pos_embed | |
H, W = x.shape[1], x.shape[2] | |
reconstructer = self.token_reconstruction_block.derive_unpooler() | |
reconstructer.update_state(hw_shape=(H, W)) | |
for i, blk in enumerate(self.blocks): | |
if i == self.clustering_location: | |
x, pad_hw = self.cluster(x, reconstructer) | |
x = blk(x) | |
if x.shape[1] != H or x.shape[2] != W: | |
x = self.reconstruct(x, H, W, reconstructer, pad_hw) | |
x = self.neck(x.permute(0, 3, 1, 2)) | |
return x | |
def load_hourglass_args(self, **hourglass_args): | |
hourglass_clustering_location = hourglass_args.get('hourglass_clustering_location', self.clustering_location) | |
hourglass_num_cluster = hourglass_args.get('hourglass_num_cluster', self.token_clustering_block.num_spixels[0] * self.token_clustering_block.num_spixels[1]) | |
hourglass_cluster_iters = hourglass_args.get('hourglass_cluster_iters', self.token_clustering_block.n_iters) | |
hourglass_temperture = hourglass_args.get('hourglass_temperture', self.token_clustering_block.temperture) | |
hourglass_cluster_window_size = hourglass_args.get('hourglass_cluster_window_size', self.token_clustering_block.r * 2 + 1) | |
hourglass_reconstruction_k = hourglass_args.get('hourglass_reconstruction_k', self.token_reconstruction_block.k) | |
self.clustering_location = hourglass_clustering_location if hourglass_clustering_location >= 0 else len(self.blocks) + 1 | |
self.ws_new = int(math.sqrt(hourglass_num_cluster)) | |
for i, blk in enumerate(self.blocks): | |
blk.window_size = (self.window_size if i < self.clustering_location else self.ws_new) if blk.window_size != 0 else 0 | |
self.token_clustering_block = TokenClusteringBlock( | |
num_spixels=hourglass_num_cluster, | |
n_iters=hourglass_cluster_iters, | |
temperture=hourglass_temperture, | |
window_size=hourglass_cluster_window_size, | |
) | |
self.token_reconstruction_block = TokenReconstructionBlock( | |
k=hourglass_reconstruction_k, | |
temperture=hourglass_temperture, | |
) | |
class HourglassBlock(Block): | |
"""Transformer blocks with support of window attention and residual propagation blocks""" | |
def __init__( | |
self, | |
dim: int, | |
num_heads: int, | |
mlp_ratio: float = 4.0, | |
qkv_bias: bool = True, | |
norm_layer: Type[nn.Module] = nn.LayerNorm, | |
act_layer: Type[nn.Module] = nn.GELU, | |
use_rel_pos: bool = False, | |
rel_pos_zero_init: bool = True, | |
window_size: int = 0, | |
input_size: Optional[Tuple[int, int]] = None, | |
window_size_ckpt: int = 0, | |
) -> None: | |
""" | |
Args: | |
dim (int): Number of input channels. | |
num_heads (int): Number of attention heads in each ViT block. | |
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. | |
qkv_bias (bool): If True, add a learnable bias to query, key, value. | |
norm_layer (nn.Module): Normalization layer. | |
act_layer (nn.Module): Activation layer. | |
use_rel_pos (bool): If True, add relative positional embeddings to the attention map. | |
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. | |
window_size (int): Window size for window attention blocks. If it equals 0, then | |
use global attention. | |
input_size (int or None): Input resolution for calculating the relative positional | |
parameter size. | |
""" | |
super(HourglassBlock, self).__init__( | |
dim=dim, | |
num_heads=num_heads, | |
mlp_ratio=mlp_ratio, | |
qkv_bias=qkv_bias, | |
norm_layer=norm_layer, | |
act_layer=act_layer, | |
use_rel_pos=use_rel_pos, | |
rel_pos_zero_init=rel_pos_zero_init, | |
window_size=window_size, | |
input_size=input_size, | |
) | |
self.attn = Attention( | |
dim, | |
num_heads=num_heads, | |
qkv_bias=qkv_bias, | |
use_rel_pos=use_rel_pos, | |
rel_pos_zero_init=rel_pos_zero_init, | |
input_size=input_size if window_size == 0 else (window_size_ckpt, window_size_ckpt), | |
) | |