import random

import numpy as np
import PIL
import torch
import torchvision

from src.mast3r_src.dust3r.dust3r.datasets.utils.transforms import ImgNorm
from src.mast3r_src.dust3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
import src.mast3r_src.dust3r.dust3r.datasets.utils.cropping as cropping


def crop_resize_if_necessary(image, depthmap, intrinsics, resolution):
    """Adapted from DUST3R's Co3D dataset implementation"""

    if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)

    # Crop centered on the principal point:
    # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx, cy)
    W, H = image.size
    cx, cy = intrinsics[:2, 2].round().astype(int)
    min_margin_x = min(cx, W - cx)
    min_margin_y = min(cy, H - cy)
    assert min_margin_x > W / 5, f"Principal point {(cx, cy)} is too close to the image border for a centered crop"
    assert min_margin_y > H / 5, f"Principal point {(cx, cy)} is too close to the image border for a centered crop"
    l, t = cx - min_margin_x, cy - min_margin_y
    r, b = cx + min_margin_x, cy + min_margin_y
    crop_bbox = (l, t, r, b)
    image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

    # High-quality Lanczos down-scaling
    target_resolution = np.array(resolution)
    image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)

    # Actual cropping (if necessary) with bilinear interpolation
    intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
    crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
    image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

    return image, depthmap, intrinsics2
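
# Example (hypothetical values; the cropping helpers come from DUST3R): for a
# 640x480 frame whose principal point sits near the image centre,
#
#   img = PIL.Image.new("RGB", (640, 480))
#   depth = np.ones((480, 640), dtype=np.float32)
#   K = np.array([[500., 0., 320.],
#                 [0., 500., 240.],
#                 [0., 0., 1.]])
#   img, depth, K = crop_resize_if_necessary(img, depth, K, resolution=(512, 384))
#
# the returned image satisfies img.size == (512, 384), and K is updated to
# account for both the centred crop and the Lanczos rescale.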


class DUST3RSplattingDataset(torch.utils.data.Dataset):
    """Training dataset: each sample pairs two overlapping context views with several target views from one sequence."""

    def __init__(self, data, coverage, resolution, num_epochs_per_epoch=1, alpha=0.3, beta=0.3):

        super().__init__()
        self.data = data
        self.coverage = coverage

        # The sampler below always returns exactly two context views
        self.num_context_views = 2
        self.num_target_views = 3

        self.resolution = resolution
        self.transform = ImgNorm
        self.org_transform = torchvision.transforms.ToTensor()
        self.num_epochs_per_epoch = num_epochs_per_epoch  # how many samples to draw from each sequence per epoch

        # Overlap thresholds forwarded to `sample` (context and target respectively)
        self.alpha = alpha
        self.beta = beta

    def __getitem__(self, idx):
        """Build one training sample: two context views (with point clouds) and `num_target_views` target views."""

        sequence = self.data.sequences[idx // self.num_epochs_per_epoch]
        sequence_length = len(self.data.color_paths[sequence])

        context_views, target_views = self.sample(sequence, self.num_target_views, self.alpha, self.beta)

        views = {"context": [], "target": [], "scene": sequence}

        # Fetch the context views
        for c_view in context_views:

            assert c_view < sequence_length, f"Invalid view index: {c_view}, sequence length: {sequence_length}, c_views: {context_views}"

            view = self.data.get_view(sequence, c_view, self.resolution)

            # Transform the input
            view['img'] = self.transform(view['original_img'])
            view['original_img'] = self.org_transform(view['original_img'])

            # Create the point cloud and validity mask
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
            view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
            assert view['valid_mask'].any(), f"Invalid mask for sequence: {sequence}, view: {c_view}"

            views['context'].append(view)

        # Fetch the target views
        for t_view in target_views:

            view = self.data.get_view(sequence, t_view, self.resolution)
            view['original_img'] = self.org_transform(view['original_img'])
            views['target'].append(view)

        return views

    def __len__(self):

        return len(self.data.sequences) * self.num_epochs_per_epoch

    def sample(self, sequence, num_target_views, context_overlap_threshold=0.5, target_overlap_threshold=0.6):
        """Sample two context views whose mutual overlap exceeds `context_overlap_threshold`, then
        `num_target_views` target views whose best overlap with either context view exceeds
        `target_overlap_threshold`, falling back to the highest-overlap frames when the thresholds
        cannot be met."""

        first_context_view = random.randint(0, len(self.data.color_paths[sequence]) - 1)

        # Pick a second context view that has sufficient overlap with the first context view
        valid_second_context_views = []
        for frame in range(len(self.data.color_paths[sequence])):
            if frame == first_context_view:
                continue
            overlap = self.coverage[sequence][first_context_view][frame]
            if overlap > context_overlap_threshold:
                valid_second_context_views.append(frame)
        if len(valid_second_context_views) > 0:
            second_context_view = random.choice(valid_second_context_views)

        # If there are no valid second context views, pick the best one
        else:
            best_view = None
            best_overlap = None
            for frame in range(len(self.data.color_paths[sequence])):
                if frame == first_context_view:
                    continue
                overlap = self.coverage[sequence][first_context_view][frame]
                if best_view is None or overlap > best_overlap:
                    best_view = frame
                    best_overlap = overlap
            second_context_view = best_view

        # Pick the target views
        valid_target_views = []
        for frame in range(len(self.data.color_paths[sequence])):
            if frame == first_context_view or frame == second_context_view:
                continue
            overlap_max = max(
                self.coverage[sequence][first_context_view][frame],
                self.coverage[sequence][second_context_view][frame]
            )
            if overlap_max > target_overlap_threshold:
                valid_target_views.append(frame)
        if len(valid_target_views) >= num_target_views:
            target_views = random.sample(valid_target_views, num_target_views)

        # If there are not enough valid target views, pick the best ones
        else:
            overlaps = []
            for frame in range(len(self.data.color_paths[sequence])):
                if frame == first_context_view or frame == second_context_view:
                    continue
                overlap = max(
                    self.coverage[sequence][first_context_view][frame],
                    self.coverage[sequence][second_context_view][frame]
                )
                overlaps.append((frame, overlap))
            overlaps.sort(key=lambda x: x[1], reverse=True)
            target_views = [frame for frame, _ in overlaps[:num_target_views]]

        return [first_context_view, second_context_view], target_views
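
    # Note: the sampling above assumes `self.coverage[sequence]` is an all-pairs
    # overlap lookup, i.e. coverage[sequence][i][j] scores how much view j
    # overlaps view i (higher = more shared content). A hypothetical stand-in
    # for quick testing could be a uniform matrix:
    #
    #   coverage = {seq: np.ones((T, T)) for seq, T in lengths.items()}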


class DUST3RSplattingTestDataset(torch.utils.data.Dataset):
    """Evaluation dataset over a fixed list of (sequence, context_1, context_2, target) samples."""

    def __init__(self, data, samples, resolution):

        self.data = data
        self.samples = samples

        self.resolution = resolution
        self.transform = ImgNorm
        self.org_transform = torchvision.transforms.ToTensor()

    def get_view(self, sequence, c_view):

        view = self.data.get_view(sequence, c_view, self.resolution)

        # Transform the input
        view['img'] = self.transform(view['original_img'])
        view['original_img'] = self.org_transform(view['original_img'])

        # Create the point cloud and validity mask
        pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
        view['pts3d'] = pts3d
        view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
        assert view['valid_mask'].any(), f"Invalid mask for sequence: {sequence}, view: {c_view}"

        return view

    def __getitem__(self, idx):

        sequence, c_view_1, c_view_2, target_view = self.samples[idx]
        c_view_1, c_view_2, target_view = int(c_view_1), int(c_view_2), int(target_view)
        fetched_c_view_1 = self.get_view(sequence, c_view_1)
        fetched_c_view_2 = self.get_view(sequence, c_view_2)
        fetched_target_view = self.get_view(sequence, target_view)

        views = {"context": [fetched_c_view_1, fetched_c_view_2], "target": [fetched_target_view], "scene": sequence}

        return views

    def __len__(self):

        return len(self.samples)
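
# Minimal usage sketch (hypothetical: assumes a `data` object exposing
# `sequences`, `color_paths[sequence]` and `get_view(sequence, index, resolution)`
# as used above, plus a matching `coverage` lookup):
#
#   train_set = DUST3RSplattingDataset(data, coverage, resolution=(512, 384))
#   views = train_set[0]
#   print(views["scene"], len(views["context"]), len(views["target"]))  # <scene>, 2, 3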