# wavlm_ssl_sv/loss/aamsoftmax.py
#! /usr/bin/python
# -*- encoding: utf-8 -*-

# Adapted from https://github.com/wujiyang/Face_Pytorch (Apache License)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import accuracy
class LossFunction(nn.Module):
    def __init__(self, nOut, nClasses, margin=0.3, scale=15, easy_margin=False, **kwargs):
        super(LossFunction, self).__init__()

        self.test_normalize = True

        self.m = margin
        self.s = scale
        self.in_feats = nOut
        self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True)
        self.ce = nn.CrossEntropyLoss(reduction='none')  # return the loss per sample
        nn.init.xavier_normal_(self.weight, gain=1)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)

        # Keep cos(theta + m) monotonically decreasing while theta is in [0°, 180°].
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        # Sample-selection thresholds: LGL keeps samples by (log-)loss,
        # LC relabels samples by the confidence of the clean view.
        self.lgl_threshold = 1e6
        self.lc_threshold = 0.5

        print('Initialised AAMSoftmax margin %.3f scale %.3f' % (self.m, self.s))
    def _forward(self, x, label):
        # cos(theta) between L2-normalized embeddings and class weights
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))

        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Fall back to a linear penalty once theta + m would exceed pi.
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # Apply the margin only to the target-class logit, then scale.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        return output
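
    # Worked example of the margin's effect (illustrative values, not taken
    # from the code): with m = 0.3 rad and theta = 60 degrees (cosine = 0.5),
    # cos(theta + m) = cos(1.3472) ~= 0.222, so the target logit drops from
    # 0.5 to ~0.22 before scaling, forcing a larger angular gap to the class.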
    def _forward_softmax_sharpened(self, x, e=0.1):
        # Regular (margin-free) softmax, sharpened by a low temperature e.
        output = F.linear(x, self.weight)
        probas = F.softmax(output / e, dim=1)
        return probas
    def forward(self, x, x_clean, label=None, epoch=-1):
        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.in_feats

        output = self._forward(x, label)
        output_clean = self._forward_softmax_sharpened(x_clean)
        ce = self.ce(output, label)

        # No LGL (ablation):
        # prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        # return ce, prec1, None

        # LGL: keep only samples whose log-loss falls under the threshold.
        mask = (torch.log(ce) <= self.lgl_threshold).detach()

        if epoch <= 8:
            # LGL only: average the CE loss over the selected samples.
            nselect = torch.clamp(mask.sum(), min=1).item()
            loss = torch.sum(ce * mask, dim=-1) / nselect
            prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0]
            return loss, prec1, ce

        # LGL + LC: samples rejected by LGL are relabeled with the argmax of
        # the sharpened clean-view softmax, if that prediction is confident.
        label_LC = output_clean.argmax(dim=1)
        max_vals = torch.gather(output_clean, 1, label_LC.unsqueeze(1)).squeeze(1)
        mask_LC = (max_vals > self.lc_threshold).detach()
        ce_LC = self.ce(output, label_LC)

        mask_LGL_LC = ~mask & mask_LC
        loss = torch.mean(ce * mask + ce_LC * mask_LGL_LC, dim=-1)
        prec1 = accuracy(output.detach(), label * mask.detach() + label_LC * mask_LGL_LC.detach(), topk=(1,))[0]
        return loss, prec1, ce
    def get_pseudo_labels(self, x, label):
        # `label` is unused; pseudo-labels come from the sharpened softmax.
        output = self._forward_softmax_sharpened(x)
        return output.argmax(dim=1)
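
    # Earlier variant of forward, kept for reference (disabled):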
"""
def forward(self, x, x_clean, label=None):
assert x.size()[0] == label.size()[0]
assert x.size()[1] == self.in_feats
P_aam = self._forward(x, label)
P_softmax = self._forward_softmax_sharpened(x)
P_clean_softmax = self._forward_softmax_sharpened(x_clean)
ce = self.ce(P_aam, label)
# No LGL
# prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
# return ce, prec1, None
mask = (torch.log(ce) <= self.lgl_threshold).detach()
# LGL only
# nselect = torch.clamp(sum(mask), min=1).item()
# loss = torch.sum(ce * mask, dim=-1) / nselect
# prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0]
# return loss, prec1, ce
# LGL + LC
label_LC = P_clean_softmax.argmax(dim=1)
ce_LC = self.ce(P_softmax, label_LC)
inverted_mask = ~mask
loss = torch.mean(ce * mask + ce_LC * inverted_mask, dim=-1)
prec1 = accuracy(P_softmax.detach(), label * mask.detach() + label_LC * inverted_mask.detach(), topk=(1,))[0]
return loss, prec1, ce
"""