|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import time |
|
import math |
|
import cv2 |
|
import numpy as np |
|
import itertools |
|
import shutil |
|
from tqdm import tqdm |
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
# Optional 3D-processing dependencies. Any import failure is re-raised as an
# ImportError carrying a readable message, with the original error chained as
# its cause (the previous `raise "<str>"` was itself a TypeError in Python 3).
try:
    import trimesh
    import mcubes
    import xatlas
    import open3d as o3d
except ImportError as err:
    raise ImportError("failed to import 3d libraries ") from err
|
|
|
from ..modules.rendering_neus.mesh import Mesh |
|
from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext |
|
|
|
from ..utils.ops import scale_tensor |
|
from ..util import count_params, instantiate_from_config |
|
from ..vis_util import render |
|
|
|
|
|
def unwrap_uv(v_pos, t_pos_idx):
    """Compute a UV parameterisation of a triangle mesh with xatlas.

    Args:
        v_pos: vertex positions, handed to ``xatlas.Atlas.add_mesh``.
        t_pos_idx: triangle vertex indices, handed to ``xatlas.Atlas.add_mesh``.

    Returns:
        ``(uvs, indices)`` -- the UV coordinates produced by xatlas and the
        per-face indices into those UVs, converted to ``np.int64``.
    """
    print("Using xatlas to perform UV unwrapping, may take a while ...")
    packer = xatlas.Atlas()
    packer.add_mesh(v_pos, t_pos_idx)
    chart_opts = xatlas.ChartOptions()
    pack_opts = xatlas.PackOptions()
    packer.generate(chart_opts, pack_opts)
    # Only the first (and only) mesh added above is retrieved; the vertex
    # mapping returned first is not needed here.
    _vmapping, uv_faces, uv_coords = packer.get_mesh(0)
    return uv_coords, uv_faces.astype(np.int64, casting="same_kind")
