# EdgeTA/utils/dl/common/pruning.py
import copy
import os
from typing import List, Optional, Tuple
import torch
from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from ...common.others import get_cur_time_str


def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    # NOTE: `verbose` is currently unused in this helper.
pruner.compress()
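    # Export the generated masks to uniquely named temp files (pid + timestamp
    # avoids collisions between concurrent processes); the masked weight file
    # itself is not needed, so it is deleted immediately after export.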
pid = os.getpid()
timestamp = get_cur_time_str()
tmp_model_path = './tmp_weight-{}-{}.pth'.format(pid, timestamp)
tmp_mask_path = './tmp_mask-{}-{}.pth'.format(pid, timestamp)
pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)
os.remove(tmp_model_path)
    # Speed up: apply the mask file to physically remove the masked filters/channels.
    dummy_input = torch.rand(model_input_size).to(device)
    pruned_model = model  # ModelSpeedup modifies the model in place
    pruned_model.eval()
model_speedup = ModelSpeedup(pruned_model, dummy_input, tmp_mask_path, device)
fixed_mask = model_speedup.speedup_model()
    os.remove(tmp_mask_path)
    if need_return_mask:
        return pruned_model, fixed_mask
    return pruned_model
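

def _l1_filter_importance(conv: torch.nn.Conv2d) -> torch.Tensor:
    """Illustrative sketch only (not called by the pipeline above): the L1 Filter
    Pruning criterion of Li et al. (2016) ranks each output filter of a conv layer
    by the L1 norm of its weights; `L1FilterPruner` masks the lowest-ranked filters.
    """
    # Sum |w| over (in_channels, kH, kW) to get one score per output filter.
    return conv.weight.detach().abs().sum(dim=(1, 2, 3))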


def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int, ...], device: str, verbose=False, need_return_mask=False,
                   dep_aware=False):
"""Get the pruned model via L1 Filter Pruning.
Reference:
Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient convnets[J]. arXiv preprint arXiv:1608.08710, 2016.
Args:
model (torch.nn.Module): A PyTorch model.
        pruned_layers_name (Optional[List[str]]): Names of the layers to prune. If `None`,
            all `Conv2d`/`ConvTranspose2d` layers are pruned.
        sparsity (float): Target sparsity in [0, 1]; higher sparsity yields a smaller model.
        model_input_size (Tuple[int, ...]): Size of a dummy input, typically `(1, 3, 32, 32)`
            or `(1, 3, 224, 224)`.
        device (str): Typically 'cpu' or 'cuda'.
        verbose (bool, optional): Whether to output the verbose log. Defaults to False.
            (Currently has no effect; the flag is not used downstream.)
        need_return_mask (bool, optional): Whether to also return the fine-grained mask
            generated by the NNI framework (for debugging). Defaults to False.
        dep_aware (bool, optional): Maps to the `dependency_aware` argument of the NNI
            pruner. Defaults to False.
    Returns:
        torch.nn.Module: Pruned model. If `need_return_mask` is True, a tuple
            `(pruned_model, mask)` is returned instead.
"""
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        return model
    # `model` is wrapped by the pruner to generate masks; `pruned_model` is a clean
    # copy that ModelSpeedup physically shrinks using the exported mask file.
    pruned_model = copy.deepcopy(model).to(device)
# generate mask
model.eval()
if pruned_layers_name is not None:
config_list = [{
'op_types': ['Conv2d', 'ConvTranspose2d'],
'op_names': pruned_layers_name,
'sparsity': sparsity
}]
else:
config_list = [{
'op_types': ['Conv2d', 'ConvTranspose2d'],
'sparsity': sparsity
}]
pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware,
dummy_input=torch.rand(model_input_size).to(device) if dep_aware else None)
pruned_model = _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)
return pruned_model
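

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module). The toy
    # CNN below is an assumption for demonstration; real callers pass their own model.
    # Note: this module uses relative imports, so run it as part of the package
    # (e.g. via `python -m`), not as a standalone script.
    import torch.nn as nn

    class _ToyCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(32, 10)

        def forward(self, x):
            x = torch.relu(self.conv1(x))
            x = torch.relu(self.conv2(x))
            return self.fc(self.pool(x).flatten(1))

    # Prune `conv1` to 50% sparsity; ModelSpeedup shrinks `conv2`'s input channels
    # to match the surviving filters.
    pruned = l1_prune_model(_ToyCNN(), pruned_layers_name=['conv1'], sparsity=0.5,
                            model_input_size=(1, 3, 32, 32), device='cpu')
    print(pruned)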