Working demo
This view is limited to 50 files because the commit contains too many changes.
- .DS_Store +0 -0
- .gitattributes +67 -0
- README.md +2 -2
- app.py +161 -0
- arial.ttf +0 -0
- checkpoints/sid_1-500_m3lteacher.pth +3 -0
- checkpoints/sid_1-500_mtteacher.pth +3 -0
- classcolors.png +3 -0
- colors.pkl +3 -0
- datasets/__init__.py +0 -0
- datasets/__pycache__/__init__.cpython-36.pyc +0 -0
- datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- datasets/__pycache__/base_dataset.cpython-36.pyc +0 -0
- datasets/__pycache__/base_dataset.cpython-38.pyc +0 -0
- datasets/__pycache__/base_dataset.cpython-39.pyc +0 -0
- datasets/__pycache__/citysundepth.cpython-36.pyc +0 -0
- datasets/__pycache__/citysundepth.cpython-39.pyc +0 -0
- datasets/__pycache__/citysunrgb.cpython-36.pyc +0 -0
- datasets/__pycache__/citysunrgb.cpython-38.pyc +0 -0
- datasets/__pycache__/citysunrgb.cpython-39.pyc +0 -0
- datasets/__pycache__/citysunrgbd.cpython-36.pyc +0 -0
- datasets/__pycache__/citysunrgbd.cpython-38.pyc +0 -0
- datasets/__pycache__/get_dataset.cpython-36.pyc +0 -0
- datasets/__pycache__/get_dataset.cpython-39.pyc +0 -0
- datasets/__pycache__/preprocessors.cpython-36.pyc +0 -0
- datasets/__pycache__/preprocessors.cpython-38.pyc +0 -0
- datasets/__pycache__/tfnyu.cpython-36.pyc +0 -0
- datasets/base_dataset.py +128 -0
- datasets/citysunrgbd.py +67 -0
- datasets/get_dataset.py +146 -0
- datasets/preprocessors.py +144 -0
- examples/.DS_Store +0 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png +3 -0
- examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png +3 -0
- examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png +3 -0
.DS_Store
ADDED
Binary file (6.15 kB).

.gitattributes
CHANGED
@@ -32,3 +32,70 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+checkpoints/sid_1-500_m3lteacher.pth filter=lfs diff=lfs merge=lfs -text
+checkpoints/sid_1-500_mtteacher.pth filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/rgb/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_3_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_8_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_4e7d4edf7c2a40ae93437d1fbde22043_office_6_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_5_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_9_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_10_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_6_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_1_domain.png filter=lfs diff=lfs merge=lfs -text
+examples/depth/camera_0c2610e800274e329f5fac53402855ef_office_32_frame_7_domain.png filter=lfs diff=lfs merge=lfs -text
+classcolors.png filter=lfs diff=lfs merge=lfs -text

README.md
CHANGED
@@ -1,7 +1,7 @@
 ---
 title: M3L
-emoji:
-colorFrom:
+emoji: 📚
+colorFrom: purple
 colorTo: gray
 sdk: gradio
 sdk_version: 3.23.0

app.py
ADDED
@@ -0,0 +1,161 @@
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
import os
import random
import pickle as pkl
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from datasets.preprocessors import RGBDValPre
from utils.constants import Constants as C

class Arguments:
    def __init__(self, ratio):
        self.ratio = ratio
        self.masking_ratio = 1.0

colors = pkl.load(open('./colors.pkl', 'rb'))
args = Arguments(ratio = 0.8)

# Load the Mean Teacher and M3L teacher checkpoints on CPU.
mtmodel = WeTrLinearFusion("mit_b2", args, num_classes=13, pretrained=False)
mtmodelpath = './checkpoints/sid_1-500_mtteacher.pth'
mtmodel.load_state_dict(torch.load(mtmodelpath, map_location=torch.device('cpu')))
mtmodel.eval()

m3lmodel = LinearFusionMaskedConsistencyMixBatch("mit_b2", args, num_classes=13, pretrained=False)
m3lmodelpath = './checkpoints/sid_1-500_m3lteacher.pth'
m3lmodel.load_state_dict(torch.load(m3lmodelpath, map_location=torch.device('cpu')))
m3lmodel.eval()


class MaskStudentTeacher(nn.Module):

    def __init__(self, student, teacher, ema_alpha, mode = 'train'):
        super(MaskStudentTeacher, self).__init__()
        self.student = student
        self.teacher = teacher
        self.teacher = self._detach_teacher(self.teacher)
        self.ema_alpha = ema_alpha
        self.mode = mode

    def forward(self, data, student = True, teacher = True, mask = False, range_batches_to_mask = None, **kwargs):
        ret = []
        if student:
            if self.mode == 'train':
                ret.append(self.student(data, mask = mask, range_batches_to_mask = range_batches_to_mask, **kwargs))
            elif self.mode == 'val':
                ret.append(self.student(data, mask = False, **kwargs))
            else:
                raise Exception('Mode not supported')
        if teacher:
            # No loss is ever computed for the teacher, but its output is returned in the same format as the student's.
            ret.append(self.teacher(data, mask = False, **kwargs))
        return ret

    def _detach_teacher(self, model):
        for param in model.parameters():
            param.detach_()
        return model

    def update_teacher_models(self, global_step):
        alpha = min(1 - 1 / (global_step + 1), self.ema_alpha)
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha = 1 - alpha)
        return

    def copy_student_to_teacher(self):
        for ema_param, param in zip(self.teacher.parameters(), self.student.parameters()):
            ema_param.data.mul_(0).add_(param.data)
        return

    def get_params(self):
        student_params = self.student.get_params()
        teacher_params = self.teacher.get_params()
        return student_params


def preprocess_data(rgb, depth, dataset_settings):
    # rgb: np.array in RGB order
    # depth: np.array, min-max normalized and scaled to [0, 255]
    preprocess = RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
    rgb, depth = preprocess(rgb, depth)
    if rgb is not None:
        rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
    if depth is not None:
        depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
    return rgb, depth


def visualize(colors, pred, num_classes, dataset_settings):
    pred = pred.transpose(1, 2, 0)
    predvis = np.zeros((dataset_settings['orig_height'], dataset_settings['orig_width'], 3))
    for i in range(num_classes):
        color = colors[i]
        predvis = np.where(pred == i, color, predvis)
    predvis /= 255.0
    predvis = predvis[:, :, ::-1]
    return predvis

def predict(rgb, depth, check):
    dataset_settings = {}
    dataset_settings['image_height'], dataset_settings['image_width'] = 540, 540
    dataset_settings['orig_height'], dataset_settings['orig_width'] = 540, 540

    rgb, depth = preprocess_data(rgb, depth, dataset_settings)
    if rgb is not None:
        rgb = rgb.unsqueeze(dim = 0)
    if depth is not None:
        depth = depth.unsqueeze(dim = 0)
    ret = [None, None, './classcolors.png']
    if "Mean Teacher" in check:
        if rgb is None:
            rgb = torch.zeros_like(depth)
        if depth is None:
            depth = torch.zeros_like(rgb)
        scores = mtmodel([rgb, depth])[2]
        scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
        prob = scores.detach()
        _, pred = torch.max(prob, dim=1)
        pred = pred.numpy()
        predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
        ret[0] = predvis
    if "M3L" in check:
        mask = False
        masking_branch = None
        if rgb is None:
            mask = True
            masking_branch = 0
        if depth is None:
            mask = True
            masking_branch = 1
        scores = m3lmodel([rgb, depth], mask = mask, masking_branch = masking_branch)[2]
        scores = torch.nn.functional.interpolate(scores, size = (dataset_settings["orig_height"], dataset_settings["orig_width"]), mode = 'bilinear', align_corners = True)
        prob = scores.detach()
        _, pred = torch.max(prob, dim=1)
        pred = pred.numpy()
        predvis = visualize(colors, pred, num_classes=13, dataset_settings=dataset_settings)
        ret[1] = predvis

    return ret

imgs = os.listdir('./examples/rgb')
random.shuffle(imgs)
examples = []
for img in imgs:
    examples.append([
        './examples/rgb/'+img, './examples/depth/'+img, ["M3L", "Mean Teacher"]
    ])

with gr.Blocks() as demo:
    with gr.Row():
        rgbinput = gr.Image(label="RGB Input").style(height=256, width=256)
        depthinput = gr.Image(label="Depth Input").style(height=256, width=256)
    with gr.Row():
        modelcheck = gr.CheckboxGroup(["Mean Teacher", "M3L"], label="Predictions from", info="Predict using model trained with:")
    with gr.Row():
        submit_btn = gr.Button("Submit")
    with gr.Row():
        mtoutput = gr.Image(label="Mean Teacher Output").style(height=384, width=384)
        m3loutput = gr.Image(label="M3L Output").style(height=384, width=384)
        classnameouptut = gr.Image(label="Classes").style(height=384, width=384)
    with gr.Row():
        examplesRow = gr.Examples(examples=examples, examples_per_page=10, inputs=[rgbinput, depthinput, modelcheck])
    submit_btn.click(fn = predict, inputs = [rgbinput, depthinput, modelcheck], outputs = [mtoutput, m3loutput, classnameouptut])

demo.launch()

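The demo only runs the frozen teachers, but MaskStudentTeacher.update_teacher_models above shows how such a teacher is maintained during training: an exponential moving average of the student weights. Below is a minimal sketch of that update on toy linear layers; the layer sizes and the ema_alpha / global_step values are illustrative stand-ins, not values taken from this repo.

import torch
import torch.nn as nn

# Toy stand-ins for the SegFormer-based student/teacher used in the repo.
student = nn.Linear(4, 2)
teacher = nn.Linear(4, 2)
for p in teacher.parameters():
    p.detach_()  # the teacher is never updated by backprop, mirroring _detach_teacher

def ema_update(teacher, student, global_step, ema_alpha=0.99):
    # Same rule as update_teacher_models: early on the effective alpha is small,
    # so the teacher tracks the student closely before settling at ema_alpha.
    alpha = min(1 - 1 / (global_step + 1), ema_alpha)
    for ema_p, p in zip(teacher.parameters(), student.parameters()):
        ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha)

ema_update(teacher, student, global_step=10)
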
arial.ttf
ADDED
Binary file (289 kB).

checkpoints/sid_1-500_m3lteacher.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5a23d7e2697b44b18e368e01353c328b13055a05a1cb0946ffb95b692d6facd
size 99192724

checkpoints/sid_1-500_mtteacher.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7eb24d6275e15376c40ec526281d550e4842ffa71aaa7af58fea54cbf56c2eeb
size 99186911

classcolors.png
ADDED
Stored with Git LFS.

colors.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c03050cb753b0802781f0ce92893ac22129c15724dd4ece6f4b9b4a352db591
size 2342

datasets/__init__.py
ADDED
Empty file.

datasets/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (192 Bytes).

datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (165 Bytes).

datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (200 Bytes).

datasets/__pycache__/base_dataset.cpython-36.pyc
ADDED
Binary file (4.77 kB).

datasets/__pycache__/base_dataset.cpython-38.pyc
ADDED
Binary file (4.75 kB).

datasets/__pycache__/base_dataset.cpython-39.pyc
ADDED
Binary file (4.78 kB).

datasets/__pycache__/citysundepth.cpython-36.pyc
ADDED
Binary file (1.65 kB).

datasets/__pycache__/citysundepth.cpython-39.pyc
ADDED
Binary file (1.65 kB).

datasets/__pycache__/citysunrgb.cpython-36.pyc
ADDED
Binary file (2.09 kB).

datasets/__pycache__/citysunrgb.cpython-38.pyc
ADDED
Binary file (2.04 kB).

datasets/__pycache__/citysunrgb.cpython-39.pyc
ADDED
Binary file (2.15 kB).

datasets/__pycache__/citysunrgbd.cpython-36.pyc
ADDED
Binary file (2.01 kB).

datasets/__pycache__/citysunrgbd.cpython-38.pyc
ADDED
Binary file (1.96 kB).

datasets/__pycache__/get_dataset.cpython-36.pyc
ADDED
Binary file (5.37 kB).

datasets/__pycache__/get_dataset.cpython-39.pyc
ADDED
Binary file (5.32 kB).

datasets/__pycache__/preprocessors.cpython-36.pyc
ADDED
Binary file (5.7 kB).

datasets/__pycache__/preprocessors.cpython-38.pyc
ADDED
Binary file (5.3 kB).

datasets/__pycache__/tfnyu.cpython-36.pyc
ADDED
Binary file (2.04 kB).

datasets/base_dataset.py
ADDED
@@ -0,0 +1,128 @@
import torch
import torch.utils.data as data
import numpy as np
import cv2
from PIL import Image
from utils.img_utils import pad_image_to_shape

class BaseDataset(data.Dataset):

    def __init__(self, dataset_settings, mode, unsupervised):
        self._mode = mode
        self.unsupervised = unsupervised
        self._rgb_path = dataset_settings['rgb_root']
        self._depth_path = dataset_settings['depth_root']
        self._gt_path = dataset_settings['gt_root']
        self._train_source = dataset_settings['train_source']
        self._eval_source = dataset_settings['eval_source']
        self.modalities = dataset_settings['modalities']
        # self._file_length = dataset_settings['max_samples']
        self._required_length = dataset_settings['required_length']
        self._file_names = self._get_file_names(mode)
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __len__(self):
        if self._required_length is not None:
            return self._required_length
        return len(self._file_names)  # when mode == "val"

    def _get_file_names(self, mode):
        assert mode in ['train', 'val']
        source = self._train_source
        if mode == "val":
            source = self._eval_source

        file_names = []
        with open(source) as f:
            files = f.readlines()

        for item in files:
            names = self._process_item_names(item)
            file_names.append(names)

        if mode == "val":
            return file_names
        elif self._required_length <= len(file_names):
            return file_names[:self._required_length]
        else:
            return self._construct_new_file_names(file_names, self._required_length)

    def _construct_new_file_names(self, file_names, length):
        assert isinstance(length, int)
        files_len = len(file_names)

        new_file_names = file_names * (length // files_len)  # length % files_len items remaining

        rand_indices = torch.randperm(files_len).tolist()
        new_indices = rand_indices[:length % files_len]

        new_file_names += [file_names[i] for i in new_indices]

        return new_file_names

    def _process_item_names(self, item):
        item = item.strip()
        item = item.split('\t')
        num_modalities = len(self.modalities)
        num_items = len(item)
        names = {}
        if not self.unsupervised:
            assert num_modalities + 1 == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}" + item[0]
            for i, modality in enumerate(self.modalities):
                names[modality] = item[i]
            names['gt'] = item[-1]
        else:
            assert num_modalities == num_items, f"Number of modalities and number of items in file name don't match, len(modalities) = {num_modalities} and len(item) = {num_items}"
            for i, modality in enumerate(self.modalities):
                names[modality] = item[i]
            names['gt'] = None

        return names

    def _open_rgb(self, rgb_path, dtype = None):
        bgr = cv2.imread(rgb_path, cv2.IMREAD_COLOR)  # cv2 reads in BGR format, HxWxC
        rgb = np.array(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), dtype=dtype)  # Pretrained PyTorch models expect images in RGB
        return rgb

    def _open_depth(self, depth_path, dtype = None):  # returns HxWx3 with the same image in all channels
        img_arr = np.array(Image.open(depth_path))
        if len(img_arr.shape) == 2:  # grayscale
            img_arr = np.array(np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0), dtype = dtype)
        img_arr = (img_arr - img_arr.min()) * 255.0 / (img_arr.max() - img_arr.min())
        return img_arr

    def _open_depth_tf_nyu(self, depth_path, dtype = None):  # returns HxWx3 with the same image in all channels
        img_arr = np.array(Image.open(depth_path))
        if len(img_arr.shape) == 2:  # grayscale
            img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
        return img_arr

    def _open_gt(self, gt_path, dtype = None):
        return np.array(cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE), dtype=dtype)

    def slide_over_image(self, img, crop_size, stride_rate):
        H, W, C = img.shape
        long_size = H if H > W else W
        output = []
        if long_size <= min(crop_size[0], crop_size[1]):
            raise Exception("Crop size is greater than the image size itself. Not handled right now")

        else:
            stride_0 = int(np.ceil(crop_size[0] * stride_rate))
            stride_1 = int(np.ceil(crop_size[1] * stride_rate))
            r_grid = int(np.ceil((H - crop_size[0]) / stride_0)) + 1
            c_grid = int(np.ceil((W - crop_size[1]) / stride_1)) + 1

            for grid_yidx in range(r_grid):
                for grid_xidx in range(c_grid):
                    s_x = grid_xidx * stride_1
                    s_y = grid_yidx * stride_0
                    e_x = min(s_x + crop_size[1], W)
                    e_y = min(s_y + crop_size[0], H)
                    s_x = e_x - crop_size[1]
                    s_y = e_y - crop_size[0]
                    img_sub = img[s_y:e_y, s_x:e_x, :]
                    img_sub, margin = pad_image_to_shape(img_sub, crop_size, cv2.BORDER_CONSTANT, value=0)
                    output.append((img_sub, np.array([s_y, e_y, s_x, e_x]), margin))

        return output

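A quick way to see what slide_over_image produces is to recompute just the window coordinates on a toy image. The 7x9 image, 4x4 crop, and 0.5 stride rate below are made-up numbers chosen for illustration, not values used anywhere in the repo.

import numpy as np

# Window positions as computed in BaseDataset.slide_over_image (coordinates only, no padding).
H, W = 7, 9
crop_size = (4, 4)
stride_rate = 0.5
stride_0 = int(np.ceil(crop_size[0] * stride_rate))
stride_1 = int(np.ceil(crop_size[1] * stride_rate))
r_grid = int(np.ceil((H - crop_size[0]) / stride_0)) + 1
c_grid = int(np.ceil((W - crop_size[1]) / stride_1)) + 1
for grid_yidx in range(r_grid):
    for grid_xidx in range(c_grid):
        e_x = min(grid_xidx * stride_1 + crop_size[1], W)
        e_y = min(grid_yidx * stride_0 + crop_size[0], H)
        s_x, s_y = e_x - crop_size[1], e_y - crop_size[0]
        print(s_y, e_y, s_x, e_x)  # windows at the border are shifted back inside the image
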
datasets/citysunrgbd.py
ADDED
@@ -0,0 +1,67 @@
import torch
import numpy as np

from datasets.base_dataset import BaseDataset


class CityScapesSunRGBD(BaseDataset):

    def __init__(self, dataset_settings, mode, unsupervised, preprocess, sliding = False, stride_rate = None):
        super(CityScapesSunRGBD, self).__init__(dataset_settings, mode, unsupervised)
        self.preprocess = preprocess
        self.sliding = sliding
        self.stride_rate = stride_rate
        if self.sliding and self._mode == 'train':
            print("Ensure correct preprocessing is being done!")

    def __getitem__(self, index):
        # if self._file_length is not None:
        #     names = self._construct_new_file_names(self._file_length)[index]
        # else:
        #     names = self._file_names[index]
        names = self._file_names[index]
        rgb_path = self._rgb_path + names['rgb']
        depth_path = self._rgb_path + names['depth']
        if not self.unsupervised:
            gt_path = self._gt_path + names['gt']
        item_name = names['rgb'].split("/")[-1].split(".")[0]

        rgb = self._open_rgb(rgb_path)
        depth = self._open_depth(depth_path)
        gt = None
        if not self.unsupervised:
            gt = self._open_gt(gt_path)

        if not self.sliding:
            if self.preprocess is not None:
                rgb, depth, gt = self.preprocess(rgb, depth, gt)

            if self._mode in ['train', 'val']:
                rgb = torch.from_numpy(np.ascontiguousarray(rgb)).float()
                depth = torch.from_numpy(np.ascontiguousarray(depth)).float()
                if gt is not None:
                    gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
            else:
                raise Exception(f"{self._mode} not supported in CityScapesSunRGBD")

            # output_dict = dict(rgb=rgb, fn=str(item_name),
            #                    n=len(self._file_names))
            output_dict = dict(data=[rgb, depth], name = item_name)
            if gt is not None:
                output_dict['gt'] = gt
            return output_dict

        else:
            sliding_output = self.slide_over_image(rgb, self.model_input_shape, self.stride_rate)
            output_dict = {}
            if self._mode in ['train', 'val']:
                if gt is not None:
                    gt = torch.from_numpy(np.ascontiguousarray(gt)).long()
                    output_dict['gt'] = gt
            output_dict['sliding_output'] = []
            for img_sub, pos, margin in sliding_output:
                if self.preprocess is not None:
                    img_sub, _ = self.preprocess(img_sub, None)
                img_sub = torch.from_numpy(np.ascontiguousarray(img_sub)).float()
                output_dict['sliding_output'].append(([img_sub], pos, margin))
            return output_dict

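For reference, an item returned by CityScapesSunRGBD.__getitem__ in the non-sliding, supervised case has the key layout sketched below. The tensors, sizes, and the item name are dummies; only the dictionary structure follows the code above.

import torch

item = {
    'data': [torch.zeros(3, 540, 540),     # preprocessed RGB, CxHxW (dummy size)
             torch.zeros(3, 540, 540)],    # preprocessed depth, CxHxW (dummy size)
    'name': 'some_frame_name',             # hypothetical; derived from the rgb filename in practice
    'gt': torch.zeros(540, 540).long(),    # dummy ground truth; present only when unsupervised is False
}
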
datasets/get_dataset.py
ADDED
@@ -0,0 +1,146 @@
import torch
from torch.utils.data import DataLoader
from datasets.citysundepth import CityScapesSunDepth
from datasets.citysunrgb import CityScapesSunRGB
from datasets.citysunrgbd import CityScapesSunRGBD
from datasets.preprocessors import DepthTrainPre, DepthValPre, NYURGBDTrainPre, NYURGBDValPre, RGBDTrainPre, RGBDValPre, RGBTrainPre, RGBValPre
from datasets.tfnyu import TFNYU
from utils.constants import Constants as C

def get_dataset(args):
    datasetClass = None
    if args.data == "nyudv2":
        return TFNYU
    if args.data == "city" or args.data == "sunrgbd" or args.data == 'stanford_indoor':
        if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
            datasetClass = CityScapesSunRGB
        elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
            datasetClass = CityScapesSunDepth
        elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
            datasetClass = CityScapesSunRGBD
        else:
            raise Exception(f"{args.modalities} not configured in get_dataset function.")
    else:
        raise Exception(f"{args.data} not configured in get_dataset function.")
    return datasetClass

def get_preprocessors(args, dataset_settings, mode):
    if args.data == "nyudv2" and len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
        if mode == 'train':
            return NYURGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return NYURGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)

    if len(args.modalities) == 1 and args.modalities[0] == 'rgb':
        if mode == 'train':
            return RGBTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return RGBValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        else:
            raise Exception("%s mode not defined" % mode)
    elif len(args.modalities) == 1 and args.modalities[0] == 'depth':
        if mode == 'train':
            return DepthTrainPre(dataset_settings)
        elif mode == 'val':
            return DepthValPre(dataset_settings)
        else:
            raise Exception("%s mode not defined" % mode)
    elif len(args.modalities) == 2 and args.modalities[0] == 'rgb' and args.modalities[1] == 'depth':
        if mode == 'train':
            return RGBDTrainPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        elif mode == 'val':
            return RGBDValPre(C.pytorch_mean, C.pytorch_std, dataset_settings)
        else:
            raise Exception("%s mode not defined" % mode)
    else:
        raise Exception("%s not configured for preprocessing" % args.modalities)

def get_train_loader(datasetClass, args, train_source, unsupervised = False):
    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': train_source,
                        'eval_source': args.eval_source,
                        'required_length': args.total_train_imgs,  # Every dataloader runs total train images / batch size iterations, to stay consistent
                        # 'max_samples': args.max_samples,
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}

    preprocessing = get_preprocessors(args, dataset_settings, "train")
    train_dataset = datasetClass(dataset_settings, "train", unsupervised, preprocessing)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas = args.world_size, rank = args.rank)
    if unsupervised and "unsup_batch_size" in args:
        batch_size = args.unsup_batch_size
    else:
        batch_size = args.batch_size
    train_loader = DataLoader(train_dataset,
                              batch_size = args.batch_size // args.world_size,
                              num_workers = args.num_workers,
                              drop_last = True,
                              shuffle = False,
                              sampler = train_sampler)
    return train_loader

def get_val_loader(datasetClass, args):
    dataset_settings = {'rgb_root': args.rgb_root,
                        'gt_root': args.gt_root,
                        'depth_root': args.depth_root,
                        'train_source': None,
                        'eval_source': args.eval_source,
                        'required_length': None,
                        'max_samples': None,
                        'train_scale_array': args.train_scale_array,
                        'image_height': args.image_height,
                        'image_width': args.image_width,
                        'modalities': args.modalities}
    if args.data == 'sunrgbd':
        eval_sources = []
        for shape in ['427_561', '441_591', '530_730', '531_681']:
            eval_sources.append(dataset_settings['eval_source'].split('.')[0] + '_' + shape + '.txt')
    else:
        eval_sources = [args.eval_source]

    preprocessing = get_preprocessors(args, dataset_settings, "val")
    if args.sliding_eval:
        collate_fn = _sliding_collate_fn
    else:
        collate_fn = None

    val_loaders = []
    for eval_source in eval_sources:
        dataset_settings['eval_source'] = eval_source
        val_dataset = datasetClass(dataset_settings, "val", False, preprocessing, args.sliding_eval, args.stride_rate)
        if args.rank is not None:  # DDP evaluation
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas = args.world_size, rank = args.rank)
            batch_size = args.val_batch_size // args.world_size
        else:  # DP evaluation
            val_sampler = None
            batch_size = args.val_batch_size

        val_loader = DataLoader(val_dataset,
                                batch_size = batch_size,
                                num_workers = 4,
                                drop_last = False,
                                shuffle = False,
                                collate_fn = collate_fn,
                                sampler = val_sampler)
        val_loaders.append(val_loader)
    return val_loaders


def _sliding_collate_fn(batch):
    gt = torch.stack([b['gt'] for b in batch])
    sliding_output = []
    num_modalities = len(batch[0]['sliding_output'][0][0])
    for i in range(len(batch[0]['sliding_output'])):  # i iterates over window positions
        imgs = [torch.stack([b['sliding_output'][i][0][m] for b in batch]) for m in range(num_modalities)]
        pos = batch[0]['sliding_output'][i][1]
        pos_compare = [(b['sliding_output'][i][1] == pos).all() for b in batch]
        assert all(pos_compare), f"Position not same for all points in the batch: {pos_compare}, {[b['sliding_output'][i][1] for b in batch]}"
        margin = batch[0]['sliding_output'][i][2]
        margin_compare = [(b['sliding_output'][i][2] == margin).all() for b in batch]
        assert all(margin_compare), f"Margin not same for all points in the batch: {margin_compare}, {[b['sliding_output'][i][2] for b in batch]}"
        sliding_output.append((imgs, pos, margin))
    return {"gt": gt, "sliding_output": sliding_output}

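The SUN RGB-D branch of get_val_loader expands a single eval list into one list per image resolution by rewriting the file name. The base name val.txt below is a hypothetical placeholder; only the renaming rule comes from the code above.

eval_source = 'val.txt'  # hypothetical base eval list
shapes = ['427_561', '441_591', '530_730', '531_681']
eval_sources = [eval_source.split('.')[0] + '_' + shape + '.txt' for shape in shapes]
print(eval_sources)
# ['val_427_561.txt', 'val_441_591.txt', 'val_530_730.txt', 'val_531_681.txt']
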
datasets/preprocessors.py
ADDED
@@ -0,0 +1,144 @@
from utils.img_utils import normalizedepth, random_crop_pad_to_shape, random_mirror, random_scale, normalize, resizedepth, resizergb, tfnyu_normalizedepth

class RGBTrainPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.train_scale_array = dataset_settings['train_scale_array']
        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, gt):
        transformed_dict = random_mirror({"rgb": rgb, "gt": gt})
        if self.train_scale_array is not None:
            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))

        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size)  # Makes gt HxWx1
        rgb = transformed_dict['rgb']
        gt = transformed_dict['gt']
        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)

        rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return rgb, gt


class RGBValPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, gt):
        rgb = resizergb(rgb, self.model_input_shape)
        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
        rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return rgb, gt


class RGBDTrainPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.train_scale_array = dataset_settings['train_scale_array']
        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, depth, gt):
        transformed_dict = random_mirror({"rgb": rgb, "depth": depth, "gt": gt})
        if self.train_scale_array is not None:
            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))

        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size)  # Makes gt HxWx1
        rgb = transformed_dict['rgb']
        depth = transformed_dict['depth']
        gt = transformed_dict['gt']
        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
        depth = normalizedepth(depth)
        rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return rgb, depth, gt


class RGBDValPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, depth):
        if rgb is not None:
            rgb = resizergb(rgb, self.model_input_shape)
            rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
            rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        if depth is not None:
            depth = resizedepth(depth, self.model_input_shape)
            depth = normalizedepth(depth)
            depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW

        return rgb, depth


class NYURGBDTrainPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.train_scale_array = dataset_settings['train_scale_array']
        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, depth, gt):
        transformed_dict = random_mirror({"rgb": rgb, "depth": depth, "gt": gt})
        if self.train_scale_array is not None:
            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (rgb.shape[0], rgb.shape[1]))

        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['rgb'].shape[:2], self.crop_size)  # Makes gt HxWx1
        rgb = transformed_dict['rgb']
        depth = transformed_dict['depth']
        gt = transformed_dict['gt']
        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
        depth = tfnyu_normalizedepth(depth)
        rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return rgb, depth, gt


class NYURGBDValPre(object):
    def __init__(self, pytorch_mean, pytorch_std, dataset_settings):
        self.pytorch_mean = pytorch_mean
        self.pytorch_std = pytorch_std
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, rgb, depth, gt):
        rgb = resizergb(rgb, self.model_input_shape)
        depth = resizedepth(depth, self.model_input_shape)
        rgb = normalize(rgb, self.pytorch_mean, self.pytorch_std)
        depth = tfnyu_normalizedepth(depth)
        rgb = rgb.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return rgb, depth, gt


class DepthTrainPre(object):
    def __init__(self, dataset_settings):
        self.train_scale_array = dataset_settings['train_scale_array']
        self.crop_size = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, depth, gt):
        transformed_dict = random_mirror({"depth": depth, "gt": gt})
        if self.train_scale_array is not None:
            transformed_dict, _ = random_scale(transformed_dict, self.train_scale_array, (depth.shape[0], depth.shape[1]))

        transformed_dict, _ = random_crop_pad_to_shape(transformed_dict, transformed_dict['depth'].shape[:2], self.crop_size)  # Makes gt HxWx1
        depth = transformed_dict['depth']
        gt = transformed_dict['gt']
        depth = normalizedepth(depth)
        depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return depth, gt


class DepthValPre(object):
    def __init__(self, dataset_settings):
        self.model_input_shape = (dataset_settings['image_height'], dataset_settings['image_width'])

    def __call__(self, depth, gt):
        depth = resizedepth(depth, self.model_input_shape)
        depth = normalizedepth(depth)
        depth = depth.transpose(2, 0, 1)  # Brings the channel dimension to the front. Final output = CxHxW
        return depth, gt

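All of the preprocessors end with the same convention: per-channel normalization followed by a transpose to channel-first layout. A minimal numpy sketch of that convention is shown below, using ImageNet statistics as a stand-in for C.pytorch_mean / C.pytorch_std; the actual values and the exact scaling inside normalize are defined in utils.img_utils, which is not part of this commit.

import numpy as np

rgb = np.random.randint(0, 256, size=(540, 540, 3)).astype(np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed ImageNet mean
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)   # assumed ImageNet std
rgb = (rgb - mean) / std        # per-channel normalization, as normalize() is expected to do
rgb = rgb.transpose(2, 0, 1)    # channel-first, as every __call__ above does before returning
print(rgb.shape)                # (3, 540, 540)
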
examples/.DS_Store
ADDED
Binary file (6.15 kB).

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_45_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_46_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_47_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_48_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_49_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_50_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_51_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_52_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0a529ee09fc04037831902b83073862b_office_4_frame_53_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_0_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_10_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_1_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_2_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_3_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_4_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_5_domain.png
ADDED
Stored with Git LFS.

examples/depth/camera_0b2396756adf4a76bbb985af92f534af_office_7_frame_6_domain.png
ADDED
Stored with Git LFS.