Commit · 757ed1c · 0 Parent(s) · initial commit
- .gitignore +1 -0
- README.md +13 -0
- __pycache__/binvox_rw.cpython-38.pyc +0 -0
- __pycache__/config.cpython-38.pyc +0 -0
- __pycache__/data_transforms.cpython-38.pyc +0 -0
- __pycache__/helpers.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +39 -0
- check.ipynb +0 -0
- config.py +108 -0
- helpers.py +119 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-38.pyc +0 -0
- models/__pycache__/decoder.cpython-38.pyc +0 -0
- models/__pycache__/encoder.cpython-38.pyc +0 -0
- models/__pycache__/merger.cpython-38.pyc +0 -0
- models/__pycache__/refiner.cpython-38.pyc +0 -0
- models/decoder.py +88 -0
- models/encoder.py +85 -0
- models/merger.py +71 -0
- models/refiner.py +77 -0
- requirements.txt +11 -0
- utils/__pycache__/binvox_rw.cpython-38.pyc +0 -0
- utils/__pycache__/data_transforms.cpython-38.pyc +0 -0
- utils/binvox_rw.py +296 -0
- utils/data_transforms.py +452 -0
.gitignore
ADDED
@@ -0,0 +1 @@
saved_model/
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Snap2scene
emoji: 😻
colorFrom: pink
colorTo: blue
sdk: streamlit
sdk_version: 1.44.1
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/binvox_rw.cpython-38.pyc
ADDED
Binary file (7.44 kB)

__pycache__/config.cpython-38.pyc
ADDED
Binary file (2.48 kB)

__pycache__/data_transforms.cpython-38.pyc
ADDED
Binary file (11.8 kB)

__pycache__/helpers.cpython-38.pyc
ADDED
Binary file (2.86 kB)

__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.8 kB)
app.py
ADDED
@@ -0,0 +1,39 @@
import streamlit as st
import numpy as np
import plotly.graph_objects as go
from PIL import Image
from helpers import *

# --- APP START ---
st.title("2D → 3D Voxel Reconstruction Viewer")

uploaded_images = st.file_uploader(f"Upload images", accept_multiple_files=True, type=["png", "jpg", "jpeg"])
# print(uploaded_images)


# --- DISPLAY ---
if uploaded_images:
    st.subheader("Uploaded Input Views")
    cols = st.columns(len(uploaded_images))
    rendering_images = []

    for i, uploaded_file in enumerate(uploaded_images):
        img = Image.open(uploaded_file)

        cols[i].image(img, caption=f"View {i+1}", use_container_width=True)

        img_np = np.array(img).astype(np.float32) / 255.0

        rendering_images.append(img_np)


    if st.button("Submit for Reconstruction"):
        gv = None
        with st.spinner("Reconstructing..."):
            gv = predict_voxel_from_images(rendering_images)

        fig = voxel_to_plotly(gv)
        st.plotly_chart(fig, use_container_width=True)

else:
    st.info(f"Upload images to continue.")
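The Streamlit front end above only wires uploads into `predict_voxel_from_images` from helpers.py. A minimal headless sketch of the same flow, assuming a hypothetical local test image `chair.png` and the `saved_model/Pix2Vox.pth` checkpoint that helpers.py expects:

```python
# Sketch: exercise the reconstruction pipeline without the Streamlit UI.
# "chair.png" is a hypothetical test image, not part of this commit.
import numpy as np
from PIL import Image
from helpers import predict_voxel_from_images, voxel_to_plotly

img = np.array(Image.open("chair.png")).astype(np.float32) / 255.0  # H x W x C in [0, 1]
gv = predict_voxel_from_images([img])      # list of views -> 32x32x32 binary voxel grid
fig = voxel_to_plotly(gv)                  # occupied voxels as a 3D scatter plot
fig.write_html("reconstruction.html")      # inspect the result in a browser
```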
check.ipynb
ADDED
The diff for this file is too large to render.
config.py
ADDED
@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <[email protected]>

from easydict import EasyDict as edict

__C = edict()
cfg = __C

#
# Dataset Config
#
__C.DATASETS = edict()
__C.DATASETS.SHAPENET = edict()
__C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH = 'datasets/ShapeNet.json'
# __C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH = './datasets/PascalShapeNet.json'
__C.DATASETS.SHAPENET.RENDERING_PATH = 'datasets/ShapeNetRendering/%s/%s/rendering/%02d.png'
# __C.DATASETS.SHAPENET.RENDERING_PATH = '/home/hzxie/Datasets/ShapeNet/PascalShapeNetRendering/%s/%s/render_%04d.jpg'
__C.DATASETS.SHAPENET.VOXEL_PATH = 'datasets/ShapeNetVox32/%s/%s/model.binvox'
__C.DATASETS.PASCAL3D = edict()
__C.DATASETS.PASCAL3D.TAXONOMY_FILE_PATH = 'datasets/Pascal3D.json'
__C.DATASETS.PASCAL3D.ANNOTATION_PATH = '/home/hzxie/Datasets/PASCAL3D/Annotations/%s_imagenet/%s.mat'
__C.DATASETS.PASCAL3D.RENDERING_PATH = '/home/hzxie/Datasets/PASCAL3D/Images/%s_imagenet/%s.JPEG'
__C.DATASETS.PASCAL3D.VOXEL_PATH = '/home/hzxie/Datasets/PASCAL3D/CAD/%s/%02d.binvox'
__C.DATASETS.PIX3D = edict()
__C.DATASETS.PIX3D.TAXONOMY_FILE_PATH = 'datasets/Pix3D.json'
__C.DATASETS.PIX3D.ANNOTATION_PATH = 'datasets/Pix3D/pix3d.json'
__C.DATASETS.PIX3D.RENDERING_PATH = 'datasets/Pix3D/img/%s/%s.%s'
__C.DATASETS.PIX3D.VOXEL_PATH = 'datasets/Pix3D/model/%s/%s/%s.binvox'

#
# Dataset
#
__C.DATASET = edict()
__C.DATASET.MEAN = [0.5, 0.5, 0.5]
__C.DATASET.STD = [0.5, 0.5, 0.5]
__C.DATASET.TRAIN_DATASET = 'ShapeNet'
__C.DATASET.TEST_DATASET = 'ShapeNet'
# __C.DATASET.TEST_DATASET = 'Pascal3D'
# __C.DATASET.TEST_DATASET = 'Pix3D'

#
# Common
#
__C.CONST = edict()
__C.CONST.DEVICE = '0'
__C.CONST.RNG_SEED = 0
__C.CONST.IMG_W = 224  # Image width for input
__C.CONST.IMG_H = 224  # Image height for input
__C.CONST.N_VOX = 32
__C.CONST.BATCH_SIZE = 64
__C.CONST.N_VIEWS_RENDERING = 1  # Dummy property for Pascal 3D
__C.CONST.CROP_IMG_W = 128  # Dummy property for Pascal 3D
__C.CONST.CROP_IMG_H = 128  # Dummy property for Pascal 3D

#
# Directories
#
__C.DIR = edict()
__C.DIR.OUT_PATH = './output'
__C.DIR.RANDOM_BG_PATH = '/home/hzxie/Datasets/SUN2012/JPEGImages'

#
# Network
#
__C.NETWORK = edict()
__C.NETWORK.LEAKY_VALUE = .2
__C.NETWORK.TCONV_USE_BIAS = False
__C.NETWORK.USE_REFINER = True
__C.NETWORK.USE_MERGER = True

#
# Training
#
__C.TRAIN = edict()
__C.TRAIN.RESUME_TRAIN = False
__C.TRAIN.NUM_WORKER = 4  # number of data workers
__C.TRAIN.NUM_EPOCHES = 5
__C.TRAIN.BRIGHTNESS = .4
__C.TRAIN.CONTRAST = .4
__C.TRAIN.SATURATION = .4
__C.TRAIN.NOISE_STD = .1
__C.TRAIN.RANDOM_BG_COLOR_RANGE = [[225, 255], [225, 255], [225, 255]]
__C.TRAIN.POLICY = 'adam'  # available options: sgd, adam
__C.TRAIN.EPOCH_START_USE_REFINER = 0
__C.TRAIN.EPOCH_START_USE_MERGER = 0
__C.TRAIN.ENCODER_LEARNING_RATE = 1e-3
__C.TRAIN.DECODER_LEARNING_RATE = 1e-3
__C.TRAIN.REFINER_LEARNING_RATE = 1e-3
__C.TRAIN.MERGER_LEARNING_RATE = 1e-4
__C.TRAIN.DISCRIMINATOR_LR = 1e-4
__C.TRAIN.GAN_LOSS_WEIGHT = 0.01
__C.TRAIN.ENCODER_LR_MILESTONES = [150]
__C.TRAIN.DECODER_LR_MILESTONES = [150]
__C.TRAIN.REFINER_LR_MILESTONES = [150]
__C.TRAIN.MERGER_LR_MILESTONES = [150]
__C.TRAIN.BETAS = (.9, .999)
__C.TRAIN.MOMENTUM = .9
__C.TRAIN.GAMMA = .5
__C.TRAIN.SAVE_FREQ = 10  # weights will be overwritten every save_freq epoch
__C.TRAIN.UPDATE_N_VIEWS_RENDERING = False

#
# Testing options
#
__C.TEST = edict()
__C.TEST.RANDOM_BG_COLOR_RANGE = [[240, 240], [240, 240], [240, 240]]
__C.TEST.VOXEL_THRESH = [.2, .3, .4, .5]
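config.py exposes a single `cfg` EasyDict that other modules mutate at import or call time; helpers.py, for example, sets `cfg.CONST.WEIGHTS` before loading the checkpoint. A minimal sketch of overriding fields at runtime (the values shown are illustrative, not the committed defaults):

```python
# Sketch: runtime overrides of the shared EasyDict config.
from config import cfg

cfg.CONST.WEIGHTS = 'saved_model/Pix2Vox.pth'   # checkpoint path, as helpers.py does
cfg.CONST.N_VIEWS_RENDERING = 3                 # illustrative: number of input views
cfg.NETWORK.USE_REFINER = False                 # illustrative: skip the refiner pass
```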
helpers.py
ADDED
@@ -0,0 +1,119 @@
import utils.binvox_rw as binvox_rw
import numpy as np
import plotly.graph_objects as go
from models.encoder import Encoder
from models.decoder import Decoder
from models.merger import Merger
from models.refiner import Refiner
from config import cfg
import torch
from datetime import datetime as dt
import utils.data_transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# device='cpu'

cfg.CONST.WEIGHTS = 'saved_model/Pix2Vox.pth'


def read_binvox(file) -> np.ndarray:
    model = binvox_rw.read_as_3d_array(file)
    return model.data.astype(np.uint8)


def voxel_to_plotly(voxels):
    x, y, z = voxels.nonzero()
    fig = go.Figure(data=[
        go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(size=3, color=z, colorscale='Viridis', opacity=0.7)
        )
    ])
    fig.update_layout(scene=dict(aspectmode='data'))
    return fig



IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])


def predict_voxel_from_images(rendering_images):
    transformed_images = test_transforms(rendering_images)

    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)


    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS)

    epoch_idx = checkpoint['epoch_idx']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    if cfg.NETWORK.USE_REFINER:
        refiner.load_state_dict(checkpoint['refiner_state_dict'])
    if cfg.NETWORK.USE_MERGER:
        merger.load_state_dict(checkpoint['merger_state_dict'])


    encoder.eval()
    decoder.eval()
    merger.eval()
    refiner.eval()


    with torch.no_grad():

        transformed_images = transformed_images.unsqueeze(0)  # adding the batch_dim
        transformed_images = transformed_images.to(device)

        # print(rendering_images.shape)
        image_features = encoder(transformed_images)
        print(image_features.shape)
        raw_features, generated_volume = decoder(image_features)
        print(generated_volume.shape)


        if cfg.NETWORK.USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)


        # encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10


        if cfg.NETWORK.USE_REFINER:
            generated_volume = refiner(generated_volume)
            # refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10
        else:
            # refiner_loss = encoder_loss
            pass


        generated_volume = generated_volume.squeeze(0)
        gv = generated_volume.cpu().numpy()
        gv = (gv >= 0.5).astype(np.uint8)




    torch.cuda.empty_cache()
    return gv
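One caveat with the loading code above: `torch.load(cfg.CONST.WEIGHTS)` raises an error on a CPU-only machine when the checkpoint stores CUDA tensors; `map_location` is the standard remedy. A minimal sketch of the CPU-safe variant, using the same checkpoint keys as above:

```python
# Sketch: CPU-safe checkpoint loading; maps GPU-saved tensors onto the current device.
import torch
from config import cfg

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(cfg.CONST.WEIGHTS, map_location=device)
print(checkpoint.keys())   # e.g. epoch_idx, encoder_state_dict, decoder_state_dict, ...
```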
models/__init__.py
ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (169 Bytes)

models/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (2.26 kB)

models/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (2.05 kB)

models/__pycache__/merger.cpython-38.pyc
ADDED
Binary file (1.65 kB)

models/__pycache__/refiner.cpython-38.pyc
ADDED
Binary file (1.92 kB)
models/decoder.py
ADDED
@@ -0,0 +1,88 @@
import torch

class Decoder(torch.nn.Module):
    def __init__(self, cfg):
        super(Decoder, self).__init__()
        self.cfg = cfg

        # Layer Definition
        self.layer1 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(2048, 512, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(512),
            torch.nn.ReLU()
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(512, 128, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(128),
            torch.nn.ReLU()
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(128, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU()
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(32, 8, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(8),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=cfg.NETWORK.TCONV_USE_BIAS),
            torch.nn.Sigmoid()
        )

    def forward(self, image_features):
        image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()
        image_features = torch.split(image_features, 1, dim=0)
        gen_volumes = []
        raw_features = []

        for features in image_features:
            gen_volume = features.view(-1, 2048, 2, 2, 2)
            # print(gen_volume.size())   # torch.Size([batch_size, 2048, 2, 2, 2])
            gen_volume = self.layer1(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 512, 4, 4, 4])
            gen_volume = self.layer2(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 128, 8, 8, 8])
            gen_volume = self.layer3(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 32, 16, 16, 16])
            gen_volume = self.layer4(gen_volume)
            raw_feature = gen_volume
            # print(gen_volume.size())   # torch.Size([batch_size, 8, 32, 32, 32])
            gen_volume = self.layer5(gen_volume)
            # print(gen_volume.size())   # torch.Size([batch_size, 1, 32, 32, 32])
            raw_feature = torch.cat((raw_feature, gen_volume), dim=1)
            # print(raw_feature.size())  # torch.Size([batch_size, 9, 32, 32, 32])

            gen_volumes.append(torch.squeeze(gen_volume, dim=1))
            raw_features.append(raw_feature)

        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
        # print(gen_volumes.size())      # torch.Size([batch_size, n_views, 32, 32, 32])
        # print(raw_features.size())     # torch.Size([batch_size, n_views, 9, 32, 32, 32])
        return raw_features, gen_volumes


class DummyCfg:
    class NETWORK:
        TCONV_USE_BIAS = False

cfg = DummyCfg()

# Instantiate the decoder
decoder = Decoder(cfg)

# Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
n_views = 1
batch_size = 64
img_c, img_h, img_w = 256, 8, 8
dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)

# Run the decoder
print(dummy_input.shape)
raw_features, gen_volumes = decoder(dummy_input)

# Output shapes
print("raw_features shape:", raw_features.shape)  # Expected: [64, 5, 9, 32, 32, 32]
print("gen_volumes shape:", gen_volumes.shape)
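The `features.view(-1, 2048, 2, 2, 2)` call in `Decoder.forward` works because the encoder hands over 256 x 8 x 8 feature maps per view, and 256 * 8 * 8 = 16384 = 2048 * 2 * 2 * 2; the self-test at the bottom of the file mirrors this with `img_c, img_h, img_w = 256, 8, 8`. A quick, standalone check of that arithmetic:

```python
# Sketch: sanity check of the reshape used to seed the 3D deconvolutions.
import torch

per_view = torch.randn(4, 256, 8, 8)               # one view's encoder features, batch of 4
assert 256 * 8 * 8 == 2048 * 2 * 2 * 2             # 16384 elements either way
vol_seed = per_view.view(-1, 2048, 2, 2, 2)        # same elements, re-grouped as a tiny volume
print(vol_seed.shape)                               # torch.Size([4, 2048, 2, 2, 2])
```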
models/encoder.py
ADDED
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <[email protected]>
#
# References:
# - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py

import torch
import torchvision.models


class Encoder(torch.nn.Module):
    def __init__(self, cfg):
        super(Encoder, self).__init__()
        self.cfg = cfg

        # Layer Definition
        vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)
        self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU()
        )

        # Don't update params in VGG16
        for param in vgg16_bn.parameters():
            param.requires_grad = False

    def forward(self, rendering_images):
        # print(rendering_images.size())  # torch.Size([batch_size, n_views, img_c, img_h, img_w])
        rendering_images = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
        rendering_images = torch.split(rendering_images, 1, dim=0)
        image_features = []

        for img in rendering_images:
            features = self.vgg(img.squeeze(dim=0))
            # print(features.size())    # torch.Size([batch_size, 512, 28, 28])
            features = self.layer1(features)
            # print(features.size())    # torch.Size([batch_size, 512, 26, 26])
            features = self.layer2(features)
            # print(features.size())    # torch.Size([batch_size, 512, 24, 24])
            features = self.layer3(features)
            # print(features.size())    # torch.Size([batch_size, 256, 8, 8])
            image_features.append(features)

        image_features = torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()
        # print(image_features.size())  # torch.Size([batch_size, n_views, 256, 8, 8])
        return image_features



class DummyCfg:
    class NETWORK:
        TCONV_USE_BIAS = False

cfg = DummyCfg()

# Instantiate the encoder
encoder = Encoder(cfg)

# Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]

batch_size = 64
n_views = 5
img_c, img_h, img_w = 3, 224, 224
dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)

# Run the encoder
print(dummy_input.shape)
image_features = encoder(dummy_input)


print("image_features shape:", image_features.shape)  # Expected: [64, 5, 256, 8, 8]
models/merger.py
ADDED
@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <[email protected]>

import torch


class Merger(torch.nn.Module):
    def __init__(self, cfg):
        super(Merger, self).__init__()
        self.cfg = cfg

        # Layer Definition
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv3d(9, 16, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(16),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv3d(16, 8, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(8),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv3d(8, 4, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(4),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv3d(4, 2, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(2),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv3d(2, 1, kernel_size=3, padding=1),
            torch.nn.BatchNorm3d(1),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)
        )

    def forward(self, raw_features, coarse_volumes):
        n_views_rendering = coarse_volumes.size(1)
        raw_features = torch.split(raw_features, 1, dim=1)
        volume_weights = []

        for i in range(n_views_rendering):
            raw_feature = torch.squeeze(raw_features[i], dim=1)
            # print(raw_feature.size())       # torch.Size([batch_size, 9, 32, 32, 32])

            volume_weight = self.layer1(raw_feature)
            # print(volume_weight.size())     # torch.Size([batch_size, 16, 32, 32, 32])
            volume_weight = self.layer2(volume_weight)
            # print(volume_weight.size())     # torch.Size([batch_size, 8, 32, 32, 32])
            volume_weight = self.layer3(volume_weight)
            # print(volume_weight.size())     # torch.Size([batch_size, 4, 32, 32, 32])
            volume_weight = self.layer4(volume_weight)
            # print(volume_weight.size())     # torch.Size([batch_size, 2, 32, 32, 32])
            volume_weight = self.layer5(volume_weight)
            # print(volume_weight.size())     # torch.Size([batch_size, 1, 32, 32, 32])

            volume_weight = torch.squeeze(volume_weight, dim=1)
            # print(volume_weight.size())     # torch.Size([batch_size, 32, 32, 32])
            volume_weights.append(volume_weight)

        volume_weights = torch.stack(volume_weights).permute(1, 0, 2, 3, 4).contiguous()
        volume_weights = torch.softmax(volume_weights, dim=1)
        # print(volume_weights.size())        # torch.Size([batch_size, n_views, 32, 32, 32])
        # print(coarse_volumes.size())        # torch.Size([batch_size, n_views, 32, 32, 32])
        coarse_volumes = coarse_volumes * volume_weights
        coarse_volumes = torch.sum(coarse_volumes, dim=1)

        return torch.clamp(coarse_volumes, min=0, max=1)
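The merger scores each view's 9-channel raw features with a small 3D CNN, softmax-normalizes the scores across the view dimension, and blends the per-view coarse volumes with those weights. A stripped-down sketch of just the fusion step, with random tensors standing in for the network outputs:

```python
# Sketch: the Merger's weighted fusion in isolation (random stand-ins for real features).
import torch

batch, n_views, n_vox = 2, 3, 32
coarse_volumes = torch.rand(batch, n_views, n_vox, n_vox, n_vox)   # per-view decoder outputs
scores = torch.randn(batch, n_views, n_vox, n_vox, n_vox)          # per-view, per-voxel scores

weights = torch.softmax(scores, dim=1)          # normalize across views for every voxel
fused = (coarse_volumes * weights).sum(dim=1)   # weighted sum over the view dimension
fused = fused.clamp(min=0, max=1)               # same clamp as Merger.forward
print(fused.shape)                               # torch.Size([2, 32, 32, 32])
```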
models/refiner.py
ADDED
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <[email protected]>

import torch


class Refiner(torch.nn.Module):
    def __init__(self, cfg):
        super(Refiner, self).__init__()
        self.cfg = cfg

        # Layer Definition
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv3d(1, 32, kernel_size=4, padding=2),
            torch.nn.BatchNorm3d(32),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
            torch.nn.MaxPool3d(kernel_size=2)
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv3d(32, 64, kernel_size=4, padding=2),
            torch.nn.BatchNorm3d(64),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
            torch.nn.MaxPool3d(kernel_size=2)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv3d(64, 128, kernel_size=4, padding=2),
            torch.nn.BatchNorm3d(128),
            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),
            torch.nn.MaxPool3d(kernel_size=2)
        )
        self.layer4 = torch.nn.Sequential(
            torch.nn.Linear(8192, 2048),
            torch.nn.ReLU()
        )
        self.layer5 = torch.nn.Sequential(
            torch.nn.Linear(2048, 8192),
            torch.nn.ReLU()
        )
        self.layer6 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(64),
            torch.nn.ReLU()
        )
        self.layer7 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.BatchNorm3d(32),
            torch.nn.ReLU()
        )
        self.layer8 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),
            torch.nn.Sigmoid()
        )

    def forward(self, coarse_volumes):
        volumes_32_l = coarse_volumes.view((-1, 1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))
        # print(volumes_32_l.size())       # torch.Size([batch_size, 1, 32, 32, 32])
        volumes_16_l = self.layer1(volumes_32_l)
        # print(volumes_16_l.size())       # torch.Size([batch_size, 32, 16, 16, 16])
        volumes_8_l = self.layer2(volumes_16_l)
        # print(volumes_8_l.size())        # torch.Size([batch_size, 64, 8, 8, 8])
        volumes_4_l = self.layer3(volumes_8_l)
        # print(volumes_4_l.size())        # torch.Size([batch_size, 128, 4, 4, 4])
        flatten_features = self.layer4(volumes_4_l.view(-1, 8192))
        # print(flatten_features.size())   # torch.Size([batch_size, 2048])
        flatten_features = self.layer5(flatten_features)
        # print(flatten_features.size())   # torch.Size([batch_size, 8192])
        volumes_4_r = volumes_4_l + flatten_features.view(-1, 128, 4, 4, 4)
        # print(volumes_4_r.size())        # torch.Size([batch_size, 128, 4, 4, 4])
        volumes_8_r = volumes_8_l + self.layer6(volumes_4_r)
        # print(volumes_8_r.size())        # torch.Size([batch_size, 64, 8, 8, 8])
        volumes_16_r = volumes_16_l + self.layer7(volumes_8_r)
        # print(volumes_16_r.size())       # torch.Size([batch_size, 32, 16, 16, 16])
        volumes_32_r = (volumes_32_l + self.layer8(volumes_16_r)) * 0.5
        # print(volumes_32_r.size())       # torch.Size([batch_size, 1, 32, 32, 32])

        return volumes_32_r.view((-1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))
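The jump from `layer3` to the `Linear(8192, 2048)` bottleneck relies on 128 * 4 * 4 * 4 = 8192, which is why the forward pass flattens with `view(-1, 8192)` and later reshapes back with `view(-1, 128, 4, 4, 4)`. With `kernel_size=4, padding=2` each Conv3d grows a side from 32 to 33 before the max pool halves it, so the encoder path shrinks 32 -> 16 -> 8 -> 4. A small standalone check of that bookkeeping:

```python
# Sketch: spatial-size bookkeeping for the Refiner's downsampling path.
import torch

x = torch.rand(1, 1, 32, 32, 32)
conv = torch.nn.Conv3d(1, 32, kernel_size=4, padding=2)   # 32 + 2*2 - 4 + 1 = 33 per side
pool = torch.nn.MaxPool3d(kernel_size=2)                  # floor(33 / 2) = 16 per side
print(pool(conv(x)).shape)     # torch.Size([1, 32, 16, 16, 16])
print(128 * 4 * 4 * 4)         # 8192, the Linear bottleneck input size after two more stages
```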
requirements.txt
ADDED
@@ -0,0 +1,11 @@
argparse
easydict
matplotlib
numpy
opencv-python
scipy
torch
torchvision
streamlit
plotly
pillow
utils/__pycache__/binvox_rw.cpython-38.pyc
ADDED
Binary file (7.44 kB)

utils/__pycache__/data_transforms.cpython-38.pyc
ADDED
Binary file (11.8 kB)
utils/binvox_rw.py
ADDED
@@ -0,0 +1,296 @@
# Copyright (C) 2012 Daniel Maturana
# This file is part of binvox-rw-py.
#
# binvox-rw-py is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# binvox-rw-py is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
#
"""
Binvox to Numpy and back.


>>> import numpy as np
>>> import binvox_rw
>>> with open('chair.binvox', 'rb') as f:
...     m1 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims
[32, 32, 32]
>>> m1.scale
41.133000000000003
>>> m1.translate
[0.0, 0.0, 0.0]
>>> with open('chair_out.binvox', 'wb') as f:
...     m1.write(f)
...
>>> with open('chair_out.binvox', 'rb') as f:
...     m2 = binvox_rw.read_as_3d_array(f)
...
>>> m1.dims == m2.dims
True
>>> m1.scale == m2.scale
True
>>> m1.translate == m2.translate
True
>>> np.all(m1.data == m2.data)
True

>>> with open('chair.binvox', 'rb') as f:
...     md = binvox_rw.read_as_3d_array(f)
...
>>> with open('chair.binvox', 'rb') as f:
...     ms = binvox_rw.read_as_coord_array(f)
...
>>> data_ds = binvox_rw.dense_to_sparse(md.data)
>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
>>> np.all(data_sd == md.data)
True
>>> # the ordering of elements returned by numpy.nonzero changes with axis
>>> # ordering, so to compare for equality we first lexically sort the voxels.
>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
True
"""

import numpy as np


class Voxels(object):
    """ Holds a binvox model.
    data is either a three-dimensional numpy boolean array (dense representation)
    or a two-dimensional numpy float array (coordinate representation).

    dims, translate and scale are the model metadata.

    dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.

    scale and translate relate the voxels to the original model coordinates.

    To translate voxel coordinates i, j, k to original coordinates x, y, z:

    x_n = (i+.5)/dims[0]
    y_n = (j+.5)/dims[1]
    z_n = (k+.5)/dims[2]
    x = scale*x_n + translate[0]
    y = scale*y_n + translate[1]
    z = scale*z_n + translate[2]

    """
    def __init__(self, data, dims, translate, scale, axis_order):
        self.data = data
        self.dims = dims
        self.translate = translate
        self.scale = scale
        assert (axis_order in ('xzy', 'xyz'))
        self.axis_order = axis_order

    def clone(self):
        data = self.data.copy()
        dims = self.dims[:]
        translate = self.translate[:]
        return Voxels(data, dims, translate, self.scale, self.axis_order)

    def write(self, fp):
        write(self, fp)


def read_header(fp):
    """ Read binvox header. Mostly meant for internal use.
    """
    line = fp.readline().strip()
    if not line.startswith(b'#binvox'):
        raise IOError('[ERROR] Not a binvox file')
    dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
    translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
    scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
    fp.readline()
    return dims, translate, scale


def read_as_3d_array(fp, fix_coords=True):
    """ Read binary binvox format as array.

    Returns the model with accompanying metadata.

    Voxels are stored in a three-dimensional numpy array, which is simple and
    direct, but may use a lot of memory for large models. (Storage requirements
    are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy
    boolean arrays use a byte per element).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
    # if just using reshape() on the raw data:
    # indexing the array as array[i,j,k], the indices map into the
    # coords as:
    # i -> x
    # j -> z
    # k -> y
    # if fix_coords is true, then data is rearranged so that
    # mapping is
    # i -> x
    # j -> y
    # k -> z
    values, counts = raw_data[::2], raw_data[1::2]
    data = np.repeat(values, counts).astype(np.int32)
    data = data.reshape(dims)
    if fix_coords:
        # xzy to xyz TODO the right thing
        data = np.transpose(data, (0, 2, 1))
        axis_order = 'xyz'
    else:
        axis_order = 'xzy'
    return Voxels(data, dims, translate, scale, axis_order)


def read_as_coord_array(fp, fix_coords=True):
    """ Read binary binvox format as coordinates.

    Returns binvox model with voxels in a "coordinate" representation, i.e. an
    3 x N array where N is the number of nonzero voxels. Each column
    corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
    of the voxel. (The odd ordering is due to the way binvox format lays out
    data). Note that coordinates refer to the binvox voxels, without any
    scaling or translation.

    Use this to save memory if your model is very sparse (mostly empty).

    Doesn't do any checks on input except for the '#binvox' line.
    """
    dims, translate, scale = read_header(fp)
    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)

    values, counts = raw_data[::2], raw_data[1::2]

    # sz = np.prod(dims)
    # index, end_index = 0, 0
    end_indices = np.cumsum(counts)
    indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)

    values = values.astype(np.bool)
    indices = indices[values]
    end_indices = end_indices[values]

    nz_voxels = []
    for index, end_index in zip(indices, end_indices):
        nz_voxels.extend(range(index, end_index))
    nz_voxels = np.array(nz_voxels)
    # TODO are these dims correct?
    # according to docs,
    # index = x * wxh + z * width + y; // wxh = width * height = d * d

    x = nz_voxels / (dims[0] * dims[1])
    zwpy = nz_voxels % (dims[0] * dims[1])  # z*w + y
    z = zwpy / dims[0]
    y = zwpy % dims[0]
    if fix_coords:
        data = np.vstack((x, y, z))
        axis_order = 'xyz'
    else:
        data = np.vstack((x, z, y))
        axis_order = 'xzy'

    # return Voxels(data, dims, translate, scale, axis_order)
    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)


def dense_to_sparse(voxel_data, dtype=int):
    """ From dense representation to sparse (coordinate) representation.
    No coordinate reordering.
    """
    if voxel_data.ndim != 3:
        raise ValueError('[ERROR] voxel_data is wrong shape; should be 3D array.')
    return np.asarray(np.nonzero(voxel_data), dtype)


def sparse_to_dense(voxel_data, dims, dtype=bool):
    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:
        raise ValueError('[ERROR] voxel_data is wrong shape; should be 3xN array.')
    if np.isscalar(dims):
        dims = [dims] * 3
    dims = np.atleast_2d(dims).T
    # truncate to integers
    xyz = voxel_data.astype(np.int)
    # discard voxels that fall outside dims
    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
    xyz = xyz[:, valid_ix]
    out = np.zeros(dims.flatten(), dtype=dtype)
    out[tuple(xyz)] = True
    return out


# def get_linear_index(x, y, z, dims):
#     """ Assuming xzy order. (y increasing fastest.
#     TODO ensure this is right when dims are not all same
#     """
#     return x*(dims[1]*dims[2]) + z*dims[1] + y


def write(voxel_model, fp):
    """ Write binary binvox format.

    Note that when saving a model in sparse (coordinate) format, it is first
    converted to dense format.

    Doesn't check if the model is 'sane'.

    """
    if voxel_model.data.ndim == 2:
        # TODO avoid conversion to dense
        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims).astype(int)
    else:
        dense_voxel_data = voxel_model.data.astype(int)

    file_header = [
        '#binvox 1\n',
        'dim %s\n' % ' '.join(map(str, voxel_model.dims)),
        'translate %s\n' % ' '.join(map(str, voxel_model.translate)),
        'scale %s\n' % str(voxel_model.scale), 'data\n'
    ]

    for fh in file_header:
        fp.write(fh.encode('latin-1'))

    if voxel_model.axis_order not in ('xzy', 'xyz'):
        raise ValueError('[ERROR] Unsupported voxel model axis order')

    if voxel_model.axis_order == 'xzy':
        voxels_flat = dense_voxel_data.flatten()
    elif voxel_model.axis_order == 'xyz':
        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()

    # keep a sort of state machine for writing run length encoding
    state = voxels_flat[0]
    ctr = 0
    for c in voxels_flat:
        if c == state:
            ctr += 1
            # if ctr hits max, dump
            if ctr == 255:
                fp.write(chr(state).encode('latin-1'))
                fp.write(chr(ctr).encode('latin-1'))
                ctr = 0
        else:
            # if switch state, dump
            fp.write(chr(state).encode('latin-1'))
            fp.write(chr(ctr).encode('latin-1'))
            state = c
            ctr = 1
    # flush out remainders
    if ctr > 0:
        fp.write(chr(state).encode('latin-1'))
        fp.write(chr(ctr).encode('latin-1'))


if __name__ == '__main__':
    import doctest
    doctest.testmod()
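Note that `read_as_coord_array` and `sparse_to_dense` still use the `np.bool` and `np.int` aliases, which were removed in NumPy 1.24, so on recent NumPy they raise AttributeError; the dense path used by helpers.py (`read_as_3d_array`) is unaffected. A minimal sketch of the drop-in replacements, should those functions be needed:

```python
# Sketch: replacements for the NumPy aliases removed in NumPy 1.24.
import numpy as np

raw = np.array([0, 1, 1, 0], dtype=np.uint8)
flags = raw.astype(bool)         # instead of raw.astype(np.bool)

coords = np.array([[1.9, 2.1], [0.0, 3.7], [4.2, 5.0]])
ixyz = coords.astype(np.int64)   # instead of coords.astype(np.int); truncates to integers
```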
utils/data_transforms.py
ADDED
@@ -0,0 +1,452 @@
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <[email protected]>
# References:
# - https://github.com/xiumingzhang/GenRe-ShapeHD

import cv2
# import matplotlib.pyplot as plt
# import matplotlib.patches as patches
import numpy as np
import os
import random
import torch


class Compose(object):
    """ Composes several transforms together.
    For example:
    >>> transforms.Compose([
    >>>     transforms.RandomBackground(),
    >>>     transforms.CenterCrop(127, 127, 3),
    >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, rendering_images, bounding_box=None):
        for t in self.transforms:
            if t.__class__.__name__ == 'RandomCrop' or t.__class__.__name__ == 'CenterCrop':
                rendering_images = t(rendering_images, bounding_box)
            else:
                rendering_images = t(rendering_images)

        return rendering_images


class ToTensor(object):
    """
    Convert a PIL Image or numpy.ndarray to tensor.
    Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))
        array = np.transpose(rendering_images, (0, 3, 1, 2))
        # handle numpy array
        tensor = torch.from_numpy(array)

        # put it from HWC to CHW format
        return tensor.float()


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))
        rendering_images -= self.mean
        rendering_images /= self.std

        return rendering_images


class RandomPermuteRGB(object):
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))

        random_permutation = np.random.permutation(3)
        for img_idx, img in enumerate(rendering_images):
            rendering_images[img_idx] = img[..., random_permutation]

        return rendering_images


class CenterCrop(object):
    def __init__(self, img_size, crop_size):
        """Set the height and weight before and after cropping"""
        self.img_size_h = img_size[0]
        self.img_size_w = img_size[1]
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, rendering_images, bounding_box=None):
        if len(rendering_images) == 0:
            return rendering_images

        crop_size_c = rendering_images[0].shape[2]
        processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))
        for img_idx, img in enumerate(rendering_images):
            img_height, img_width, _ = img.shape

            if bounding_box is not None:
                bounding_box = [
                    bounding_box[0] * img_width,
                    bounding_box[1] * img_height,
                    bounding_box[2] * img_width,
                    bounding_box[3] * img_height
                ]  # yapf: disable

                # Calculate the size of bounding boxes
                bbox_width = bounding_box[2] - bounding_box[0]
                bbox_height = bounding_box[3] - bounding_box[1]
                bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5
                bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5

                # Make the crop area as a square
                square_object_size = max(bbox_width, bbox_height)
                x_left = int(bbox_x_mid - square_object_size * .5)
                x_right = int(bbox_x_mid + square_object_size * .5)
                y_top = int(bbox_y_mid - square_object_size * .5)
                y_bottom = int(bbox_y_mid + square_object_size * .5)

                # If the crop position is out of the image, fix it with padding
                pad_x_left = 0
                if x_left < 0:
                    pad_x_left = -x_left
                    x_left = 0
                pad_x_right = 0
                if x_right >= img_width:
                    pad_x_right = x_right - img_width + 1
                    x_right = img_width - 1
                pad_y_top = 0
                if y_top < 0:
                    pad_y_top = -y_top
                    y_top = 0
                pad_y_bottom = 0
                if y_bottom >= img_height:
                    pad_y_bottom = y_bottom - img_height + 1
                    y_bottom = img_height - 1

                # Padding the image and resize the image
                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
                                         mode='edge')
                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
            else:
                if img_height > self.crop_size_h and img_width > self.crop_size_w:
                    x_left = int(img_width - self.crop_size_w) // 2
                    x_right = int(x_left + self.crop_size_w)
                    y_top = int(img_height - self.crop_size_h) // 2
                    y_bottom = int(y_top + self.crop_size_h)
                else:
                    x_left = 0
                    x_right = img_width
                    y_top = 0
                    y_bottom = img_height

                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))

            processed_images = np.append(processed_images, [processed_image], axis=0)
            # Debug
            # fig = plt.figure()
            # ax1 = fig.add_subplot(1, 2, 1)
            # ax1.imshow(img)
            # if not bounding_box is None:
            #     rect = patches.Rectangle((bounding_box[0], bounding_box[1]),
            #                              bbox_width,
            #                              bbox_height,
            #                              linewidth=1,
            #                              edgecolor='r',
            #                              facecolor='none')
            #     ax1.add_patch(rect)
            # ax2 = fig.add_subplot(1, 2, 2)
            # ax2.imshow(processed_image)
            # plt.show()
        return processed_images


class RandomCrop(object):
    def __init__(self, img_size, crop_size):
        """Set the height and weight before and after cropping"""
        self.img_size_h = img_size[0]
        self.img_size_w = img_size[1]
        self.crop_size_h = crop_size[0]
        self.crop_size_w = crop_size[1]

    def __call__(self, rendering_images, bounding_box=None):
        if len(rendering_images) == 0:
            return rendering_images

        crop_size_c = rendering_images[0].shape[2]
        processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))
        for img_idx, img in enumerate(rendering_images):
            img_height, img_width, _ = img.shape

            if bounding_box is not None:
                bounding_box = [
                    bounding_box[0] * img_width,
                    bounding_box[1] * img_height,
                    bounding_box[2] * img_width,
                    bounding_box[3] * img_height
                ]  # yapf: disable

                # Calculate the size of bounding boxes
                bbox_width = bounding_box[2] - bounding_box[0]
                bbox_height = bounding_box[3] - bounding_box[1]
                bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5
                bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5

                # Make the crop area as a square
                square_object_size = max(bbox_width, bbox_height)
                square_object_size = square_object_size * random.uniform(0.8, 1.2)

                x_left = int(bbox_x_mid - square_object_size * random.uniform(.4, .6))
                x_right = int(bbox_x_mid + square_object_size * random.uniform(.4, .6))
                y_top = int(bbox_y_mid - square_object_size * random.uniform(.4, .6))
                y_bottom = int(bbox_y_mid + square_object_size * random.uniform(.4, .6))

                # If the crop position is out of the image, fix it with padding
                pad_x_left = 0
                if x_left < 0:
                    pad_x_left = -x_left
                    x_left = 0
                pad_x_right = 0
                if x_right >= img_width:
                    pad_x_right = x_right - img_width + 1
                    x_right = img_width - 1
                pad_y_top = 0
                if y_top < 0:
                    pad_y_top = -y_top
                    y_top = 0
                pad_y_bottom = 0
                if y_bottom >= img_height:
                    pad_y_bottom = y_bottom - img_height + 1
                    y_bottom = img_height - 1

                # Padding the image and resize the image
                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],
                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),
                                         mode='edge')
                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))
            else:
                if img_height > self.crop_size_h and img_width > self.crop_size_w:
                    x_left = int(img_width - self.crop_size_w) // 2
                    x_right = int(x_left + self.crop_size_w)
                    y_top = int(img_height - self.crop_size_h) // 2
                    y_bottom = int(y_top + self.crop_size_h)
                else:
                    x_left = 0
                    x_right = img_width
                    y_top = 0
                    y_bottom = img_height

                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))

            processed_images = np.append(processed_images, [processed_image], axis=0)

        return processed_images


class RandomFlip(object):
    def __call__(self, rendering_images):
        assert (isinstance(rendering_images, np.ndarray))

        for img_idx, img in enumerate(rendering_images):
            if random.randint(0, 1):
                rendering_images[img_idx] = np.fliplr(img)

        return rendering_images


class ColorJitter(object):
    def __init__(self, brightness, contrast, saturation):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation

    def __call__(self, rendering_images):
        if len(rendering_images) == 0:
            return rendering_images

        # Allocate new space for storing processed images
        img_height, img_width, img_channels = rendering_images[0].shape
        processed_images = np.empty(shape=(0, img_height, img_width, img_channels))

        # Randomize the value of changing brightness, contrast, and saturation
        brightness = 1 + np.random.uniform(low=-self.brightness, high=self.brightness)
        contrast = 1 + np.random.uniform(low=-self.contrast, high=self.contrast)
        saturation = 1 + np.random.uniform(low=-self.saturation, high=self.saturation)

        # Randomize the order of changing brightness, contrast, and saturation
        attr_names = ['brightness', 'contrast', 'saturation']
        attr_values = [brightness, contrast, saturation]  # The value of changing attrs
        attr_indexes = np.array(range(len(attr_names)))  # The order of changing attrs
        np.random.shuffle(attr_indexes)

        for img_idx, img in enumerate(rendering_images):
            processed_image = img
            for idx in attr_indexes:
                processed_image = self._adjust_image_attr(processed_image, attr_names[idx], attr_values[idx])

            processed_images = np.append(processed_images, [processed_image], axis=0)
            # print('ColorJitter', np.mean(ori_img), np.mean(processed_image))
            # fig = plt.figure(figsize=(8, 4))
            # ax1 = fig.add_subplot(1, 2, 1)
            # ax1.imshow(ori_img)
            # ax2 = fig.add_subplot(1, 2, 2)
            # ax2.imshow(processed_image)
            # plt.show()
        return processed_images

    def _adjust_image_attr(self, img, attr_name, attr_value):
        """
        Adjust or randomize the specified attribute of the image

        Args:
            img: Image in BGR format
                Numpy array of shape (h, w, 3)
            attr_name: Image attribute to adjust or randomize
                'brightness', 'saturation', or 'contrast'
            attr_value: the alpha for blending is randomly drawn from [1 - d, 1 + d]

        Returns:
            Output image in BGR format
                Numpy array of the same shape as input
        """
        gs = self._bgr_to_gray(img)

        if attr_name == 'contrast':
            img = self._alpha_blend(img, np.mean(gs[:, :, 0]), attr_value)
        elif attr_name == 'saturation':
            img = self._alpha_blend(img, gs, attr_value)
        elif attr_name == 'brightness':
            img = self._alpha_blend(img, 0, attr_value)
        else:
            raise NotImplementedError(attr_name)
        return img

    def _bgr_to_gray(self, bgr):
        """
        Convert a RGB image to a grayscale image
        Differences from cv2.cvtColor():
            1. Input image can be float
            2. Output image has three repeated channels, other than a single channel

        Args:
            bgr: Image in BGR format
                Numpy array of shape (h, w, 3)

        Returns:
            gs: Grayscale image
                Numpy array of the same shape as input; the three channels are the same
        """
        ch = 0.114 * bgr[:, :, 0] + 0.587 * bgr[:, :, 1] + 0.299 * bgr[:, :, 2]
        gs = np.dstack((ch, ch, ch))
        return gs

    def _alpha_blend(self, im1, im2, alpha):
        """
        Alpha blending of two images or one image and a scalar

        Args:
            im1, im2: Image or scalar
                Numpy array and a scalar or two numpy arrays of the same shape
            alpha: Weight of im1
                Float ranging usually from 0 to 1

        Returns:
            im_blend: Blended image -- alpha * im1 + (1 - alpha) * im2
                Numpy array of the same shape as input image
        """
        im_blend = alpha * im1 + (1 - alpha) * im2
        return im_blend


class RandomNoise(object):
    def __init__(self,
                 noise_std,
                 eigvals=(0.2175, 0.0188, 0.0045),
                 eigvecs=((-0.5675, 0.7192, 0.4009), (-0.5808, -0.0045, -0.8140), (-0.5836, -0.6948, 0.4203))):
        self.noise_std = noise_std
        self.eigvals = np.array(eigvals)
        self.eigvecs = np.array(eigvecs)

    def __call__(self, rendering_images):
        alpha = np.random.normal(loc=0, scale=self.noise_std, size=3)
        noise_rgb = \
            np.sum(
                np.multiply(
                    np.multiply(
                        self.eigvecs,
                        np.tile(alpha, (3, 1))
                    ),
                    np.tile(self.eigvals, (3, 1))
                ),
                axis=1
            )

        # Allocate new space for storing processed images
        img_height, img_width, img_channels = rendering_images[0].shape
        assert (img_channels == 3), "Please use RandomBackground to normalize image channels"
        processed_images = np.empty(shape=(0, img_height, img_width, img_channels))

        for img_idx, img in enumerate(rendering_images):
            processed_image = img[:, :, ::-1]  # BGR -> RGB
            for i in range(img_channels):
                processed_image[:, :, i] += noise_rgb[i]

            processed_image = processed_image[:, :, ::-1]  # RGB -> BGR
            processed_images = np.append(processed_images, [processed_image], axis=0)
            # from copy import deepcopy
            # ori_img = deepcopy(img)
            # print(noise_rgb, np.mean(processed_image), np.mean(ori_img))
            # print('RandomNoise', np.mean(ori_img), np.mean(processed_image))
            # fig = plt.figure(figsize=(8, 4))
            # ax1 = fig.add_subplot(1, 2, 1)
            # ax1.imshow(ori_img)
            # ax2 = fig.add_subplot(1, 2, 2)
            # ax2.imshow(processed_image)
            # plt.show()
        return processed_images


class RandomBackground(object):
    def __init__(self, random_bg_color_range, random_bg_folder_path=None):
        self.random_bg_color_range = random_bg_color_range
        self.random_bg_files = []
        if random_bg_folder_path is not None:
            self.random_bg_files = os.listdir(random_bg_folder_path)
            self.random_bg_files = [os.path.join(random_bg_folder_path, rbf) for rbf in self.random_bg_files]

    def __call__(self, rendering_images):
        if len(rendering_images) == 0:
            return rendering_images

        img_height, img_width, img_channels = rendering_images[0].shape
        # If the image has the alpha channel, add the background
        if not img_channels == 4:
            return rendering_images

        # Generate random background
        r, g, b = np.array([
            np.random.randint(self.random_bg_color_range[i][0], self.random_bg_color_range[i][1] + 1) for i in range(3)
        ]) / 255.

        random_bg = None
        if len(self.random_bg_files) > 0:
            random_bg_file_path = random.choice(self.random_bg_files)
            random_bg = cv2.imread(random_bg_file_path).astype(np.float32) / 255.

        # Apply random background
        processed_images = np.empty(shape=(0, img_height, img_width, img_channels - 1))
        for img_idx, img in enumerate(rendering_images):
            alpha = (np.expand_dims(img[:, :, 3], axis=2) == 0).astype(np.float32)
            img = img[:, :, :3]
            bg_color = random_bg if random.randint(0, 1) and random_bg is not None else np.array([[[r, g, b]]])
            img = alpha * bg_color + (1 - alpha) * img

            processed_images = np.append(processed_images, [img], axis=0)

        return processed_images
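helpers.py composes `CenterCrop`, `RandomBackground`, `Normalize`, and `ToTensor` from this module for test-time preprocessing. A minimal sketch of that pipeline on a dummy RGB view, using the sizes and ranges from config.py; the result is the `[n_views, C, H, W]` float tensor that `predict_voxel_from_images` then unsqueezes into a batch:

```python
# Sketch: the test-time transform pipeline from helpers.py applied to one dummy view.
import numpy as np
import utils.data_transforms as dt

IMG_SIZE, CROP_SIZE = (224, 224), (128, 128)
transforms = dt.Compose([
    dt.CenterCrop(IMG_SIZE, CROP_SIZE),
    dt.RandomBackground([[240, 240], [240, 240], [240, 240]]),   # cfg.TEST.RANDOM_BG_COLOR_RANGE
    dt.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    dt.ToTensor(),
])

view = np.random.rand(300, 300, 3).astype(np.float32)   # one H x W x 3 image in [0, 1]
tensor = transforms([view])                             # list of views -> stacked tensor
print(tensor.shape)                                      # torch.Size([1, 3, 224, 224])
```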