File size: 1,175 Bytes
9b7fcdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
"""
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class InterpolateSparse2d(nn.Module):
    """Efficiently interpolate a dense tensor at given sparse 2D positions.

    Thin wrapper around ``F.grid_sample`` that reads per-point channels out of
    a ``[B, C, H, W]`` tensor given pixel-space ``(x, y)`` positions.

    Args:
        mode: interpolation mode forwarded to ``F.grid_sample``
            (e.g. ``'bicubic'``, ``'bilinear'``, ``'nearest'``).
        align_corners: forwarded to ``F.grid_sample``.
    """

    def __init__(self, mode: str = 'bicubic', align_corners: bool = False):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners

    def normgrid(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Normalize pixel coords to [-1, 1].

        The last dimension of ``x`` is ``(x, y)``. x is divided by ``W - 1``
        and y by ``H - 1``, so 0 maps to -1 and ``W - 1`` / ``H - 1`` map to
        +1. Note that ``W == 1`` or ``H == 1`` divides by zero.
        """
        scale = torch.tensor([W - 1, H - 1], device=x.device, dtype=x.dtype)
        return 2. * (x / scale) - 1.

    def forward(self, x: torch.Tensor, pos: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Sample ``x`` at the sparse positions ``pos``.

        Input
            x: [B, C, H', W'] feature tensor
            pos: [B, N, 2] tensor of (x, y) positions
            H, W: int, original resolution of input 2d positions -- used in
                normalization to [-1, 1]
        Returns
            [B, N, C] sampled channels at 2d positions
        """
        # [B, N, 2] -> [B, N, 1, 2]: grid_sample takes a 2-D sampling grid per
        # batch item, so the N points are laid out as an N x 1 grid.
        grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
        # NOTE(review): normgrid divides by (W-1, H-1), which is grid_sample's
        # align_corners=True convention; with align_corners=False the sample
        # lands half a pixel off. Kept as-is since callers presumably rely on it.
        # Honour the constructor's setting (it was previously ignored and
        # hard-coded to False, which is still the default).
        x = F.grid_sample(x, grid, mode=self.mode, align_corners=self.align_corners)
        # [B, C, N, 1] -> [B, N, 1, C] -> [B, N, C]; only the grid's singleton
        # dim is squeezed, so C == 1 or N == 1 keeps its dimension.
        return x.permute(0, 2, 3, 1).squeeze(-2)