from itertools import chain
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from model import get_model, get_head
from .base import BaseMethod
from .norm_mse import norm_mse_loss


class BYOL(BaseMethod):
    """ implements BYOL loss https://arxiv.org/abs/2006.07733 """

    def __init__(self, cfg):
        """ init additional target and predictor networks """
        super().__init__(cfg)
        # predictor MLP: maps online embeddings to predictions of target embeddings
        self.pred = nn.Sequential(
            nn.Linear(cfg.emb, cfg.head_size),
            nn.BatchNorm1d(cfg.head_size),
            nn.ReLU(),
            nn.Linear(cfg.head_size, cfg.emb),
        )
        # target network: same architecture as the online network,
        # updated only via exponential moving average, never by gradients
        self.model_t, _ = get_model(cfg.arch, cfg.dataset)
        self.head_t = get_head(self.out_size, cfg)
        for param in chain(self.model_t.parameters(), self.head_t.parameters()):
            param.requires_grad = False
        # tau=0 copies the online weights into the target verbatim
        self.update_target(0)
        self.byol_tau = cfg.byol_tau
        self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss

    def update_target(self, tau):
        """ EMA update of target parameters from the online network """
        for t, s in zip(self.model_t.parameters(), self.model.parameters()):
            t.data.copy_(t.data * tau + s.data * (1.0 - tau))
        for t, s in zip(self.head_t.parameters(), self.head.parameters()):
            t.data.copy_(t.data * tau + s.data * (1.0 - tau))

    def forward(self, samples):
        # online branch: backbone -> projection head -> predictor
        z = [self.pred(self.head(self.model(x))) for x in samples]
        # target branch: no gradients flow into the target network
        with torch.no_grad():
            zt = [self.head_t(self.model_t(x)) for x in samples]
        # symmetrized loss over all pairs of augmented views
        loss = 0
        for i in range(len(samples) - 1):
            for j in range(i + 1, len(samples)):
                loss += self.loss_f(z[i], zt[j]) + self.loss_f(z[j], zt[i])
        loss /= self.num_pairs
        return loss

    def step(self, progress):
        """ update target network with cosine increasing schedule """
        # tau grows from byol_tau at progress=0 to 1.0 at progress=1
        tau = 1 - (1 - self.byol_tau) * (math.cos(math.pi * progress) + 1) / 2
        self.update_target(tau)
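

# A minimal sketch (not part of the training pipeline) illustrating the
# momentum schedule used in `step`: tau starts at cfg.byol_tau and follows a
# cosine increase to 1.0 as training progress goes from 0 to 1. The base
# momentum value below is an assumed example, not a value from this repo.
if __name__ == "__main__":
    byol_tau = 0.99  # assumed base momentum; in training it comes from cfg
    for progress in (0.0, 0.25, 0.5, 0.75, 1.0):
        # same formula as BYOL.step
        tau = 1 - (1 - byol_tau) * (math.cos(math.pi * progress) + 1) / 2
        print(f"progress={progress:.2f} -> tau={tau:.5f}")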