|
""" Cross Entropy w/ smoothing or soft targets |
|
|
|
Hacked together by / Copyright 2021 Ross Wightman |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    The per-element loss is a weighted sum of two terms computed from the
    log-softmax of the logits over the last dimension:

    * the negative log-probability of the target class, weighted by
      ``confidence = 1 - smoothing``;
    * the mean negative log-probability over all classes, weighted by
      ``smoothing``.

    Args:
        smoothing: Weight given to the uniform (all-classes) term.
            Must be strictly less than 1.0.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        assert smoothing < 1.0, "smoothing must be < 1.0"
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the smoothed loss, averaged over all target elements.

        Args:
            x: Logits with the class scores on the last dimension,
                i.e. shape ``(..., num_classes)``.
            target: Integer class indices with shape ``x.shape[:-1]``.

        Returns:
            A scalar tensor holding the mean loss.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Index along the class (last) axis. Using -1 rather than a fixed 1
        # keeps the index tensor aligned with `logprobs` when the logits
        # carry extra leading dimensions; for (N, C) input it is identical.
        picked = logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = -picked.squeeze(-1)
        # Mean over classes == cross entropy against a uniform distribution.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
|
|
|
|
|
class SoftTargetCrossEntropy(nn.Module):
    """Cross entropy between logits and a soft (dense) target tensor.

    The target is multiplied element-wise with the log-softmax of the
    logits, summed over the last dimension, negated, and averaged.
    """

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean soft-target cross entropy.

        Args:
            x: Logits; softmax is taken over the last dimension.
            target: Per-class weights multiplied element-wise with the
                log-probabilities of ``x``.

        Returns:
            A scalar tensor: the mean over the remaining dimensions of the
            per-element summed loss.
        """
        log_probs = F.log_softmax(x, dim=-1)
        weighted = -target * log_probs
        per_sample = weighted.sum(dim=-1)
        return per_sample.mean()
|
|