import torch
from torch import nn
from abc import ABC, abstractmethod

from utils.dl.common.model import get_model_device, get_model_latency, get_model_size
from utils.common.log import logger


class KTakesAll(nn.Module):
    # k means sparsity (the larger k is, the smaller the extracted model is)
    def __init__(self, k):
        super(KTakesAll, self).__init__()
        self.k = k
        self.cached_i = None

    def forward(self, g: torch.Tensor):
        # zero out the k * C smallest attention values along the last dimension,
        # keeping only the most salient channels
        k = int(g.size(-1) * self.k)
        i = (-g).topk(k, -1)[1]
        self.cached_i = i
        t = g.scatter(-1, i, 0)
        return t
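
# A minimal sketch (not part of the original module) of KTakesAll's behaviour:
# with k=0.5, the smallest half of the attention values along the last
# dimension are zeroed and only the most salient channels survive, e.g.
#
#   >>> wta = KTakesAll(0.5)
#   >>> wta(torch.tensor([[0.1, 0.9, 0.3, 0.7]]))
#   tensor([[0.0000, 0.9000, 0.0000, 0.7000]])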


class Abs(nn.Module):
    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        return x.abs()


class Layer_WrappedWithFBS(nn.Module):
    # base class of layers wrapped with FBS: holds a KTakesAll gate and
    # caches the (raw) channel attention produced during the forward pass
    def __init__(self):
        super(Layer_WrappedWithFBS, self).__init__()

        init_sparsity = 0.5
        self.k_takes_all = KTakesAll(init_sparsity)

        self.cached_raw_channel_attention = None
        self.cached_channel_attention = None
        self.use_cached_channel_attention = False


class ElasticDNNUtil(ABC):
    @abstractmethod
    def convert_raw_dnn_to_master_dnn(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raise NotImplementedError

    def convert_raw_dnn_to_master_dnn_with_perf_test(self, raw_dnn: nn.Module, r: float, ignore_layers=[]):
        raw_dnn_size = get_model_size(raw_dnn, True)
        master_dnn = self.convert_raw_dnn_to_master_dnn(raw_dnn, r, ignore_layers)
        master_dnn_size = get_model_size(master_dnn, True)

        logger.info(f'master DNN w/o FBS ({raw_dnn_size:.3f}MB) -> master DNN w/ FBS ({master_dnn_size:.3f}MB) '
                    f'(↑ {(((master_dnn_size - raw_dnn_size) / raw_dnn_size) * 100.):.2f}%)')
        return master_dnn

    def set_master_dnn_inference_via_cached_channel_attention(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                assert module.cached_channel_attention is not None
                module.use_cached_channel_attention = True

    def set_master_dnn_dynamic_inference(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_channel_attention = None
                module.use_cached_channel_attention = False
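
    # Sketch (assumed usage, not from the original source): these two helpers
    # switch every wrapped layer between static inference (reuse the cached
    # channel attention) and dynamic inference (recompute the gates per input):
    #
    #   util.set_master_dnn_inference_via_cached_channel_attention(master_dnn)  # static gates
    #   util.set_master_dnn_dynamic_inference(master_dnn)                       # per-input gates
    # where `util` is an instance of a concrete ElasticDNNUtil subclass.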

    def train_only_fbs_of_master_dnn(self, master_dnn: nn.Module):
        fbs_params = []
        for n, p in master_dnn.named_parameters():
            if '.fbs' in n:
                fbs_params += [p]
                p.requires_grad = True
            else:
                p.requires_grad = False
        return fbs_params
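
    # Sketch (assumed usage): the returned FBS parameters can be passed to an
    # optimizer so that only the channel-attention branches are trained while
    # the backbone weights stay frozen, e.g.
    #
    #   fbs_params = util.train_only_fbs_of_master_dnn(master_dnn)
    #   optimizer = torch.optim.Adam(fbs_params, lr=1e-3)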

    def get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        res = 0.
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res += module.cached_raw_channel_attention.norm(1)
        return res
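
    # Sketch (assumed usage): the accumulated L1 norm of the raw channel
    # attention can be added to the task loss to encourage sparse, prunable
    # attention, e.g.
    #
    #   loss = task_loss + l1_lambda * util.get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(master_dnn)
    # where `l1_lambda` is a user-chosen regularization weight (hypothetical name).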

    def get_raw_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        res = {}
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                res[name] = module.cached_raw_channel_attention
        return res

    def set_master_dnn_sparsity(self, master_dnn: nn.Module, sparsity: float):
        assert 0 <= sparsity <= 1., sparsity
        for name, module in master_dnn.named_modules():
            if isinstance(module, KTakesAll):
                module.k = sparsity
        logger.debug(f'set master DNN sparsity to {sparsity}')
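
    # Sketch (assumed usage): a single master DNN can be switched between
    # sparsity levels at run time before extracting a surrogate, e.g.
    #
    #   for s in (0.2, 0.5, 0.8):
    #       util.set_master_dnn_sparsity(master_dnn, s)
    #       # ... extract and evaluate a surrogate DNN at this sparsity ...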

    def clear_cached_channel_attention_in_master_dnn(self, master_dnn: nn.Module):
        for name, module in master_dnn.named_modules():
            if isinstance(module, Layer_WrappedWithFBS):
                module.cached_raw_channel_attention = None
                module.cached_channel_attention = None

    @abstractmethod
    def select_most_rep_sample(self, master_dnn: nn.Module, samples: torch.Tensor):
        raise NotImplementedError

    @abstractmethod
    def extract_surrogate_dnn_via_samples(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        raise NotImplementedError

    def extract_surrogate_dnn_via_samples_with_perf_test(self, master_dnn: nn.Module, samples: torch.Tensor, return_detail=False):
        master_dnn_size = get_model_size(master_dnn, True)
        master_dnn_latency = get_model_latency(master_dnn, (1, *list(samples.size())[1:]), 50,
                                               get_model_device(master_dnn), 50, False)

        res = self.extract_surrogate_dnn_via_samples(master_dnn, samples, return_detail)
        if not return_detail:
            surrogate_dnn = res
        else:
            surrogate_dnn, unpruned_indexes_of_layers = res
        surrogate_dnn_size = get_model_size(surrogate_dnn, True)
        surrogate_dnn_latency = get_model_latency(surrogate_dnn, (1, *list(samples.size())[1:]), 50,
                                                  get_model_device(surrogate_dnn), 50, False)

        logger.info(f'master DNN ({master_dnn_size:.3f}MB, {master_dnn_latency:.4f}s/sample) -> '
                    f'surrogate DNN ({surrogate_dnn_size:.3f}MB, {surrogate_dnn_latency:.4f}s/sample)\n'
                    f'(model size: ↓ {(master_dnn_size / surrogate_dnn_size):.2f}x, '
                    f'latency: ↓ {(master_dnn_latency / surrogate_dnn_latency):.2f}x)')

        return res
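

# End-to-end sketch (hypothetical; `MyElasticDNNUtil`, `raw_dnn`, `r` and
# `samples` are placeholders for a concrete subclass, a user model, the FBS
# conversion argument and user data; none of them are defined in this file):
#
#   util = MyElasticDNNUtil()
#   master_dnn = util.convert_raw_dnn_to_master_dnn_with_perf_test(raw_dnn, r)
#   util.set_master_dnn_sparsity(master_dnn, 0.8)
#   rep_samples = util.select_most_rep_sample(master_dnn, samples)
#   surrogate_dnn = util.extract_surrogate_dnn_via_samples_with_perf_test(master_dnn, rep_samples)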