from datetime import datetime as dt

import numpy as np
import plotly.graph_objects as go
import torch

import utils.binvox_rw as binvox_rw
import utils.data_transforms
from config import cfg
from models.encoder import Encoder
from models.decoder import Decoder
from models.merger import Merger
from models.refiner import Refiner

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# device = torch.device('cpu')  # uncomment to force CPU

# Path to the pretrained Pix2Vox checkpoint.
cfg.CONST.WEIGHTS = 'saved_model/Pix2Vox.pth'
def read_binvox(file) -> np.ndarray:
    """Read a .binvox file object and return its occupancy grid as a uint8 array."""
    model = binvox_rw.read_as_3d_array(file)
    return model.data.astype(np.uint8)
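# Example usage (the file name is a hypothetical placeholder):
#   with open('model.binvox', 'rb') as f:
#       voxels = read_binvox(f)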
def voxel_to_plotly(voxels):
    """Render a 3D occupancy grid as a Plotly scatter plot of the occupied voxels."""
    x, y, z = voxels.nonzero()
    fig = go.Figure(data=[
        go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(size=3, color=z, colorscale='Viridis', opacity=0.7)
        )
    ])
    fig.update_layout(scene=dict(aspectmode='data'))
    return fig
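# Example usage, where `voxels` is any binary 3D occupancy array
# (e.g. from read_binvox above or predict_voxel_from_images below):
#   fig = voxel_to_plotly(voxels)
#   fig.show()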
def remove_module_prefix(state_dict):
    """Strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names,
    so the checkpoint can be loaded into plain (non-DataParallel) modules on CPU."""
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('module.'):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict
IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W

# Test-time preprocessing: center-crop, fill the background, normalize with
# the dataset statistics, and convert the views to a tensor.
test_transforms = utils.data_transforms.Compose([
    utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
    utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
    utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),
    utils.data_transforms.ToTensor(),
])
def predict_voxel_from_images(rendering_images):
    """Run Pix2Vox on a set of rendering images and return a binary occupancy grid."""
    transformed_images = test_transforms(rendering_images)

    # Build the four network stages.
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)

    print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))
    checkpoint = torch.load(cfg.CONST.WEIGHTS, weights_only=False, map_location=device)

    if torch.cuda.is_available():
        # The checkpoint was saved from DataParallel models, so wrap the modules
        # the same way before loading the state dicts.
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()
        encoder_state_dict = checkpoint['encoder_state_dict']
        decoder_state_dict = checkpoint['decoder_state_dict']
        merger_state_dict = checkpoint['merger_state_dict']
        refiner_state_dict = checkpoint['refiner_state_dict']
    else:
        # On CPU the modules are not wrapped in DataParallel, so drop the 'module.' prefix.
        encoder_state_dict = remove_module_prefix(checkpoint['encoder_state_dict'])
        decoder_state_dict = remove_module_prefix(checkpoint['decoder_state_dict'])
        merger_state_dict = remove_module_prefix(checkpoint['merger_state_dict'])
        refiner_state_dict = remove_module_prefix(checkpoint['refiner_state_dict'])

    epoch_idx = checkpoint['epoch_idx']  # epoch at which the checkpoint was saved

    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    if cfg.NETWORK.USE_REFINER:
        refiner.load_state_dict(refiner_state_dict)
    if cfg.NETWORK.USE_MERGER:
        merger.load_state_dict(merger_state_dict)

    encoder.eval()
    decoder.eval()
    merger.eval()
    refiner.eval()

    with torch.no_grad():
        transformed_images = transformed_images.unsqueeze(0)  # add the batch dimension
        transformed_images = transformed_images.to(device)

        image_features = encoder(transformed_images)
        print(image_features.shape)
        raw_features, generated_volume = decoder(image_features)
        print(generated_volume.shape)

        # Merge the per-view volumes into one, or average them if the merger is disabled.
        if cfg.NETWORK.USE_MERGER:
            generated_volume = merger(raw_features, generated_volume)
        else:
            generated_volume = torch.mean(generated_volume, dim=1)

        if cfg.NETWORK.USE_REFINER:
            generated_volume = refiner(generated_volume)

        generated_volume = generated_volume.squeeze(0)

    gv = generated_volume.cpu().numpy()
    gv = (gv >= 0.5).astype(np.uint8)  # threshold voxel probabilities into a binary grid
    torch.cuda.empty_cache()
    return gv
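# A minimal end-to-end sketch of how this module might be driven. The view paths
# below are hypothetical placeholders, and Pillow is assumed to be available.
# Pix2Vox expects one or more RGB views of the same object as float32 arrays of
# shape (n_views, H, W, 3); the exact value range expected here depends on how
# the data_transforms (RandomBackground/Normalize) are configured.
if __name__ == '__main__':
    from PIL import Image

    view_paths = ['views/view_0.png', 'views/view_1.png']  # hypothetical input views
    rendering_images = np.asarray(
        [np.asarray(Image.open(p).convert('RGB'), dtype=np.float32) / 255.0 for p in view_paths]
    )

    voxels = predict_voxel_from_images(rendering_images)
    print('Occupied voxels: %d' % int(voxels.sum()))

    fig = voxel_to_plotly(voxels)
    fig.show()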