PhyscalX's picture
TAP v1.1 models release
4cee877
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Image decoder."""
try:
from flash_attn import flash_attn_func
except ImportError:
flash_attn_func = None
import torch
from torch import nn
class TransposedLayerNorm(nn.LayerNorm):
    """LayerNorm applied to a 4D input whose normalized axis sits at dim 1.

    The input is permuted so that dim 1 becomes the last axis, passed through
    the regular ``nn.LayerNorm`` forward, then permuted back to its original
    axis order.
    """

    def forward(self, input):
        """Normalize ``input`` and return it in the same axis order it came in."""
        # Move dim 1 to the end, where nn.LayerNorm normalizes.
        moved_last = input.permute(0, 2, 3, 1)
        normalized = super().forward(moved_last)
        # Restore the original (0, 1, 2, 3) axis order.
        return normalized.permute(0, 3, 1, 2)
class MLP(nn.Module):
    """Feed-forward network: linear -> activation -> linear."""

    def __init__(self, dim, mlp_dim, activation_type="ReLU"):
        """Build the two linear layers and the activation between them.

        Args:
            dim: Feature size of the input and of the output.
            mlp_dim: Feature size of the hidden layer.
            activation_type: Name of the activation class looked up on ``torch.nn``.
        """
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        activation_cls = getattr(nn, activation_type)
        self.activation = activation_cls()
        # Set unconditionally; only activations that read ``inplace`` are affected.
        self.activation.inplace = True

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
class Attention(nn.Module):
"""Multi-head attention."""
def __init__(self, dim=256, num_heads=8, attn_ratio=1):
super(Attention, self).__init__()
self.num_heads = num_heads or dim // 64
self.head_dim = int(dim * attn_ratio) // self.num_heads
self.q_proj = nn.Linear(dim, self.num_heads * self.head_dim)
self.k_proj = nn.Linear(dim, self.num_heads * self.head_dim)
self.v_proj = nn.Linear(dim, self.num_heads * self.head_dim)
self.proj = nn.Linear(self.num_heads * self.head_dim, dim)
self.scale = self.head_dim**-0.5
def forward(self, q, k, v):
q = self.q_proj(q).view((-1, q.size(1), self.num_heads, self.head_dim))
k = self.k_proj(k).view((-1, k.size(1), self.num_heads, self.head_dim))
v = self.v_proj(v).view((-1, v.size(1), self.num_heads, self.head_dim))
o = flash_attn_func(q, k, v, softmax_scale=self.scale)
return self.proj(o.flatten(2))
class Block(nn.Module):
    """Transformer block with four stages.

    In order: query self-attention, query-to-key cross-attention, an MLP on
    the query, and key-to-query cross-attention. Each stage is followed by
    its own LayerNorm.
    """

    def __init__(
        self,
        dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        skip_first_query_pos=False,
    ):
        """Create the attention, MLP, normalization and dropout submodules.

        Args:
            dim: Feature size of query and key.
            num_heads: Number of heads for every attention layer.
            attn_ratio: Projection ratio used by the two cross-attention layers.
            mlp_dim: Hidden size of the MLP.
            activation_type: Name of the ``torch.nn`` activation used by the MLP.
            skip_first_query_pos: If true, self-attention ignores ``query_pos``
                and its output replaces the query without a residual sum.
        """
        super().__init__()
        self.self_attn = Attention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn_token_to_image = Attention(dim, num_heads, attn_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, activation_type)
        self.norm3 = nn.LayerNorm(dim)
        self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
        self.norm4 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(0.1, inplace=True)
        self.skip_first_query_pos = skip_first_query_pos

    def forward(self, query, key, query_pos, key_pos):
        """Update the query and then the key; return ``(query, key)``."""
        # Stage 1: self-attention over the query.
        if not self.skip_first_query_pos:
            query_with_pos = query + query_pos
            attended = self.self_attn(query_with_pos, query_with_pos, query)
            query = self.norm1(self.dropout(attended).add_(query))
        else:
            query = self.norm1(self.self_attn(query, query, query))

        # Stage 2: the query attends to the key.
        query_with_pos = query + query_pos
        key_with_pos = key + key_pos
        attended = self.cross_attn_token_to_image(query_with_pos, key_with_pos, key)
        query = self.norm2(self.dropout(attended).add_(query))

        # Stage 3: MLP on the query.
        expanded = self.mlp(query)
        query = self.norm3(self.dropout(expanded).add_(query))

        # Stage 4: the key attends to the updated query (no dropout on this path).
        attended = self.cross_attn_image_to_token(key_with_pos, query + query_pos, query)
        key = self.norm4(attended.add_(key))
        return query, key
class Transformer(nn.Module):
    """Two-way transformer decoder: a stack of blocks plus a final query-to-key attention."""

    def __init__(
        self,
        embed_dim=256,
        num_heads=8,
        attn_ratio=0.5,
        mlp_dim=2048,
        activation_type="ReLU",
        depth=2,
    ):
        """Create ``depth`` blocks followed by the final attention, norm and dropout.

        Args:
            embed_dim: Feature size of query and key.
            num_heads: Number of attention heads.
            attn_ratio: Projection ratio for the cross-attention layers.
            mlp_dim: Hidden size of each block's MLP.
            activation_type: Name of the ``torch.nn`` activation used in the MLPs.
            depth: Number of blocks.
        """
        super().__init__()
        layers = []
        for index in range(depth):
            # Only the first block skips the query positional term.
            layers.append(
                Block(
                    embed_dim,
                    num_heads,
                    attn_ratio=attn_ratio,
                    mlp_dim=mlp_dim,
                    activation_type=activation_type,
                    skip_first_query_pos=index == 0,
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(0.1, inplace=True)

    def forward(self, query, key, query_pos, key_pos):
        """Run every block, then one more query-to-key attention; return ``(query, key)``."""
        for block in self.blocks:
            query, key = block(query, key, query_pos, key_pos)
        query_with_pos = query + query_pos
        key_with_pos = key + key_pos
        attended = self.final_attn_token_to_image(query_with_pos, key_with_pos, key)
        query = self.norm(self.dropout(attended).add_(query))
        return query, key
class Predictor(nn.Module):
    """Stack of linear layers with ReLU between them (none after the last)."""

    def __init__(self, in_dim, out_dim, mlp_dim=None, depth=3):
        """Create ``depth`` linear layers.

        Args:
            in_dim: Input feature size.
            out_dim: Output feature size.
            mlp_dim: Hidden feature size; a falsy value selects ``in_dim``.
            depth: Number of linear layers.
        """
        super().__init__()
        hidden_dim = mlp_dim or in_dim
        # Consecutive pairs of widths give each layer's (in, out) sizes.
        widths = [in_dim] + [hidden_dim] * (depth - 1) + [out_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        """Apply every layer, with an in-place ReLU after all but the last."""
        *hidden_layers, head = self.layers
        for layer in hidden_layers:
            x = nn.functional.relu(layer(x), inplace=True)
        return head(x)
class ImageDecoder(nn.Module):
    """Decode semantic tokens, IoU scores and masks from image and prompt embeddings."""

    def __init__(self, depth, embed_dim, num_heads, num_mask_tokens=4, sem_embed_dim=1024):
        """Create the transformer, learned tokens, upscaling convs and prediction heads.

        Args:
            depth: Number of transformer blocks.
            embed_dim: Feature size of tokens and image embeddings.
            num_heads: Number of attention heads in the transformer.
            num_mask_tokens: Number of mask tokens (and of semantic tokens).
            sem_embed_dim: Hidden and output size of the semantic head.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_mask_tokens = num_mask_tokens
        self.transformer = Transformer(embed_dim, num_heads, depth=depth)
        self.iou_token = nn.Embedding(1, embed_dim)
        self.sem_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)
        mid_dim, up_dim = embed_dim // 4, embed_dim // 8
        # Two stride-2 transposed convs: each doubles the spatial size.
        self.output_conv = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, mid_dim, 2, 2),
            TransposedLayerNorm(mid_dim),
            nn.GELU(),
            nn.ConvTranspose2d(mid_dim, up_dim, 2, 2),
            nn.GELU(),
        )
        self.mask_pred = nn.ModuleList(
            [Predictor(embed_dim, up_dim) for _ in range(num_mask_tokens)]
        )
        self.iou_pred = Predictor(embed_dim, num_mask_tokens)
        self.sem_pred = Predictor(embed_dim, sem_embed_dim, sem_embed_dim)

    def get_outputs(self, inputs):
        """Run the decoder and return the raw prediction dict.

        Args:
            inputs: Mapping with ``"img_embeds"``, ``"sparse_embeds"`` and
                ``"img_pos"`` entries. ``img_embeds`` must be 5D (it is expanded
                with five sizes); presumably laid out as
                (images, 1, height, width, channels) — TODO confirm with caller.

        Returns:
            Dict with ``"iou_pred"``, ``"mask_pred"``, ``"sem_tokens"`` and
            ``"sem_embeds"``.
        """
        img_embeds = inputs["img_embeds"]
        sparse_embeds = inputs["sparse_embeds"]
        num_images, num_prompts = img_embeds.size(0), sparse_embeds.size(0)
        grid_size = img_embeds.shape[2:-1]

        # Query: learned tokens (semantic, IoU, mask) followed by the prompt embeddings.
        learned = torch.cat(
            [self.sem_tokens.weight, self.iou_token.weight, self.mask_tokens.weight]
        )
        num_tokens = learned.size(0)
        learned = learned.unsqueeze(0).expand(num_prompts, -1, -1)
        query = torch.cat((learned, sparse_embeds), dim=1)

        # Key: image embeddings repeated per prompt, spatial axes merged into one.
        key = img_embeds.expand(-1, num_prompts // num_images, -1, -1, -1)
        key = key.flatten(0, 1).flatten(1, 2)

        # Decode; the initial query doubles as the query positional term.
        query, key = self.transformer(query, key, query, inputs["img_pos"])

        # Put the key back on its spatial grid and upscale it.
        key = key.transpose(1, 2).view(-1, self.embed_dim, *grid_size)
        mask_embeds = self.output_conv(key).flatten(2)

        # Split the learned tokens back out of the query.
        sem_tokens = query[:, : self.num_mask_tokens]
        iou_tokens, *mask_tokens = query[:, self.num_mask_tokens : num_tokens].unbind(1)

        # One predictor per mask token; a matmul with the upscaled key yields the masks.
        mask_weights = torch.stack(
            [head(token) for head, token in zip(self.mask_pred, mask_tokens)], dim=1
        )
        mask_size = [4 * extent for extent in grid_size]
        mask_pred = (mask_weights @ mask_embeds).view([-1, self.num_mask_tokens] + mask_size)

        outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
        sem_tokens = sem_tokens.unsqueeze_(2)
        outputs["sem_tokens"] = sem_tokens
        outputs["sem_embeds"] = self.sem_pred(sem_tokens.flatten(2))
        return outputs

    def forward(self, inputs):
        """Return ``get_outputs(inputs)`` with ``"iou_pred"`` cast to float32."""
        outputs = self.get_outputs(inputs)
        iou_pred = outputs["iou_pred"]
        outputs["iou_pred"] = iou_pred.float()
        return outputs