from typing import Any, Dict
from copy import deepcopy
import os
import random

import numpy as np
import torch
import torch.optim
import torch.nn.functional as F
from torch import nn
import tqdm
from schema import Schema

from data import Scenario, MergedDataset, build_dataloader
from methods.base.alg import BaseAlg
from utils.common.log import logger
from utils.dl.common.env import create_tbwriter
from utils.dl.common.model import LayerActivation, get_module
from ..model import ElasticDNN_OfflineFMModel, ElasticDNN_OfflineMDModel
from ...model.base import ElasticDNNUtil


class ElasticDNN_MDPretrainingWFBSAlg(BaseAlg):
    """
    TODO: fine-tuned FM -> init MD -> trained MD -> construct indexes (only between similar weights) and fine-tune
    """

    def get_required_models_schema(self) -> Schema:
        return Schema({
            'fm': ElasticDNN_OfflineFMModel,
            'md': ElasticDNN_OfflineMDModel
        })

    def get_required_hyp_schema(self) -> Schema:
        return Schema({
            'launch_tbboard': bool,
            'samples_size': (int, int, int, int),
            'generate_md_width_ratio': int,
            'FBS_r': int,
            'FBS_ignore_layers': [str],
            'train_batch_size': int,
            'val_batch_size': int,
            'num_workers': int,
            'optimizer': str,
            'optimizer_args': dict,
            'scheduler': str,
            'scheduler_args': dict,
            'num_iters': int,
            'val_freq': int,
            'max_sparsity': float,
            'min_sparsity': float,
            'l1_reg_loss_weight': float,
            'val_num_sparsities': int,
            'bn_cal_num_iters': int
        })
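
    # Example hyps dict accepted by the schema above. Every concrete value is
    # an illustrative assumption, not a recommended setting:
    #
    #   hyps = {
    #       'launch_tbboard': False,
    #       'samples_size': (1, 3, 224, 224),   # probe input shape
    #       'generate_md_width_ratio': 4,
    #       'FBS_r': 16,
    #       'FBS_ignore_layers': [],
    #       'train_batch_size': 256, 'val_batch_size': 512, 'num_workers': 8,
    #       'optimizer': 'AdamW', 'optimizer_args': {'lr': 1e-4},
    #       'scheduler': 'StepLR', 'scheduler_args': {'step_size': 20000},
    #       'num_iters': 80000, 'val_freq': 4000,
    #       'max_sparsity': 0.9, 'min_sparsity': 0.0,
    #       'l1_reg_loss_weight': 1e-9,
    #       'val_num_sparsities': 4, 'bn_cal_num_iters': 800,
    #   }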

    def bn_cal(self, model: nn.Module, train_loader, num_iters, device):
        """Recalibrate BatchNorm running statistics at the current sparsity by
        forwarding a few training batches; returns the calibrated BN modules."""
        has_bn = False
        for _, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                has_bn = True
                break

        if not has_bn:
            return {}

        def bn_calibration_init(m):
            """calculate post-statistics of batch normalization"""
            if getattr(m, 'track_running_stats', False):
                # reset all values for post-statistics
                m.reset_running_stats()
                # set bn in training mode to update post-statistics
                m.training = True

        with torch.no_grad():
            model.eval()
            model.apply(bn_calibration_init)
            for _ in range(num_iters):
                x, _ = next(train_loader)
                model(x.to(device))
            model.eval()

        bn_stats = {}
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_stats[n] = m
        return bn_stats
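
    # Minimal bn_cal usage sketch (argument values here are illustrative
    # assumptions):
    #
    #   loader = iter(build_dataloader(train_dataset, 256, 8, True, None))
    #   stats = self.bn_cal(md_for_test, loader, num_iters=800, device=device)
    #
    # For transformer-style master DNNs there is no BatchNorm2d, so bn_cal
    # returns {} and validation proceeds without recalibration.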

    def run(self, scenario: Scenario, hyps: Dict) -> Dict[str, Any]:
        super().run(scenario, hyps)

        assert isinstance(self.models['md'], ElasticDNN_OfflineMDModel)  # for auto completion
        assert isinstance(self.models['fm'], ElasticDNN_OfflineFMModel)  # for auto completion

        device = self.models['md'].device

        # 1. build train/val loaders over the merged offline datasets
        offline_datasets = scenario.get_offline_datasets()
        train_dataset = MergedDataset([d['train'] for d in offline_datasets.values()])
        val_dataset = MergedDataset([d['val'] for d in offline_datasets.values()])
        train_loader = iter(build_dataloader(train_dataset, hyps['train_batch_size'],
                                             hyps['num_workers'], True, None))
        val_loader = build_dataloader(val_dataset, hyps['val_batch_size'],
                                      hyps['num_workers'], False, False)

        logger.info(f'master DNN acc before inserting FBS: {self.models["md"].get_accuracy(val_loader):.4f}')

        # 2. add FBS: wrap the master DNN with channel-attention (gating) modules
        master_dnn = self.models['md'].models_dict['main']
        elastic_dnn_util = self.models['fm'].get_elastic_dnn_util()
        master_dnn = elastic_dnn_util.convert_raw_dnn_to_master_dnn_with_perf_test(
            master_dnn, hyps['FBS_r'], hyps['FBS_ignore_layers']).to(device)
        self.models['md'].models_dict['main'] = master_dnn
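
        # What the conversion above gives us (inferred from the ElasticDNNUtil
        # calls used in this file, not verified against its implementation):
        # each eligible layer gains a channel-attention module;
        # set_master_dnn_sparsity() later zeroes the lowest-attention channels,
        # and extract_surrogate_dnn_via_samples_with_perf_test() materializes
        # the surviving channels as a smaller standalone surrogate DNN.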

        # 3. train the whole master DNN (knowledge distillation)
        for p in master_dnn.parameters():
            p.requires_grad = True
        self.models['md'].to_train_mode()

        optimizer = torch.optim.__dict__[hyps['optimizer']]([
            {'params': self.models['md'].models_dict['main'].parameters(), **hyps['optimizer_args']}
        ])
        scheduler = torch.optim.lr_scheduler.__dict__[hyps['scheduler']](optimizer, **hyps['scheduler_args'])
        tb_writer = create_tbwriter(os.path.join(self.res_save_dir, 'tb_log'), launch_tbboard=hyps['launch_tbboard'])

        pbar = tqdm.tqdm(range(hyps['num_iters']), dynamic_ncols=True)
        best_avg_val_acc = 0.
        for iter_index in pbar:
            self.models['md'].to_train_mode()
            self.models['fm'].to_eval_mode()

            # sample a random sparsity in [min_sparsity, max_sparsity) for this step
            rand_sparsity = random.random() * (hyps['max_sparsity'] - hyps['min_sparsity']) + hyps['min_sparsity']
            elastic_dnn_util.set_master_dnn_sparsity(self.models['md'].models_dict['main'], rand_sparsity)

            x, y = next(train_loader)
            x, y = x.to(device), y.to(device)

            task_loss = self.models['md'].forward_to_get_task_loss(x, y)
            l1_reg_loss = hyps['l1_reg_loss_weight'] * elastic_dnn_util.get_accu_l1_reg_of_raw_channel_attention_in_master_dnn(master_dnn)
            total_loss = task_loss + l1_reg_loss
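
            # The objective is total_loss = task_loss + λ · Σ|attention|, with
            # λ = hyps['l1_reg_loss_weight']; the L1 term drives channel
            # attentions toward zero, so pruning the lowest-attention channels
            # at high sparsity costs little accuracy.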
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()

            if (iter_index + 1) % hyps['val_freq'] == 0:
                elastic_dnn_util.clear_cached_channel_attention_in_master_dnn(self.models['md'].models_dict['main'])
                cur_md = self.models['md'].models_dict['main']
                md_for_test = deepcopy(self.models['md'].models_dict['main'])
                val_accs = {}
                avg_val_acc = 0.
                bn_stats = {}

                for val_sparsity in np.linspace(hyps['min_sparsity'], hyps['max_sparsity'], num=hyps['val_num_sparsities']):
                    elastic_dnn_util.set_master_dnn_sparsity(md_for_test, val_sparsity)
                    bn_stats[f'{val_sparsity:.4f}'] = self.bn_cal(md_for_test, train_loader, hyps['bn_cal_num_iters'], device)
                    # generate a separate surrogate DNN at this sparsity
                    test_sd = elastic_dnn_util.extract_surrogate_dnn_via_samples_with_perf_test(md_for_test, x)
                    self.models['md'].models_dict['main'] = test_sd
                    self.models['md'].to_eval_mode()

                    val_acc = self.models['md'].get_accuracy(val_loader)
                    val_accs[f'{val_sparsity:.4f}'] = val_acc
                    avg_val_acc += val_acc
                avg_val_acc /= hyps['val_num_sparsities']

                # restore the master DNN, attach the calibrated BN stats, and checkpoint
                self.models['md'].models_dict['main'] = cur_md
                self.models['md'].models_dict['bn_stats'] = bn_stats
                self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_last.pt'))
                self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_last.pt'))

                if avg_val_acc > best_avg_val_acc:
                    best_avg_val_acc = avg_val_acc
                    self.models['md'].save_model(os.path.join(self.res_save_dir, 'models/md_best.pt'))
                    self.models['fm'].save_model(os.path.join(self.res_save_dir, 'models/fm_best.pt'))

            tb_writer.add_scalars('losses', dict(task=task_loss, l1_reg=l1_reg_loss, total=total_loss), iter_index)
            pbar.set_description(f'loss: {total_loss:.6f}')
            if (iter_index + 1) >= hyps['val_freq']:  # val metrics exist only after the first validation
                tb_writer.add_scalars('accs/val_accs', val_accs, iter_index)
                tb_writer.add_scalar('accs/avg_val_acc', avg_val_acc, iter_index)
                pbar.set_description(f'loss: {total_loss:.6f}, task_loss: {task_loss:.6f}, '
                                     f'l1_loss: {l1_reg_loss:.6f}, avg_val_acc: {avg_val_acc:.4f}')
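

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of this module's API): the constructor
# signature of BaseAlg is assumed here, and `scenario`, `fm_model`, and
# `md_model` must be built elsewhere.
#
#   alg = ElasticDNN_MDPretrainingWFBSAlg(
#       models={'fm': fm_model, 'md': md_model},
#       res_save_dir='logs/md_wfbs_pretrain')   # hypothetical path
#   alg.run(scenario, hyps)  # hyps must satisfy get_required_hyp_schema()
# ---------------------------------------------------------------------------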