# MapLocNet / models / maplocnet.py
# wangerniu
# Commit message.
# 124ba77
# Copyright (c) Meta Platforms, Inc. and affiliates.
import numpy as np
import torch
from torch.nn.functional import normalize
from . import get_model
from models.base import BaseModel
# from models.bev_net import BEVNet
# from models.bev_projection import CartesianProjection, PolarProjectionDepth
from models.voting import (
argmax_xyr,
conv2d_fft_batchwise,
expectation_xyr,
log_softmax_spatial,
mask_yaw_prior,
nll_loss_xyr,
nll_loss_xyr_smoothed,
TemplateSampler,
UAVTemplateSampler,
UAVTemplateSamplerFast
)
from .map_encoder import MapEncoder
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
class MapLocNet(BaseModel):
    """Localization model that matches image features against map features.

    Image feature maps are turned into templates by ``UAVTemplateSampler``
    (one per rotation bin, presumably -- TODO confirm against
    ``models.voting``) and correlated with the encoded map via an FFT-based
    convolution. The resulting score volume over (row, column, rotation) is
    converted to log-probabilities, from which arg-max and expectation
    estimates are derived.
    """

    # NOTE(review): "???" looks like the marker for values the caller must
    # supply -- confirm against how BaseModel merges configurations.
    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # deprecated
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,
    }

    def _init(self, conf):
        """Build the sub-modules described by ``conf``.

        Args:
            conf: model configuration; this method reads ``image_encoder``,
                ``map_encoder``, ``pixel_per_meter``, ``num_rotations`` and
                ``add_temperature`` from it.

        Raises:
            AssertionError: if any deprecated option differs from its
                default value.
        """
        # The deprecated options are only accepted at their default values.
        # (self.conf is presumably populated by BaseModel before this hook
        # runs -- TODO confirm.)
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        encoder_name = conf.image_encoder.get("name", "feature_extractor_v2")
        encoder_cls = get_model(encoder_name)
        self.image_encoder = encoder_cls(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)

        # Not used further in this method; the read is kept so the value is
        # still accessed during initialization.
        # NOTE(review): can likely be dropped once confirmed to be harmless.
        ppm = conf.pixel_per_meter

        # The BEV network, the polar/Cartesian projections, the scale
        # classifier and the feature projection are not instantiated in this
        # variant; templates are sampled directly from the image features.
        self.template_sampler = UAVTemplateSampler(conf.num_rotations)

        if conf.add_temperature:
            # Learnable scalar; the matching scores are multiplied by its
            # exponential in exhaustive_voting.
            self.register_parameter(
                "temperature", torch.nn.Parameter(torch.tensor(0.0))
            )

    def exhaustive_voting(self, f_bev, f_map):
        """Correlate the templates sampled from ``f_bev`` with ``f_map``.

        Args:
            f_bev: image-side feature tensor fed to the template sampler.
            f_map: map feature tensor the templates are matched against.

        Returns:
            The score tensor produced by ``conv2d_fft_batchwise``, scaled by
            ``exp(temperature)`` when ``add_temperature`` is enabled.
        """
        if self.conf.normalize_features:
            # Normalize along dim 1 (presumably the channel axis).
            f_bev, f_map = (normalize(feat, dim=1) for feat in (f_bev, f_map))

        # Build the templates and exhaustively match against the map.
        # Earlier author's note on the template shape: [batch,256,8,129,129]
        # -- TODO confirm.
        templates = self.template_sampler(f_bev)

        # The FFT-based matching runs in float32 with autocast turned off.
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )

        if not self.conf.add_temperature:
            return scores
        return scores * torch.exp(self.temperature)

    def _forward(self, data):
        """Run the model on a batch and return scores and pose estimates.

        Args:
            data: batch dictionary passed to both encoders; the optional
                entries ``"map_mask"`` and ``"yaw_prior"`` are applied to
                the scores when present.

        Returns:
            Dictionary with the map encoder output under ``"map"``, the raw
            ``"scores"``, ``"log_probs"``, the arg-max and expectation
            estimates (``uvr``/``uv``/``yaw``) and ``"features_image"``.
        """
        outputs = {}
        map_out = outputs["map"] = self.map_encoder(data)
        # Earlier author's note on the shape: [batch,8,256,256] -- TODO confirm.
        map_feats = map_out["map_features"][0]

        # Extract image features from the first feature-map level.
        # Earlier author's note on the shape: [batch,128,128,176] -- TODO confirm.
        feature_level = 0
        image_feats = self.image_encoder(data)["feature_maps"][feature_level]

        scores = self.exhaustive_voting(image_feats, map_feats)
        scores = scores.moveaxis(1, -1)  # B,H,W,N

        if "log_prior" in map_out and self.conf.apply_map_prior:
            scores = scores + map_out["log_prior"][0].unsqueeze(-1)

        if "map_mask" in data:
            # In-place: cells where the mask is False get a score of -inf.
            outside_map = ~data["map_mask"][..., None]
            scores.masked_fill_(outside_map, -np.inf)

        if "yaw_prior" in data:
            # Return value is unused, so this presumably edits scores in
            # place -- TODO confirm in models.voting.
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)

        log_probs = log_softmax_spatial(scores)

        # The point estimates are computed without tracking gradients.
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        outputs["scores"] = scores
        outputs["log_probs"] = log_probs
        outputs["uvr_max"] = uvr_max
        outputs["uv_max"] = uvr_max[..., :2]
        outputs["yaw_max"] = uvr_max[..., 2]
        outputs["uvr_expectation"] = uvr_avg
        outputs["uv_expectation"] = uvr_avg[..., :2]
        outputs["yaw_expectation"] = uvr_avg[..., 2]
        outputs["features_image"] = image_feats
        return outputs

    def loss(self, pred, data):
        """Compute the negative log-likelihood loss of the ground-truth pose.

        Args:
            pred: output of ``_forward``; only ``"log_probs"`` is read.
            data: batch dictionary providing ``"uv"`` and
                ``"roll_pitch_yaw"`` (its last component is used as yaw) and,
                optionally, ``"map_mask"`` for the smoothed variant.

        Returns:
            Dictionary with ``"total"`` and ``"nll"`` (the same value) and,
            while training with ``add_temperature``, ``"temperature"``.
        """
        target_xy = data["uv"]
        target_yaw = data["roll_pitch_yaw"][..., -1]
        log_probs = pred["log_probs"]

        if not self.conf.do_label_smoothing:
            nll = nll_loss_xyr(log_probs, target_xy, target_yaw)
        else:
            nll = nll_loss_xyr_smoothed(
                log_probs,
                target_xy,
                target_yaw,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )

        losses = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            # Expanded to one entry per element of nll so it can be reported
            # alongside the per-sample losses.
            losses["temperature"] = self.temperature.expand(len(nll))
        return losses

    def metrics(self):
        """Return the evaluation metrics keyed by their reporting name.

        All position and yaw recalls are evaluated on the arg-max estimate
        at thresholds of 1, 3 and 5 (metres / degrees, going by the key
        names).
        """
        ppm = self.conf.pixel_per_meter
        thresholds = (1.0, 3.0, 5.0)

        metrics = {
            "xy_max_error": Location2DError("uv_max", ppm),
            "xy_expectation_error": Location2DError("uv_expectation", ppm),
            "yaw_max_error": AngleError("yaw_max"),
        }
        for threshold in thresholds:
            metrics[f"xy_recall_{threshold:g}m"] = Location2DRecall(
                threshold, ppm, "uv_max"
            )
        for threshold in thresholds:
            metrics[f"yaw_recall_{threshold:g}°"] = AngleRecall(
                threshold, "yaw_max"
            )
        return metrics