# EdgeTA/dnns/yolov3/vit_yolov3.py
import torch
from torch import nn
from timm.models.vision_transformer import VisionTransformer
from functools import partial
from einops import rearrange
from dnns.yolov3.yolo_fpn import YOLOFPN
from dnns.yolov3.head import YOLOXHead
from utils.dl.common.model import set_module, get_module
from types import MethodType
import os
from utils.common.log import logger
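# This module grafts a YOLOX-style detection head onto a timm VisionTransformer:
# the ViT supplies intermediate patch features, the 'before_fpns' rescale them
# into a three-level feature pyramid, and a YOLOXHead predicts boxes on top.
# make_vit_yolov3() at the bottom wires the pieces together.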
class VisionTransformerYOLOv3(VisionTransformer):
    def forward_head(self, x):
        return self.head(x)

    def forward_features(self, x):
        # tap the raw features after the 1/3, 2/3 and final transformer blocks
        return self._intermediate_layers(x, n=[len(self.blocks) // 3 - 1, len(self.blocks) // 3 * 2 - 1, len(self.blocks) - 1])

    def forward(self, x, targets=None):
        features = self.forward_features(x)
        return self.head(x, features, targets)
    @staticmethod
    def init_from_vit(vit: VisionTransformer):
        res = VisionTransformerYOLOv3()
        # copy every sub-module of the source ViT into the new wrapper instance
        for attr in dir(vit):
            if isinstance(getattr(vit, attr), nn.Module):
                try:
                    setattr(res, attr, getattr(vit, attr))
                except Exception as e:
                    logger.warning(f'cannot copy attr {attr}: {e}')
        return res
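# A quick orientation sketch (assuming a stock 12-block ViT-B/16 from timm):
# forward_features() taps n = [12//3 - 1, 12//3*2 - 1, 12 - 1] = [3, 7, 11], i.e.
# the raw outputs of the 4th, 8th and 12th blocks. For a 224x224 input each has
# shape (B, 1 + 14*14, 768); the heads below strip the leading class token and
# reshape the rest to (B, 768, 14, 14).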
class Norm2d(nn.Module):
    """LayerNorm over the channel dim of an NCHW feature map."""
    def __init__(self, embed_dim):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        # NCHW -> NHWC, normalize over C, then back to NCHW
        x = x.permute(0, 2, 3, 1)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
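# e.g. Norm2d(768) maps (B, 768, H, W) -> (B, 768, H, W), LayerNorm-ing the
# 768-dim channel vector at every spatial position (no batch statistics involved).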
class ViTYOLOv3Head(nn.Module):
def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns, cls_vit_ckpt_path, init_head):
super(ViTYOLOv3Head, self).__init__()
self.im_size = im_size
self.patch_size = patch_size
        # (an older per-level projection/resize pipeline lived here; its helper
        # get_change_feature_size() below is kept but currently unused)
embed_dim = 768
        # default: 2x upsample / identity / 2x downsample of the 14x14 feature map
        self.before_fpns = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        if use_bigger_fpns == 1:
            logger.info('use 4/2/1x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
            ])
        if use_bigger_fpns == -1:
            logger.info('use 1/0.5/0.25x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Identity(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), nn.MaxPool2d(kernel_size=2, stride=2))
            ])
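        # resulting pyramid from the 14x14 ViT grid (224x224 input): 28/14/7 by
        # default, 56/28/14 when use_bigger_fpns == 1, 14/7/3 when use_bigger_fpns == -1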
        # no real FPN for now; features go straight from before_fpns to the head
        self.fpn = nn.Identity()  # (a YOLOFPN could be plugged in here)
self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')
if init_head:
logger.info('init head')
self.load_pretrained_weight(cls_vit_ckpt_path)
else:
logger.info('do not init head')
    def load_pretrained_weight(self, cls_vit_ckpt_path):
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))
        # drop weights whose shapes do not match this config: the class predictors,
        # the darknet backbone, and the stem convs
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k in ckpt['model'].keys():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            ckpt['model'].pop(k, None)  # pop() tolerates keys absent from the ckpt
        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')  # the ckpt names its FPN 'backbone'
            new_state_dict[new_k] = v
        # (loading extra weights from a classification ViT checkpoint is disabled;
        # make_vit_yolov3() asserts cls_vit_ckpt_path is None)
        self.load_state_dict(new_state_dict, strict=False)
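        # NOTE: strict=False only tolerates missing/unexpected keys, not shape
        # mismatches, which is why the mismatched cls_preds / stem weights are
        # removed explicitly above; they stay randomly initialized here.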
    def get_change_feature_size(self, in_channels, out_channels, target_size):
        # (helper for the older resize pipeline noted in __init__; currently unused)
H, W = self.im_size
GS = H // self.patch_size # 14
if target_size == GS:
return nn.Identity()
elif target_size < GS:
return nn.AdaptiveMaxPool2d((target_size, target_size))
else:
return {
20: nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
# nn.BatchNorm2d(out_channels),
# nn.ReLU(),
nn.AdaptiveMaxPool2d((target_size, target_size))
),
40: nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
# nn.BatchNorm2d(out_channels),
# nn.ReLU(),
)
}[target_size]
    def forward(self, input_images, x, targets=None):
        # NOTE: YOLOX backbone (w/o FPN) output, i.e. FPN input: {'dark3': torch.Size([4, 256, 40, 40]), 'dark4': torch.Size([4, 512, 20, 20]), 'dark5': torch.Size([4, 512, 10, 10])}
        # NOTE: YOLOXHead input: [torch.Size([4, 128, 40, 40]), torch.Size([4, 256, 20, 20]), torch.Size([4, 512, 10, 10])]
        x = [i[:, 1:] for i in x]  # strip the class token
        # 14 is hardcoded (224 // patch_size 16); cf. timm.layers.patch_embed
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]
        # single-layer variant: every before_fpn consumes the LAST intermediate feature
        xs = [before_fpn(x[-1]) for before_fpn in self.before_fpns]
        # e.g. [torch.Size([1, 768, 28, 28]), torch.Size([1, 768, 14, 14]), torch.Size([1, 768, 7, 7])]
        xs = self.fpn(xs)
        xs = tuple(xs)
        if targets is not None:
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
        return self.head(xs)
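    # Training mode (targets given) returns the loss dict above; in eval the head
    # returns decoded predictions, (B, n_anchors, num_classes + 5) judging by the
    # smoke-test assert in make_vit_yolov3() below.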
class ViTYOLOv3Head2(nn.Module):
def __init__(self, im_size, patch_size, patch_dim, num_classes, use_bigger_fpns):
super(ViTYOLOv3Head2, self).__init__()
self.im_size = im_size
self.patch_size = patch_size
        # (same older projection/resize pipeline note as in ViTYOLOv3Head.__init__)
embed_dim = 768
        # default: 2x upsample / identity / 2x downsample of the 14x14 feature map
        self.before_fpns = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            ),
            nn.Identity(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ])
        if use_bigger_fpns:
            logger.info('use 4/2/1x fpns')
            self.before_fpns = nn.ModuleList([
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                    Norm2d(embed_dim),
                    nn.GELU(),
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Sequential(
                    nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
                ),
                nn.Identity(),
            ])
        # no real FPN for now; features go straight from before_fpns to the head
        self.fpn = nn.Identity()  # (a YOLOFPN could be plugged in here)
self.head = YOLOXHead(num_classes, in_channels=[768, 768, 768], act='lrelu')
self.load_pretrained_weight()
    def load_pretrained_weight(self):
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'yolox_darknet.pth'))
        # drop weights whose shapes do not match this config (see ViTYOLOv3Head)
        removed_k = [f'head.cls_preds.{i}.{j}' for i in [0, 1, 2] for j in ['weight', 'bias']]
        for k in ckpt['model'].keys():
            if 'backbone.backbone' in k:
                removed_k += [k]
            if 'head.stems' in k and 'conv.weight' in k:
                removed_k += [k]
        for k in removed_k:
            ckpt['model'].pop(k, None)
        new_state_dict = {}
        for k, v in ckpt['model'].items():
            new_k = k.replace('backbone', 'fpn')
            new_state_dict[new_k] = v
        self.load_state_dict(new_state_dict, strict=False)
    def get_change_feature_size(self, in_channels, out_channels, target_size):
        # (currently unused; see the note in ViTYOLOv3Head)
H, W = self.im_size
GS = H // self.patch_size # 14
if target_size == GS:
return nn.Identity()
elif target_size < GS:
return nn.AdaptiveMaxPool2d((target_size, target_size))
else:
return {
20: nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=2, padding=0),
# nn.BatchNorm2d(out_channels),
# nn.ReLU(),
nn.AdaptiveMaxPool2d((target_size, target_size))
),
40: nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3, 3), stride=3, padding=1),
# nn.BatchNorm2d(out_channels),
# nn.ReLU(),
)
}[target_size]
    def forward(self, input_images, x, targets=None):
        x = [i[:, 1:] for i in x]  # strip the class token
        # 14 is hardcoded (224 // patch_size 16); cf. timm.layers.patch_embed
        x = [i.permute(0, 2, 1).reshape(input_images.size(0), -1, 14, 14) for i in x]
        # multi-layer variant: each before_fpn consumes its OWN intermediate feature
        xs = [before_fpn(i) for i, before_fpn in zip(x, self.before_fpns)]
        xs = self.fpn(xs)
        xs = tuple(xs)
        if targets is not None:
            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(xs, targets, input_images)
            return {
                "total_loss": loss,
                "iou_loss": iou_loss,
                "l1_loss": l1_loss,
                "conf_loss": conf_loss,
                "cls_loss": cls_loss,
                "num_fg": num_fg,
            }
        return self.head(xs)
def _forward_head(self, x):
    # leftover helper for the MethodType rebinding experiment below (disabled)
    return self.head(x)
# def ensure_forward_head_obj_repoint(self):
#     self.forward_head = MethodType(_forward_head, self)
@torch.no_grad()
def make_vit_yolov3(vit: VisionTransformer, samples: torch.Tensor, patch_size, patch_dim, num_classes,
                    use_bigger_fpns=False, use_multi_layer_feature=False, cls_vit_ckpt_path=None, init_head=False):
    assert cls_vit_ckpt_path is None
    # vit -> (before_fpns ->) fpn -> head; VisionTransformerYOLOv3 overrides
    # forward_features() so the backbone emits intermediate features for the head
    vit = VisionTransformerYOLOv3.init_from_vit(vit)
if not use_multi_layer_feature:
set_module(vit, 'head', ViTYOLOv3Head(
im_size=(samples.size(2), samples.size(3)),
patch_size=patch_size,
patch_dim=patch_dim,
num_classes=num_classes,
use_bigger_fpns=use_bigger_fpns,
cls_vit_ckpt_path=cls_vit_ckpt_path,
init_head=init_head
))
    else:
        # NOTE: the multi-layer-feature path below is currently disabled
        raise NotImplementedError
        logger.info('use multi layer feature')
        # (ViTYOLOv3Head2 takes no cls_vit_ckpt_path argument)
        set_module(vit, 'head', ViTYOLOv3Head2(
            im_size=(samples.size(2), samples.size(3)),
            patch_size=patch_size,
            patch_dim=patch_dim,
            num_classes=num_classes,
            use_bigger_fpns=use_bigger_fpns
        ))
    # smoke test: one forward pass to verify the output shape
    vit.eval()
    output = vit(samples)
    assert len(output) == samples.size(0) and output[0].size(1) == num_classes + 5, f'{[oo.size() for oo in output]}, {num_classes}'
return vit
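# A minimal usage sketch, mirroring the __main__ check below (vit_b_16 comes from
# dnns.vit; 20 classes -> each prediction row is 4 box coords + 1 objectness + 20 scores):
#
#   det = make_vit_yolov3(vit_b_16(), torch.rand((1, 3, 224, 224)),
#                         patch_size=16, patch_dim=768, num_classes=20)
#   preds = det(torch.rand((1, 3, 224, 224)))  # preds[0].size(1) == 25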
if __name__ == '__main__':
from dnns.vit import vit_b_16
vit_b_16 = vit_b_16()
make_vit_yolov3(vit_b_16, torch.rand((1, 3, 224, 224)), 16, 768, 20)
    exit()
    # scratch test (unreachable after the exit() above): how MethodType-bound
    # methods behave under deepcopy
    from types import MethodType
    class Student(object):
        pass
    def set_name(self, name):
        self.name = name
    def get_name(self):
        print(self.name)
    s1 = Student()
    # bind the methods to instance s1 (s2 below gets its own copies via deepcopy)
    s1.set_name = MethodType(set_name, s1)
    s1.get_name = MethodType(get_name, s1)
s1.set_name('s1')
from copy import deepcopy
s2 = deepcopy(s1)
s2.get_name()
s2.set_name('s2')
s1.get_name()
s2.get_name()