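# Wrapper for GIM's DKMv3-based dense matcher, exposed through this
# repository's BaseModel interface.
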
import subprocess
from pathlib import Path

import torch

from .. import logger
from ..utils.base_model import BaseModel
from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

weight_path = Path(__file__).parent / "networks" / "dkm"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GIM(BaseModel):
    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        "match_threshold": 0.2,
        "checkpoint_dir": weight_path,
        # Assumed default: _forward needs this key to sample matches.
        "max_keypoints": 2000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Download URLs for the original DKMv3 weights. Note that the default
    # GIM checkpoint (gim_dkm_100h.ckpt) has no URL here and must be placed
    # in `checkpoint_dir` manually.
    dkm_models = {
        "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
        "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
    }

    def _init(self, conf):
        model_path = weight_path / conf["model_name"]

        # Download the model if a known URL exists for it.
        if not model_path.exists():
            if conf["model_name"] not in self.dkm_models:
                raise FileNotFoundError(
                    f"{model_path} does not exist and no download URL is "
                    f"known for {conf['model_name']!r}; place the weights "
                    "there manually."
                )
            model_path.parent.mkdir(exist_ok=True)
            link = self.dkm_models[conf["model_name"]]
            cmd = ["wget", link, "-O", str(model_path)]
            logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
            subprocess.run(cmd, check=True)

        # DKMv3 runs at a fixed internal resolution of 672x896 (H x W).
        model = DKMv3(None, 672, 896, upsample_preds=True)

        state_dict = torch.load(str(model_path), map_location="cpu")
        # The GIM checkpoint nests the weights under "state_dict" and
        # prefixes every key with "model."; strip both.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        # Drop the encoder's classification head, which is unused for
        # matching. (Popping it in the same loop as the renaming would
        # raise a KeyError on keys matching both conditions.)
        for k in list(state_dict.keys()):
            if "encoder.net.fc" in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)

        self.net = model

    def _forward(self, data):
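        # DKMv3 resizes its inputs to 672x896 internally, so first
        # center-pad both images onto a shared canvas with (at least) that
        # 4:3 aspect ratio; this keeps the aspect ratio intact through the
        # resize.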
        image0, image1 = data["image0"], data["image1"]
        orig_height0, orig_width0 = image0.shape[-2:]
        orig_height1, orig_width1 = image1.shape[-2:]
        aspect_ratio = 896 / 672
        # Grow each image's canvas to the smallest size that contains it
        # at the target aspect ratio, then take the union of both canvases.
        new_width0 = max(orig_width0, int(orig_height0 * aspect_ratio))
        new_height0 = max(orig_height0, int(orig_width0 / aspect_ratio))
        new_width1 = max(orig_width1, int(orig_height1 * aspect_ratio))
        new_height1 = max(orig_height1, int(orig_width1 / aspect_ratio))
        new_width = max(new_width0, new_width1)
        new_height = max(new_height0, new_height1)

        # Split the padding evenly between both sides, with the extra
        # pixel on the bottom/right when the difference is odd.
        pad_height0 = new_height - orig_height0
        pad_width0 = new_width - orig_width0
        pad_height1 = new_height - orig_height1
        pad_width1 = new_width - orig_width1
        pad_top0 = pad_height0 // 2
        pad_bottom0 = pad_height0 - pad_top0
        pad_left0 = pad_width0 // 2
        pad_right0 = pad_width0 - pad_left0
        pad_top1 = pad_height1 // 2
        pad_bottom1 = pad_height1 - pad_top1
        pad_left1 = pad_width1 // 2
        pad_right1 = pad_width1 - pad_left1
        image0 = torch.nn.functional.pad(
            image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)
        )
        image1 = torch.nn.functional.pad(
            image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)
        )
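
        # Dense matching over the padded pair, then sparse sampling of the
        # most certain correspondences.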
        dense_matches, dense_certainty = self.net.match(image0, image1)
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )
        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]
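        # Matches come back in [-1, 1] normalized coordinates; convert them
        # to pixel coordinates in the padded images.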
        kpts0 = sparse_matches[:, :2]
        kpts1 = sparse_matches[:, 2:]
        kpts0 = torch.stack(
            (width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2), dim=-1
        )
        kpts1 = torch.stack(
            (width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2), dim=-1
        )
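        # mconf[None] adds a batch dimension, so torch.where yields a batch
        # index and a match index per correspondence.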
        b_ids, i_ids = torch.where(mconf[None])
        # Undo the padding offsets to express keypoints in the original
        # image frames.
        kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None]
        kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None]
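        # Discard matches that land in the padded border, i.e. outside
        # either original image.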
        mask = (
            (kpts0[:, 0] > 0)
            & (kpts0[:, 1] > 0)
            & (kpts1[:, 0] > 0)
            & (kpts1[:, 1] > 0)
            & (kpts0[:, 0] <= (orig_width0 - 1))
            & (kpts1[:, 0] <= (orig_width1 - 1))
            & (kpts0[:, 1] <= (orig_height0 - 1))
            & (kpts1[:, 1] <= (orig_height1 - 1))
        )
        pred = {
            "keypoints0": kpts0[i_ids],
            "keypoints1": kpts1[i_ids],
            "confidence": mconf[i_ids],
            "batch_indexes": b_ids,
        }
        scores, b_ids = pred["confidence"], pred["batch_indexes"]
        kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
        pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask]
        pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask]
        out = {
            "keypoints0": pred["keypoints0"],
            "keypoints1": pred["keypoints1"],
        }
        return out
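

# Minimal usage sketch (illustrative only; the exact invocation depends on
# how BaseModel wires __call__ to _forward in this repository):
#
#     model = GIM({"max_keypoints": 2000})
#     data = {"image0": img0, "image1": img1}  # (1, C, H, W) tensors in [0, 1]
#     pred = model(data)
#     kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]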