"""Model head modules."""

import math

import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_

from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors

from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_

__all__ = 'Detect', 'Segment', 'ExtendedSegment', 'Pose', 'Classify', 'RTDETRDecoder'

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""
    dynamic = False  # force grid reconstruction
    export = False  # export mode; `self.format` is set by the Exporter alongside this flag
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh by image size to mitigate quantization error of TFLite integer models
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

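# Usage sketch (illustrative only; the channel and stride values below are assumptions for a
# small three-level neck, and `stride` is normally populated by the model builder):
#
#   head = Detect(nc=80, ch=(64, 128, 256))
#   head.stride = torch.tensor([8., 16., 32.])
#   head.eval()
#   feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   y, raw = head(feats)  # y: (1, 4 + nc, 8400) decoded xywh boxes (image scale) + class scores
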
class Segment(Detect):
    """YOLOv8 Segment head for segmentation models."""

    def __init__(self, nc=80, nm=32, npr=256, ch=()):
        """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
        super().__init__(nc, ch)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

        c4 = max(ch[0] // 4, self.nm)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)

    def forward(self, x):
        """Return detection outputs, mask coefficients, and prototypes during training; otherwise processed outputs."""
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size

        mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
        x = self.detect(self, x)
        if self.training:
            return x, mc, p
        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))

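# Usage sketch (illustrative; same assumed channels as the Detect example above):
#
#   seg = Segment(nc=80, nm=32, npr=256, ch=(64, 128, 256))
#   seg.stride = torch.tensor([8., 16., 32.])
#   seg.eval()
#   feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   y, (raw, mc, p) = seg(feats)
#   # y: (1, 4 + nc + nm, 8400) boxes + class scores + mask coefficients
#   # p: (1, nm, 160, 160) prototypes; masks are assembled downstream (e.g. ultralytics.utils.ops.process_mask)
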
class ExtendedSegment(Segment):
    """Extends the Segment class with a regression head predicting a 5-channel vector per location."""

    def __init__(self, nc=80, nm=32, npr=256, ch=()):
        """Initialize the Segment head plus a per-level, sigmoid-bounded regression branch."""
        super().__init__(nc, nm, npr, ch)
        self.regression_head = nn.ModuleList(nn.Sequential(
            Conv(x, max(x // 4, 128), 3),
            nn.Conv2d(max(x // 4, 128), 5, 1),
            nn.Sigmoid()
        ) for x in ch)

    def forward(self, x):
        """Return Segment outputs together with the concatenated regression predictions."""
        regression_outputs = [self.regression_head[i](x[i]).view(x[i].shape[0], 5, -1) for i in range(self.nl)]
        regression_tensor = torch.cat(regression_outputs, 2)  # (bs, 5, anchors across all levels)
        outputs = super().forward(x)

        if self.training:
            x, mc, p = outputs
            return x, mc, p, regression_tensor
        elif self.export:
            out_1, out_2 = outputs
            return out_1, out_2, regression_tensor
        else:
            out_1, out_2 = outputs
            return (out_1, regression_tensor), (out_2[0], out_2[1], out_2[2], regression_tensor)

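# Usage sketch (illustrative): the extra regression output rides along with the standard
# Segment outputs in every mode.
#
#   ext = ExtendedSegment(nc=80, nm=32, npr=256, ch=(64, 128, 256))
#   ext.stride = torch.tensor([8., 16., 32.])
#   ext.eval()
#   feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   (y, reg), (raw, mc, p, reg2) = ext(feats)  # reg: (1, 5, 8400), sigmoid-bounded in (0, 1)
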
class Pose(Detect):
    """YOLOv8 Pose head for keypoints models."""

    def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
        """Initialize YOLO network with default parameters and Convolutional Layers."""
        super().__init__(nc, ch)
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total
        self.detect = Detect.forward

        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

    def forward(self, x):
        """Perform forward pass through YOLO model and return predictions."""
        bs = x[0].shape[0]  # batch size
        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
        x = self.detect(self, x)
        if self.training:
            return x, kpt
        pred_kpt = self.kpts_decode(bs, kpt)
        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

    def kpts_decode(self, bs, kpts):
        """Decodes keypoints."""
        ndim = self.kpt_shape[1]
        if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
            if ndim == 3:
                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
            return a.view(bs, self.nk, -1)
        else:
            y = kpts.clone()
            if ndim == 3:
                y[:, 2::3].sigmoid_()  # in-place sigmoid on visibility channel
            y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
            y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
            return y

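# Usage sketch (illustrative; kpt_shape=(17, 3) is the COCO convention of 17 keypoints
# with (x, y, visibility) per keypoint):
#
#   pose = Pose(nc=1, kpt_shape=(17, 3), ch=(64, 128, 256))
#   pose.stride = torch.tensor([8., 16., 32.])
#   pose.eval()
#   feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   y, (raw, kpt) = pose(feats)  # y: (1, 4 + nc + 17 * 3, 8400) boxes + scores + decoded keypoints
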
class Classify(nn.Module):
    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
        """Initializes YOLOv8 classification head with specified input/output channels, kernel, stride, padding, groups."""
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, p, g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        """Performs a forward pass of the YOLO model on input image data."""
        if isinstance(x, list):
            x = torch.cat(x, 1)
        x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
        return x if self.training else x.softmax(1)

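# Usage sketch (illustrative):
#
#   cls = Classify(c1=256, c2=1000)
#   cls.eval()
#   probs = cls(torch.randn(2, 256, 20, 20))  # (2, 1000), softmax-normalized in eval mode
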
class RTDETRDecoder(nn.Module):
    """RT-DETR transformer decoder head: projects backbone features and predicts boxes and classes via object queries."""
    export = False  # export mode

    def __init__(
            self,
            nc=80,
            ch=(512, 1024, 2048),
            hd=256,  # hidden dim
            nq=300,  # num queries
            ndp=4,  # num decoder points
            nh=8,  # num heads
            ndl=6,  # num decoder layers
            d_ffn=1024,  # dim of feedforward
            dropout=0.,
            act=nn.ReLU(),
            eval_idx=-1,
            # Training args
            nd=100,  # num denoising
            label_noise_ratio=0.5,
            box_noise_scale=1.0,
            learnt_init_query=False):
        """Initializes the RTDETRDecoder module with the given parameters."""
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num levels
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # Backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

        # Denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # Decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # Encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # Decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        self._reset_parameters()

    def forward(self, x, batch=None):
        """Runs the forward pass, returning bounding box and classification scores for the input features."""
        from ultralytics.models.utils.ops import get_cdn_group

        # Input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # Prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = \
            get_cdn_group(batch,
                          self.nc,
                          self.num_queries,
                          self.denoising_class_embed.weight,
                          self.num_denoising,
                          self.label_noise_ratio,
                          self.box_noise_scale,
                          self.training)

        embed, refer_bbox, enc_bboxes, enc_scores = \
            self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # Decoder
        dec_bboxes, dec_scores = self.decoder(embed,
                                              refer_bbox,
                                              feats,
                                              shapes,
                                              self.dec_bbox_head,
                                              self.dec_score_head,
                                              self.query_pos_head,
                                              attn_mask=attn_mask)
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4 + nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
        """Generates anchor bounding boxes for the given shapes and marks the valid ones."""
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([w, h], dtype=dtype, device=device)  # (w, h) order to match (x, y) in grid_xy
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2), normalized cell centers
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # (1, h*w*nl, 1)
        anchors = torch.log(anchors / (1 - anchors))  # inverse sigmoid
        anchors = anchors.masked_fill(~valid_mask, float('inf'))
        return anchors, valid_mask

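    # Worked example for _generate_anchors (assumed shapes, for intuition): with a single 2x2
    # level, each cell center is (0.5 + i) / 2, so the first anchor is (0.25, 0.25, 0.05, 0.05)
    # before the inverse sigmoid, i.e. roughly (-1.0986, -1.0986, -2.9444, -2.9444) after it:
    #
    #   anchors, valid = decoder._generate_anchors([(2, 2)])  # anchors: (1, 4, 4), valid: (1, 4, 1)
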
    def _get_encoder_input(self, x):
        """Projects the input feature maps and flattens them into encoder-ready sequences."""
        # Get projection features
        x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
        # Get encoder inputs
        feats = []
        shapes = []
        for feat in x:
            h, w = feat.shape[2:]
            # [b, c, h, w] -> [b, h*w, c]
            feats.append(feat.flatten(2).permute(0, 2, 1))
            # [nl, 2]
            shapes.append([h, w])

        # [b, h*w, c]
        feats = torch.cat(feats, 1)
        return feats, shapes

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        """Generates and prepares the decoder inputs from the provided features and shapes."""
        bs = len(feats)
        # Prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # (bs, h*w, 256)

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)

        # Query selection: keep the top num_queries positions ranked by max class score
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # Dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    def _reset_parameters(self):
        """Initializes the model's classification, regression, and projection parameters."""
        # Class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc

        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.)
            constant_(reg_.layers[-1].bias, 0.)

        linear_init_(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)

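# Usage sketch (illustrative; feature sizes follow the default ch=(512, 1024, 2048)):
#
#   decoder = RTDETRDecoder(nc=80)
#   decoder.eval()
#   feats = [torch.randn(1, c, s, s) for c, s in zip((512, 1024, 2048), (80, 40, 20))]
#   y, x = decoder(feats)  # y: (1, 300, 4 + nc), normalized xywh boxes + sigmoid class scores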