# Source: wgcban — "Upload 98 files" (commit 803ef9e)
from itertools import chain
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import get_model, get_head
from .base import BaseMethod
from .norm_mse import norm_mse_loss
class BYOL(BaseMethod):
    """BYOL (Bootstrap Your Own Latent) loss, https://arxiv.org/abs/2006.07733

    Maintains a slowly moving "target" copy of the encoder + projection head
    and trains an online predictor MLP to match the target's embeddings
    across augmented views.
    """

    def __init__(self, cfg):
        """Build the online predictor and the frozen target networks."""
        super().__init__(cfg)
        # Online predictor: emb -> head_size -> emb MLP with BN + ReLU.
        self.pred = nn.Sequential(
            nn.Linear(cfg.emb, cfg.head_size),
            nn.BatchNorm1d(cfg.head_size),
            nn.ReLU(),
            nn.Linear(cfg.head_size, cfg.emb),
        )
        # Target branch mirrors the online encoder/head; it receives no
        # gradients and is only moved via the EMA in update_target().
        self.model_t, _ = get_model(cfg.arch, cfg.dataset)
        self.head_t = get_head(self.out_size, cfg)
        for frozen in chain(self.model_t.parameters(), self.head_t.parameters()):
            frozen.requires_grad = False
        # tau = 0 hard-copies the online weights into the target at start.
        self.update_target(0)
        self.byol_tau = cfg.byol_tau
        self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss

    def update_target(self, tau):
        """EMA update of the target: target <- tau * target + (1 - tau) * online."""
        net_pairs = ((self.model_t, self.model), (self.head_t, self.head))
        for target_net, online_net in net_pairs:
            for t_p, o_p in zip(target_net.parameters(), online_net.parameters()):
                t_p.data.copy_(t_p.data * tau + o_p.data * (1.0 - tau))

    def forward(self, samples):
        """Return the symmetric BYOL loss averaged over all ordered view pairs."""
        online = [self.pred(self.head(self.model(view))) for view in samples]
        with torch.no_grad():  # target branch must not accumulate gradients
            target = [self.head_t(self.model_t(view)) for view in samples]
        n_views = len(samples)
        loss = 0
        for i in range(n_views - 1):
            for j in range(i + 1, n_views):
                # Symmetrized: each online view regresses every other target view.
                loss = loss + self.loss_f(online[i], target[j]) + self.loss_f(online[j], target[i])
        return loss / self.num_pairs

    def step(self, progress):
        """Anneal tau from byol_tau toward 1 on a cosine schedule, then EMA-update."""
        cosine = (1.0 + math.cos(math.pi * progress)) / 2
        self.update_target(1 - (1 - self.byol_tau) * cosine)