"""Per-sample filter (output-channel) selection for Conv2d + BatchNorm2d pairs.

No real speedup by itself. But it's OK because our big model only forwards
once to find the best sub-model; the sub-model doesn't contain filter
selection modules -- it's just a normal model.
"""

import copy
from typing import Optional

import torch
from torch import nn

# from turtle import forward  # NOTE(review): accidental IDE auto-import, never used -- removed.
# from methods.utils.data import get_source_dataloader
from utils.dl.common.model import get_model_device, get_model_latency, get_model_size, \
    get_module, get_super_module, set_module
from utils.common.log import logger


class KTakesAll(nn.Module):
    """Hard gate: zero out the smallest fraction ``k`` of each row of a 2-D tensor.

    The surviving (largest) entries keep their original soft values, so
    gradients still flow through the kept channels.
    """

    def __init__(self, k):
        super(KTakesAll, self).__init__()
        # Fraction of channels to drop per sample (0. drops nothing).
        self.k = k

    def forward(self, g: torch.Tensor):
        """Mask the weakest gates.

        Args:
            g: (batch, channels) gate values.

        Returns:
            (batch, channels, 1, 1) tensor with the ``int(channels * k)``
            smallest entries of each row set to 0, broadcastable over
            conv feature maps.
        """
        num_to_drop = int(g.size(1) * self.k)
        # topk of -g yields the indices of the smallest entries of g.
        drop_idx = (-g).topk(num_to_drop, 1)[1]
        masked = g.scatter(1, drop_idx, 0)
        return masked.unsqueeze(2).unsqueeze(3)


class Abs(nn.Module):
    """Element-wise absolute value, wrapped as a module for ``nn.Sequential``."""

    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        return x.abs()


class DomainDynamicConv2d(nn.Module):
    """A ``Conv2d``/``BatchNorm2d`` pair whose output channels are gated per sample.

    A small "filter selection" head (abs -> global avg pool -> linear -> ReLU)
    predicts a non-negative gate per output channel from the input feature
    map, and :class:`KTakesAll` hard-prunes the weakest fraction ``k``.
    """

    def __init__(self, raw_conv2d: nn.Conv2d, raw_bn: nn.BatchNorm2d, k: float, bn_after_fc=False):
        """
        Args:
            raw_conv2d: the convolution to wrap (used as-is).
            raw_bn: the BN that followed it in the original network; it is
                applied here, so the original BN slot must be replaced with
                ``nn.Identity`` (see ``boost_raw_model_with_filter_selection``).
            k: initial drop fraction handed to ``KTakesAll``.
            bn_after_fc: unsupported variant (BN inside the selection head).
        """
        super(DomainDynamicConv2d, self).__init__()
        assert not bn_after_fc  # variant not implemented

        self.filter_selection_module = nn.Sequential(
            Abs(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(raw_conv2d.in_channels, raw_conv2d.out_channels),
            nn.ReLU(),
        )
        self.k_takes_all = KTakesAll(k)

        self.raw_conv2d = raw_conv2d
        self.bn = raw_bn  # remember to clear the original BNs in the network

        # Bias starts at 1 so every gate is initially "on".
        nn.init.constant_(self.filter_selection_module[3].bias, 1.)
        nn.init.kaiming_normal_(self.filter_selection_module[3].weight)

        self.cached_raw_w = None     # last soft gates, (batch, out_channels)
        self.l1_reg_of_raw_w = None  # L1 norm of the soft gates (sparsity regularizer)
        self.cached_w = None         # last hard gates, (batch, out_channels, 1, 1)
        self.static_w = None         # frozen per-channel gates (out_channels,), or None
        self.pruning_ratios = None   # list collecting kept-channel fractions, or None

    def forward(self, x):
        raw_x = self.bn(self.raw_conv2d(x))

        if self.static_w is None:
            # Dynamic mode: predict per-sample gates from this batch.
            raw_w = self.filter_selection_module(x)
            self.cached_raw_w = raw_w
            # L1 over the whole batch (not a per-sample mean) drives sparsity.
            self.l1_reg_of_raw_w = raw_w.norm(1)
            w = self.k_takes_all(raw_w)
            self.cached_w = w
        else:
            # Static mode: one frozen gate vector shared by all samples.
            w = self.static_w.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        if self.pruning_ratios is not None:
            # NOTE(review): this records the fraction of *kept* channels even
            # though the attribute is named "pruning_ratios"; stored as a plain
            # float (the original kept 0-dim tensors alive).
            self.pruning_ratios += [float((w > 0.).sum()) / w.numel()]

        return raw_x * w

    def to_static(self):
        """Freeze the gates: batch-average the cached soft gates and hard-prune
        once, so later forwards skip the selection head.

        Requires at least one prior dynamic forward (to fill ``cached_raw_w``).
        NOTE(review): reconstructed from commented-out history, which called
        ``topk`` with a float k (invalid); confirm the averaging scheme.
        """
        global_w = self.cached_raw_w.detach().mean(0, keepdim=True)  # (1, C)
        self.static_w = self.k_takes_all(global_w).flatten()         # (C,)

    def to_dynamic(self):
        """Back to per-sample gating."""
        self.static_w = None


def boost_raw_model_with_filter_selection(model: nn.Module, init_k: float, bn_after_fc=False,
                                          ignore_layers=None, perf_test=True,
                                          model_input_size: Optional[tuple] = None):
    """Return a deep copy of ``model`` where every (Conv2d, BatchNorm2d) pair is
    replaced by a :class:`DomainDynamicConv2d` and the original BN slot is
    neutralized with ``nn.Identity``.

    Args:
        model: source network (left unmodified).
        init_k: initial drop fraction for every ``KTakesAll``.
        bn_after_fc: forwarded to ``DomainDynamicConv2d`` (must be False).
        ignore_layers: conv layer names to leave untouched; ``None`` ignores nothing.
        perf_test: log model size / latency before and after boosting.
        model_input_size: input size for the latency test, e.g. ``(1, 3, 224, 224)``.

    Returns:
        ``(boosted_model, conv_bn_map)`` where ``conv_bn_map`` maps each
        converted conv name to the name of the BN that followed it.
    """
    model = copy.deepcopy(model)
    device = get_model_device(model)

    if perf_test:
        before_model_size = get_model_size(model, True)
        before_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)

    # Pair each BN with the conv that precedes it in named_modules() order
    # (assumes module order matches forward order -- TODO(review): confirm
    # for non-sequential architectures).
    num_original_bns = 0
    last_conv_name = None
    conv_bn_map = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_conv_name = name
        # BUG FIX(review): was `ignore_layers is not None and ...`, which made
        # the whole function a silent no-op with the default ignore_layers=None.
        if isinstance(module, nn.BatchNorm2d) and \
                (ignore_layers is None or last_conv_name not in ignore_layers):
            num_original_bns += 1
            conv_bn_map[last_conv_name] = name

    num_conv = 0
    for name, module in model.named_modules():
        # Only convert convs that actually have a paired BN (and are not ignored);
        # membership in conv_bn_map encodes both conditions.
        if isinstance(module, nn.Conv2d) and name in conv_bn_map:
            set_module(model, name, DomainDynamicConv2d(
                module, get_module(model, conv_bn_map[name]), init_k, bn_after_fc))
            num_conv += 1
    assert num_conv == num_original_bns

    # BN is now applied inside DomainDynamicConv2d; neutralize the originals.
    for bn_layer in conv_bn_map.values():
        set_module(model, bn_layer, nn.Identity())

    if perf_test:
        after_model_size = get_model_size(model, True)
        after_model_latency = get_model_latency(
            model, model_input_size, 50, device, 50)
        logger.info(f'raw model -> raw model w/ filter selection:\n'
                    f'model size: {before_model_size:.3f}MB -> {after_model_size:.3f}MB '
                    f'latency: {before_model_latency:.6f}s -> {after_model_latency:.6f}s')

    return model, conv_bn_map


