PUMP / core /pixel_desc.py
Philippe Weinzaepfel
huggingface demo
3ef85e9
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tvf
from core.conv_mixer import ConvMixer
# Per-channel RGB normalization (these are the usual ImageNet mean/std values).
# DescPipeline.extract_descs applies it to images that were first divided by 255.
norm_RGB = tvf.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class PixelDesc(nn.Module):
    """Trained pixel descriptor backed by a ConvMixer network.

    The network weights are read from a checkpoint file at construction time
    and the network is put in eval mode.
    """

    def __init__(self, path='models/PUMP_st.pt'):
        """Load the checkpoint at `path` (onto CPU) into a fresh ConvMixer."""
        super().__init__()
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        weights = torch.load(path, 'cpu')
        network = ConvMixer(output_dim=128, hidden_dim=512, depth=7, patch_size=4, kernel_size=9)
        self.pixel_desc = network.eval()
        self.pixel_desc.load_state_dict(weights)

    def configure(self, pipeline):
        """Re-type `pipeline` in place so DescPipeline's methods take precedence.

        A new class named ``<OriginalName>_Trained`` is created on the fly with
        DescPipeline ahead of the pipeline's current class in its bases, and
        assigned to ``pipeline.__class__``. Returns `self`.
        """
        # hot-update of the default pipeline (HOG-based, per the original author)
        current_cls = type(pipeline)
        trained_cls = type(current_cls.__name__ + '_Trained', (DescPipeline, current_cls), {})
        pipeline.__class__ = trained_cls
        return self

    def get_atomic_patch_size(self):
        """Return the atomic patch size, fixed to 4."""
        return 4

    def forward(self, img, stride=1, offset=0):
        """Run the network on `img` and subsample its output.

        Args:
            img: input tensor; a 3-dimensional input gets a leading batch axis.
            stride: sampling step along the last two axes of the output.
            offset: first sampled index along the last two axes.

        Returns:
            ``(desc, trf)`` where `desc` is the subsampled network output and
            `trf` is a 3x3 identity matrix on the same device as `img`.
        """
        batch = img[None] if img.ndim == 3 else img
        dense = self.pixel_desc(batch)
        # contiguous() copies the strided selection, so the dense map is not
        # kept alive by a view (the original comment: "free memory")
        sampled = dense[..., offset::stride, offset::stride].contiguous()
        identity = torch.eye(3, device=batch.device)
        return sampled, identity
class DescPipeline:
    """Mixin that makes a matching pipeline use the trained pixel descriptor.

    It is not usable on its own: ``extract_descs`` relies on the host class
    for ``self.demultiplex_img_trf`` and ``self.pixel_desc`` (presumably a
    PixelDesc instance, given the ``stride``/``offset`` call -- verify
    against the host pipeline).
    """

    def extract_descs(self, img1, img2, dtype=None):
        """Compute descriptors for an image pair.

        Args:
            img1, img2: inputs accepted by ``self.demultiplex_img_trf``. The
                returned images are divided by 255, so they are presumably
                in the 0..255 range -- TODO confirm.
            dtype: tensor type for the network input and for the returned
                descriptors. ``None`` falls back to ``torch.float32``.

        Returns:
            ``(img1, img2), (desc1, desc2), (trf1, trf2)``: the demultiplexed
            images, their descriptors cast to `dtype`, and for each image the
            matrix product of its demultiplexing transform with the transform
            returned by the descriptor. `desc1` is sampled with
            ``stride=4, offset=2``; `desc2` uses the descriptor's defaults.
        """
        # this will rotate the image if needed (original author's note)
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)

        # Tensor.type(None) returns the type *name* (a str) instead of casting,
        # which would break the division below and the final casts.
        if dtype is None:
            dtype = torch.float32

        # convert to float and normalize std
        fimg1, fimg2 = [norm_RGB(img.type(dtype) / 255) for img in (img1, img2)]
        # make the network parameters match the input type
        self.pixel_desc.type(fimg1.dtype)

        desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
        desc2, trf2 = self.pixel_desc(fimg2)
        return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1 @ trf1, sca2 @ trf2)

    def first_level(self, desc1, desc2, **kw):
        """Correlate every descriptor of `desc1` with every position of `desc2`.

        Args:
            desc1: tensor of shape (1, C, H1, W1).
            desc2: tensor of shape (B, C, H2, W2); only the first batch
                element contributes to the result.
            **kw: accepted and ignored.

        Returns:
            ``(corr, norms)`` where `corr` has shape (H1, W1, H2, W2) and holds
            the channel-wise dot products, and `norms` is an (H1, W1) tensor
            of ones on the same device as `corr`.
        """
        _, C, H, W = desc1.shape
        # rearrange(desc1, '1 C H W -> (H W) C 1 1')
        # reshape() instead of view(): still zero-copy when the layout allows
        # it, but does not raise when desc1 is a non-contiguous slice.
        weights = desc1.permute(0, 2, 3, 1).reshape(H * W, C, 1, 1)
        # each desc1 pixel acts as a 1x1 convolution filter over desc2
        corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
        norms = torch.ones(desc1.shape[-2:], device=corr.device)
        return corr.reshape(desc1.shape[-2:] + desc2.shape[-2:]), norms