Spaces:
Runtime error
Runtime error
File size: 10,629 Bytes
e2ebf5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import os
import open3d as o3d
import torch
from pytorch3d.renderer import TexturesUV
from pytorch3d.structures import Meshes
from pytorch3d.ops import knn_points
from pytorch3d.io import save_obj
from sugar.sugar_scene.gs_model import GaussianSplattingWrapper
from sugar.sugar_scene.sugar_model import SuGaR, extract_texture_image_and_uv_from_gaussians
from sugar.sugar_utils.spherical_harmonics import SH2RGB
from rich.console import Console
def _last_path_component(path):
    """Return the last component of a '/'-separated path, tolerating one trailing slash."""
    parts = path.split('/')
    return parts[-1] if len(parts[-1]) > 0 else parts[-2]


def _postprocess_refined_sugar(refined_sugar, nerfmodel, density_threshold, n_iterations, console):
    """Drop low-density border triangles from a mesh-bound SuGaR model and rebuild it.

    Border faces are peeled off for ``n_iterations`` rounds (a face counts as "inside"
    when each of its 3 edges is shared with another kept face). Removed faces whose
    center has a ``compute_density`` value above ``density_threshold`` are then added
    back. A new ``SuGaR`` model is bound to the remaining triangles and receives the
    Gaussian parameters of those triangles.

    Args:
        refined_sugar: Refined SuGaR model bound to a surface mesh.
        nerfmodel: The ``GaussianSplattingWrapper`` passed to the new ``SuGaR`` model.
        density_threshold: Minimum density at a face center for the face to be restored.
        n_iterations: Number of border-peeling rounds.
        console: ``rich`` console used for progress messages.

    Returns:
        A new ``SuGaR`` model bound to the filtered mesh.
    """
    with torch.no_grad():
        verts = refined_sugar.surface_mesh.verts_list()[0].detach().clone()
        faces = refined_sugar.surface_mesh.faces_list()[0].detach().clone()

        # Sorted vertex-index pairs of the 3 edges of every face: shape (n_faces, 3, 2).
        all_edges = torch.cat([
            faces[..., None, (0, 1)].sort(dim=-1)[0],
            faces[..., None, (1, 2)].sort(dim=-1)[0],
            faces[..., None, (2, 0)].sort(dim=-1)[0],
        ], dim=-2)

        # One flag per *face*. A mask derived from per-Gaussian strengths only has the
        # right length when there is exactly one Gaussian per triangle.
        face_mask = torch.ones(len(faces), dtype=torch.bool, device=faces.device)

        for i in range(n_iterations):
            console.print("\nStarting postprocessing iteration", i)
            if not face_mask.any():
                break  # Every face was removed; knn_points needs at least one point.
            kept_edges = all_edges[face_mask].view(1, -1, 2).float()
            # A shared edge appears twice, so its 2nd nearest neighbour (the 1st is
            # itself) is at distance 0; a border edge has no duplicate.
            edges_neighbors = knn_points(kept_edges, kept_edges, K=2)
            is_inside = (edges_neighbors.dists[0][..., 1].view(-1, 3) < 0.01).all(-1)
            # Remove border faces.
            face_mask[face_mask.clone()] = is_inside

        # Then add back removed faces whose center lies in a high-density region.
        removed = ~face_mask
        if removed.any():
            face_centers = verts[faces[removed]].mean(-2)
            face_densities = refined_sugar.compute_density(face_centers)
            face_mask[removed] = face_densities > density_threshold

        n_faces = len(face_mask)
        n_kept = int(face_mask.sum())
        n_per_face = refined_sugar.n_gaussians_per_surface_triangle

        def keep_rows(param):
            # Assumes Gaussians are stored face-major (n_per_face consecutive rows per
            # triangle). Explicit sizes (no -1) keep this valid for zero-size trailing
            # dims, e.g. SH "rest" coefficients when only degree 0 is used.
            grouped = param.reshape(n_faces, n_per_face, *param.shape[1:])
            return grouped[face_mask].reshape(n_kept * n_per_face, *param.shape[1:])

        new_scales = keep_rows(refined_sugar._scales)
        new_quaternions = keep_rows(refined_sugar._quaternions)
        new_densities = keep_rows(refined_sugar.all_densities)
        new_sh_coordinates_dc = keep_rows(refined_sugar._sh_coordinates_dc)
        # Uses the tensor's own SH size rather than a hard-coded 15 (= degree 3).
        new_sh_coordinates_rest = keep_rows(refined_sugar._sh_coordinates_rest)

        kept_faces = faces[face_mask]
        new_o3d_mesh = o3d.geometry.TriangleMesh()
        new_o3d_mesh.vertices = o3d.utility.Vector3dVector(verts.cpu().numpy())
        new_o3d_mesh.triangles = o3d.utility.Vector3iVector(kept_faces.cpu().numpy())
        # Per-vertex normals must have one entry per vertex; the mesh's per-face normals
        # (one per kept face) do not fit here, so derive vertex normals from the mesh.
        new_o3d_mesh.compute_vertex_normals()
        new_o3d_mesh.vertex_colors = o3d.utility.Vector3dVector(torch.ones_like(verts).cpu().numpy())

        new_sugar = SuGaR(
            nerfmodel=nerfmodel,
            points=None,
            colors=None,
            initialize=False,
            sh_levels=nerfmodel.gaussians.active_sh_degree + 1,
            keep_track_of_knn=False,
            knn_to_track=0,
            beta_mode='average',
            surface_mesh_to_bind=new_o3d_mesh,
            n_gaussians_per_surface_triangle=n_per_face,
        )
        # In-place writes into the new model's parameters (hence torch.no_grad above).
        new_sugar._scales[...] = new_scales
        new_sugar._quaternions[...] = new_quaternions
        new_sugar.all_densities[...] = new_densities
        new_sugar._sh_coordinates_dc[...] = new_sh_coordinates_dc
        new_sugar._sh_coordinates_rest[...] = new_sh_coordinates_rest
    return new_sugar


