import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from transformers.activations import ACT2CLS, ACT2FN
from transformers.image_transforms import center_to_corners_format, corners_to_center_format
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_ninja_available,
is_scipy_available,
is_torch_cuda_available,
logging,
replace_return_docstrings,
requires_backends,
)
from transformers.models.rt_detr.configuration_rt_detr_resnet import RTDetrResNetConfig
from transformers.models.rt_detr.modeling_rt_detr import (
RTDetrConfig,
RTDetrDecoderOutput,
RTDetrModelOutput,
RTDetrObjectDetectionOutput,
RTDetrFrozenBatchNorm2d,
RTDetrConvEncoder,
RTDetrConvNormLayer,
RTDetrEncoderLayer,
RTDetrRepVggBlock,
RTDetrCSPRepLayer,
RTDetrMultiscaleDeformableAttention,
RTDetrMultiheadAttention,
RTDetrDecoderLayer,
RTDetrPreTrainedModel,
RTDetrEncoder,
RTDetrHybridEncoder,
RTDetrDecoder,
RTDetrModel,
RTDetrMLPPredictionHead,
    RTDetrForObjectDetection,
    # Custom CUDA kernel autograd Function used by the optional fast path below;
    # assumes it is still exposed by modeling_rt_detr in the targeted transformers version.
    MultiScaleDeformableAttentionFunction,
)
from transformers.loss.loss_rt_detr import (RTDetrLoss, RTDetrHungarianMatcher)
from transformers.utils.backbone_utils import load_backbone
# TODO: define the config in its own module (from .configuration_rt_detr_v2 import RTDetrV2Config); defined inline below for now
class RTDetrV2Config(RTDetrConfig):
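    """
    Configuration for RT-DETRv2. Extends `RTDetrConfig` with the two decoder parameters
    introduced in v2:

    Args:
        decoder_n_levels (`int`, *optional*, defaults to 3):
            Number of feature levels sampled by the decoder's deformable attention.
        decoder_offset_scale (`float`, *optional*, defaults to 0.5):
            Scale applied to the predicted sampling offsets (relative to the reference
            box size) in the decoder's deformable attention.
    """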
model_type = "rt_detr_v2" # Update the model type
def __init__(
self,
decoder_n_levels=3,
decoder_offset_scale=0.5,
**kwargs
):
super().__init__(**kwargs)
self.decoder_n_levels = decoder_n_levels
self.decoder_offset_scale = decoder_offset_scale
class RTDetrV2ResNetConfig(RTDetrResNetConfig):
model_type = "rt_detr_v2_resnet"
logger = logging.get_logger(__name__)
class RTDetrV2DecoderOutput(RTDetrDecoderOutput):
pass
class RTDetrV2ModelOutput(RTDetrModelOutput):
pass
class RTDetrV2ObjectDetectionOutput(RTDetrObjectDetectionOutput):
pass
class RTDetrV2FrozenBatchNorm2d(RTDetrFrozenBatchNorm2d):
pass
class RTDetrV2ConvEncoder(RTDetrConvEncoder):
pass
class RTDetrV2ConvNormLayer(RTDetrConvNormLayer):
pass
class RTDetrV2EncoderLayer(RTDetrEncoderLayer):
pass
class RTDetrV2RepVggBlock(RTDetrRepVggBlock):
pass
class RTDetrV2CSPRepLayer(RTDetrCSPRepLayer):
pass
# New implementation of multi-scale deformable attention (v2)
def multi_scale_deformable_attention_v2(
value: Tensor,
value_spatial_shapes: Tensor,
sampling_locations: Tensor,
attention_weights: Tensor,
num_points_list: List[int],
method="default",
) -> Tensor:
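    """
    Pure-PyTorch multi-scale deformable attention used by RT-DETRv2.

    Unlike the v1 helper, the number of sampling points may differ per feature level
    (`num_points_list`), and two sampling modes are supported:
      - "default": bilinear sampling via `nn.functional.grid_sample`
      - "discrete": nearest-pixel sampling via integer indexing

    Returns a tensor of shape (batch_size, num_queries, num_heads * hidden_dim).
    """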
batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_points, _ = sampling_locations.shape
value_list = (
value.permute(0, 2, 3, 1)
.flatten(0, 1)
.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=-1)
)
# sampling_offsets [8, 480, 8, 12, 2]
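    # "default" rescales the normalized [0, 1] sampling locations to the [-1, 1] range
    # expected by grid_sample; "discrete" keeps them normalized and rounds them to
    # integer pixel indices further below.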
if method == "default":
sampling_grids = 2 * sampling_locations - 1
elif method == "discrete":
sampling_grids = sampling_locations
sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1)
sampling_grids = sampling_grids.split(num_points_list, dim=-2)
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
        # value_list[level_id]: batch_size*num_heads, hidden_dim, height*width
        # -> batch_size*num_heads, hidden_dim, height, width
value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[level_id]
# batch_size*num_heads, hidden_dim, num_queries, num_points
if method == "default":
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
elif method == "discrete":
sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to(
torch.int64
)
# Separate clamping for x and y coordinates
sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1)
sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1)
# Combine the clamped coordinates
sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1)
sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2)
sampling_idx = (
torch.arange(sampling_coord.shape[0], device=value.device)
.unsqueeze(-1)
.repeat(1, sampling_coord.shape[1])
)
sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]]
sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape(
batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id]
)
sampling_value_list.append(sampling_value_l_)
    # (batch_size, num_queries, num_heads, num_levels*num_points)
    # -> (batch_size, num_heads, num_queries, num_levels*num_points)
    # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.permute(0, 2, 1, 3).reshape(
batch_size * num_heads, 1, num_queries, sum(num_points_list)
)
output = (
(torch.concat(sampling_value_list, dim=-1) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
class RTDetrV2MultiscaleDeformableAttention(RTDetrMultiscaleDeformableAttention):
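    """
    RT-DETRv2 multi-scale deformable attention. Reuses the v1 projections, but for
    4-dim (box) reference points it scales the predicted offsets by the reference box
    size, a per-point factor (`n_points_scale`), and `offset_scale`, instead of the
    per-level spatial normalizer used in v1.
    """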
def __init__(self, config: RTDetrV2Config):
super().__init__(config, config.decoder_attention_heads, config.decoder_n_points)
self.n_levels = config.decoder_n_levels
self.offset_scale = config.decoder_offset_scale
n_points_list = [self.n_points for _ in range(self.n_levels)]
self.n_points_list = n_points_list
n_points_scale = [1 / n for n in n_points_list for _ in range(n)]
self.register_buffer("n_points_scale", torch.tensor(n_points_scale, dtype=torch.float32))
self._reset_parameters()
def _reset_parameters(self):
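        # Same scheme as v1 / Deformable DETR: zero the offset weights and bias each head
        # towards a distinct direction on the unit circle, with the magnitude growing with
        # the point index; attention weights start at zero (uniform after softmax) and the
        # value/output projections use Xavier initialization.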
nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
default_dtype = torch.get_default_dtype()
thetas = torch.arange(self.n_heads, dtype=torch.int64).to(default_dtype) * (2.0 * math.pi / self.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(self.n_heads, 1, 1, 2)
.repeat(1, self.n_levels, self.n_points, 1)
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(self.attention_weights.weight.data, 0.0)
nn.init.constant_(self.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(self.value_proj.weight.data)
nn.init.constant_(self.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(self.output_proj.weight.data)
nn.init.constant_(self.output_proj.bias.data, 0.0)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
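        # Shapes: hidden_states (batch_size, num_queries, d_model) carry the decoder
        # queries; encoder_hidden_states (batch_size, sequence_length, d_model) are the
        # flattened multi-scale feature maps described by `spatial_shapes`;
        # reference_points are normalized (x, y) points or (x, y, w, h) boxes.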
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = F.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
num_coordinates = reference_points.shape[-1]
if num_coordinates == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif num_coordinates == 4:
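            # v2 change: offsets are scaled by the reference box size (w, h), a per-point
            # factor, and `offset_scale`, rather than by the per-level spatial shapes.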
n_points_scale = self.n_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1)
offset = sampling_offsets * n_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale
sampling_locations = reference_points[:, :, None, :, :2] + offset
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
if self.disable_custom_kernels:
# PyTorch implementation
output = multi_scale_deformable_attention_v2(
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list
)
else:
try:
# custom kernel
output = MultiScaleDeformableAttentionFunction.apply(
value,
spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step,
)
except Exception:
# PyTorch implementation
output = multi_scale_deformable_attention_v2(
value, spatial_shapes, sampling_locations, attention_weights, self.n_points_list
)
output = self.output_proj(output)
return output, attention_weights
class RTDetrV2MultiheadAttention(RTDetrMultiheadAttention):
pass
class RTDetrV2DecoderLayer(RTDetrDecoderLayer):
pass
class RTDetrV2PreTrainedModel(RTDetrPreTrainedModel):
config_class = RTDetrV2Config
base_model_prefix = "rt_detr_v2"
main_input_name = "pixel_values"
_no_split_modules = [r"RTDetrV2ConvEncoder", r"RTDetrV2EncoderLayer", r"RTDetrV2DecoderLayer"]
class RTDetrV2Encoder(RTDetrEncoder):
pass
class RTDetrV2HybridEncoder(RTDetrHybridEncoder):
pass
class RTDetrV2Decoder(RTDetrDecoder):
pass
class RTDetrV2Model(RTDetrModel):
pass
class RTDetrV2Loss(RTDetrLoss):
pass
class RTDetrV2MLPPredictionHead(RTDetrMLPPredictionHead):
pass
class RTDetrV2HungarianMatcher(RTDetrHungarianMatcher):
pass
# must inherit from the new V2 classes!
class RTDetrV2ForObjectDetection(RTDetrForObjectDetection):
pass
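# Minimal usage sketch (not part of the module): assumes the V2 classes above are
# registered with transformers in the usual way, and uses an untrained model with
# random weights purely to illustrate the expected inputs and outputs.
#
#   config = RTDetrV2Config(decoder_n_levels=3, decoder_offset_scale=0.5)
#   model = RTDetrV2ForObjectDetection(config)
#   pixel_values = torch.randn(1, 3, 640, 640)
#   outputs = model(pixel_values=pixel_values)
#   logits, pred_boxes = outputs.logits, outputs.pred_boxes  # class scores and normalized boxes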