|
|
|
|
|
import numpy as np |
|
import torch |
|
from torch.nn.functional import normalize |
|
|
|
from . import get_model |
|
from models.base import BaseModel |
|
|
|
|
|
from models.voting import ( |
|
argmax_xyr, |
|
conv2d_fft_batchwise, |
|
expectation_xyr, |
|
log_softmax_spatial, |
|
mask_yaw_prior, |
|
nll_loss_xyr, |
|
nll_loss_xyr_smoothed, |
|
TemplateSampler, |
|
UAVTemplateSampler, |
|
UAVTemplateSamplerFast |
|
) |
|
from .map_encoder import MapEncoder |
|
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall |
|
|
|
|
|
class MapLocNet(BaseModel):
    """Localize an image against a map by exhaustive template matching.

    The image encoder output is turned into a set of templates by the
    template sampler, the templates are correlated with the encoded map
    features, and the resulting score volume is normalized into
    log-probabilities from which arg-max and expectation estimates are read.
    """

    # "???" presumably marks a value the user configuration must provide
    # (OmegaConf-style mandatory field) — confirm against BaseModel.
    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # The next five options are pinned by assertions in `_init`: only the
        # values listed here are accepted.
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        # Not read anywhere in this class.
        "retrieval_dim": None,
    }

    def _init(self, conf):
        """Build the encoders, the template sampler and the optional temperature.

        Args:
            conf: model configuration exposing the keys of ``default_conf``
                as attributes.

        Raises:
            AssertionError: if one of the pinned options differs from the only
                value this implementation accepts.
        """
        # NOTE(review): these checks read `self.conf` while the rest of the
        # method reads the `conf` argument — presumably the same object;
        # confirm in BaseModel.
        assert not self.conf.norm_depth_scores, "norm_depth_scores must be False"
        assert (
            self.conf.depth_parameterization == "scale"
        ), 'depth_parameterization must be "scale"'
        assert (
            not self.conf.normalize_scores_by_dim
        ), "normalize_scores_by_dim must be False"
        assert (
            self.conf.normalize_scores_by_num_valid
        ), "normalize_scores_by_num_valid must be True"
        assert self.conf.prior_renorm, "prior_renorm must be True"

        # The image encoder class is looked up by name, with a fallback when
        # the configuration does not name one.
        encoder_cls = get_model(
            conf.image_encoder.get("name", "feature_extractor_v2")
        )
        self.image_encoder = encoder_cls(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)

        self.template_sampler = UAVTemplateSampler(conf.num_rotations)

        if conf.add_temperature:
            # Learnable scalar, initialized to 0 so that exp(temperature) == 1.
            temperature = torch.nn.Parameter(torch.tensor(0.0))
            self.register_parameter("temperature", temperature)

    def exhaustive_voting(self, f_bev, f_map):
        """Correlate templates sampled from ``f_bev`` with the map features.

        Args:
            f_bev: feature tensor handed to the template sampler. It is
                normalized along dim 1 when ``normalize_features`` is set
                (presumably the channel dim — TODO confirm).
            f_map: map feature tensor, normalized along dim 1 under the same
                option.

        Returns:
            The scores returned by ``conv2d_fft_batchwise``, multiplied by
            ``exp(self.temperature)`` when ``add_temperature`` is enabled.
        """
        if self.conf.normalize_features:
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        templates = self.template_sampler(f_bev)

        # Autocast is switched off and both inputs are cast to float32 for
        # the FFT-based correlation.
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * torch.exp(self.temperature)

        return scores

    def _forward(self, data):
        """Run both encoders, match their features and derive pose estimates.

        Args:
            data: batch dictionary passed to both encoders. The optional
                entries ``"map_mask"`` and ``"yaw_prior"`` restrict the
                score volume when present.

        Returns:
            A dictionary holding the map encoder output under ``"map"``, the
            masked ``"scores"``, the ``"log_probs"``, the arg-max estimates
            (``"uvr_max"``, ``"uv_max"``, ``"yaw_max"``), the expectation
            estimates (``"uvr_expectation"``, ``"uv_expectation"``,
            ``"yaw_expectation"``) and the image features under
            ``"features_image"``.
        """
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]

        # Only the first feature map of the image encoder is used.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]

        scores = self.exhaustive_voting(f_image, f_map)
        # Move dim 1 (presumably the rotation bins — TODO confirm) to the last
        # axis; the prior and the mask below broadcast over it through a
        # trailing singleton dimension.
        scores = scores.moveaxis(1, -1)
        if "log_prior" in pred_map and self.conf.apply_map_prior:
            scores = scores + pred_map["log_prior"][0].unsqueeze(-1)

        if "map_mask" in data:
            # Positions where the mask is False are excluded by setting their
            # score to -inf.
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        if "yaw_prior" in data:
            # The return value is not used, so this call is expected to
            # modify `scores` in place.
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
        log_probs = log_softmax_spatial(scores)

        # Both point estimates are computed without gradient tracking.
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }

    def loss(self, pred, data):
        """Compute the negative log-likelihood of the ground-truth pose.

        Args:
            pred: output of ``_forward``; only ``"log_probs"`` is read.
            data: batch dictionary providing ``"uv"`` and
                ``"roll_pitch_yaw"`` (its last component is used as the yaw)
                and, optionally, ``"map_mask"``.

        Returns:
            A dictionary with the NLL under both ``"total"`` and ``"nll"``.
            While training with ``add_temperature`` enabled it also carries
            the temperature expanded to ``len(nll)`` under ``"temperature"``.
        """
        xy_gt = data["uv"]
        yaw_gt = data["roll_pitch_yaw"][..., -1]
        if self.conf.do_label_smoothing:
            nll = nll_loss_xyr_smoothed(
                pred["log_probs"],
                xy_gt,
                yaw_gt,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )
        else:
            nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
        loss = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            # Presumably reported alongside the loss for monitoring; the
            # scalar is expanded to the length of `nll`.
            loss["temperature"] = self.temperature.expand(len(nll))
        return loss

    def metrics(self):
        """Return the evaluation metrics, keyed by name.

        Position errors and recalls at 1, 3 and 5 are built on ``"uv_max"``
        (plus one error on ``"uv_expectation"``) and receive
        ``pixel_per_meter``; the yaw error and recalls at 1, 3 and 5 are built
        on ``"yaw_max"``.
        """
        ppm = self.conf.pixel_per_meter
        return {
            "xy_max_error": Location2DError("uv_max", ppm),
            "xy_expectation_error": Location2DError("uv_expectation", ppm),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_1m": Location2DRecall(1.0, ppm, "uv_max"),
            "xy_recall_3m": Location2DRecall(3.0, ppm, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, ppm, "uv_max"),
            "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
            "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }
|
|