def extract_mesh_and_texture_from_refined_sugar(args):
    """Extract a textured OBJ mesh from a refined SuGaR checkpoint.

    Reads from ``args``: ``scene_path``, ``eval``, ``iteration_to_load``,
    ``checkpoint_path``, ``refined_model_path``, ``n_gaussians_per_surface_triangle``,
    ``mesh_output_dir``, ``postprocess_mesh`` (plus ``postprocess_density_threshold``
    and ``postprocess_iterations`` when it is set), ``square_size`` and ``gpu``.
    ``args.mesh_output_dir`` is filled in (mutated) when it is None.

    Returns:
        Path of the saved ``.obj`` file.
    """
    CONSOLE = Console(width=120)
    n_skip_images_for_eval_split = 8

    # --- Scene data parameters ---
    source_path = args.scene_path
    use_train_test_split = args.eval
    scene_name = _last_path_component(source_path)

    # --- Vanilla 3DGS parameters ---
    iteration_to_load = args.iteration_to_load
    gs_checkpoint_path = args.checkpoint_path

    # --- Fine model parameters ---
    refined_model_path = args.refined_model_path
    # Name of the directory containing the refined checkpoint; it is parsed below.
    refined_run_name = refined_model_path.split('/')[-2]
    if args.n_gaussians_per_surface_triangle is None:
        # The run directory name is expected to end with "_gaussperface<N>".
        n_gaussians_per_surface_triangle = int(refined_run_name.split('_gaussperface')[-1])
    else:
        n_gaussians_per_surface_triangle = args.n_gaussians_per_surface_triangle

    # --- Output parameters ---
    if args.mesh_output_dir is None:
        args.mesh_output_dir = os.path.join("./output/refined_mesh", scene_name)
    mesh_output_dir = args.mesh_output_dir
    os.makedirs(mesh_output_dir, exist_ok=True)
    mesh_file_name = refined_run_name + ('_postprocessed' if args.postprocess_mesh else '') + '.obj'
    mesh_save_path = os.path.join(mesh_output_dir, mesh_file_name)

    sugar_mesh_path = os.path.join(
        './output/coarse_mesh/', scene_name,
        refined_run_name.split('_normalconsistency')[0].replace('sugarfine', 'sugarmesh') + '.ply')

    if args.square_size is not None:
        square_size = args.square_size
    else:
        # Same default for every Gaussians-per-triangle count, so square_size is always
        # bound (restricting it to counts of 1 or 6 crashed for any other value).
        square_size = 10

    # Postprocessing
    postprocess_mesh = args.postprocess_mesh
    if postprocess_mesh:
        postprocess_density_threshold = args.postprocess_density_threshold
        postprocess_iterations = args.postprocess_iterations

    CONSOLE.print('==================================================')
    CONSOLE.print("Starting extracting texture from refined SuGaR model:")
    CONSOLE.print('Scene path:', source_path)
    CONSOLE.print('Iteration to load:', iteration_to_load)
    CONSOLE.print('Vanilla 3DGS checkpoint path:', gs_checkpoint_path)
    CONSOLE.print('Refined model path:', refined_model_path)
    CONSOLE.print('Coarse mesh path:', sugar_mesh_path)
    CONSOLE.print('Mesh output directory:', mesh_output_dir)
    CONSOLE.print('Mesh save path:', mesh_save_path)
    CONSOLE.print('Number of gaussians per surface triangle:', n_gaussians_per_surface_triangle)
    CONSOLE.print('Square size:', square_size)
    CONSOLE.print('Postprocess mesh:', postprocess_mesh)
    CONSOLE.print('==================================================')

    # Set the GPU
    torch.cuda.set_device(args.gpu)

    # ==========================

    # --- Loading Vanilla 3DGS model ---
    CONSOLE.print("Source path:", source_path)
    CONSOLE.print("Gaussian splatting checkpoint path:", gs_checkpoint_path)
    CONSOLE.print(f"\nLoading Vanilla 3DGS model config {gs_checkpoint_path}...")
    nerfmodel = GaussianSplattingWrapper(
        source_path=source_path,
        output_path=gs_checkpoint_path,
        iteration_to_load=iteration_to_load,
        load_gt_images=False,  # TODO: Check
        eval_split=use_train_test_split,
        eval_split_interval=n_skip_images_for_eval_split,
    )
    CONSOLE.print("Vanilla 3DGS Loaded.")
    CONSOLE.print(f'{len(nerfmodel.training_cameras)} training images detected.')
    CONSOLE.print(f'The model has been trained for {iteration_to_load} steps.')
    CONSOLE.print(len(nerfmodel.gaussians._xyz) / 1e6, "M gaussians detected.")

    # --- Loading coarse mesh ---
    o3d_mesh = o3d.io.read_triangle_mesh(sugar_mesh_path)

    # --- Loading refined SuGaR model ---
    checkpoint = torch.load(refined_model_path, map_location=nerfmodel.device)
    refined_sugar = SuGaR(
        nerfmodel=nerfmodel,
        points=checkpoint['state_dict']['_points'],
        colors=SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :]),
        initialize=False,
        sh_levels=nerfmodel.gaussians.active_sh_degree + 1,
        keep_track_of_knn=False,
        knn_to_track=0,
        beta_mode='average',
        surface_mesh_to_bind=o3d_mesh,
        n_gaussians_per_surface_triangle=n_gaussians_per_surface_triangle,
    )
    refined_sugar.load_state_dict(checkpoint['state_dict'])
    refined_sugar.eval()

    if postprocess_mesh:
        CONSOLE.print("Postprocessing mesh by removing border triangles with low-opacity gaussians...")
        refined_sugar = _postprocess_refined_sugar(
            refined_sugar,
            nerfmodel,
            density_threshold=postprocess_density_threshold,
            n_iterations=postprocess_iterations,
            console=CONSOLE,
        )
        CONSOLE.print("Mesh postprocessed.")

    # Compute texture
    with torch.no_grad():
        verts_uv, faces_uv, texture_img = extract_texture_image_and_uv_from_gaussians(
            refined_sugar, square_size=square_size, n_sh=1, texture_with_gaussian_renders=True)

        textures_uv = TexturesUV(
            maps=texture_img[None],
            verts_uvs=verts_uv[None],
            faces_uvs=faces_uv[None],
            sampling_mode='nearest',
        )
        textured_mesh = Meshes(
            verts=[refined_sugar.surface_mesh.verts_list()[0]],
            faces=[refined_sugar.surface_mesh.faces_list()[0]],
            textures=textures_uv,
        )
    CONSOLE.print("Texture extracted.")
    CONSOLE.print("Texture shape:", texture_img.shape)

    CONSOLE.print("Saving textured mesh...")
    with torch.no_grad():
        save_obj(
            mesh_save_path,
            verts=textured_mesh.verts_list()[0],
            faces=textured_mesh.faces_list()[0],
            verts_uvs=textured_mesh.textures.verts_uvs_list()[0],
            faces_uvs=textured_mesh.textures.faces_uvs_list()[0],
            texture_map=textured_mesh.textures.maps_padded()[0].clamp(0., 1.),
        )
    CONSOLE.print("Texture saved at:", mesh_save_path)

    return mesh_save_path