def get_l1_reg_in_model(boosted_model):
    """Sum the cached gate L1 regularizers over all boosted layers."""
    res = 0.
    for name, module in boosted_model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            res += module.l1_reg_of_raw_w
    return res


def get_cached_w(model):
    """Concatenate the last hard gates of every boosted layer along dim 1."""
    res = []
    for name, module in model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            res += [module.cached_w]
    return torch.cat(res, dim=1)


def set_pruning_rate(model, k):
    """Set the drop fraction ``k`` of every ``KTakesAll`` in ``model``."""
    for name, module in model.named_modules():
        if isinstance(module, KTakesAll):
            module.k = k


def get_cached_raw_w(model):
    """Concatenate the last soft gates of every boosted layer along dim 1."""
    res = []
    for name, module in model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            res += [module.cached_raw_w]
    return torch.cat(res, dim=1)


def start_accmu_flops(model):
    """Start recording kept-channel fractions in every boosted layer."""
    for name, module in model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            module.pruning_ratios = []


def get_accmu_flops(model):
    """Stop recording and return (per-layer ratios, all ratios, their average)."""
    layer_res = {}
    total_res = []
    for name, module in model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            layer_res[name] = module.pruning_ratios
            total_res += module.pruning_ratios
            module.pruning_ratios = None  # stop recording
    # Guard: no boosted layers, or recording never started -> avoid div-by-zero.
    avg_pruning_ratio = sum(total_res) / len(total_res) if total_res else 0.
    return layer_res, total_res, avg_pruning_ratio


def convert_boosted_model_to_static(boosted_model, a_few_data):
    """Freeze the gates of every boosted layer from one calibration batch."""
    # One dynamic forward populates cached_raw_w in every layer.
    boosted_model(a_few_data)
    for name, module in boosted_model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            module.to_static()  # TODO: use fn3 techniques


def ensure_boosted_model_to_dynamic(boosted_model):
    """Undo ``convert_boosted_model_to_static``."""
    for name, module in boosted_model.named_modules():
        if isinstance(module, DomainDynamicConv2d):
            module.to_dynamic()


def train_only_gate(model):
    """Freeze everything except the filter-selection heads; return their params."""
    gate_params = []
    for n, p in model.named_parameters():
        if 'filter_selection_module' in n:
            gate_params += [p]
        else:
            p.requires_grad = False
    return gate_params


if __name__ == '__main__':
    # (A large body of commented-out scratch/benchmark code was removed here.)
    n = KTakesAll(0.6)
    rand_input = torch.rand((1, 5))
    print(n(rand_input))