|
import torch |
|
from torch import nn |
|
from timm.models.vision_transformer import VisionTransformer |
|
from functools import partial |
|
from einops import rearrange |
|
from dnns.yolov3.yolo_fpn import YOLOFPN |
|
from dnns.yolov3.head import YOLOXHead |
|
from utils.dl.common.model import set_module, get_module |
|
from types import MethodType |
|
import os |
|
from utils.common.log import logger |
|
|
|
|
|
class VisionTransformerYOLOv3(VisionTransformer):
    """A timm VisionTransformer whose head is a YOLOv3/YOLOX-style detector.

    ``self.head`` is expected to be swapped for a detection head (see
    ``make_vit_yolov3``) that takes ``(input_images, features, targets)``.
    """

    def forward_head(self, x):
        # Kept for interface compatibility with timm's VisionTransformer;
        # the detection path goes through forward() below, not this method.
        return self.head(x)

    def forward_features(self, x):
        # Collect three intermediate feature maps at roughly 1/3, 2/3 and
        # full depth of the transformer, for a multi-scale detection head.
        return self._intermediate_layers(x, n=[len(self.blocks) // 3 - 1, len(self.blocks) // 3 * 2 - 1, len(self.blocks) - 1])

    def forward(self, x, targets=None):
        """Run the ViT backbone, then the detection head.

        Args:
            x: batch of input images.
            targets: optional detection targets; when provided, the head is
                expected to return losses instead of predictions.
        """
        features = self.forward_features(x)
        return self.head(x, features, targets)

    @staticmethod
    def init_from_vit(vit: VisionTransformer):
        """Build a ``VisionTransformerYOLOv3`` sharing submodules with ``vit``.

        Every ``nn.Module`` attribute of ``vit`` is re-attached to the new
        instance, so weights are shared (not copied).
        """
        res = VisionTransformerYOLOv3()

        for attr in dir(vit):
            if isinstance(getattr(vit, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(vit, attr))
                except Exception as e:
                    # Some attributes (e.g. read-only properties) reject
                    # assignment; log and keep copying the rest. Uses the
                    # module logger for consistency with the rest of the file.
                    logger.warning(f'cannot copy attribute {attr}: {e}')

        return res
|
|
|
|
|
class Norm2d(nn.Module):
    """LayerNorm applied over the channel dimension of NCHW feature maps.

    Channels are moved to the last axis, normalized, and moved back, so the
    module can sit between convolution/deconvolution layers.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # eps matches the value commonly paired with ViT backbones.
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        # NCHW -> NHWC, normalize over C, then restore NCHW layout.
        normalized = self.ln(x.permute(0, 2, 3, 1))
        return normalized.permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
class ViTYOLOv3Head(nn.Module):
    """YOLOX-style detection head fed from a single ViT feature map.

    The three ``before_fpns`` branches resize the (last) ViT feature map to
    three spatial scales, which go straight into a ``YOLOXHead``;
    ``self.fpn`` is an identity placeholder (its name matters only for
    checkpoint key matching in ``load_pretrained_weight``).
    """

    def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns, cls_vit_ckpt_path, init_head):
        # im_size: (H, W) of the network input images.
        # patch_size: ViT patch size (used only by get_change_feature_size).
        # patch_dim: accepted for API compatibility; not stored or used here.
        # num_classes: number of detection classes for the YOLOX head.
        # use_bigger_fpns: 1 -> 4x/2x/1x scales, -1 -> 1x/0.5x/0.25x,
        #                  anything else -> default 2x/1x/0.5x.
        # cls_vit_ckpt_path: forwarded to load_pretrained_weight, which
        #                    currently ignores it (see NOTE there).
        # init_head: when True, load YOLOX darknet weights into the head.
        super(ViTYOLOv3Head, self).__init__()

        self.im_size = im_size
        self.patch_size = patch_size

        # ViT-B embedding width; all branches keep the channel count fixed.
        embed_dim = 768
        # Default branches: upsample 2x, keep 1x, downsample to 0.5x.
        self.before_fpns = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])

        if use_bigger_fpns == 1:
            # Larger scales: 4x (two stride-2 deconvs), 2x, 1x.
            logger.info('use 421x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
            ])
        if use_bigger_fpns == -1:
            # Smaller scales: 1x, 0.5x, 0.25x (stacked max-pools).
            logger.info('use 1/0.5/0.25x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Identity(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), nn.MaxPool2d(kernel_size=2, stride=2))
            ])

        # Identity placeholder where a real FPN could go.
        self.fpn = nn.Identity()
        self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')

        if init_head:
            logger.info('init head')
            self.load_pretrained_weight(cls_vit_ckpt_path)
        else:
            logger.info('do not init head')

    def load_pretrained_weight(self, cls_vit_ckpt_path):
        """Initialize the head from the bundled YOLOX darknet checkpoint.

        NOTE(review): ``cls_vit_ckpt_path`` is accepted but never used — the
        checkpoint path is hard-coded to 'yolox_darknet.pth' next to this
        file; confirm whether that is intended.
        """
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))

        # Drop weights that do not transfer:
        # - class prediction convs (this model's num_classes may differ),
        # - the Darknet backbone (replaced by the ViT),
        # - head stem conv weights (input channels presumably differ from
        #   the 768-channel ViT features — confirm).
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k, v in ckpt['model'].items():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            del ckpt['model'][k]

        # The checkpoint names its FPN 'backbone'; rename to match self.fpn.
        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')
            new_state_dict[new_k] = v

        # strict=False: self.fpn is an Identity here, so the renamed fpn.*
        # keys are silently ignored; only matching head.* weights load.
        self.load_state_dict(new_state_dict, strict=False)

    def get_change_feature_size(self, in_channels, out_channels, target_size):
        """Return a module resizing a square feature map to ``target_size``.

        GS is the patch-grid side (H // patch_size). Downsampling uses
        adaptive max-pooling; upsampling is only defined for target sizes
        20 and 40 (any other larger size raises KeyError).
        NOTE(review): not referenced anywhere in this file — possibly dead.
        """
        H, W = self.im_size
        GS = H // self.patch_size

        if target_size == GS:
            return nn.Identity()
        elif target_size < GS:
            return nn.AdaptiveMaxPool2d((target_size, target_size))
        else:
            # Both candidate modules are built eagerly; only the one keyed
            # by target_size is returned.
            return {
                20: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
                    nn.AdaptiveMaxPool2d((target_size, target_size))
                ),
                40: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
                )
            }[target_size]

    def forward(self, input_images, x, targets=None):
        """Produce detections (eval) or a loss dict (when targets given).

        Args:
            input_images: original image batch (used for batch size, and
                passed to the YOLOX head during training).
            x: list of ViT intermediate-layer token tensors, each of shape
                (B, 1 + num_patches, C) including the class token.
            targets: optional ground-truth boxes for loss computation.
        """
        # Drop the class token from each layer's token sequence.
        x = [i[:, 1:] for i in x]
        # (B, N, C) -> (B, C, 14, 14). Assumes a 14x14 patch grid, i.e.
        # 224 input with 16-pixel patches — TODO confirm for other sizes.
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]

        # NOTE(review): every branch receives the LAST feature map (x[-1]);
        # the loop variable i is unused, and zip only fixes the iteration
        # count. ViTYOLOv3Head2 feeds per-layer features instead — confirm
        # this single-scale choice is intentional.
        xs = [before_fpn(x[-1]) for i, before_fpn in zip(x, self.before_fpns)]

        xs = self.fpn(xs)

        xs = tuple(xs)

        if targets is not None:
            # Training path: the YOLOX head computes the losses directly.
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }

        return self.head(xs)
|
|
|
|
|
class ViTYOLOv3Head2(nn.Module):
    """YOLOX-style detection head fed from MULTIPLE ViT feature maps.

    Unlike ``ViTYOLOv3Head``, each of the three ``before_fpns`` branches
    receives its own intermediate-layer feature map. ``self.fpn`` is an
    identity placeholder; pretrained YOLOX weights are always loaded.
    """

    def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns):
        # im_size: (H, W) of the network input images.
        # patch_size: ViT patch size (used only by get_change_feature_size).
        # patch_dim: accepted for API compatibility; not stored or used here.
        # num_classes: number of detection classes for the YOLOX head.
        # use_bigger_fpns: truthy -> 4x/2x/1x scales; falsy -> 2x/1x/0.5x.
        super(ViTYOLOv3Head2, self).__init__()

        self.im_size = im_size
        self.patch_size = patch_size

        # ViT-B embedding width; all branches keep the channel count fixed.
        embed_dim = 768
        # Default branches: upsample 2x, keep 1x, downsample to 0.5x.
        self.before_fpns = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])

        if use_bigger_fpns:
            # Larger scales: 4x (two stride-2 deconvs), 2x, 1x.
            # (Log message says 8/4/2x but the modules realize 4x/2x/1x
            # relative to the token grid — NOTE(review): message vs. code.)
            logger.info('use 8/4/2x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
            ])

        # Identity placeholder where a real FPN could go.
        self.fpn = nn.Identity()
        self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')

        # Always initialize from the bundled YOLOX checkpoint.
        self.load_pretrained_weight()

    def load_pretrained_weight(self):
        """Initialize the head from the bundled YOLOX darknet checkpoint."""
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))

        # Drop weights that do not transfer: class prediction convs,
        # the Darknet backbone, and the head stem conv weights.
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k, v in ckpt['model'].items():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            del ckpt['model'][k]

        # The checkpoint names its FPN 'backbone'; rename to match self.fpn.
        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')
            new_state_dict[new_k] = v
        # strict=False: fpn.* keys have no counterpart (fpn is Identity).
        self.load_state_dict(new_state_dict, strict=False)

    def get_change_feature_size(self, in_channels, out_channels, target_size):
        """Return a module resizing a square feature map to ``target_size``.

        GS is the patch-grid side (H // patch_size); upsampling is only
        defined for target sizes 20 and 40 (KeyError otherwise).
        NOTE(review): not referenced anywhere in this file — possibly dead.
        """
        H, W = self.im_size
        GS = H // self.patch_size

        if target_size == GS:
            return nn.Identity()
        elif target_size < GS:
            return nn.AdaptiveMaxPool2d((target_size, target_size))
        else:
            # Both candidate modules are built eagerly; only one is returned.
            return {
                20: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
                    nn.AdaptiveMaxPool2d((target_size, target_size))
                ),
                40: nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
                )
            }[target_size]

    def forward(self, input_images, x, targets=None):
        """Produce detections (eval) or a loss dict (when targets given).

        Args:
            input_images: original image batch (batch size / loss input).
            x: list of ViT intermediate-layer token tensors, each of shape
                (B, 1 + num_patches, C) including the class token.
            targets: optional ground-truth boxes for loss computation.
        """
        # Drop the class token from each layer's token sequence.
        x = [i[:, 1:] for i in x]
        # (B, N, C) -> (B, C, 14, 14). Assumes a 14x14 patch grid, i.e.
        # 224 input with 16-pixel patches — TODO confirm for other sizes.
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]

        # Each branch gets its own intermediate feature map (multi-layer),
        # in contrast to ViTYOLOv3Head which reuses x[-1] for all branches.
        xs = [before_fpn(i) for i, before_fpn in zip(x, self.before_fpns)]

        xs = self.fpn(xs)

        xs = tuple(xs)

        if targets is not None:
            # Training path: the YOLOX head computes the losses directly.
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }

        return self.head(xs)
|
|
|
|
|
def _forward_head(self, x): |
|
return self.head(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def make_vit_yolov3(vit: VisionTransformer, samples: torch.Tensor, patch_size, patch_dim, num_classes,
                    use_bigger_fpns=False, use_multi_layer_feature=False, cls_vit_ckpt_path=None, init_head=False):
    """Convert a timm VisionTransformer into a ViT-backed YOLOv3 detector.

    Args:
        vit: source ViT whose submodules are shared into the new model.
        samples: example input batch (B, C, H, W) used to size the head and
            smoke-test the converted model.
        patch_size: ViT patch size, forwarded to the detection head.
        patch_dim: patch embedding width, forwarded to the detection head.
        num_classes: number of detection classes.
        use_bigger_fpns: scale-selection flag forwarded to the head.
        use_multi_layer_feature: must be False; the multi-layer head path
            is not wired up (see below).
        cls_vit_ckpt_path: must be None (unsupported).
        init_head: whether to load pretrained YOLOX weights into the head.

    Returns:
        The converted ``VisionTransformerYOLOv3`` in eval mode.

    Raises:
        NotImplementedError: if ``use_multi_layer_feature`` is set.
    """
    assert cls_vit_ckpt_path is None

    # Share the ViT's submodules into a subclass with a detection forward().
    vit = VisionTransformerYOLOv3.init_from_vit(vit)

    if not use_multi_layer_feature:
        set_module(vit, 'head', ViTYOLOv3Head(
            im_size=(samples.size(2), samples.size(3)),
            patch_size=patch_size,
            patch_dim=patch_dim,
            num_classes=num_classes,
            use_bigger_fpns=use_bigger_fpns,
            cls_vit_ckpt_path=cls_vit_ckpt_path,
            init_head=init_head
        ))
    else:
        # The multi-layer-feature path is not supported: ViTYOLOv3Head2's
        # __init__ no longer accepts cls_vit_ckpt_path/init_head, so the
        # old construction code here could never have run successfully.
        raise NotImplementedError('use_multi_layer_feature is not supported')

    # Smoke test: one forward pass in eval mode; each image must yield
    # predictions with num_classes + 5 values per box (4 box + 1 obj).
    vit.eval()
    output = vit(samples)

    assert len(output) == samples.size(0) and output[0].size(1) == num_classes + 5, f'{[oo.size() for oo in output]}, {num_classes}'

    return vit
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a ViT-B/16 and convert it into a YOLOv3-style
    # detector for 20 classes on 224x224 inputs.
    from dnns.vit import vit_b_16
    # NOTE(review): rebinds the imported factory name to the model instance.
    vit_b_16 = vit_b_16()
    make_vit_yolov3(vit_b_16, torch.rand((1, 3, 224, 224)), 16, 768, 20)
    exit()

    # Everything below is UNREACHABLE (exit() above). It is a retained
    # scratch experiment probing whether MethodType-bound methods survive
    # copy.deepcopy and which instance they end up bound to.
    from types import MethodType

    class Student(object):
        pass

    def set_name(self, name):
        self.name = name

    def get_name(self):
        print(self.name)

    s1 = Student()

    # Bind the free functions to s1 as instance attributes (not class methods).
    s1.set_name = MethodType(set_name, s1)
    s1.get_name = MethodType(get_name, s1)
    s1.set_name('s1')

    from copy import deepcopy
    s2 = deepcopy(s1)
    s2.get_name()

    # Probe which instance the copied bound methods mutate/read.
    s2.set_name('s2')
    s1.get_name()
    s2.get_name()