MapLocNet / models /metrics.py
wangerniu
Commit message.
124ba77
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
import torchmetrics
from torchmetrics.utilities.data import dim_zero_cat
from .utils import deg2rad, rotmat2d
def location_error(uv, uv_gt, ppm=1):
return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
def location_error_single(uv, uv_gt, ppm=1):
return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
def angle_error(t, t_gt):
error = torch.abs(t % 360 - t_gt.to(t) % 360)
error = torch.minimum(error, 360 - error)
return error
class Location2DRecall(torchmetrics.MeanMetric):
def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
self.threshold = threshold
self.ppm = pixel_per_meter
self.key = key
super().__init__(*args, **kwargs)
def update(self, pred, data):
self.cuda()
error = location_error(pred[self.key], data["uv"], self.ppm)
# print(error,self.threshold)
super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
class Location1DRecall(torchmetrics.MeanMetric):
def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
self.threshold = threshold
self.ppm = pixel_per_meter
self.key = key
super().__init__(*args, **kwargs)
def update(self, pred, data):
self.cuda()
error = location_error(pred[self.key], data["uv"], self.ppm)
# print(error,self.threshold)
super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
class AngleRecall(torchmetrics.MeanMetric):
def __init__(self, threshold, key="yaw_max", *args, **kwargs):
self.threshold = threshold
self.key = key
super().__init__(*args, **kwargs)
def update(self, pred, data):
self.cuda()
error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
super().update((error <= self.threshold).float())
class MeanMetricWithRecall(torchmetrics.Metric):
full_state_update = True
def __init__(self):
super().__init__()
self.add_state("value", default=[], dist_reduce_fx="cat")
def compute(self):
return dim_zero_cat(self.value).mean(0)
def get_errors(self):
return dim_zero_cat(self.value)
def recall(self, thresholds):
self.cuda()
error = self.get_errors()
thresholds = error.new_tensor(thresholds)
return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100
class AngleError(MeanMetricWithRecall):
def __init__(self, key):
super().__init__()
self.key = key
def update(self, pred, data):
self.cuda()
value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
if value.numel():
self.value.append(value)
class Location2DError(MeanMetricWithRecall):
def __init__(self, key, pixel_per_meter):
super().__init__()
self.key = key
self.ppm = pixel_per_meter
def update(self, pred, data):
self.cuda()
value = location_error(pred[self.key], data["uv"], self.ppm)
if value.numel():
self.value.append(value)
class LateralLongitudinalError(MeanMetricWithRecall):
def __init__(self, pixel_per_meter, key="uv_max"):
super().__init__()
self.ppm = pixel_per_meter
self.key = key
def update(self, pred, data):
self.cuda()
yaw = deg2rad(data["roll_pitch_yaw"][..., -1])
shift = (pred[self.key] - data["uv"]) * yaw.new_tensor([-1, 1])
shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1)
error = torch.abs(shift) / self.ppm
value = error.view(-1, 2)
if value.numel():
self.value.append(value)