adirathor07 committed
Commit 757ed1c · 0 Parent(s)

initial commit
.gitignore ADDED
@@ -0,0 +1 @@
1
+ saved_model/
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Snap2scene
3
+ emoji: 😻
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/binvox_rw.cpython-38.pyc ADDED
Binary file (7.44 kB).
 
__pycache__/config.cpython-38.pyc ADDED
Binary file (2.48 kB).
 
__pycache__/data_transforms.cpython-38.pyc ADDED
Binary file (11.8 kB).
 
__pycache__/helpers.cpython-38.pyc ADDED
Binary file (2.86 kB).
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.8 kB).
 
app.py ADDED
@@ -0,0 +1,39 @@
1
+ import streamlit as st
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from PIL import Image
5
+ from helpers import *
6
+
7
+ # --- APP START ---
8
+ st.title("2D → 3D Voxel Reconstruction Viewer")
9
+
10
+ uploaded_images = st.file_uploader("Upload images", accept_multiple_files=True, type=["png", "jpg", "jpeg"])
11
+ # print(uploaded_images)
12
+
13
+
14
+ # --- DISPLAY ---
15
+ if uploaded_images:
16
+ st.subheader("Uploaded Input Views")
17
+ cols = st.columns(len(uploaded_images))
18
+ rendering_images = []
19
+
20
+ for i, uploaded_file in enumerate(uploaded_images):
21
+ img = Image.open(uploaded_file)
22
+
23
+ cols[i].image(img, caption=f"View {i+1}", use_container_width=True)
24
+
25
+ img_np = np.array(img).astype(np.float32) / 255.0
26
+
27
+ rendering_images.append(img_np)
28
+
29
+
30
+ if st.button("Submit for Reconstruction"):
31
+ gv=None
32
+ with st.spinner("Reconstructing..."):
33
+ gv = predict_voxel_from_images(rendering_images)
34
+
35
+ fig = voxel_to_plotly(gv)
36
+ st.plotly_chart(fig, use_container_width=True)
37
+
38
+ else:
39
+ st.info("Upload images to continue.")
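For reference, a minimal offline sketch (not part of the commit) of how the same helpers could be exercised without Streamlit; the image paths and output file name are placeholders, and it assumes saved_model/Pix2Vox.pth is present:

import numpy as np
from PIL import Image

from helpers import predict_voxel_from_images, voxel_to_plotly

# Load input views as float32 arrays in [0, 1]; .convert("RGB") is an extra
# safeguard not present in app.py itself.
paths = ["view_0.png", "view_1.png"]  # placeholder file names
rendering_images = [
    np.array(Image.open(p).convert("RGB")).astype(np.float32) / 255.0 for p in paths
]

gv = predict_voxel_from_images(rendering_images)  # (32, 32, 32) uint8 occupancy grid
print("occupied voxels:", int(gv.sum()))

fig = voxel_to_plotly(gv)               # same Plotly scatter the app renders
fig.write_html("reconstruction.html")   # open in a browser to inspect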
check.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
config.py ADDED
@@ -0,0 +1,108 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Developed by Haozhe Xie <[email protected]>
4
+
5
+ from easydict import EasyDict as edict
6
+
7
+ __C = edict()
8
+ cfg = __C
9
+
10
+ #
11
+ # Dataset Config
12
+ #
13
+ __C.DATASETS = edict()
14
+ __C.DATASETS.SHAPENET = edict()
15
+ __C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH = 'datasets/ShapeNet.json'
16
+ # __C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH = './datasets/PascalShapeNet.json'
17
+ __C.DATASETS.SHAPENET.RENDERING_PATH = 'datasets/ShapeNetRendering/%s/%s/rendering/%02d.png'
18
+ # __C.DATASETS.SHAPENET.RENDERING_PATH = '/home/hzxie/Datasets/ShapeNet/PascalShapeNetRendering/%s/%s/render_%04d.jpg'
19
+ __C.DATASETS.SHAPENET.VOXEL_PATH = 'datasets/ShapeNetVox32/%s/%s/model.binvox'
20
+ __C.DATASETS.PASCAL3D = edict()
21
+ __C.DATASETS.PASCAL3D.TAXONOMY_FILE_PATH = 'datasets/Pascal3D.json'
22
+ __C.DATASETS.PASCAL3D.ANNOTATION_PATH = '/home/hzxie/Datasets/PASCAL3D/Annotations/%s_imagenet/%s.mat'
23
+ __C.DATASETS.PASCAL3D.RENDERING_PATH = '/home/hzxie/Datasets/PASCAL3D/Images/%s_imagenet/%s.JPEG'
24
+ __C.DATASETS.PASCAL3D.VOXEL_PATH = '/home/hzxie/Datasets/PASCAL3D/CAD/%s/%02d.binvox'
25
+ __C.DATASETS.PIX3D = edict()
26
+ __C.DATASETS.PIX3D.TAXONOMY_FILE_PATH = 'datasets/Pix3D.json'
27
+ __C.DATASETS.PIX3D.ANNOTATION_PATH = 'datasets/Pix3D/pix3d.json'
28
+ __C.DATASETS.PIX3D.RENDERING_PATH = 'datasets/Pix3D/img/%s/%s.%s'
29
+ __C.DATASETS.PIX3D.VOXEL_PATH = 'datasets/Pix3D/model/%s/%s/%s.binvox'
30
+
31
+ #
32
+ # Dataset
33
+ #
34
+ __C.DATASET = edict()
35
+ __C.DATASET.MEAN = [0.5, 0.5, 0.5]
36
+ __C.DATASET.STD = [0.5, 0.5, 0.5]
37
+ __C.DATASET.TRAIN_DATASET = 'ShapeNet'
38
+ __C.DATASET.TEST_DATASET = 'ShapeNet'
39
+ # __C.DATASET.TEST_DATASET = 'Pascal3D'
40
+ # __C.DATASET.TEST_DATASET = 'Pix3D'
41
+
42
+ #
43
+ # Common
44
+ #
45
+ __C.CONST = edict()
46
+ __C.CONST.DEVICE = '0'
47
+ __C.CONST.RNG_SEED = 0
48
+ __C.CONST.IMG_W = 224 # Image width for input
49
+ __C.CONST.IMG_H = 224 # Image height for input
50
+ __C.CONST.N_VOX = 32
51
+ __C.CONST.BATCH_SIZE = 64
52
+ __C.CONST.N_VIEWS_RENDERING = 1 # Dummy property for Pascal 3D
53
+ __C.CONST.CROP_IMG_W = 128 # Dummy property for Pascal 3D
54
+ __C.CONST.CROP_IMG_H = 128 # Dummy property for Pascal 3D
55
+
56
+ #
57
+ # Directories
58
+ #
59
+ __C.DIR = edict()
60
+ __C.DIR.OUT_PATH = './output'
61
+ __C.DIR.RANDOM_BG_PATH = '/home/hzxie/Datasets/SUN2012/JPEGImages'
62
+
63
+ #
64
+ # Network
65
+ #
66
+ __C.NETWORK = edict()
67
+ __C.NETWORK.LEAKY_VALUE = .2
68
+ __C.NETWORK.TCONV_USE_BIAS = False
69
+ __C.NETWORK.USE_REFINER = True
70
+ __C.NETWORK.USE_MERGER = True
71
+
72
+ #
73
+ # Training
74
+ #
75
+ __C.TRAIN = edict()
76
+ __C.TRAIN.RESUME_TRAIN = False
77
+ __C.TRAIN.NUM_WORKER = 4 # number of data workers
78
+ __C.TRAIN.NUM_EPOCHES = 5
79
+ __C.TRAIN.BRIGHTNESS = .4
80
+ __C.TRAIN.CONTRAST = .4
81
+ __C.TRAIN.SATURATION = .4
82
+ __C.TRAIN.NOISE_STD = .1
83
+ __C.TRAIN.RANDOM_BG_COLOR_RANGE = [[225, 255], [225, 255], [225, 255]]
84
+ __C.TRAIN.POLICY = 'adam' # available options: sgd, adam
85
+ __C.TRAIN.EPOCH_START_USE_REFINER = 0
86
+ __C.TRAIN.EPOCH_START_USE_MERGER = 0
87
+ __C.TRAIN.ENCODER_LEARNING_RATE = 1e-3
88
+ __C.TRAIN.DECODER_LEARNING_RATE = 1e-3
89
+ __C.TRAIN.REFINER_LEARNING_RATE = 1e-3
90
+ __C.TRAIN.MERGER_LEARNING_RATE = 1e-4
91
+ __C.TRAIN.DISCRIMINATOR_LR = 1e-4
92
+ __C.TRAIN.GAN_LOSS_WEIGHT = 0.01
93
+ __C.TRAIN.ENCODER_LR_MILESTONES = [150]
94
+ __C.TRAIN.DECODER_LR_MILESTONES = [150]
95
+ __C.TRAIN.REFINER_LR_MILESTONES = [150]
96
+ __C.TRAIN.MERGER_LR_MILESTONES = [150]
97
+ __C.TRAIN.BETAS = (.9, .999)
98
+ __C.TRAIN.MOMENTUM = .9
99
+ __C.TRAIN.GAMMA = .5
100
+ __C.TRAIN.SAVE_FREQ = 10 # weights will be overwritten every save_freq epoch
101
+ __C.TRAIN.UPDATE_N_VIEWS_RENDERING = False
102
+
103
+ #
104
+ # Testing options
105
+ #
106
+ __C.TEST = edict()
107
+ __C.TEST.RANDOM_BG_COLOR_RANGE = [[240, 240], [240, 240], [240, 240]]
108
+ __C.TEST.VOXEL_THRESH = [.2, .3, .4, .5]
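Since cfg is an EasyDict, downstream code can override any of these defaults at runtime; helpers.py below does exactly that for the checkpoint path. A small illustrative sketch (the second override is hypothetical, not something this repo does):

from config import cfg

cfg.CONST.WEIGHTS = 'saved_model/Pix2Vox.pth'  # helpers.py sets this before loading the model
cfg.CONST.N_VIEWS_RENDERING = 3                # hypothetical override for multi-view experiments
print(cfg.CONST.IMG_W, cfg.CONST.IMG_H)        # 224 224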
helpers.py ADDED
@@ -0,0 +1,119 @@
1
+ import utils.binvox_rw as binvox_rw
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from models.encoder import Encoder
5
+ from models.decoder import Decoder
6
+ from models.merger import Merger
7
+ from models.refiner import Refiner
8
+ from config import cfg
9
+ import torch
10
+ from datetime import datetime as dt
11
+ import utils.data_transforms
12
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ print(device)
14
+ # device='cpu'
15
+
16
+ cfg.CONST.WEIGHTS='saved_model/Pix2Vox.pth'
17
+
18
+
19
+ def read_binvox(file) -> np.ndarray:
20
+ model = binvox_rw.read_as_3d_array(file)
21
+ return model.data.astype(np.uint8)
22
+
23
+
24
+ def voxel_to_plotly(voxels):
25
+ x, y, z = voxels.nonzero()
26
+ fig = go.Figure(data=[
27
+ go.Scatter3d(
28
+ x=x, y=y, z=z,
29
+ mode='markers',
30
+ marker=dict(size=3, color=z, colorscale='Viridis', opacity=0.7)
31
+ )
32
+ ])
33
+ fig.update_layout(scene=dict(aspectmode='data'))
34
+ return fig
35
+
36
+
37
+
38
+ IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
39
+ CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
40
+ test_transforms = utils.data_transforms.Compose([
41
+ utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
42
+ utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
43
+ utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
44
+ utils.data_transforms.ToTensor(),
45
+ ])
46
+
47
+
48
+ def predict_voxel_from_images(rendering_images):
49
+ transformed_images = test_transforms(rendering_images)
50
+
51
+ encoder = Encoder(cfg)
52
+ decoder = Decoder(cfg)
53
+ refiner = Refiner(cfg)
54
+ merger = Merger(cfg)
55
+
56
+
57
+ if torch.cuda.is_available():
58
+ encoder = torch.nn.DataParallel(encoder).cuda()
59
+ decoder = torch.nn.DataParallel(decoder).cuda()
60
+ refiner = torch.nn.DataParallel(refiner).cuda()
61
+ merger = torch.nn.DataParallel(merger).cuda()
62
+
63
+ print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
64
+ checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=device)  # map to CPU when no GPU is available
65
+
66
+ epoch_idx = checkpoint['epoch_idx']
67
+ encoder.load_state_dict(checkpoint['encoder_state_dict'])
68
+ decoder.load_state_dict(checkpoint['decoder_state_dict'])
69
+
70
+ if cfg.NETWORK.USE_REFINER:
71
+ refiner.load_state_dict(checkpoint['refiner_state_dict'])
72
+ if cfg.NETWORK.USE_MERGER:
73
+ merger.load_state_dict(checkpoint['merger_state_dict'])
74
+
75
+
76
+ encoder.eval()
77
+ decoder.eval()
78
+ merger.eval()
79
+ refiner.eval()
80
+
81
+
82
+ with torch.no_grad():
83
+
84
+ transformed_images = transformed_images.unsqueeze(0) #adding the batch_dim
85
+ transformed_images = transformed_images.to(device)
86
+
87
+ # print(rendering_images.shape)
88
+ image_features = encoder(transformed_images)
89
+ print(image_features.shape)
90
+ raw_features, generated_volume = decoder(image_features)
91
+ print(generated_volume.shape)
92
+
93
+
94
+ if cfg.NETWORK.USE_MERGER:
95
+ generated_volume = merger(raw_features, generated_volume)
96
+ else:
97
+ generated_volume = torch.mean(generated_volume, dim=1)
98
+
99
+
100
+ # encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10
101
+
102
+
103
+ if cfg.NETWORK.USE_REFINER:
104
+ generated_volume = refiner(generated_volume)
105
+ # refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
106
+ else:
107
+ # refiner_loss = encoder_loss
108
+ pass
109
+
110
+
111
+ generated_volume=generated_volume.squeeze(0)
112
+ gv = generated_volume.cpu().numpy()
113
+ gv = (gv >= 0.5).astype(np.uint8)
114
+
115
+
116
+
117
+
118
+ torch.cuda.empty_cache()
119
+ return gv
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (169 Bytes).
 
models/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (2.26 kB).
 
models/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (2.05 kB).
 
models/__pycache__/merger.cpython-38.pyc ADDED
Binary file (1.65 kB).
 
models/__pycache__/refiner.cpython-38.pyc ADDED
Binary file (1.92 kB).
 
models/decoder.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+
3
+ class Decoder(torch.nn.Module):
4
+ def __init__(self, cfg):
5
+ super(Decoder, self).__init__()
6
+ self.cfg = cfg
7
+
8
+ # Layer Definition
9
+ self.layer1 = torch.nn.Sequential(
10
+ torch.nn.ConvTranspose3d(2048, 512, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
11
+ torch.nn.BatchNorm3d(512),
12
+ torch.nn.ReLU()
13
+ )
14
+ self.layer2 = torch.nn.Sequential(
15
+ torch.nn.ConvTranspose3d(512, 128, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
16
+ torch.nn.BatchNorm3d(128),
17
+ torch.nn.ReLU()
18
+ )
19
+ self.layer3 = torch.nn.Sequential(
20
+ torch.nn.ConvTranspose3d(128, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
21
+ torch.nn.BatchNorm3d(32),
22
+ torch.nn.ReLU()
23
+ )
24
+ self.layer4 = torch.nn.Sequential(
25
+ torch.nn.ConvTranspose3d(32, 8, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
26
+ torch.nn.BatchNorm3d(8),
27
+ torch.nn.ReLU()
28
+ )
29
+ self.layer5 = torch.nn.Sequential(
30
+ torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=cfg.NETWORK.TCONV_USE_BIAS),
31
+ torch.nn.Sigmoid()
32
+ )
33
+
34
+ def forward(self, image_features):
35
+ image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()
36
+ image_features = torch.split(image_features, 1, dim=0)
37
+ gen_volumes = []
38
+ raw_features = []
39
+
40
+ for features in image_features:
41
+ gen_volume = features.view(-1, 2048, 2, 2, 2)
42
+ # print(gen_volume.size()) # torch.Size([batch_size, 2048, 2, 2, 2])
43
+ gen_volume = self.layer1(gen_volume)
44
+ # print(gen_volume.size()) # torch.Size([batch_size, 512, 4, 4, 4])
45
+ gen_volume = self.layer2(gen_volume)
46
+ # print(gen_volume.size()) # torch.Size([batch_size, 128, 8, 8, 8])
47
+ gen_volume = self.layer3(gen_volume)
48
+ # print(gen_volume.size()) # torch.Size([batch_size, 32, 16, 16, 16])
49
+ gen_volume = self.layer4(gen_volume)
50
+ raw_feature = gen_volume
51
+ # print(gen_volume.size()) # torch.Size([batch_size, 8, 32, 32, 32])
52
+ gen_volume = self.layer5(gen_volume)
53
+ # print(gen_volume.size()) # torch.Size([batch_size, 1, 32, 32, 32])
54
+ raw_feature = torch.cat((raw_feature, gen_volume), dim=1)
55
+ # print(raw_feature.size()) # torch.Size([batch_size, 9, 32, 32, 32])
56
+
57
+ gen_volumes.append(torch.squeeze(gen_volume, dim=1))
58
+ raw_features.append(raw_feature)
59
+
60
+ gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
61
+ raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
62
+ # print(gen_volumes.size()) # torch.Size([batch_size, n_views, 32, 32, 32])
63
+ # print(raw_features.size()) # torch.Size([batch_size, n_views, 9, 32, 32, 32])
64
+ return raw_features, gen_volumes
65
+
66
+
+ # Quick shape check; guarded so it does not run when Decoder is imported
+ # (helpers.py imports this module at app start-up).
+ if __name__ == '__main__':
+     class DummyCfg:
+         class NETWORK:
+             TCONV_USE_BIAS = False
+
+     cfg = DummyCfg()
+
+     # Instantiate the decoder
+     decoder = Decoder(cfg)
+
+     # Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
+     n_views = 1
+     batch_size = 64
+     img_c, img_h, img_w = 256, 8, 8
+     dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)
+
+     # Run the decoder
+     print(dummy_input.shape)
+     raw_features, gen_volumes = decoder(dummy_input)
+
+     # Output shapes
+     print("raw_features shape:", raw_features.shape)  # Expected: [64, 1, 9, 32, 32, 32]
+     print("gen_volumes shape:", gen_volumes.shape)    # Expected: [64, 1, 32, 32, 32]
models/encoder.py ADDED
@@ -0,0 +1,85 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Developed by Haozhe Xie <[email protected]>
4
+ #
5
+ # References:
6
+ # - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py
7
+
8
+ import torch
9
+ import torchvision.models
10
+
11
+
12
+ class Encoder(torch.nn.Module):
13
+ def __init__(self, cfg):
14
+ super(Encoder, self).__init__()
15
+ self.cfg = cfg
16
+
17
+ # Layer Definition
18
+ vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)
19
+ self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]
20
+ self.layer1 = torch.nn.Sequential(
21
+ torch.nn.Conv2d(512, 512, kernel_size=3),
22
+ torch.nn.BatchNorm2d(512),
23
+ torch.nn.ELU(),
24
+ )
25
+ self.layer2 = torch.nn.Sequential(
26
+ torch.nn.Conv2d(512, 512, kernel_size=3),
27
+ torch.nn.BatchNorm2d(512),
28
+ torch.nn.ELU(),
29
+ torch.nn.MaxPool2d(kernel_size=3)
30
+ )
31
+ self.layer3 = torch.nn.Sequential(
32
+ torch.nn.Conv2d(512, 256, kernel_size=1),
33
+ torch.nn.BatchNorm2d(256),
34
+ torch.nn.ELU()
35
+ )
36
+
37
+ # Don't update params in VGG16
38
+ for param in vgg16_bn.parameters():
39
+ param.requires_grad = False
40
+
41
+ def forward(self, rendering_images):
42
+ # print(rendering_images.size()) # torch.Size([batch_size, n_views, img_c, img_h, img_w])
43
+ rendering_images = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
44
+ rendering_images = torch.split(rendering_images, 1, dim=0)
45
+ image_features = []
46
+
47
+ for img in rendering_images:
48
+ features = self.vgg(img.squeeze(dim=0))
49
+ # print(features.size()) # torch.Size([batch_size, 512, 28, 28])
50
+ features = self.layer1(features)
51
+ # print(features.size()) # torch.Size([batch_size, 512, 26, 26])
52
+ features = self.layer2(features)
53
+ # print(features.size()) # torch.Size([batch_size, 512, 8, 8])
54
+ features = self.layer3(features)
55
+ # print(features.size()) # torch.Size([batch_size, 256, 8, 8])
56
+ image_features.append(features)
57
+
58
+ image_features = torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()
59
+ # print(image_features.size()) # torch.Size([batch_size, n_views, 256, 8, 8])
60
+ return image_features
61
+
62
+
63
+
+ # Quick shape check; guarded so it does not run when Encoder is imported
+ # (helpers.py imports this module at app start-up).
+ if __name__ == '__main__':
+     class DummyCfg:
+         class NETWORK:
+             TCONV_USE_BIAS = False
+
+     cfg = DummyCfg()
+
+     # Instantiate the encoder
+     encoder = Encoder(cfg)
+
+     # Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
+     batch_size = 64
+     n_views = 5
+     img_c, img_h, img_w = 3, 224, 224
+     dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)
+
+     # Run the encoder
+     print(dummy_input.shape)
+     image_features = encoder(dummy_input)
+
+     print("image_features shape:", image_features.shape)  # Expected: [64, 5, 256, 8, 8]
models/merger.py ADDED
@@ -0,0 +1,71 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Developed by Haozhe Xie <[email protected]>
4
+
5
+ import torch
6
+
7
+
8
+ class Merger(torch.nn.Module):
9
+ def __init__(self, cfg):
10
+ super(Merger, self).__init__()
11
+ self.cfg = cfg
12
+
13
+ # Layer Definition
14
+ self.layer1 = torch.nn.Sequential(
15
+ torch.nn.Conv3d(9, 16, kernel_size=3, padding=1),
16
+ torch.nn.BatchNorm3d(16),
17
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
18
+ )
19
+ self.layer2 = torch.nn.Sequential(
20
+ torch.nn.Conv3d(16, 8, kernel_size=3, padding=1),
21
+ torch.nn.BatchNorm3d(8),
22
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
23
+ )
24
+ self.layer3 = torch.nn.Sequential(
25
+ torch.nn.Conv3d(8, 4, kernel_size=3, padding=1),
26
+ torch.nn.BatchNorm3d(4),
27
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
28
+ )
29
+ self.layer4 = torch.nn.Sequential(
30
+ torch.nn.Conv3d(4, 2, kernel_size=3, padding=1),
31
+ torch.nn.BatchNorm3d(2),
32
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
33
+ )
34
+ self.layer5 = torch.nn.Sequential(
35
+ torch.nn.Conv3d(2, 1, kernel_size=3, padding=1),
36
+ torch.nn.BatchNorm3d(1),
37
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
38
+ )
39
+
40
+ def forward(self, raw_features, coarse_volumes):
41
+ n_views_rendering = coarse_volumes.size(1)
42
+ raw_features = torch.split(raw_features, 1, dim=1)
43
+ volume_weights = []
44
+
45
+ for i in range(n_views_rendering):
46
+ raw_feature = torch.squeeze(raw_features[i], dim=1)
47
+ # print(raw_feature.size()) # torch.Size([batch_size, 9, 32, 32, 32])
48
+
49
+ volume_weight = self.layer1(raw_feature)
50
+ # print(volume_weight.size()) # torch.Size([batch_size, 16, 32, 32, 32])
51
+ volume_weight = self.layer2(volume_weight)
52
+ # print(volume_weight.size()) # torch.Size([batch_size, 8, 32, 32, 32])
53
+ volume_weight = self.layer3(volume_weight)
54
+ # print(volume_weight.size()) # torch.Size([batch_size, 4, 32, 32, 32])
55
+ volume_weight = self.layer4(volume_weight)
56
+ # print(volume_weight.size()) # torch.Size([batch_size, 2, 32, 32, 32])
57
+ volume_weight = self.layer5(volume_weight)
58
+ # print(volume_weight.size()) # torch.Size([batch_size, 1, 32, 32, 32])
59
+
60
+ volume_weight = torch.squeeze(volume_weight, dim=1)
61
+ # print(volume_weight.size()) # torch.Size([batch_size, 32, 32, 32])
62
+ volume_weights.append(volume_weight)
63
+
64
+ volume_weights = torch.stack(volume_weights).permute(1, 0, 2, 3, 4).contiguous()
65
+ volume_weights = torch.softmax(volume_weights, dim=1)
66
+ # print(volume_weights.size()) # torch.Size([batch_size, n_views, 32, 32, 32])
67
+ # print(coarse_volumes.size()) # torch.Size([batch_size, n_views, 32, 32, 32])
68
+ coarse_volumes = coarse_volumes * volume_weights
69
+ coarse_volumes = torch.sum(coarse_volumes, dim=1)
70
+
71
+ return torch.clamp(coarse_volumes, min=0, max=1)
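The fusion at the end of forward() is a per-voxel softmax over the view dimension followed by a weighted sum; a tiny standalone illustration of that step with toy tensors (not part of the commit):

import torch

coarse_volumes = torch.rand(2, 3, 32, 32, 32)  # [batch, n_views, 32, 32, 32]
scores = torch.rand(2, 3, 32, 32, 32)          # stand-in for the per-view maps from layer5
weights = torch.softmax(scores, dim=1)         # weights across views sum to 1 per voxel
fused = (coarse_volumes * weights).sum(dim=1)  # [batch, 32, 32, 32]
print(fused.shape)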
models/refiner.py ADDED
@@ -0,0 +1,77 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Developed by Haozhe Xie <[email protected]>
4
+
5
+ import torch
6
+
7
+
8
+ class Refiner(torch.nn.Module):
9
+ def __init__(self, cfg):
10
+ super(Refiner, self).__init__()
11
+ self.cfg = cfg
12
+
13
+ # Layer Definition
14
+ self.layer1 = torch.nn.Sequential(
15
+ torch.nn.Conv3d(1, 32, kernel_size=4, padding=2),
16
+ torch.nn.BatchNorm3d(32),
17
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
18
+ torch.nn.MaxPool3d(kernel_size=2)
19
+ )
20
+ self.layer2 = torch.nn.Sequential(
21
+ torch.nn.Conv3d(32, 64, kernel_size=4, padding=2),
22
+ torch.nn.BatchNorm3d(64),
23
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
24
+ torch.nn.MaxPool3d(kernel_size=2)
25
+ )
26
+ self.layer3 = torch.nn.Sequential(
27
+ torch.nn.Conv3d(64, 128, kernel_size=4, padding=2),
28
+ torch.nn.BatchNorm3d(128),
29
+ torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
30
+ torch.nn.MaxPool3d(kernel_size=2)
31
+ )
32
+ self.layer4 = torch.nn.Sequential(
33
+ torch.nn.Linear(8192, 2048),
34
+ torch.nn.ReLU()
35
+ )
36
+ self.layer5 = torch.nn.Sequential(
37
+ torch.nn.Linear(2048, 8192),
38
+ torch.nn.ReLU()
39
+ )
40
+ self.layer6 = torch.nn.Sequential(
41
+ torch.nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
42
+ torch.nn.BatchNorm3d(64),
43
+ torch.nn.ReLU()
44
+ )
45
+ self.layer7 = torch.nn.Sequential(
46
+ torch.nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
47
+ torch.nn.BatchNorm3d(32),
48
+ torch.nn.ReLU()
49
+ )
50
+ self.layer8 = torch.nn.Sequential(
51
+ torch.nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
52
+ torch.nn.Sigmoid()
53
+ )
54
+
55
+ def forward(self, coarse_volumes):
56
+ volumes_32_l = coarse_volumes.view((-1, 1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))
57
+ # print(volumes_32_l.size()) # torch.Size([batch_size, 1, 32, 32, 32])
58
+ volumes_16_l = self.layer1(volumes_32_l)
59
+ # print(volumes_16_l.size()) # torch.Size([batch_size, 32, 16, 16, 16])
60
+ volumes_8_l = self.layer2(volumes_16_l)
61
+ # print(volumes_8_l.size()) # torch.Size([batch_size, 64, 8, 8, 8])
62
+ volumes_4_l = self.layer3(volumes_8_l)
63
+ # print(volumes_4_l.size()) # torch.Size([batch_size, 128, 4, 4, 4])
64
+ flatten_features = self.layer4(volumes_4_l.view(-1, 8192))
65
+ # print(flatten_features.size()) # torch.Size([batch_size, 2048])
66
+ flatten_features = self.layer5(flatten_features)
67
+ # print(flatten_features.size()) # torch.Size([batch_size, 8192])
68
+ volumes_4_r = volumes_4_l + flatten_features.view(-1, 128, 4, 4, 4)
69
+ # print(volumes_4_r.size()) # torch.Size([batch_size, 128, 4, 4, 4])
70
+ volumes_8_r = volumes_8_l + self.layer6(volumes_4_r)
71
+ # print(volumes_8_r.size()) # torch.Size([batch_size, 64, 8, 8, 8])
72
+ volumes_16_r = volumes_16_l + self.layer7(volumes_8_r)
73
+ # print(volumes_16_r.size()) # torch.Size([batch_size, 32, 16, 16, 16])
74
+ volumes_32_r = (volumes_32_l + self.layer8(volumes_16_r)) * 0.5
75
+ # print(volumes_32_r.size()) # torch.Size([batch_size, 1, 32, 32, 32])
76
+
77
+ return volumes_32_r.view((-1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ argparse
2
+ easydict
3
+ matplotlib
4
+ numpy
5
+ opencv-python
6
+ scipy
7
+ torch
8
+ torchvision
9
+ streamlit
10
+ plotly
11
+ pillow
utils/__pycache__/binvox_rw.cpython-38.pyc ADDED
Binary file (7.44 kB).
 
utils/__pycache__/data_transforms.cpython-38.pyc ADDED
Binary file (11.8 kB).
 
utils/binvox_rw.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (C) 2012 Daniel Maturana
2
+ # This file is part of binvox-rw-py.
3
+ #
4
+ # binvox-rw-py is free software: you can redistribute it and/or modify
5
+ # it under the terms of the GNU General Public License as published by
6
+ # the Free Software Foundation, either version 3 of the License, or
7
+ # (at your option) any later version.
8
+ #
9
+ # binvox-rw-py is distributed in the hope that it will be useful,
10
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
+ # GNU General Public License for more details.
13
+ #
14
+ # You should have received a copy of the GNU General Public License
15
+ # along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
16
+ #
17
+ """
18
+ Binvox to Numpy and back.
19
+
20
+
21
+ >>> import numpy as np
22
+ >>> import binvox_rw
23
+ >>> with open('chair.binvox', 'rb') as f:
24
+ ... m1 = binvox_rw.read_as_3d_array(f)
25
+ ...
26
+ >>> m1.dims
27
+ [32, 32, 32]
28
+ >>> m1.scale
29
+ 41.133000000000003
30
+ >>> m1.translate
31
+ [0.0, 0.0, 0.0]
32
+ >>> with open('chair_out.binvox', 'wb') as f:
33
+ ... m1.write(f)
34
+ ...
35
+ >>> with open('chair_out.binvox', 'rb') as f:
36
+ ... m2 = binvox_rw.read_as_3d_array(f)
37
+ ...
38
+ >>> m1.dims == m2.dims
39
+ True
40
+ >>> m1.scale == m2.scale
41
+ True
42
+ >>> m1.translate == m2.translate
43
+ True
44
+ >>> np.all(m1.data == m2.data)
45
+ True
46
+
47
+ >>> with open('chair.binvox', 'rb') as f:
48
+ ... md = binvox_rw.read_as_3d_array(f)
49
+ ...
50
+ >>> with open('chair.binvox', 'rb') as f:
51
+ ... ms = binvox_rw.read_as_coord_array(f)
52
+ ...
53
+ >>> data_ds = binvox_rw.dense_to_sparse(md.data)
54
+ >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
55
+ >>> np.all(data_sd == md.data)
56
+ True
57
+ >>> # the ordering of elements returned by numpy.nonzero changes with axis
58
+ >>> # ordering, so to compare for equality we first lexically sort the voxels.
59
+ >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
60
+ True
61
+ """
62
+
63
+ import numpy as np
64
+
65
+
66
+ class Voxels(object):
67
+ """ Holds a binvox model.
68
+ data is either a three-dimensional numpy boolean array (dense representation)
69
+ or a two-dimensional numpy float array (coordinate representation).
70
+
71
+ dims, translate and scale are the model metadata.
72
+
73
+ dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.
74
+
75
+ scale and translate relate the voxels to the original model coordinates.
76
+
77
+ To translate voxel coordinates i, j, k to original coordinates x, y, z:
78
+
79
+ x_n = (i+.5)/dims[0]
80
+ y_n = (j+.5)/dims[1]
81
+ z_n = (k+.5)/dims[2]
82
+ x = scale*x_n + translate[0]
83
+ y = scale*y_n + translate[1]
84
+ z = scale*z_n + translate[2]
85
+
86
+ """
87
+ def __init__(self, data, dims, translate, scale, axis_order):
88
+ self.data = data
89
+ self.dims = dims
90
+ self.translate = translate
91
+ self.scale = scale
92
+ assert (axis_order in ('xzy', 'xyz'))
93
+ self.axis_order = axis_order
94
+
95
+ def clone(self):
96
+ data = self.data.copy()
97
+ dims = self.dims[:]
98
+ translate = self.translate[:]
99
+ return Voxels(data, dims, translate, self.scale, self.axis_order)
100
+
101
+ def write(self, fp):
102
+ write(self, fp)
103
+
104
+
105
+ def read_header(fp):
106
+ """ Read binvox header. Mostly meant for internal use.
107
+ """
108
+ line = fp.readline().strip()
109
+ if not line.startswith(b'#binvox'):
110
+ raise IOError('[ERROR] Not a binvox file')
111
+ dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
112
+ translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
113
+ scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
114
+ fp.readline()
115
+ return dims, translate, scale
116
+
117
+
118
+ def read_as_3d_array(fp, fix_coords=True):
119
+ """ Read binary binvox format as array.
120
+
121
+ Returns the model with accompanying metadata.
122
+
123
+ Voxels are stored in a three-dimensional numpy array, which is simple and
124
+ direct, but may use a lot of memory for large models. (Storage requirements
125
+ are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy
126
+ boolean arrays use a byte per element).
127
+
128
+ Doesn't do any checks on input except for the '#binvox' line.
129
+ """
130
+ dims, translate, scale = read_header(fp)
131
+ raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
132
+ # if just using reshape() on the raw data:
133
+ # indexing the array as array[i,j,k], the indices map into the
134
+ # coords as:
135
+ # i -> x
136
+ # j -> z
137
+ # k -> y
138
+ # if fix_coords is true, then data is rearranged so that
139
+ # mapping is
140
+ # i -> x
141
+ # j -> y
142
+ # k -> z
143
+ values, counts = raw_data[::2], raw_data[1::2]
144
+ data = np.repeat(values, counts).astype(np.int32)
145
+ data = data.reshape(dims)
146
+ if fix_coords:
147
+ # xzy to xyz TODO the right thing
148
+ data = np.transpose(data, (0, 2, 1))
149
+ axis_order = 'xyz'
150
+ else:
151
+ axis_order = 'xzy'
152
+ return Voxels(data, dims, translate, scale, axis_order)
153
+
154
+
155
+ def read_as_coord_array(fp, fix_coords=True):
156
+ """ Read binary binvox format as coordinates.
157
+
158
+ Returns binvox model with voxels in a "coordinate" representation, i.e. an
159
+ 3 x N array where N is the number of nonzero voxels. Each column
160
+ corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
161
+ of the voxel. (The odd ordering is due to the way binvox format lays out
162
+ data). Note that coordinates refer to the binvox voxels, without any
163
+ scaling or translation.
164
+
165
+ Use this to save memory if your model is very sparse (mostly empty).
166
+
167
+ Doesn't do any checks on input except for the '#binvox' line.
168
+ """
169
+ dims, translate, scale = read_header(fp)
170
+ raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
171
+
172
+ values, counts = raw_data[::2], raw_data[1::2]
173
+
174
+ # sz = np.prod(dims)
175
+ # index, end_index = 0, 0
176
+ end_indices = np.cumsum(counts)
177
+ indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
178
+
179
+ values = values.astype(bool)  # np.bool was removed in NumPy 1.24+
180
+ indices = indices[values]
181
+ end_indices = end_indices[values]
182
+
183
+ nz_voxels = []
184
+ for index, end_index in zip(indices, end_indices):
185
+ nz_voxels.extend(range(index, end_index))
186
+ nz_voxels = np.array(nz_voxels)
187
+ # TODO are these dims correct?
188
+ # according to docs,
189
+ # index = x * wxh + z * width + y; // wxh = width * height = d * d
190
+
191
+ x = nz_voxels / (dims[0] * dims[1])
192
+ zwpy = nz_voxels % (dims[0] * dims[1]) # z*w + y
193
+ z = zwpy / dims[0]
194
+ y = zwpy % dims[0]
195
+ if fix_coords:
196
+ data = np.vstack((x, y, z))
197
+ axis_order = 'xyz'
198
+ else:
199
+ data = np.vstack((x, z, y))
200
+ axis_order = 'xzy'
201
+
202
+ #return Voxels(data, dims, translate, scale, axis_order)
203
+ return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
204
+
205
+
206
+ def dense_to_sparse(voxel_data, dtype=int):
207
+ """ From dense representation to sparse (coordinate) representation.
208
+ No coordinate reordering.
209
+ """
210
+ if voxel_data.ndim != 3:
211
+ raise ValueError('[ERROR] voxel_data is wrong shape; should be 3D array.')
212
+ return np.asarray(np.nonzero(voxel_data), dtype)
213
+
214
+
215
+ def sparse_to_dense(voxel_data, dims, dtype=bool):
216
+ if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
217
+ raise ValueError('[ERROR] voxel_data is wrong shape; should be 3xN array.')
218
+ if np.isscalar(dims):
219
+ dims = [dims] * 3
220
+ dims = np.atleast_2d(dims).T
221
+ # truncate to integers
222
+ xyz = voxel_data.astype(int)  # np.int was removed in NumPy 1.24+
223
+ # discard voxels that fall outside dims
224
+ valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
225
+ xyz = xyz[:, valid_ix]
226
+ out = np.zeros(dims.flatten(), dtype=dtype)
227
+ out[tuple(xyz)] = True
228
+ return out
229
+
230
+
231
+ #def get_linear_index(x, y, z, dims):
232
+ #""" Assuming xzy order. (y increasing fastest.
233
+ #TODO ensure this is right when dims are not all same
234
+ #"""
235
+ #return x*(dims[1]*dims[2]) + z*dims[1] + y
236
+
237
+
238
+ def write(voxel_model, fp):
239
+ """ Write binary binvox format.
240
+
241
+ Note that when saving a model in sparse (coordinate) format, it is first
242
+ converted to dense format.
243
+
244
+ Doesn't check if the model is 'sane'.
245
+
246
+ """
247
+ if voxel_model.data.ndim == 2:
248
+ # TODO avoid conversion to dense
249
+ dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims).astype(int)
250
+ else:
251
+ dense_voxel_data = voxel_model.data.astype(int)
252
+
253
+ file_header = [
254
+ '#binvox 1\n',
255
+ 'dim %s\n' % ' '.join(map(str, voxel_model.dims)),
256
+ 'translate %s\n' % ' '.join(map(str, voxel_model.translate)),
257
+ 'scale %s\n' % str(voxel_model.scale), 'data\n'
258
+ ]
259
+
260
+ for fh in file_header:
261
+ fp.write(fh.encode('latin-1'))
262
+
263
+ if voxel_model.axis_order not in ('xzy', 'xyz'):
264
+ raise ValueError('[ERROR] Unsupported voxel model axis order')
265
+
266
+ if voxel_model.axis_order == 'xzy':
267
+ voxels_flat = dense_voxel_data.flatten()
268
+ elif voxel_model.axis_order == 'xyz':
269
+ voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
270
+
271
+ # keep a sort of state machine for writing run length encoding
272
+ state = voxels_flat[0]
273
+ ctr = 0
274
+ for c in voxels_flat:
275
+ if c == state:
276
+ ctr += 1
277
+ # if ctr hits max, dump
278
+ if ctr == 255:
279
+ fp.write(chr(state).encode('latin-1'))
280
+ fp.write(chr(ctr).encode('latin-1'))
281
+ ctr = 0
282
+ else:
283
+ # if switch state, dump
284
+ fp.write(chr(state).encode('latin-1'))
285
+ fp.write(chr(ctr).encode('latin-1'))
286
+ state = c
287
+ ctr = 1
288
+ # flush out remainders
289
+ if ctr > 0:
290
+ fp.write(chr(state).encode('latin-1'))
291
+ fp.write(chr(ctr).encode('latin-1'))
292
+
293
+
294
+ if __name__ == '__main__':
295
+ import doctest
296
+ doctest.testmod()
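The Voxels/write pair above can also round-trip a predicted 32x32x32 occupancy grid back to a .binvox file; a brief sketch with a stand-in grid and an assumed output name (not part of the commit):

import numpy as np
import utils.binvox_rw as binvox_rw

gv = np.zeros((32, 32, 32), dtype=np.uint8)  # stand-in for a predicted occupancy grid
gv[12:20, 12:20, 12:20] = 1

vox = binvox_rw.Voxels(gv.astype(bool), [32, 32, 32], [0.0, 0.0, 0.0], 1.0, 'xyz')
with open('prediction.binvox', 'wb') as f:
    vox.write(f)  # run-length encoded binvox output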
utils/data_transforms.py ADDED
@@ -0,0 +1,452 @@
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # Developed by Haozhe Xie <[email protected]>
4
+ # References:
5
+ # - https://github.com/xiumingzhang/GenRe-ShapeHD
6
+
7
+ import cv2
8
+ # import matplotlib.pyplot as plt
9
+ # import matplotlib.patches as patches
10
+ import numpy as np
11
+ import os
12
+ import random
13
+ import torch
14
+
15
+
16
+ class Compose(object):
17
+ """ Composes several transforms together.
18
+ For example:
19
+ >>> transforms.Compose([
20
+ >>> transforms.RandomBackground(),
21
+ >>> transforms.CenterCrop(127, 127, 3),
22
+ >>> ])
23
+ """
24
+ def __init__(self, transforms):
25
+ self.transforms = transforms
26
+
27
+ def __call__(self, rendering_images, bounding_box=None):
28
+ for t in self.transforms:
29
+ if t.__class__.__name__ == 'RandomCrop' or t.__class__.__name__ == 'CenterCrop':
30
+ rendering_images = t(rendering_images, bounding_box)
31
+ else:
32
+ rendering_images = t(rendering_images)
33
+
34
+ return rendering_images
35
+
36
+
37
+ class ToTensor(object):
38
+ """
39
+ Convert a PIL Image or numpy.ndarray to tensor.
40
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
41
+ """
42
+ def __call__(self, rendering_images):
43
+ assert (isinstance(rendering_images, np.ndarray))
44
+ array = np.transpose(rendering_images, (0, 3, 1, 2))
45
+ # handle numpy array
46
+ tensor = torch.from_numpy(array)
47
+
48
+ # put it from HWC to CHW format
49
+ return tensor.float()
50
+
51
+
52
+ class Normalize(object):
53
+ def __init__(self, mean, std):
54
+ self.mean = mean
55
+ self.std = std
56
+
57
+ def __call__(self, rendering_images):
58
+ assert (isinstance(rendering_images, np.ndarray))
59
+ rendering_images -= self.mean
60
+ rendering_images /= self.std
61
+
62
+ return rendering_images
63
+
64
+
65
+ class RandomPermuteRGB(object):
66
+ def __call__(self, rendering_images):
67
+ assert (isinstance(rendering_images, np.ndarray))
68
+
69
+ random_permutation = np.random.permutation(3)
70
+ for img_idx, img in enumerate(rendering_images):
71
+ rendering_images[img_idx] = img[..., random_permutation]
72
+
73
+ return rendering_images
74
+
75
+
76
+ class CenterCrop(object):
77
+ def __init__(self, img_size, crop_size):
78
+ """Set the height and weight before and after cropping"""
79
+ self.img_size_h = img_size[0]
80
+ self.img_size_w = img_size[1]
81
+ self.crop_size_h = crop_size[0]
82
+ self.crop_size_w = crop_size[1]
83
+
84
+ def __call__(self, rendering_images, bounding_box=None):
85
+ if len(rendering_images) == 0:
86
+ return rendering_images
87
+
88
+ crop_size_c = rendering_images[0].shape[2]
89
+ processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))
90
+ for img_idx, img in enumerate(rendering_images):
91
+ img_height, img_width, _ = img.shape
92
+
93
+ if bounding_box is not None:
94
+ bounding_box = [
95
+ bounding_box[0] * img_width,
96
+ bounding_box[1] * img_height,
97
+ bounding_box[2] * img_width,
98
+ bounding_box[3] * img_height
99
+ ] # yapf: disable
100
+
101
+ # Calculate the size of bounding boxes
102
+ bbox_width = bounding_box[2] - bounding_box[0]
103
+ bbox_height = bounding_box[3] - bounding_box[1]
104
+ bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5
105
+ bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5
106
+
107
+ # Make the crop area as a square
108
+ square_object_size = max(bbox_width, bbox_height)
109
+ x_left = int(bbox_x_mid - square_object_size * .5)
110
+ x_right = int(bbox_x_mid + square_object_size * .5)
111
+ y_top = int(bbox_y_mid - square_object_size * .5)
112
+ y_bottom = int(bbox_y_mid + square_object_size * .5)
113
+
114
+ # If the crop position is out of the image, fix it with padding
115
+ pad_x_left = 0
116
+ if x_left < 0:
117
+ pad_x_left = -x_left
118
+ x_left = 0
119
+ pad_x_right = 0
120
+ if x_right >= img_width:
121
+ pad_x_right = x_right - img_width + 1
122
+ x_right = img_width - 1
123
+ pad_y_top = 0
124
+ if y_top < 0:
125
+ pad_y_top = -y_top
126
+ y_top = 0
127
+ pad_y_bottom = 0
128
+ if y_bottom >= img_height:
129
+ pad_y_bottom = y_bottom - img_height + 1
130
+ y_bottom = img_height - 1
131
+
132
+ # Padding the image and resize the image
133
+ processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
134
+ ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
135
+ mode='edge')
136
+ processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
137
+ else:
138
+ if img_height > self.crop_size_h and img_width > self.crop_size_w:
139
+ x_left = int(img_width - self.crop_size_w) // 2
140
+ x_right = int(x_left + self.crop_size_w)
141
+ y_top = int(img_height - self.crop_size_h) // 2
142
+ y_bottom = int(y_top + self.crop_size_h)
143
+ else:
144
+ x_left = 0
145
+ x_right = img_width
146
+ y_top = 0
147
+ y_bottom = img_height
148
+
149
+ processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))
150
+
151
+ processed_images = np.append(processed_images, [processed_image], axis=0)
152
+ # Debug
153
+ # fig = plt.figure()
154
+ # ax1 = fig.add_subplot(1, 2, 1)
155
+ # ax1.imshow(img)
156
+ # if not bounding_box is None:
157
+ # rect = patches.Rectangle((bounding_box[0], bounding_box[1]),
158
+ # bbox_width,
159
+ # bbox_height,
160
+ # linewidth=1,
161
+ # edgecolor='r',
162
+ # facecolor='none')
163
+ # ax1.add_patch(rect)
164
+ # ax2 = fig.add_subplot(1, 2, 2)
165
+ # ax2.imshow(processed_image)
166
+ # plt.show()
167
+ return processed_images
168
+
169
+
170
+ class RandomCrop(object):
171
+ def __init__(self, img_size, crop_size):
172
+ """Set the height and weight before and after cropping"""
173
+ self.img_size_h = img_size[0]
174
+ self.img_size_w = img_size[1]
175
+ self.crop_size_h = crop_size[0]
176
+ self.crop_size_w = crop_size[1]
177
+
178
+ def __call__(self, rendering_images, bounding_box=None):
179
+ if len(rendering_images) == 0:
180
+ return rendering_images
181
+
182
+ crop_size_c = rendering_images[0].shape[2]
183
+ processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))
184
+ for img_idx, img in enumerate(rendering_images):
185
+ img_height, img_width, _ = img.shape
186
+
187
+ if bounding_box is not None:
188
+ bounding_box = [
189
+ bounding_box[0] * img_width,
190
+ bounding_box[1] * img_height,
191
+ bounding_box[2] * img_width,
192
+ bounding_box[3] * img_height
193
+ ] # yapf: disable
194
+
195
+ # Calculate the size of bounding boxes
196
+ bbox_width = bounding_box[2] - bounding_box[0]
197
+ bbox_height = bounding_box[3] - bounding_box[1]
198
+ bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5
199
+ bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5
200
+
201
+ # Make the crop area as a square
202
+ square_object_size = max(bbox_width, bbox_height)
203
+ square_object_size = square_object_size * random.uniform(0.8, 1.2)
204
+
205
+ x_left = int(bbox_x_mid - square_object_size * random.uniform(.4, .6))
206
+ x_right = int(bbox_x_mid + square_object_size * random.uniform(.4, .6))
207
+ y_top = int(bbox_y_mid - square_object_size * random.uniform(.4, .6))
208
+ y_bottom = int(bbox_y_mid + square_object_size * random.uniform(.4, .6))
209
+
210
+ # If the crop position is out of the image, fix it with padding
211
+ pad_x_left = 0
212
+ if x_left < 0:
213
+ pad_x_left = -x_left
214
+ x_left = 0
215
+ pad_x_right = 0
216
+ if x_right >= img_width:
217
+ pad_x_right = x_right - img_width + 1
218
+ x_right = img_width - 1
219
+ pad_y_top = 0
220
+ if y_top < 0:
221
+ pad_y_top = -y_top
222
+ y_top = 0
223
+ pad_y_bottom = 0
224
+ if y_bottom >= img_height:
225
+ pad_y_bottom = y_bottom - img_height + 1
226
+ y_bottom = img_height - 1
227
+
228
+ # Padding the image and resize the image
229
+ processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
230
+ ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
231
+ mode='edge')
232
+ processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
233
+ else:
234
+ if img_height > self.crop_size_h and img_width > self.crop_size_w:
235
+ x_left = int(img_width - self.crop_size_w) // 2
236
+ x_right = int(x_left + self.crop_size_w)
237
+ y_top = int(img_height - self.crop_size_h) // 2
238
+ y_bottom = int(y_top + self.crop_size_h)
239
+ else:
240
+ x_left = 0
241
+ x_right = img_width
242
+ y_top = 0
243
+ y_bottom = img_height
244
+
245
+ processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))
246
+
247
+ processed_images = np.append(processed_images, [processed_image], axis=0)
248
+
249
+ return processed_images
250
+
251
+
252
+ class RandomFlip(object):
253
+ def __call__(self, rendering_images):
254
+ assert (isinstance(rendering_images, np.ndarray))
255
+
256
+ for img_idx, img in enumerate(rendering_images):
257
+ if random.randint(0, 1):
258
+ rendering_images[img_idx] = np.fliplr(img)
259
+
260
+ return rendering_images
261
+
262
+
263
+ class ColorJitter(object):
264
+ def __init__(self, brightness, contrast, saturation):
265
+ self.brightness = brightness
266
+ self.contrast = contrast
267
+ self.saturation = saturation
268
+
269
+ def __call__(self, rendering_images):
270
+ if len(rendering_images) == 0:
271
+ return rendering_images
272
+
273
+ # Allocate new space for storing processed images
274
+ img_height, img_width, img_channels = rendering_images[0].shape
275
+ processed_images = np.empty(shape=(0, img_height, img_width, img_channels))
276
+
277
+ # Randomize the value of changing brightness, contrast, and saturation
278
+ brightness = 1 + np.random.uniform(low=-self.brightness, high=self.brightness)
279
+ contrast = 1 + np.random.uniform(low=-self.contrast, high=self.contrast)
280
+ saturation = 1 + np.random.uniform(low=-self.saturation, high=self.saturation)
281
+
282
+ # Randomize the order of changing brightness, contrast, and saturation
283
+ attr_names = ['brightness', 'contrast', 'saturation']
284
+ attr_values = [brightness, contrast, saturation] # The value of changing attrs
285
+ attr_indexes = np.array(range(len(attr_names))) # The order of changing attrs
286
+ np.random.shuffle(attr_indexes)
287
+
288
+ for img_idx, img in enumerate(rendering_images):
289
+ processed_image = img
290
+ for idx in attr_indexes:
291
+ processed_image = self._adjust_image_attr(processed_image, attr_names[idx], attr_values[idx])
292
+
293
+ processed_images = np.append(processed_images, [processed_image], axis=0)
294
+ # print('ColorJitter', np.mean(ori_img), np.mean(processed_image))
295
+ # fig = plt.figure(figsize=(8, 4))
296
+ # ax1 = fig.add_subplot(1, 2, 1)
297
+ # ax1.imshow(ori_img)
298
+ # ax2 = fig.add_subplot(1, 2, 2)
299
+ # ax2.imshow(processed_image)
300
+ # plt.show()
301
+ return processed_images
302
+
303
+ def _adjust_image_attr(self, img, attr_name, attr_value):
304
+ """
305
+ Adjust or randomize the specified attribute of the image
306
+
307
+ Args:
308
+ img: Image in BGR format
309
+ Numpy array of shape (h, w, 3)
310
+ attr_name: Image attribute to adjust or randomize
311
+ 'brightness', 'saturation', or 'contrast'
312
+ attr_value: the alpha for blending is randomly drawn from [1 - d, 1 + d]
313
+
314
+ Returns:
315
+ Output image in BGR format
316
+ Numpy array of the same shape as input
317
+ """
318
+ gs = self._bgr_to_gray(img)
319
+
320
+ if attr_name == 'contrast':
321
+ img = self._alpha_blend(img, np.mean(gs[:, :, 0]), attr_value)
322
+ elif attr_name == 'saturation':
323
+ img = self._alpha_blend(img, gs, attr_value)
324
+ elif attr_name == 'brightness':
325
+ img = self._alpha_blend(img, 0, attr_value)
326
+ else:
327
+ raise NotImplementedError(attr_name)
328
+ return img
329
+
330
+ def _bgr_to_gray(self, bgr):
331
+ """
332
+ Convert a BGR image to a grayscale image
333
+ Differences from cv2.cvtColor():
334
+ 1. Input image can be float
335
+ 2. Output image has three repeated channels, rather than a single channel
336
+
337
+ Args:
338
+ bgr: Image in BGR format
339
+ Numpy array of shape (h, w, 3)
340
+
341
+ Returns:
342
+ gs: Grayscale image
343
+ Numpy array of the same shape as input; the three channels are the same
344
+ """
345
+ ch = 0.114 * bgr[:, :, 0] + 0.587 * bgr[:, :, 1] + 0.299 * bgr[:, :, 2]
346
+ gs = np.dstack((ch, ch, ch))
347
+ return gs
348
+
349
+ def _alpha_blend(self, im1, im2, alpha):
350
+ """
351
+ Alpha blending of two images or one image and a scalar
352
+
353
+ Args:
354
+ im1, im2: Image or scalar
355
+ Numpy array and a scalar or two numpy arrays of the same shape
356
+ alpha: Weight of im1
357
+ Float ranging usually from 0 to 1
358
+
359
+ Returns:
360
+ im_blend: Blended image -- alpha * im1 + (1 - alpha) * im2
361
+ Numpy array of the same shape as input image
362
+ """
363
+ im_blend = alpha * im1 + (1 - alpha) * im2
364
+ return im_blend
365
+
366
+
367
+ class RandomNoise(object):
368
+ def __init__(self,
369
+ noise_std,
370
+ eigvals=(0.2175, 0.0188, 0.0045),
371
+ eigvecs=((-0.5675, 0.7192, 0.4009), (-0.5808, -0.0045, -0.8140), (-0.5836, -0.6948, 0.4203))):
372
+ self.noise_std = noise_std
373
+ self.eigvals = np.array(eigvals)
374
+ self.eigvecs = np.array(eigvecs)
375
+
376
+ def __call__(self, rendering_images):
377
+ alpha = np.random.normal(loc=0, scale=self.noise_std, size=3)
378
+ noise_rgb = \
379
+ np.sum(
380
+ np.multiply(
381
+ np.multiply(
382
+ self.eigvecs,
383
+ np.tile(alpha, (3, 1))
384
+ ),
385
+ np.tile(self.eigvals, (3, 1))
386
+ ),
387
+ axis=1
388
+ )
389
+
390
+ # Allocate new space for storing processed images
391
+ img_height, img_width, img_channels = rendering_images[0].shape
392
+ assert (img_channels == 3), "Please use RandomBackground to normalize image channels"
393
+ processed_images = np.empty(shape=(0, img_height, img_width, img_channels))
394
+
395
+ for img_idx, img in enumerate(rendering_images):
396
+ processed_image = img[:, :, ::-1] # BGR -> RGB
397
+ for i in range(img_channels):
398
+ processed_image[:, :, i] += noise_rgb[i]
399
+
400
+ processed_image = processed_image[:, :, ::-1] # RGB -> BGR
401
+ processed_images = np.append(processed_images, [processed_image], axis=0)
402
+ # from copy import deepcopy
403
+ # ori_img = deepcopy(img)
404
+ # print(noise_rgb, np.mean(processed_image), np.mean(ori_img))
405
+ # print('RandomNoise', np.mean(ori_img), np.mean(processed_image))
406
+ # fig = plt.figure(figsize=(8, 4))
407
+ # ax1 = fig.add_subplot(1, 2, 1)
408
+ # ax1.imshow(ori_img)
409
+ # ax2 = fig.add_subplot(1, 2, 2)
410
+ # ax2.imshow(processed_image)
411
+ # plt.show()
412
+ return processed_images
413
+
414
+
415
+ class RandomBackground(object):
416
+ def __init__(self, random_bg_color_range, random_bg_folder_path=None):
417
+ self.random_bg_color_range = random_bg_color_range
418
+ self.random_bg_files = []
419
+ if random_bg_folder_path is not None:
420
+ self.random_bg_files = os.listdir(random_bg_folder_path)
421
+ self.random_bg_files = [os.path.join(random_bg_folder_path, rbf) for rbf in self.random_bg_files]
422
+
423
+ def __call__(self, rendering_images):
424
+ if len(rendering_images) == 0:
425
+ return rendering_images
426
+
427
+ img_height, img_width, img_channels = rendering_images[0].shape
428
+ # If the image has the alpha channel, add the background
429
+ if not img_channels == 4:
430
+ return rendering_images
431
+
432
+ # Generate random background
433
+ r, g, b = np.array([
434
+ np.random.randint(self.random_bg_color_range[i][0], self.random_bg_color_range[i][1] + 1) for i in range(3)
435
+ ]) / 255.
436
+
437
+ random_bg = None
438
+ if len(self.random_bg_files) > 0:
439
+ random_bg_file_path = random.choice(self.random_bg_files)
440
+ random_bg = cv2.imread(random_bg_file_path).astype(np.float32) / 255.
441
+
442
+ # Apply random background
443
+ processed_images = np.empty(shape=(0, img_height, img_width, img_channels - 1))
444
+ for img_idx, img in enumerate(rendering_images):
445
+ alpha = (np.expand_dims(img[:, :, 3], axis=2) == 0).astype(np.float32)
446
+ img = img[:, :, :3]
447
+ bg_color = random_bg if random.randint(0, 1) and random_bg is not None else np.array([[[r, g, b]]])
448
+ img = alpha * bg_color + (1 - alpha) * img
449
+
450
+ processed_images = np.append(processed_images, [img], axis=0)
451
+
452
+ return processed_images