|
|
|
|
|
|
|
|
|
from pdb import set_trace as bb |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as tvf |
|
|
|
from core.conv_mixer import ConvMixer |
|
|
|
# Per-channel RGB statistics (the widely used ImageNet mean/std values).
# extract_descs divides images by 255 before applying this transform.
_RGB_MEAN = [0.485, 0.456, 0.406]
_RGB_STD = [0.229, 0.224, 0.225]
norm_RGB = tvf.Normalize(mean=_RGB_MEAN, std=_RGB_STD)
|
|
|
|
|
class PixelDesc (nn.Module):
    """Dense descriptor extractor wrapping a pretrained ConvMixer.

    The ConvMixer is built with fixed hyper-parameters (output_dim=128,
    hidden_dim=512, depth=7, patch_size=4, kernel_size=9) and its weights are
    loaded from a checkpoint on construction.
    """

    def __init__(self, path='models/PUMP_st.pt'):
        """Build the ConvMixer and load its state dict from ``path`` onto the CPU.

        NOTE(review): ``torch.load`` unpickles the file, so ``path`` must point
        to a trusted checkpoint; consider ``weights_only=True`` where the
        installed torch version supports it.
        """
        super().__init__()
        state_dict = torch.load(path, map_location='cpu')
        self.pixel_desc = ConvMixer(output_dim=128, hidden_dim=512, depth=7,
                                    patch_size=4, kernel_size=9).eval()
        self.pixel_desc.load_state_dict(state_dict)

    def configure(self, pipeline):
        """Graft ``DescPipeline`` in front of ``pipeline``'s class; return ``self``.

        The pipeline's class is replaced in place by a dynamically created
        subclass named ``<OriginalName>_Trained`` whose MRO lists
        ``DescPipeline`` first, so its ``extract_descs``/``first_level`` take
        precedence. A pipeline that is already a ``DescPipeline`` is left
        unchanged, so repeated calls do not stack ``_Trained`` subclasses.
        """
        if not isinstance(pipeline, DescPipeline):
            base = type(pipeline)
            pipeline.__class__ = type(base.__name__ + '_Trained', (DescPipeline, base), {})
        return self

    def get_atomic_patch_size(self):
        """Return 4, matching the ``patch_size=4`` the ConvMixer is built with."""
        return 4

    def forward(self, img, stride=1, offset=0):
        """Compute descriptors for ``img`` and subsample them.

        Args:
            img: image tensor; a 3-D input gets a leading batch dimension.
                Presumably (C, H, W) or (B, C, H, W), already normalised -- TODO confirm.
            stride: keep every ``stride``-th descriptor along the last two axes.
            offset: index of the first kept descriptor along those axes.

        Returns:
            ``(desc, trf)``: the contiguous subsampled descriptor map, and a 3x3
            identity matrix on ``img``'s device. ``trf`` does not encode
            ``stride``/``offset``.
        """
        if img.ndim == 3:
            img = img[None]
        trf = torch.eye(3, device=img.device)
        desc = self.pixel_desc(img)
        desc = desc[..., offset::stride, offset::stride].contiguous()
        return desc, trf
|
|
|
|
|
class DescPipeline:
    """Mixin overriding descriptor extraction and first-level correlation.

    ``PixelDesc.configure`` inserts it in front of a pipeline's own class. It
    relies on the host object providing ``demultiplex_img_trf`` and a
    ``pixel_desc`` callable (presumably the ``PixelDesc`` module -- confirm);
    neither is defined here.
    """

    def extract_descs(self, img1, img2, dtype=None):
        """Compute descriptors for an image pair.

        Args:
            img1, img2: inputs accepted by ``self.demultiplex_img_trf``. Images
                are divided by 255 and passed through ``norm_RGB``, so they are
                presumably 0-255 valued -- TODO confirm.
            dtype: torch dtype for the normalised images, the descriptor
                network, and the returned descriptors. ``None`` performs no
                explicit cast (integer images still become floating point
                through the division by 255).

        Returns:
            ``(img1, img2)``, ``(desc1, desc2)``, ``(trf1, trf2)``, where each
            transform is the demultiplexed one matrix-multiplied by the
            transform returned from ``self.pixel_desc``.
        """
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)

        def _cast(t):
            # Tensor.type(None) returns a type *string*, not a tensor, so only
            # cast when a dtype was actually requested.
            return t if dtype is None else t.type(dtype)

        fimg1, fimg2 = [norm_RGB(_cast(img) / 255) for img in (img1, img2)]

        # Run the network in the same precision as its inputs (mutates the module).
        self.pixel_desc.type(fimg1.dtype)
        # Image 1 is sampled sparsely (stride 4, offset 2); image 2 densely.
        desc1, trf1 = self.pixel_desc(fimg1, stride=4, offset=2)
        desc2, trf2 = self.pixel_desc(fimg2)
        return (img1, img2), (_cast(desc1), _cast(desc2)), (sca1 @ trf1, sca2 @ trf2)

    def first_level(self, desc1, desc2, **kw):
        """Correlate every descriptor of ``desc1`` with every position of ``desc2``.

        Args:
            desc1: 4-D descriptor map ``(1, C, H1, W1)``; only a batch size of
                one is supported (the reshape folds just the spatial dims).
            desc2: 4-D descriptor map ``(B2, C, H2, W2)``; only its first batch
                element contributes to the result.
            **kw: accepted and ignored.

        Returns:
            ``(corr, norms)``: ``corr`` has shape ``(H1, W1, H2, W2)`` with
            ``corr[i, j, k, l] = <desc1[0, :, i, j], desc2[0, :, k, l]>``;
            ``norms`` is a ``(H1, W1)`` tensor of ones.
        """
        _, C, H, W = desc1.shape
        # One 1x1 conv filter per desc1 position (row-major over H, W).
        # reshape rather than view so non-contiguous desc1 is also accepted.
        weights = desc1.permute(0, 2, 3, 1).reshape(H * W, C, 1, 1)
        # A 1x1 conv with these filters is the dot product of each desc1
        # descriptor with every desc2 descriptor; [0] keeps desc2's first item.
        corr = F.conv2d(desc2, weights, padding=0, bias=None)[0]
        norms = torch.ones(desc1.shape[-2:], device=corr.device)
        return corr.reshape(desc1.shape[-2:] + desc2.shape[-2:]), norms
|
|