# Spaces: Build error
import torch | |
from torch import nn | |
class IOULoss(nn.Module):
    """IoU-family regression loss for boxes given as edge distances.

    Each row of ``pred`` / ``target`` is read as ``(left, top, right, bottom)``.
    The area formula ``(left + right) * (top + bottom)`` and the min/max
    intersection terms below treat these as distances from a shared reference
    point to the four box edges (presumably a per-location anchor point —
    confirm against the caller).

    Args:
        loss_type: ``'iou'`` for ``-log(IoU)``, ``'linear_iou'`` for
            ``1 - IoU``, or ``'giou'`` for ``1 - GIoU``.
    """

    def __init__(self, loss_type="iou"):
        super(IOULoss, self).__init__()
        self.loss_type = loss_type

    def forward(self, pred, target, weight=None):
        """Return the summed (optionally weighted) IoU-family loss.

        Args:
            pred: predicted distances; columns 0..3 are left, top, right, bottom.
            target: ground-truth distances with the same column layout.
            weight: optional weights multiplied into the per-row losses
                (presumably one per row). They are used only when their sum is
                positive; otherwise the unweighted sum is returned.

        Returns:
            A scalar tensor: the sum of the (weighted) per-row losses.

        Raises:
            NotImplementedError: if ``loss_type`` is not one of the supported values.
            ValueError: if the unweighted path receives zero rows.
        """
        pred_left = pred[:, 0]
        pred_top = pred[:, 1]
        pred_right = pred[:, 2]
        pred_bottom = pred[:, 3]

        target_left = target[:, 0]
        target_top = target[:, 1]
        target_right = target[:, 2]
        target_bottom = target[:, 3]

        pred_area = (pred_left + pred_right) * (pred_top + pred_bottom)
        target_area = (target_left + target_right) * (target_top + target_bottom)

        # Both boxes share the reference point, so the overlap along each axis
        # is the sum of the smaller distances on either side of it.
        w_intersect = torch.min(pred_left, target_left) + torch.min(pred_right, target_right)
        h_intersect = torch.min(pred_bottom, target_bottom) + torch.min(pred_top, target_top)
        area_intersect = w_intersect * h_intersect
        area_union = target_area + pred_area - area_intersect

        # +1 smoothing: for non-negative distances this keeps IoU strictly
        # positive, so -log(IoU) stays finite even with no overlap.
        ious = (area_intersect + 1.0) / (area_union + 1.0)

        if self.loss_type == 'iou':
            losses = -torch.log(ious)
        elif self.loss_type == 'linear_iou':
            losses = 1 - ious
        elif self.loss_type == 'giou':
            # Smallest enclosing box: the larger distance on each side.
            g_w_intersect = torch.max(pred_left, target_left) + torch.max(pred_right, target_right)
            g_h_intersect = torch.max(pred_bottom, target_bottom) + torch.max(pred_top, target_top)
            enclose_area = g_w_intersect * g_h_intersect + 1e-7  # avoids 0/0
            gious = ious - (enclose_area - area_union) / enclose_area
            losses = 1 - gious
        else:
            raise NotImplementedError(
                'unsupported loss_type: {!r}'.format(self.loss_type))

        if weight is not None and weight.sum() > 0:
            return (losses * weight).sum()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if losses.numel() == 0:
            raise ValueError('IOULoss received an empty batch')
        return losses.sum()
class IOUWHLoss(nn.Module):  # used for anchor guiding
    """Shape-only IoU loss ``1 - IoU**2`` between boxes aligned at a common centre.

    Inputs are flattened to rows of four values. Only the last two values of
    each row take part; they are treated as width and height. The first two
    values (presumably box centres) are ignored, so the loss compares box
    shapes, not positions.

    Args:
        reduction: ``'mean'`` or ``'sum'`` to reduce over rows; any other value
            (default ``'none'``) returns the per-row losses as a 1-D tensor.
    """

    def __init__(self, reduction='none'):
        super(IOUWHLoss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        """Return ``1 - IoU**2`` of the width/height-only boxes.

        Args:
            pred: predicted boxes; total size must be a multiple of 4.
            target: ground-truth boxes with the same layout as ``pred``.

        Returns:
            Per-row losses of shape ``(numel // 4,)``, or a scalar when
            ``reduction`` is ``'mean'`` / ``'sum'``. Neither input is modified.
        """
        # reshape (unlike view) also accepts non-contiguous inputs. Taking a
        # slice avoids writing into the caller's tensor.
        pred_wh = pred.reshape(-1, 4)[:, 2:]
        target_wh = target.reshape(-1, 4)[:, 2:]

        # With both boxes centred at the same point, the overlap along each
        # axis is the smaller of the two extents. It counts as an intersection
        # only when it is positive on both axes.
        inter_wh = torch.min(pred_wh, target_wh)
        overlaps = (inter_wh > 0).all(dim=1).to(inter_wh.dtype)
        area_i = inter_wh.prod(dim=1) * overlaps

        area_p = pred_wh.prod(dim=1)
        area_g = target_wh.prod(dim=1)
        union = area_p + area_g - area_i + 1e-16  # epsilon guards against 0/0
        iou = area_i / union

        loss = 1 - iou ** 2
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss