# PyCIL/models/podnet.py
import math
import logging
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import CosineIncrementalNet
from utils.toolkit import tensor2numpy
epochs = 100  # epochs for the initial training stage of each task
lrate = 0.1  # initial learning rate
ft_epochs = 20  # epochs for the class-balanced finetuning stage
ft_lrate = 0.005  # learning rate for finetuning
batch_size = 32
lambda_c_base = 5  # base weight of the POD-spatial loss
lambda_f_base = 1  # base weight of the POD-flat loss
nb_proxy = 10  # number of proxies per class in the LSC classifier
weight_decay = 5e-4
num_workers = 4
"""
Distillation losses: POD-flat (lambda_f = 1) + POD-spatial (lambda_c = 5).
NME (nearest-mean-of-exemplars) results are shown below; k is the number of
proxies per class (nb_proxy) and Steps is the number of incremental steps.
The reproduced results do not match the reported ones; something may have been
missed in this reproduction.
+--------------------+--------------------+--------------------+--------------------+
| Classifier | Steps | Reported (%) | Reproduced (%) |
+--------------------+--------------------+--------------------+--------------------+
| Cosine (k=1) | 50 | 56.69 | 55.49 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-CE (k=10) | 50 | 59.86 | 55.69 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-NCA (k=10) | 50 | 61.40 | 56.50 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-CE (k=10) | 25 | ----- | 59.16 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-NCA (k=10) | 25 | 62.71 | 59.79 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-CE (k=10) | 10 | ----- | 62.59 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-NCA (k=10) | 10 | 64.03 | 62.81 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-CE (k=10) | 5 | ----- | 64.16 |
+--------------------+--------------------+--------------------+--------------------+
| LSC-NCA (k=10) | 5 | 64.48 | 64.37 |
+--------------------+--------------------+--------------------+--------------------+
"""
class PODNet(BaseLearner):
def __init__(self, args):
super().__init__(args)
self._network = CosineIncrementalNet(
args, pretrained=False, nb_proxy=nb_proxy
)
self._class_means = None
def after_task(self):
self._old_network = self._network.copy().freeze()
self._known_classes = self._total_classes
logging.info("Exemplar size: {}".format(self.exemplar_size))
def incremental_train(self, data_manager):
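        """Expand the classifier for the new task, build train/test loaders
        (training data includes the current exemplar memory), train, and then
        rebuild the rehearsal memory for all classes seen so far."""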
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self.task_size = self._total_classes - self._known_classes
self._network.update_fc(self._total_classes, self._cur_task)
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
train_dset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
)
test_dset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.train_loader = DataLoader(
train_dset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
self.test_loader = DataLoader(
test_dset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
self._train(data_manager, self.train_loader, self.test_loader)
self.build_rehearsal_memory(data_manager, self.samples_per_class)
def _train(self, data_manager, train_loader, test_loader):
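        """Two-stage training.

        Stage 1: train on the new-task data plus exemplars for `epochs` epochs,
        with the distillation terms scaled by the adaptive factor
        sqrt(total_classes / new_classes).
        Stage 2 (incremental tasks only): finetune for `ft_epochs` epochs on a
        class-balanced exemplar set at the lower rate `ft_lrate`.
        On incremental tasks, fc.fc1 (assumed to hold the old-class weights of
        the split cosine classifier) is given lr=0, i.e. kept frozen, in both
        stages.
        """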
if self._cur_task == 0:
self.factor = 0
else:
self.factor = math.sqrt(
self._total_classes / (self._total_classes - self._known_classes)
)
logging.info("Adaptive factor: {}".format(self.factor))
self._network.to(self._device)
if self._old_network is not None:
self._old_network.to(self._device)
if self._cur_task == 0:
network_params = self._network.parameters()
else:
ignored_params = list(map(id, self._network.fc.fc1.parameters()))
base_params = filter(
lambda p: id(p) not in ignored_params, self._network.parameters()
)
network_params = [
{"params": base_params, "lr": lrate, "weight_decay": weight_decay},
{
"params": self._network.fc.fc1.parameters(),
"lr": 0,
"weight_decay": 0,
},
]
optimizer = optim.SGD(
network_params, lr=lrate, momentum=0.9, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=epochs
)
self._run(train_loader, test_loader, optimizer, scheduler, epochs)
if self._cur_task == 0:
return
logging.info(
"Finetune the network (classifier part) with the undersampled dataset!"
)
if self._fixed_memory:
finetune_samples_per_class = self._memory_per_class
self._construct_exemplar_unified(data_manager, finetune_samples_per_class)
else:
finetune_samples_per_class = self._memory_size // self._known_classes
self._reduce_exemplar(data_manager, finetune_samples_per_class)
self._construct_exemplar(data_manager, finetune_samples_per_class)
finetune_train_dataset = data_manager.get_dataset(
[], source="train", mode="train", appendent=self._get_memory()
)
finetune_train_loader = DataLoader(
finetune_train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
)
logging.info(
"The size of finetune dataset: {}".format(len(finetune_train_dataset))
)
ignored_params = list(map(id, self._network.fc.fc1.parameters()))
base_params = filter(
lambda p: id(p) not in ignored_params, self._network.parameters()
)
network_params = [
{"params": base_params, "lr": ft_lrate, "weight_decay": weight_decay},
{"params": self._network.fc.fc1.parameters(), "lr": 0, "weight_decay": 0},
]
optimizer = optim.SGD(
network_params, lr=ft_lrate, momentum=0.9, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer, T_max=ft_epochs
)
self._run(finetune_train_loader, test_loader, optimizer, scheduler, ft_epochs)
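        # For fixed memory, drop the new-class exemplars that were added only for
        # the balanced finetuning; build_rehearsal_memory() (called from
        # incremental_train) rebuilds the memory for all seen classes afterwards.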
if self._fixed_memory:
self._data_memory = self._data_memory[
: -self._memory_per_class * self.task_size
]
self._targets_memory = self._targets_memory[
: -self._memory_per_class * self.task_size
]
assert (
len(
np.setdiff1d(
self._targets_memory, np.arange(0, self._known_classes)
)
)
== 0
), "Exemplar error!"
def _run(self, train_loader, test_loader, optimizer, scheduler, epk):
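        """Run `epk` training epochs. The objective is the LSC/NCA
        classification loss plus, on incremental tasks, the POD-flat loss
        (cosine-embedding distance between old and new pooled features) and the
        POD-spatial loss (pod_spatial_loss on the intermediate feature maps),
        both scaled by self.factor and their lambda_*_base weights."""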
for epoch in range(1, epk + 1):
self._network.train()
lsc_losses = 0.0
spatial_losses = 0.0
flat_losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
outputs = self._network(inputs)
logits = outputs["logits"]
features = outputs["features"]
fmaps = outputs["fmaps"]
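                # "fmaps" are the intermediate feature maps returned by the backbone
                # stages; "features" is the pooled embedding fed to the cosine classifier.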
lsc_loss = nca(logits, targets)
spatial_loss = 0.0
flat_loss = 0.0
if self._old_network is not None:
with torch.no_grad():
old_outputs = self._old_network(inputs)
old_features = old_outputs["features"]
old_fmaps = old_outputs["fmaps"]
flat_loss = (
F.cosine_embedding_loss(
features,
old_features.detach(),
torch.ones(inputs.shape[0]).to(self._device),
)
* self.factor
* lambda_f_base
)
spatial_loss = (
pod_spatial_loss(fmaps, old_fmaps) * self.factor * lambda_c_base
)
loss = lsc_loss + flat_loss + spatial_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
lsc_losses += lsc_loss.item()
spatial_losses += (
spatial_loss.item() if self._cur_task != 0 else spatial_loss
)
flat_losses += flat_loss.item() if self._cur_task != 0 else flat_loss
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
if scheduler is not None:
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
test_acc = self._compute_accuracy(self._network, test_loader)
info1 = "Task {}, Epoch {}/{} (LR {:.5f}) => ".format(
self._cur_task, epoch, epk, optimizer.param_groups[0]["lr"]
)
info2 = "LSC_loss {:.2f}, Spatial_loss {:.2f}, Flat_loss {:.2f}, Train_acc {:.2f}, Test_acc {:.2f}".format(
lsc_losses / (i + 1),
spatial_losses / (i + 1),
flat_losses / (i + 1),
train_acc,
test_acc,
)
logging.info(info1 + info2)
def pod_spatial_loss(old_fmaps, fmaps, normalize=True):
"""
a, b: list of [bs, c, w, h]
"""
loss = torch.tensor(0.0).to(fmaps[0].device)
for i, (a, b) in enumerate(zip(old_fmaps, fmaps)):
assert a.shape == b.shape, "Shape error"
a = torch.pow(a, 2)
b = torch.pow(b, 2)
a_h = a.sum(dim=3).view(a.shape[0], -1) # [bs, c*w]
b_h = b.sum(dim=3).view(b.shape[0], -1) # [bs, c*w]
a_w = a.sum(dim=2).view(a.shape[0], -1) # [bs, c*h]
b_w = b.sum(dim=2).view(b.shape[0], -1) # [bs, c*h]
a = torch.cat([a_h, a_w], dim=-1)
b = torch.cat([b_h, b_w], dim=-1)
if normalize:
a = F.normalize(a, dim=1, p=2)
b = F.normalize(b, dim=1, p=2)
layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1))
loss += layer_loss
return loss / len(fmaps)
def nca(
similarities,
targets,
class_weights=None,
focal_gamma=None,
scale=1.0,
margin=0.6,
exclude_pos_denominator=True,
hinge_proxynca=False,
memory_flags=None,
):
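    """Local Similarity Classifier loss in its NCA form: a hinge margin is
    subtracted from the ground-truth similarity and, when
    exclude_pos_denominator is True, the positive term is removed from the
    softmax denominator. Otherwise this falls back to plain cross-entropy."""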
margins = torch.zeros_like(similarities)
margins[torch.arange(margins.shape[0]), targets] = margin
    similarities = scale * (similarities - margins)  # subtract the hinge margin from the ground-truth similarities only
if exclude_pos_denominator:
similarities = similarities - similarities.max(1)[0].view(-1, 1)
disable_pos = torch.zeros_like(similarities)
disable_pos[torch.arange(len(similarities)), targets] = similarities[
torch.arange(len(similarities)), targets
]
numerator = similarities[torch.arange(similarities.shape[0]), targets]
denominator = similarities - disable_pos
losses = numerator - torch.log(torch.exp(denominator).sum(-1))
if class_weights is not None:
losses = class_weights[targets] * losses
losses = -losses
if hinge_proxynca:
losses = torch.clamp(losses, min=0.0)
loss = torch.mean(losses)
return loss
return F.cross_entropy(
similarities, targets, weight=class_weights, reduction="mean"
)