Spaces:
Runtime error
Runtime error
import sys | |
import os | |
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
import time | |
import json | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
from lib.options import BaseOptions | |
from lib.mesh_util import * | |
from lib.sample_util import * | |
from lib.train_util import * | |
from lib.model import * | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import glob | |
import tqdm | |
import trimesh | |
# Command-line options, parsed exactly once when this module is imported.
# The __main__ entry point passes this object to Evaluator.
_option_parser = BaseOptions()
opt = _option_parser.parse()
del _option_parser  # keep the module namespace the same as before
class Evaluator:
    """Single-image reconstruction wrapper around the PIFu networks.

    Holds ``netG`` (an ``HGPIFuNet``) and an optional ``netC`` (a
    ``ResBlkPIFuNet``, only built when a color checkpoint is configured).
    Meshes are written under ``<opt.results_path>/<opt.name>/``.
    """

    def __init__(self, opt, projection_mode='orthogonal'):
        """Build the networks, restore their weights and prepare the output dir.

        :param opt: parsed options object. This method reads ``loadSize``,
            ``gpu_id``, ``load_netG_checkpoint_path``,
            ``load_netC_checkpoint_path``, ``results_path`` and ``name``, and
            passes the whole object to the network constructors.
        :param projection_mode: forwarded to ``HGPIFuNet``.
        """
        self.opt = opt
        self.load_size = self.opt.loadSize
        # Image preprocessing: resize, convert to a float tensor in [0, 1],
        # then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # Target device: the configured GPU when CUDA is available, else CPU.
        # (The name ``cuda`` is kept for the ``self.cuda`` attribute even on CPU.)
        if torch.cuda.is_available():
            cuda = torch.device('cuda:%d' % opt.gpu_id)
        else:
            cuda = torch.device('cpu')

        # Geometry network.
        netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
        print('Using Network: ', netG.name)
        if opt.load_netG_checkpoint_path:
            netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))

        # Optional color network. Uses the same emptiness test as netG so that
        # an empty-string path means "no checkpoint" instead of torch.load('').
        if opt.load_netC_checkpoint_path:
            print('loading for net C ...', opt.load_netC_checkpoint_path)
            netC = ResBlkPIFuNet(opt).to(device=cuda)
            netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
        else:
            netC = None

        os.makedirs(opt.results_path, exist_ok=True)
        os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

        # Record the options used for this run next to the results.
        opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
        with open(opt_log, 'w') as outfile:
            outfile.write(json.dumps(vars(opt), indent=2))

        self.cuda = cuda
        self.netG = netG
        self.netC = netC

    def load_image(self, image_path, mask_path):
        """Load an RGB image and its mask into a single-sample input dict.

        :param image_path: path to the input image; its stem becomes ``'name'``.
        :param mask_path: path to the mask image, converted to grayscale.
        :return: dict with keys ``'name'`` (str), ``'img'`` (image tensor with a
            leading batch dim of 1, background zeroed by the mask), ``'calib'``
            (1 x 4 x 4 float tensor), ``'mask'`` (mask tensor with a leading
            batch dim of 1), ``'b_min'`` / ``'b_max'`` (numpy arrays
            ``[-1, -1, -1]`` / ``[1, 1, 1]``).
        """
        img_name = os.path.splitext(os.path.basename(image_path))[0]

        # Calibration: 4x4 identity with the y axis flipped. Presumably this
        # matches the default 'orthogonal' projection mode — confirm.
        B_MIN = np.array([-1, -1, -1])
        B_MAX = np.array([1, 1, 1])
        projection_matrix = np.identity(4)
        projection_matrix[1, 1] = -1
        calib = torch.Tensor(projection_matrix).float()

        # Use ``with`` so the file handles opened by Image.open are closed
        # deterministically; convert() returns an independent, loaded image.
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')
        mask = transforms.Resize(self.load_size)(mask)
        mask = transforms.ToTensor()(mask).float()

        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        image = self.to_tensor(image)
        # Broadcast the single-channel mask over the RGB channels to blank out
        # the background. Image and mask must resize to the same H x W.
        image = mask.expand_as(image) * image

        return {
            'name': img_name,
            'img': image.unsqueeze(0),
            'calib': calib.unsqueeze(0),
            'mask': mask.unsqueeze(0),
            'b_min': B_MIN,
            'b_max': B_MAX,
        }

    def eval(self, data, use_octree=False):
        """Reconstruct a mesh for one data point and save it as an OBJ file.

        :param data: dict as produced by ``load_image``; expected to contain at
            least ``'name'``, ``'img'``, ``'calib'``, ``'b_min'`` and ``'b_max'``.
            It is passed unchanged to ``gen_mesh`` / ``gen_mesh_color``.
        :param use_octree: forwarded to the mesh generator.
        :return: None. The mesh is written to
            ``<results_path>/<name>/result_<data['name']>.obj``.
        """
        opt = self.opt
        with torch.no_grad():
            self.netG.eval()
            has_color_net = self.netC is not None
            if has_color_net:
                self.netC.eval()
            save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
            if has_color_net:
                gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
            else:
                gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
def mask_path_for(image_path):
    """Return the mask path that accompanies *image_path*.

    The mask is expected beside the image as ``<stem>_mask.png``, e.g.
    ``photos/person.jpg`` -> ``photos/person_mask.png``. ``os.path.splitext``
    handles extensions of any length (``.jpeg``, ``.webp``) correctly, unlike
    slicing off a fixed four characters.
    """
    stem, _ext = os.path.splitext(image_path)
    return stem + '_mask.png'


if __name__ == '__main__':
    import traceback  # only needed for error reporting in this entry point

    evaluator = Evaluator(opt)
    results_path = opt.results_path
    name = opt.name
    test_image_path = opt.img_path
    test_mask_path = mask_path_for(test_image_path)
    test_img_name = os.path.splitext(os.path.basename(test_image_path))[0]
    print("test_image: ", test_image_path)
    print("test_mask: ", test_mask_path)

    # Evaluator.eval saves '<results_path>/<name>/result_<stem>.obj'; the GLB
    # export is written next to it with the same stem.
    result_stem = f'{results_path}/{name}/result_{test_img_name}'
    # Negates the z coordinate before GLB export — presumably to match the
    # viewer's coordinate convention; confirm against the consumer.
    flip_z = [[1, 0, 0, 0],
              [0, 1, 0, 0],
              [0, 0, -1, 0],
              [0, 0, 0, 1]]

    try:
        data = evaluator.load_image(test_image_path, test_mask_path)
        evaluator.eval(data, use_octree=True)
        mesh = trimesh.load(f'{result_stem}.obj')
        mesh.apply_transform(flip_z)
        mesh.export(file_obj=f'{result_stem}.glb')
    except Exception as e:
        # Keep the original one-line message, but also show where the failure
        # happened and signal it through a non-zero exit status.
        print("error:", e.args)
        traceback.print_exc()
        sys.exit(1)