# Author: HungNP
# Commit message: New single commit message
# Commit: cb80c28
import logging
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import IncrementalNetWithBias
# ---- Optimisation hyper-parameters (used by both stage 1 and stage 2) ----
epochs = 170  # number of epochs run by each stage's training loop
lrate = 0.1  # initial SGD learning rate
milestones = [60, 100, 140]  # epochs at which MultiStepLR decays the learning rate
lrate_decay = 0.1  # multiplicative decay factor (gamma) applied at each milestone
weight_decay = 2e-4  # SGD weight decay for trainable parameter groups

# ---- Data loading ----
batch_size = 128
num_workers = 8
# Sizes the per-class validation split used for bias correction:
# val_samples_per_class = int(split_ratio * memory_size / known_classes).
split_ratio = 0.1

# ---- Knowledge distillation ----
T = 2  # softmax temperature applied to old and new logits
class BiC(BaseLearner):
    """Bias Correction (BiC) class-incremental learner.

    Each incremental task is learned in up to two stages:

    1. ``_stage1_training`` trains the network on the new-class data plus the
       exemplar memory. From the second task on, it adds a knowledge-distillation
       term against the frozen network snapshot taken in ``after_task``.
    2. ``_stage2_bias_correction`` (second task onward) optimises only the newest
       bias layer on a held-out validation split.

    Hyper-parameters come from the module-level constants (``epochs``, ``lrate``,
    ``milestones``, ``lrate_decay``, ``batch_size``, ``split_ratio``, ``T``,
    ``weight_decay``, ``num_workers``).
    """

    def __init__(self, args):
        super().__init__(args)
        # The positional ``False`` is interpreted by IncrementalNetWithBias
        # (presumably a "pretrained" flag) -- TODO confirm.
        self._network = IncrementalNetWithBias(
            args, False, bias_correction=True
        )
        self._class_means = None

    def after_task(self):
        """Freeze a copy of the current network as the next task's teacher."""
        self._old_network = self._network.copy().freeze()
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        """Learn the next task and rebuild the exemplar memory.

        From the second task on, the new-class data plus the exemplar memory is
        requested as a (stage-1 train, stage-2 validation) split. ``self.lamda``
        is set to known/total classes; it weights distillation against
        classification in stage 1.
        """
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task >= 1:
            train_dset, val_dset = data_manager.get_dataset_with_split(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
                val_samples_per_class=int(
                    split_ratio * self._memory_size / self._known_classes
                ),
            )
            self.val_loader = DataLoader(
                val_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
            )
            logging.info(
                "Stage1 dset: {}, Stage2 dset: {}".format(
                    len(train_dset), len(val_dset)
                )
            )
            self.lamda = self._known_classes / self._total_classes
            logging.info("Lambda: {:.3f}".format(self.lamda))
        else:
            train_dset = data_manager.get_dataset(
                np.arange(self._known_classes, self._total_classes),
                source="train",
                mode="train",
                appendent=self._get_memory(),
            )
        test_dset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.train_loader = DataLoader(
            train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.test_loader = DataLoader(
            test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        self._log_bias_params()
        self._stage1_training(self.train_loader, self.test_loader)
        if self._cur_task >= 1:
            self._stage2_bias_correction(self.val_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        # The stage methods may leave the network wrapped in DataParallel; unwrap it.
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
        self._log_bias_params()

    def _compute_loss(self, inputs, logits, targets, stage):
        """Return the scalar loss of one batch for the given stage.

        Raises:
            NotImplementedError: If ``stage`` is not "training" or "bias_correction".
        """
        if stage == "bias_correction":
            # F.cross_entropy applies log_softmax internally and expects raw logits.
            # Passing softmax probabilities (double softmax) compresses the loss
            # range and weakens the gradients reaching the bias layer.
            return F.cross_entropy(logits, targets)
        if stage != "training":
            raise NotImplementedError()

        clf_loss = F.cross_entropy(logits, targets)
        if self._old_network is None:
            return clf_loss

        # Distillation: match the old network's temperature-softened distribution
        # over the previously known classes. The teacher needs no autograd graph.
        with torch.no_grad():
            old_logits = self._old_network(inputs)["logits"]
        hat_pai_k = F.softmax(old_logits / T, dim=1)
        log_pai_k = F.log_softmax(logits[:, : self._known_classes] / T, dim=1)
        distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1))
        return distill_loss * self.lamda + clf_loss * (1 - self.lamda)

    def _run(self, train_loader, test_loader, optimizer, scheduler, stage):
        """Run ``epochs`` epochs of optimisation for one stage, logging metrics.

        Args:
            train_loader: Loader yielding ``(_, inputs, targets)`` batches to fit.
            test_loader: Loader used only for per-epoch accuracy reporting.
            optimizer: Optimizer over the parameters trained in this stage.
            scheduler: LR scheduler, stepped once per epoch.
            stage: ``"training"`` (stage 1) or ``"bias_correction"`` (stage 2).

        Raises:
            NotImplementedError: If ``stage`` is not one of the values above.
        """
        if stage not in ("training", "bias_correction"):
            raise NotImplementedError()

        for epoch in range(1, epochs + 1):
            if stage == "training":
                self._network.train()
            else:
                # Stage 2 optimises only the newest bias layer. Eval mode keeps the
                # rest of the network frozen, e.g. BatchNorm running statistics (if
                # any) are not re-estimated on the small validation split.
                self._network.eval()
            losses = 0.0
            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                loss = self._compute_loss(inputs, logits, targets, stage)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
            scheduler.step()

            train_acc = self._compute_accuracy(self._network, train_loader)
            test_acc = self._compute_accuracy(self._network, test_loader)
            info = "{} => Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.3f}, Test_accy {:.3f}".format(
                stage,
                self._cur_task,
                epoch,
                epochs,
                losses / len(train_loader),
                train_acc,
                test_acc,
            )
            logging.info(info)

    def _stage1_training(self, train_loader, test_loader):
        """Stage 1: train the network while keeping every bias layer fixed.

        Bias-layer parameters go in their own optimizer group with lr=0 and no
        weight decay, so SGD never changes them. All other parameters use
        ``lrate``/``weight_decay``.
        """
        bias_param_ids = {id(p) for p in self._network.bias_layers.parameters()}
        base_params = [
            p for p in self._network.parameters() if id(p) not in bias_param_ids
        ]
        network_params = [
            {"params": base_params, "lr": lrate, "weight_decay": weight_decay},
            {
                "params": self._network.bias_layers.parameters(),
                "lr": 0,
                "weight_decay": 0,
            },
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)
        if self._old_network is not None:
            self._old_network.to(self._device)

        self._run(train_loader, test_loader, optimizer, scheduler, stage="training")

    def _stage2_bias_correction(self, val_loader, test_loader):
        """Stage 2: fit only the newest bias layer on the validation split."""
        # Unwrap first so ``bias_layers`` is reachable on the underlying module.
        if isinstance(self._network, nn.DataParallel):
            self._network = self._network.module
        network_params = [
            {
                "params": self._network.bias_layers[-1].parameters(),
                "lr": lrate,
                "weight_decay": weight_decay,
            }
        ]
        optimizer = optim.SGD(
            network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=milestones, gamma=lrate_decay
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._network.to(self._device)

        self._run(
            val_loader, test_loader, optimizer, scheduler, stage="bias_correction"
        )

    def _log_bias_params(self):
        """Log the two parameters of every bias layer.

        Assumes ``get_bias_params()`` yields pairs, presumably (alpha, beta).
        TODO: confirm.
        """
        logging.info("Parameters of bias layer:")
        params = self._network.get_bias_params()
        for i, param in enumerate(params):
            logging.info("{} => {:.3f}, {:.3f}".format(i, param[0], param[1]))