Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmengine.model import BaseModule | |
from mmpretrain.models.heads import ClsHead | |
from mmpretrain.registry import MODELS | |
from ..utils import build_norm_layer | |
class BatchNormLinear(BaseModule):
    """A norm layer followed by a linear layer, foldable into one Linear.

    Args:
        in_channels (int): Number of input features; used both as the
            feature count of the norm layer and the linear layer's input size.
        out_channels (int): Number of output features of the linear layer.
        norm_cfg (dict, optional): Config passed to ``build_norm_layer``.
            ``None`` (the default) means ``dict(type='BN1d')``.
    """

    def __init__(self, in_channels, out_channels, norm_cfg=None):
        super(BatchNormLinear, self).__init__()
        # Build a fresh dict per call instead of sharing one mutable default
        # object between every instance.
        if norm_cfg is None:
            norm_cfg = dict(type='BN1d')
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        """Fold the norm layer's running statistics into ``self.linear``.

        Rewrites ``self.linear``'s weight and bias in place so that
        ``linear(x)`` equals ``linear(bn(x))`` computed with the norm layer's
        running mean/variance (its eval-mode behavior for BatchNorm).

        Returns:
            nn.Linear: ``self.linear`` with the fused parameters.
        """
        bn = self.bn
        # NOTE(review): assumes a BatchNorm-style layer exposing running_mean,
        # running_var and eps — true for the default BN1d config.
        std = (bn.running_var + bn.eps) ** 0.5
        # A non-affine BatchNorm has weight/bias set to None; treat them as the
        # identity scale (1) and zero shift.
        gamma = bn.weight if bn.weight is not None else torch.ones_like(std)
        beta = bn.bias if bn.bias is not None else torch.zeros_like(std)
        scale = gamma / std  # per-input-feature multiplier
        shift = beta - bn.running_mean * scale  # per-input-feature offset
        # linear(bn(x)) = W @ (scale * x + shift) + bias
        #              = (W * scale) @ x + (W @ shift + bias)
        weight = self.linear.weight * scale[None, :]
        bias = (self.linear.weight @ shift[:, None]).view(-1) + \
            self.linear.bias
        # Both results are computed from the original weight before either
        # parameter is overwritten.
        self.linear.weight.data.copy_(weight)
        self.linear.bias.data.copy_(bias)
        return self.linear

    def forward(self, x):
        """Apply the norm layer, then the linear layer."""
        x = self.bn(x)
        x = self.linear(x)
        return x
def fuse_parameters(module):
    """Swap each fusable descendant of ``module`` for its fused form, in place.

    A direct child that has a ``fuse`` attribute is replaced by the value
    returned from ``child.fuse()``; its own children are not visited. Any
    other child is searched recursively. ``module`` itself is never fused.

    Args:
        module (nn.Module): Root whose descendants are fused.
    """
    for name, sub_module in module.named_children():
        if not hasattr(sub_module, 'fuse'):
            # Not fusable itself: look for fusable modules further down.
            fuse_parameters(sub_module)
            continue
        setattr(module, name, sub_module.fuse())
class LeViTClsHead(ClsHead):
    """LeViT classification head with an optional distillation branch.

    Args:
        num_classes (int): Number of output classes. Defaults to 1000.
        distillation (bool): If True, build a second ``head_dist`` branch
            whose logits are averaged with ``head``'s at inference.
            Defaults to True.
        in_channels (int, optional): Input feature count of each
            ``BatchNormLinear`` head.
        deploy (bool): If True, fold each head's norm layer into its linear
            layer immediately after construction. Defaults to False.
        **kwargs: Forwarded to ``ClsHead.__init__``.
    """

    # NOTE(review): ``MODELS`` is imported in this file but this class is not
    # decorated with ``@MODELS.register_module()``; confirm whether
    # config-based construction is expected.

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super(LeViTClsHead, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        # Start un-fused. switch_to_deploy() returns early when self.deploy is
        # already True, so storing ``deploy`` here would skip fusion entirely.
        self.deploy = False
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)
        if deploy:
            # switch_to_deploy takes no arguments besides self.
            self.switch_to_deploy()

    def switch_to_deploy(self):
        """Replace every ``BatchNormLinear`` head with its fused ``nn.Linear``.

        Does nothing if the head has already been switched.
        """
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Input features, first reduced by the inherited ``pre_logits``.

        Returns:
            torch.Tensor: Logits from ``self.head``. With distillation
            enabled, the mean of the ``self.head`` and ``self.head_dist``
            logits.

        Raises:
            NotImplementedError: If called in training mode while
                distillation is enabled.
        """
        x = self.pre_logits(x)
        if not self.distillation:
            return self.head(x)
        # Raise before running the heads: in training mode their BatchNorm
        # layers would otherwise update running statistics on a call that
        # is going to fail anyway.
        if self.training:
            raise NotImplementedError("MMPretrain doesn't support "
                                      'training in distillation mode.')
        # Inference: average classification and distillation logits.
        return (self.head(x) + self.head_dist(x)) / 2