|
|
|
|
|
def uv_padding(image, hole_mask, uv_padding_size=2):
    """Fill uncovered texels of a baked texture by inpainting.

    Args:
        image: float tensor holding colours expected in [0, 1]; the caller in
            this file passes a (texture_res, texture_res, 3) texture.
        hole_mask: boolean tensor, True where the texel is a hole to fill.
        uv_padding_size: inpainting neighbourhood radius passed to
            ``cv2.inpaint``.

    Returns:
        The inpainted image as a ``np.uint8`` array in [0, 255].
    """
    # Clip before the uint8 cast: converting out-of-range floats to uint8
    # wraps around or is undefined, which would corrupt stray texels.
    image_u8 = np.clip(image.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    mask_u8 = (hole_mask.detach().cpu().numpy() * 255).astype(np.uint8)
    return cv2.inpaint(image_u8, mask_u8, uv_padding_size, cv2.INPAINT_TELEA)
|
|
|
def refine_mesh(vtx_refine, faces_refine, voxel_size=0.007, smooth_iterations=2):
    """Clean, simplify and smooth a triangle mesh with Open3D.

    Steps:
      1. drop unreferenced vertices and duplicated triangles/vertices;
      2. merge vertices with vertex clustering (cell size ``voxel_size``,
         merged positions averaged);
      3. apply ``smooth_iterations`` passes of Open3D's simple
         neighbour-average smoothing.

    Args:
        vtx_refine: (N, 3) vertex positions.
        faces_refine: (F, 3) triangle vertex indices.
        voxel_size: clustering cell size, in the same units as the vertices.
        smooth_iterations: number of ``filter_smooth_simple`` iterations.

    Returns:
        ``(vertices, triangles, mesh)``: float32 vertex array, triangle index
        array, and the processed ``o3d.geometry.TriangleMesh``.
    """
    mesh = o3d.geometry.TriangleMesh(
        vertices=o3d.utility.Vector3dVector(vtx_refine),
        triangles=o3d.utility.Vector3iVector(faces_refine),
    )

    mesh = mesh.remove_unreferenced_vertices()
    mesh = mesh.remove_duplicated_triangles()
    mesh = mesh.remove_duplicated_vertices()

    mesh = mesh.simplify_vertex_clustering(
        voxel_size=voxel_size,
        contraction=o3d.geometry.SimplificationContraction.Average,
    )
    mesh = mesh.filter_smooth_simple(number_of_iterations=smooth_iterations)

    vertices = np.asarray(mesh.vertices).astype(np.float32)
    triangles = np.asarray(mesh.triangles)
    return vertices, triangles, mesh
|
|
|
|
|
class SVRMModel(torch.nn.Module):
    """Image-to-mesh reconstruction model built from three config-driven parts.

    ``img_encoder`` embeds the input views, ``img_to_triplane_decoder`` turns
    those features into a triplane, and ``render`` queries the triplane (an SDF
    grid via ``forward_grid``, colours via ``forward_points``). All three parts
    are converted to half precision at construction time.
    """

    # Exponent applied to exported colours (vertex colours and texture). For
    # values in (0, 1), an exponent below 1 brightens them.
    _COLOR_RATIO = 0.8

    def __init__(
        self,
        img_encoder_config,
        img_to_triplane_config,
        render_config,
        device="cuda:0",
        **kwargs,
    ):
        """Instantiate the sub-modules from their configs.

        Args:
            img_encoder_config: config for the image encoder
                (``instantiate_from_config``).
            img_to_triplane_config: config for the feature-to-triplane decoder.
            render_config: config for the triplane renderer.
            device: device that inputs are moved to in ``export_mesh_with_uv``;
                the sub-modules themselves are not moved here.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.img_encoder = instantiate_from_config(img_encoder_config).half()
        self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half()
        self.render = instantiate_from_config(render_config).half()
        self.device = device
        count_params(self, verbose=True)

    @torch.no_grad()
    def export_mesh_with_uv(
        self,
        data,
        mesh_size: int = 384,
        ctx=None,
        context_type='cuda',
        texture_res=1024,
        target_face_count=10000,
        do_texture_mapping=True,
        out_dir='outputs/test',
    ):
        """Reconstruct a mesh from the input views and export it to ``out_dir``.

        Pipeline:
          1. encode the views and decode a triplane;
          2. sample an SDF grid in the box [-0.6, 0.6]^3 and run marching cubes;
          3. clean and simplify the result (``refine_mesh``), with optional
             quadric decimation;
          4. colour the vertices from the triplane.
        When ``do_texture_mapping`` is set, the mesh is also UV-unwrapped and a
        texture is baked by querying the triplane at every covered texel.

        Args:
            data: dict with ``"input_view"`` of shape (1, M, C, H, W) and
                ``"input_view_cam"`` of shape (1, M, D). Only batch size 1 is
                supported.
            mesh_size: resolution of the SDF grid fed to marching cubes.
            ctx: rasterizer context to reuse. A new ``NVDiffRasterizerContext``
                is created when ``None``.
            context_type: context type used when ``ctx`` has to be created.
            texture_res: side length, in texels, of the square texture.
            target_face_count: meshes with more faces than this are decimated
                down to it.
            do_texture_mapping: if False, stop after the vertex-coloured mesh.
            out_dir: output directory, created if missing. Files written:
                ``mesh_with_colors.obj`` (vertex colours), ``mesh.obj`` and
                ``mesh.glb``, plus ``texture.mtl`` and ``texture.png`` when
                texture mapping is enabled.

        Returns:
            None.
        """
        obj_vertext_path = os.path.join(out_dir, 'mesh_with_colors.obj')
        obj_path = os.path.join(out_dir, 'mesh.obj')
        obj_texture_path = os.path.join(out_dir, 'texture.png')
        obj_mtl_path = os.path.join(out_dir, 'texture.mtl')
        glb_path = os.path.join(out_dir, 'mesh.glb')
        os.makedirs(out_dir, exist_ok=True)
        color_ratio = self._COLOR_RATIO

        st = time.time()

        # ---- 1. input views -> triplane ------------------------------------
        here = {'device': self.device, 'dtype': torch.float16}
        input_view_image = data["input_view"].to(**here)
        input_view_cam = data["input_view_cam"].to(**here)

        batch_size, input_view_num, *_ = input_view_image.shape
        assert batch_size == 1, "batch size should be 1"

        input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w')
        input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d')
        input_view_feat = self.img_encoder(input_view_image, input_view_cam)
        input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num)

        torch.cuda.empty_cache()
        triplane_gen = self.img_to_triplane_decoder(input_view_feat)
        del input_view_feat
        torch.cuda.empty_cache()

        cur_triplane = triplane_gen[0:1]

        # ---- 2. triplane -> SDF grid -> triangle mesh ----------------------
        aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here)
        grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb)

        print(f"=====> Triplane forward time: {time.time() - st}")
        st = time.time()

        # Marching cubes runs on the negated SDF at iso-level 0.
        vtx, faces = mcubes.marching_cubes(
            0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0
        )

        # Map grid-index coordinates [0, mesh_size - 1] into the aabb.
        bbox = aabb[0].cpu().numpy()
        vtx = vtx / (mesh_size - 1)
        vtx = vtx * (bbox[1] - bbox[0]) + bbox[0]

        vtx_refine, faces_refine, _ = refine_mesh(vtx, faces)

        if faces_refine.shape[0] > target_face_count:
            print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}")
            vtx_refine, faces_refine = self._decimate(vtx_refine, faces_refine, target_face_count)

        # ---- 3. per-vertex colours -----------------------------------------
        vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here))
        vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy()

        self._write_vertex_color_obj(obj_vertext_path, vtx_refine, vtx_colors, faces_refine, color_ratio)
        print(f"=====> generate mesh with vertex shading time: {time.time() - st}")

        if not do_texture_mapping:
            shutil.copy(obj_vertext_path, obj_path)
            trimesh.load_mesh(obj_vertext_path).export(glb_path, file_type='glb')
            return None

        # ---- 4. UV unwrap + texture baking ---------------------------------
        st = time.time()

        vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine)
        # GPU copies feed the rasterizer. The numpy arrays are kept for the OBJ
        # writer, because indexing CUDA tensors element-wise in a Python loop
        # forces a device sync per access.
        vtx_gpu = torch.from_numpy(vtx_refine).to(self.device)
        faces_gpu = torch.from_numpy(faces_refine).to(self.device)
        t_tex_idx_gpu = torch.from_numpy(t_tex_idx).to(self.device)
        uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device)

        if ctx is None:
            ctx = NVDiffRasterizerContext(context_type, cur_triplane.device)
        # Rasterize the mesh in UV space: positions are (u, v, 0, 1) in clip space.
        uv_clip_pos = torch.cat([
            uv_clip,
            torch.zeros_like(uv_clip[..., 0:1]),
            torch.ones_like(uv_clip[..., 0:1]),
        ], dim=-1)
        rast = ctx.rasterize_one(uv_clip_pos, t_tex_idx_gpu, (texture_res, texture_res))[0]
        # Texels whose rast[..., 3] is not > 0 are treated as holes and
        # inpainted later.
        hole_mask = ~(rast[:, :, 3] > 0)

        # 3D surface position of every texel, used to query the triplane colour.
        gb_pos = ctx.interpolate_one(vtx_gpu, rast[None, ...], faces_gpu)[0][0]
        gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
        tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
        tex_map = tex_map.float().squeeze(0).view((texture_res, texture_res, 3))

        img = uv_padding(tex_map, hole_mask)
        img = ((img / 255.0) ** color_ratio) * 255
        img = img.clip(0, 255).astype(np.uint8)

        self._write_textured_obj(obj_path, obj_mtl_path, vtx_refine, vtx_tex, faces_refine, t_tex_idx)
        # OpenCV writes BGR, so reorder the RGB channels.
        cv2.imwrite(obj_texture_path, img[..., [2, 1, 0]])
        trimesh.load_mesh(obj_path).export(glb_path, file_type='glb')

        print(f"=====> generate mesh with texture shading time: {time.time() - st}")

    @staticmethod
    def _decimate(vertices, faces, target_face_count):
        """Quadric-decimate a mesh to ``target_face_count`` faces with Open3D.

        Returns float32 vertices (the same dtype ``refine_mesh`` produces) and
        the triangle index array.
        """
        o3d_mesh = o3d.geometry.TriangleMesh(
            vertices=o3d.utility.Vector3dVector(vertices),
            triangles=o3d.utility.Vector3iVector(faces),
        )
        o3d_mesh = o3d_mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0)
        return np.asarray(o3d_mesh.vertices, dtype=np.float32), np.asarray(o3d_mesh.triangles)

    @staticmethod
    def _write_vertex_color_obj(path, vertices, colors, faces, color_ratio):
        """Write an OBJ whose ``v`` lines carry per-vertex RGB after the position.

        Vertex axes are permuted as (x, y, z) -> (y, z, x) for export. Colours
        are clipped to [0, 1] (avoiding NaN from a fractional power of a
        negative) and raised to ``color_ratio``.
        """
        verts = vertices[:, [1, 2, 0]]
        shaded = np.clip(colors, 0.0, 1.0) ** color_ratio
        lines = ['v %f %f %f %f %f %f\n' % (*v, *c) for v, c in zip(verts, shaded)]
        lines += ['f %d %d %d\n' % tuple(f) for f in faces + 1]
        with open(path, 'w') as fid:
            fid.writelines(lines)

    @staticmethod
    def _write_textured_obj(obj_path, mtl_path, vertices, uvs, faces, uv_faces):
        """Write ``texture.mtl`` and a UV-mapped OBJ that references it.

        Vertex axes are permuted as (x, y, z) -> (y, z, x), matching
        ``_write_vertex_color_obj``. The material points at ``texture.png``
        next to the OBJ.
        """
        with open(mtl_path, 'w') as fid:
            fid.writelines([
                'newmtl material_0\n',
                "Ka 1.000 1.000 1.000\n",
                "Kd 1.000 1.000 1.000\n",
                "Ks 0.000 0.000 0.000\n",
                "d 1.0\n",
                "illum 2\n",
                'map_Kd texture.png\n',
            ])

        verts = vertices[:, [1, 2, 0]]
        lines = ['mtllib texture.mtl\n']
        lines += ['v %f %f %f\n' % tuple(v) for v in verts]
        # V is flipped (1 - v), presumably to go from the rasterizer's image-row
        # convention to OBJ's bottom-up texture convention.
        lines += ['vt %f %f\n' % (uv[0], 1 - uv[1]) for uv in uvs]
        lines.append('usemtl material_0\n')
        for f, ft in zip(faces + 1, uv_faces + 1):
            lines.append('f %d/%d %d/%d %d/%d\n' % (f[0], ft[0], f[1], ft[1], f[2], ft[2]))
        with open(obj_path, 'w') as fid:
            fid.writelines(lines)
|
|