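"""Wrapper exposing the GIM matcher (DKMv3 backbone) through the BaseModel interface."""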
import torch
import subprocess

from pathlib import Path
from ..utils.base_model import BaseModel
from .. import logger

from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

weight_path = Path(__file__).parent / 'networks' / 'dkm'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GIM(BaseModel):
    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        "match_threshold": 0.2,
        "checkpoint_dir": weight_path,
        # Cap on sampled matches, read by _forward; assumed default,
        # override via the matcher conf as needed.
        "max_keypoints": 2000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # Download links for the original DKMv3 weights; _init looks the model
    # name up here when the checkpoint is missing.
    dkm_models = {
        "DKMv3_outdoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_outdoor.pth",
        "DKMv3_indoor.pth": "https://github.com/Parskatt/storage/releases/download/dkmv3/DKMv3_indoor.pth",
    }

    def _init(self, conf):
        model_path = Path(conf["checkpoint_dir"]) / conf["model_name"]

        # Download the model if it is missing and a download link is known.
        if not model_path.exists():
            if conf["model_name"] not in self.dkm_models:
                raise FileNotFoundError(
                    f"Checkpoint {model_path} does not exist and no download "
                    f"link is registered for '{conf['model_name']}'."
                )
            model_path.parent.mkdir(exist_ok=True)
            link = self.dkm_models[conf["model_name"]]
            cmd = ["wget", link, "-O", str(model_path)]
            logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
            subprocess.run(cmd, check=True)

        # Build DKMv3 for 672x896 (H x W) inputs without preloaded weights;
        # the GIM checkpoint is loaded manually below.
        model = DKMv3(None, 672, 896, upsample_preds=True)

        state_dict = torch.load(str(model_path), map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # The GIM checkpoint stores weights under a 'model.' prefix; strip it.
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        # Drop the encoder's classification head, which DKMv3 does not use.
        for k in list(state_dict.keys()):
            if "encoder.net.fc" in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)

        self.net = model
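        # Note: the network is left on CPU here; the surrounding pipeline is
        # assumed to move it to `device` (see the module-level device above).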

    def _forward(self, data):
        image0, image1 = data["image0"], data["image1"]
        orig_height0, orig_width0 = image0.shape[-2:]
        orig_height1, orig_width1 = image1.shape[-2:]
        # Pad each image up to the 896/672 aspect ratio DKMv3 was built for,
        # then pad both to a shared canvas so they can be matched together.
        aspect_ratio = 896 / 672
        new_width0 = max(orig_width0, int(orig_height0 * aspect_ratio))
        new_height0 = max(orig_height0, int(orig_width0 / aspect_ratio))
        new_width1 = max(orig_width1, int(orig_height1 * aspect_ratio))
        new_height1 = max(orig_height1, int(orig_width1 / aspect_ratio))
        new_width = max(new_width0, new_width1)
        new_height = max(new_height0, new_height1)
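        # Center each image in its padded canvas.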
        pad_height0 = new_height - orig_height0
        pad_width0 = new_width - orig_width0
        pad_height1 = new_height - orig_height1
        pad_width1 = new_width - orig_width1
        pad_top0 = pad_height0 // 2
        pad_bottom0 = pad_height0 - pad_top0
        pad_left0 = pad_width0 // 2
        pad_right0 = pad_width0 - pad_left0
        pad_top1 = pad_height1 // 2
        pad_bottom1 = pad_height1 - pad_top1
        pad_left1 = pad_width1 // 2
        pad_right1 = pad_width1 - pad_left1
        image0 = torch.nn.functional.pad(
            image0, (pad_left0, pad_right0, pad_top0, pad_bottom0)
        )
        image1 = torch.nn.functional.pad(
            image1, (pad_left1, pad_right1, pad_top1, pad_bottom1)
        )
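        # Worked example with hypothetical shapes: a 480x640 image0 and a
        # 512x512 image1 get per-image canvases of 480x640 and 512x682, so
        # both are zero-padded to the common 512x682 canvas (image0 centered
        # with 16-pixel vertical and 21-pixel horizontal borders, image1
        # with 85-pixel side borders).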
        logger.debug("Running self.net.match")
        dense_matches, dense_certainty = self.net.match(image0, image1)
        logger.debug("Running self.net.sample")
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )
        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]
        kpts0 = sparse_matches[:, :2]
        kpts1 = sparse_matches[:, 2:]
        # Matches are in normalized [-1, 1] coordinates; convert to pixel
        # coordinates in the padded images.
        kpts0 = torch.stack(
            (width0 * (kpts0[:, 0] + 1) / 2, height0 * (kpts0[:, 1] + 1) / 2),
            dim=-1,
        )
        kpts1 = torch.stack(
            (width1 * (kpts1[:, 0] + 1) / 2, height1 * (kpts1[:, 1] + 1) / 2),
            dim=-1,
        )
        # Add a fake batch dimension; b_ids is all zeros and i_ids indexes
        # the matches with nonzero confidence.
        b_ids, i_ids = torch.where(mconf[None])
        # Undo the padding offsets so keypoints refer to the original images,
        # then keep only matches inside both original image bounds.
        kpts0 -= kpts0.new_tensor((pad_left0, pad_top0))[None]
        kpts1 -= kpts1.new_tensor((pad_left1, pad_top1))[None]
        mask = (
            (kpts0[:, 0] > 0)
            & (kpts0[:, 1] > 0)
            & (kpts1[:, 0] > 0)
            & (kpts1[:, 1] > 0)
        )
        mask = (
            mask
            & (kpts0[:, 0] <= (orig_width0 - 1))
            & (kpts1[:, 0] <= (orig_width1 - 1))
            & (kpts0[:, 1] <= (orig_height0 - 1))
            & (kpts1[:, 1] <= (orig_height1 - 1))
        )
        pred = {
            "keypoints0": kpts0[i_ids],
            "keypoints1": kpts1[i_ids],
            "confidence": mconf[i_ids],
            "batch_indexes": b_ids,
        }
        scores, b_ids = pred["confidence"], pred["batch_indexes"]
        kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
        pred["confidence"], pred["batch_indexes"] = scores[mask], b_ids[mask]
        pred["keypoints0"], pred["keypoints1"] = kpts0[mask], kpts1[mask]

        return {
            "keypoints0": pred["keypoints0"],
            "keypoints1": pred["keypoints1"],
        }
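
# Example usage (sketch; assumes this module is imported as part of its
# package, since it relies on relative imports):
#
#   model = GIM(GIM.default_conf).eval().to(device)
#   data = {
#       "image0": torch.rand(1, 3, 480, 640, device=device),  # (B, C, H, W), [0, 1]
#       "image1": torch.rand(1, 3, 512, 512, device=device),
#   }
#   pred = model(data)  # {"keypoints0": (N, 2), "keypoints1": (N, 2)} pixel coords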