# xfeat / modules/xfeat.py
# Uploaded by qubvel-hf (HF staff) in commit 9b7fcdb: "Clean proj with LFS"
"""
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""
import numpy as np
import os
import torch
import torch.nn.functional as F
import tqdm
from modules.model import *
from modules.interpolator import InterpolateSparse2d
class XFeat(nn.Module):
    """
    Implements the inference module for XFeat.
    It supports inference for both sparse and semi-dense feature extraction & matching.
    """

    def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096):
        """
        input:
            weights -> str path to a checkpoint loaded with torch.load, an already-loaded
                       state dict, or None to keep the freshly initialised network
            top_k   -> int: default number of features kept by the extraction/matching methods
        """
        super().__init__()
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = XFeatModel().to(self.dev).eval()
        self.top_k = top_k

        if weights is not None:
            if isinstance(weights, str):
                print('loading weights from: ' + weights)
                # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
                self.net.load_state_dict(torch.load(weights, map_location=self.dev))
            else:
                self.net.load_state_dict(weights)

        self.interpolator = InterpolateSparse2d('bicubic')

    @torch.inference_mode()
    def detectAndCompute(self, x, top_k = None):
        """
        Compute sparse keypoints & descriptors. Supports batched mode.

        input:
            x -> torch.Tensor(B, C, H, W): grayscale or rgb image
            top_k -> int: keep best k features (defaults to self.top_k)
        return:
            List[Dict] of length B:
                'keypoints'    ->  torch.Tensor(N, 2): keypoints (x,y), rescaled to the input resolution
                'scores'       ->  torch.Tensor(N,): keypoint scores (only entries with score > 0 are kept)
                'descriptors'  ->  torch.Tensor(N, 64): L2-normalised local features
        """
        if top_k is None: top_k = self.top_k
        x, rh1, rw1 = self.preprocess_tensor(x)

        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.net(x)
        M1 = F.normalize(M1, dim=1)

        # Convert logits to heatmap and extract kpts
        K1h = self.get_kpts_heatmap(K1)
        mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5)

        # Compute reliability scores: heatmap value at the kpt times the bilinearly sampled H1 map
        _nearest = InterpolateSparse2d('nearest')
        _bilinear = InterpolateSparse2d('bilinear')
        scores = (_nearest(K1h, mkpts, _H1, _W1) * _bilinear(H1, mkpts, _H1, _W1)).squeeze(-1)
        # NMS pads every batch row with (0, 0) entries; force them negative so `valid` drops them.
        # (A genuine keypoint exactly at pixel (0, 0) is dropped as well.)
        scores[torch.all(mkpts == 0, dim=-1)] = -1

        # Select top-k features
        idxs = torch.argsort(-scores)
        mkpts_x = torch.gather(mkpts[..., 0], -1, idxs)[:, :top_k]
        mkpts_y = torch.gather(mkpts[..., 1], -1, idxs)[:, :top_k]
        mkpts = torch.cat([mkpts_x[..., None], mkpts_y[..., None]], dim=-1)
        scores = torch.gather(scores, -1, idxs)[:, :top_k]

        # Interpolate descriptors at kpts positions
        feats = self.interpolator(M1, mkpts, H = _H1, W = _W1)

        # L2-Normalize
        feats = F.normalize(feats, dim=-1)

        # Undo the divisible-by-32 resize done in preprocess_tensor
        mkpts = mkpts * torch.tensor([rw1, rh1], device=mkpts.device).view(1, 1, -1)

        valid = scores > 0
        return [
            {'keypoints': mkpts[b][valid[b]],
             'scores': scores[b][valid[b]],
             'descriptors': feats[b][valid[b]]} for b in range(B)
        ]

    @torch.inference_mode()
    def detectAndComputeDense(self, x, top_k = None, multiscale = True):
        """
        Compute dense *and coarse* descriptors. Supports batched mode.

        input:
            x -> torch.Tensor(B, C, H, W): grayscale or rgb image
            top_k -> int: keep best k features (defaults to self.top_k)
            multiscale -> bool: extract at two scales (see extract_dualscale) instead of one
        return:
            Dict of batched tensors (N <= top_k). Features are sorted by reliability score,
            from most to least, within each extraction scale. With multiscale=True the
            coarse-scale block is followed by the fine-scale block.
                'keypoints'    ->  torch.Tensor(B, N, 2): coarse keypoints (x,y)
                'scales'       ->  torch.Tensor(B, N): extraction scale of each keypoint
                'descriptors'  ->  torch.Tensor(B, N, C): coarse local features (not re-normalised here)
        """
        if top_k is None: top_k = self.top_k
        if multiscale:
            mkpts, sc, feats = self.extract_dualscale(x, top_k)
        else:
            mkpts, feats = self.extractDense(x, top_k)
            sc = torch.ones(mkpts.shape[:2], device=mkpts.device)

        return {'keypoints': mkpts,
                'descriptors': feats,
                'scales': sc }

    @torch.inference_mode()
    def match_xfeat(self, img1, img2, top_k = None, min_cossim = -1):
        """
        Simple extractor and MNN matcher.
        For simplicity it does not support batched mode due to possibly different number of kpts.

        input:
            img1 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
            img2 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
            top_k -> int: keep best k features
            min_cossim -> float: when > 0, also require the best cosine similarity to exceed it
        returns:
            mkpts_0, mkpts_1 -> np.ndarray (N,2) xy coordinate matches from image1 to image2
        """
        if top_k is None: top_k = self.top_k
        img1 = self.parse_input(img1)
        img2 = self.parse_input(img2)

        out1 = self.detectAndCompute(img1, top_k=top_k)[0]
        out2 = self.detectAndCompute(img2, top_k=top_k)[0]

        idxs0, idxs1 = self.match(out1['descriptors'], out2['descriptors'], min_cossim=min_cossim)

        return out1['keypoints'][idxs0].cpu().numpy(), out2['keypoints'][idxs1].cpu().numpy()

    @torch.inference_mode()
    def match_xfeat_star(self, im_set1, im_set2, top_k = None):
        """
        Extracts coarse feats, then match pairs and finally refine matches, currently supports batched mode.

        input:
            im_set1 -> torch.Tensor(B, C, H, W) or np.ndarray (H,W,C): grayscale or rgb images.
            im_set2 -> torch.Tensor(B, C, H, W) or np.ndarray (H,W,C): grayscale or rgb images.
            top_k -> int: keep best k features
        returns:
            B > 1:  List[torch.Tensor(N, 4)] of size B with pairwise matches (x1,y1,x2,y2)
            B == 1: tuple (mkpts_0, mkpts_1) of np.ndarray (N, 2), like match_xfeat
        """
        if top_k is None: top_k = self.top_k
        im_set1 = self.parse_input(im_set1)
        im_set2 = self.parse_input(im_set2)

        # Compute coarse feats
        out1 = self.detectAndComputeDense(im_set1, top_k=top_k)
        out2 = self.detectAndComputeDense(im_set2, top_k=top_k)

        # Match batches of pairs
        idxs_list = self.batch_match(out1['descriptors'], out2['descriptors'])
        B = len(im_set1)

        # Refine coarse matches
        # this part is harder to batch, currently iterate
        matches = []
        for b in range(B):
            matches.append(self.refine_matches(out1, out2, matches = idxs_list, batch_idx=b))

        return matches if B > 1 else (matches[0][:, :2].cpu().numpy(), matches[0][:, 2:].cpu().numpy())

    def preprocess_tensor(self, x):
        """
        Guarantee that image is divisible by 32 to avoid aliasing artifacts.

        input:
            x -> torch.Tensor(B, C, H, W), or a single np.ndarray(H, W, C) image
        return:
            x -> float torch.Tensor(B, C, _H, _W) on self.dev, where _H and _W are H and W
                 rounded down to multiples of 32
            rh, rw -> float: H/_H and W/_W, used to map coordinates back to the input size
        raises:
            ValueError: if H or W is smaller than 32 (nothing is left after rounding down)
        """
        # The previous check compared the shape tuple with the int 3 (always False), so numpy
        # (H, W, C) input was never converted and then failed on `.to()`.
        if isinstance(x, np.ndarray) and x.ndim == 3:
            x = torch.tensor(x).permute(2, 0, 1)[None]
        x = x.to(self.dev).float()

        H, W = x.shape[-2:]
        _H, _W = (H // 32) * 32, (W // 32) * 32
        if _H == 0 or _W == 0:
            raise ValueError(f'Image of size {H}x{W} is too small: both sides must be at least 32 pixels.')
        rh, rw = H / _H, W / _W

        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def get_kpts_heatmap(self, kpts, softmax_temp = 1.0):
        """
        Turn keypoint logits (B, C, H, W) into a full-resolution heatmap (B, 1, 8H, 8W).

        Softmax runs over the channel axis. Only the first 64 channels are kept (any extra
        channel presumably acts as a "no keypoint" bin -- confirm against XFeatModel). Each
        cell's 64 channels are then unfolded into an 8x8 pixel block.
        """
        scores = F.softmax(kpts * softmax_temp, 1)[:, :64]
        B, _, H, W = scores.shape
        heatmap = scores.permute(0, 2, 3, 1).reshape(B, H, W, 8, 8)
        heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(B, 1, H * 8, W * 8)
        return heatmap

    def NMS(self, x, threshold = 0.05, kernel_size = 5):
        """
        Non-maximum suppression on a (B, 1, H, W) heatmap.

        return:
            torch.LongTensor(B, N, 2): (x, y) positions that are local maxima within
            `kernel_size` and above `threshold`. N is the largest per-image count; shorter
            rows are zero-padded with (0, 0).
        """
        B, _, H, W = x.shape
        pad = kernel_size // 2
        local_max = F.max_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
        pos = (x == local_max) & (x > threshold)
        # nonzero() on each (1, H, W) mask yields (c, y, x); keep (y, x) and flip to (x, y)
        pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]

        pad_val = max((len(k) for k in pos_batched), default=0)
        padded = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)

        # Pad kpts and build (B, N, 2) tensor
        for b, kpts_b in enumerate(pos_batched):
            padded[b, :len(kpts_b), :] = kpts_b

        return padded

    @torch.inference_mode()
    def batch_match(self, feats1, feats2, min_cossim = -1):
        """
        Mutual-nearest-neighbour matching for each pair in a batch.

        input:
            feats1 -> torch.Tensor(B, N, D), feats2 -> torch.Tensor(B, M, D)
            min_cossim -> float: when > 0, also require the best similarity to exceed it
        return:
            List of length B of (idx0, idx1) index tensors into feats1[b] / feats2[b]
        """
        B = len(feats1)
        cossim = torch.bmm(feats1, feats2.permute(0, 2, 1))
        match12 = torch.argmax(cossim, dim=-1)
        match21 = torch.argmax(cossim.permute(0, 2, 1), dim=-1)

        idx0 = torch.arange(len(match12[0]), device=match12.device)

        batched_matches = []

        for b in range(B):
            mutual = match21[b][match12[b]] == idx0

            if min_cossim > 0:
                cossim_max, _ = cossim[b].max(dim=1)
                good = cossim_max > min_cossim
                idx0_b = idx0[mutual & good]
                idx1_b = match12[b][mutual & good]
            else:
                idx0_b = idx0[mutual]
                idx1_b = match12[b][mutual]

            batched_matches.append((idx0_b, idx1_b))

        return batched_matches

    def subpix_softmax2d(self, heatmaps, temp = 3):
        """
        Soft-argmax over (N, H, W) heatmaps.

        return:
            torch.Tensor(N, 2): expected (x, y) offset relative to the cell (W//2, H//2)
        """
        N, H, W = heatmaps.shape
        heatmaps = torch.softmax(temp * heatmaps.view(-1, H * W), -1).view(-1, H, W)
        x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device), torch.arange(H, device = heatmaps.device), indexing = 'xy')
        x = x - (W // 2)
        y = y - (H // 2)

        coords_x = (x[None, ...] * heatmaps)
        coords_y = (y[None, ...] * heatmaps)
        coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H * W, 2)
        coords = coords.sum(1)

        return coords

    def refine_matches(self, d0, d1, matches, batch_idx, fine_conf = 0.25):
        """
        Refine coarse matches of one batch item with the fine matcher.

        input:
            d0, d1 -> dicts as returned by detectAndComputeDense
            matches -> list of (idx0, idx1) as returned by batch_match
            batch_idx -> int: which batch item to refine
            fine_conf -> float: minimum peak softmax confidence required to keep a match
        return:
            torch.Tensor(N, 4): (x0, y0, x1, y1). Only image-0 keypoints are shifted by the
            predicted offset (scaled by their extraction scale). d0/d1 are not modified.
        """
        idx0, idx1 = matches[batch_idx]
        feats1 = d0['descriptors'][batch_idx][idx0]
        feats2 = d1['descriptors'][batch_idx][idx1]
        mkpts_0 = d0['keypoints'][batch_idx][idx0]
        mkpts_1 = d1['keypoints'][batch_idx][idx1]
        sc0 = d0['scales'][batch_idx][idx0]

        # Compute fine offsets; the output is viewed as 8x8 grids below (assumes 64 logits per match)
        offsets = self.net.fine_matcher(torch.cat([feats1, feats2], dim=-1))
        conf = F.softmax(offsets * 3, dim=-1).max(dim=-1)[0]
        offsets = self.subpix_softmax2d(offsets.view(-1, 8, 8))

        mkpts_0 = mkpts_0 + offsets * sc0[:, None]

        mask_good = conf > fine_conf
        mkpts_0 = mkpts_0[mask_good]
        mkpts_1 = mkpts_1[mask_good]

        return torch.cat([mkpts_0, mkpts_1], dim=-1)

    @torch.inference_mode()
    def match(self, feats1, feats2, min_cossim = 0.82):
        """
        Mutual-nearest-neighbour matching of (N, D) and (M, D) descriptors.

        When min_cossim > 0, a match is kept only if its best similarity exceeds min_cossim.

        return:
            (idx0, idx1) index tensors into feats1 / feats2
        """
        cossim = feats1 @ feats2.t()
        cossim_t = feats2 @ feats1.t()

        _, match12 = cossim.max(dim=1)
        _, match21 = cossim_t.max(dim=1)

        idx0 = torch.arange(len(match12), device=match12.device)
        mutual = match21[match12] == idx0

        if min_cossim > 0:
            cossim, _ = cossim.max(dim=1)
            good = cossim > min_cossim
            idx0 = idx0[mutual & good]
            idx1 = match12[mutual & good]
        else:
            idx0 = idx0[mutual]
            idx1 = match12[mutual]

        return idx0, idx1

    def create_xy(self, h, w, dev):
        """Return a (h*w, 2) LongTensor of (x, y) grid coordinates in row-major order."""
        y, x = torch.meshgrid(torch.arange(h, device = dev),
                              torch.arange(w, device = dev), indexing='ij')
        xy = torch.cat([x[..., None], y[..., None]], -1).reshape(-1, 2)
        return xy

    def extractDense(self, x, top_k = 8_000):
        """
        Keep the top_k coarse cells ranked by the network's H1 map.

        input:
            x -> torch.Tensor(B, C, H, W) image batch
            top_k -> int: number of cells to keep; values < 1 mean "keep all"
        return:
            mkpts -> torch.Tensor(B, K, 2): cell positions (x, y) * 8, rescaled to the input size
            feats -> torch.Tensor(B, K, C): descriptors at those cells (C taken from the net output)
        """
        if top_k < 1:
            top_k = 100_000_000

        x, rh1, rw1 = self.preprocess_tensor(x)

        M1, K1, H1 = self.net(x)

        B, C, _H1, _W1 = M1.shape

        xy1 = (self.create_xy(_H1, _W1, M1.device) * 8).expand(B, -1, -1)

        M1 = M1.permute(0, 2, 3, 1).reshape(B, -1, C)
        H1 = H1.permute(0, 2, 3, 1).reshape(B, -1)

        _, top_idx = torch.topk(H1, k = min(len(H1[0]), top_k), dim=-1)

        # Use the runtime descriptor width instead of a hard-coded 64
        feats = torch.gather(M1, 1, top_idx[..., None].expand(-1, -1, C))
        mkpts = torch.gather(xy1, 1, top_idx[..., None].expand(-1, -1, 2))
        mkpts = mkpts * torch.tensor([rw1, rh1], device=mkpts.device).view(1, -1)

        return mkpts, feats

    def extract_dualscale(self, x, top_k, s1 = 0.6, s2 = 1.3):
        """
        Dense extraction at two image scales: 20% of top_k at scale s1 and 80% at scale s2.

        NOTE: for top_k < 5 the s1 budget int(top_k*0.20) is 0, which extractDense treats as
        "keep all" (likewise for s2 when top_k < 2).

        return:
            mkpts (B, N, 2) in input-image coordinates, sc (B, N) with the inverse scale
            (1/s) per keypoint, feats (B, N, C)
        """
        x1 = F.interpolate(x, scale_factor=s1, align_corners=False, mode='bilinear')
        x2 = F.interpolate(x, scale_factor=s2, align_corners=False, mode='bilinear')

        mkpts_1, feats_1 = self.extractDense(x1, int(top_k * 0.20))
        mkpts_2, feats_2 = self.extractDense(x2, int(top_k * 0.80))

        mkpts = torch.cat([mkpts_1 / s1, mkpts_2 / s2], dim=1)
        sc1 = torch.ones(mkpts_1.shape[:2], device=mkpts_1.device) * (1 / s1)
        sc2 = torch.ones(mkpts_2.shape[:2], device=mkpts_2.device) * (1 / s2)
        sc = torch.cat([sc1, sc2], dim=1)
        feats = torch.cat([feats_1, feats_2], dim=1)

        return mkpts, sc, feats

    def parse_input(self, x):
        """
        Normalise user input to a 4-D batch.

        - np.ndarray (H, W) or (H, W, C): converted to torch (1, C, H, W) and divided by 255
        - np.ndarray (B, H, W, C): converted to torch (B, C, H, W) and divided by 255
        - torch.Tensor (C, H, W): a batch axis is added; (B, C, H, W) is returned unchanged
        """
        if isinstance(x, np.ndarray) and x.ndim == 2:
            x = x[..., None]  # grayscale (H, W) -> (H, W, 1)

        if len(x.shape) == 3:
            x = x[None, ...]

        if isinstance(x, np.ndarray):
            x = torch.tensor(x).permute(0, 3, 1, 2) / 255

        return x