import os import open3d as o3d import torch from pytorch3d.renderer import TexturesUV from pytorch3d.structures import Meshes from pytorch3d.ops import knn_points from pytorch3d.io import save_obj from sugar.sugar_scene.gs_model import GaussianSplattingWrapper from sugar.sugar_scene.sugar_model import SuGaR, extract_texture_image_and_uv_from_gaussians from sugar.sugar_utils.spherical_harmonics import SH2RGB from rich.console import Console def extract_mesh_and_texture_from_refined_sugar(args): CONSOLE = Console(width=120) n_skip_images_for_eval_split = 8 # --- Scene data parameters --- source_path = args.scene_path use_train_test_split = args.eval # --- Vanilla 3DGS parameters --- iteration_to_load = args.iteration_to_load gs_checkpoint_path = args.checkpoint_path # --- Fine model parameters --- refined_model_path = args.refined_model_path if args.n_gaussians_per_surface_triangle is None: n_gaussians_per_surface_triangle = int(refined_model_path.split('/')[-2].split('_gaussperface')[-1]) else: n_gaussians_per_surface_triangle = args.n_gaussians_per_surface_triangle # --- Output parameters --- if args.mesh_output_dir is None: if len(args.scene_path.split("/")[-1]) > 0: args.mesh_output_dir = os.path.join("./output/refined_mesh", args.scene_path.split("/")[-1]) else: args.mesh_output_dir = os.path.join("./output/refined_mesh", args.scene_path.split("/")[-2]) mesh_output_dir = args.mesh_output_dir os.makedirs(mesh_output_dir, exist_ok=True) mesh_save_path = refined_model_path.split('/')[-2] if args.postprocess_mesh: mesh_save_path = mesh_save_path + '_postprocessed' mesh_save_path = mesh_save_path + '.obj' mesh_save_path = os.path.join(mesh_output_dir, mesh_save_path) scene_name = source_path.split('/')[-2] if len(source_path.split('/')[-1]) == 0 else source_path.split('/')[-1] sugar_mesh_path = os.path.join('./output/coarse_mesh/', scene_name, refined_model_path.split('/')[-2].split('_normalconsistency')[0].replace('sugarfine', 'sugarmesh') + '.ply') if args.square_size is None: if n_gaussians_per_surface_triangle == 1: # square_size = 5 # Maybe 4 already works square_size = 10 # Maybe 4 already works if n_gaussians_per_surface_triangle == 6: square_size = 10 else: square_size = args.square_size # Postprocessing postprocess_mesh = args.postprocess_mesh if postprocess_mesh: postprocess_density_threshold = args.postprocess_density_threshold postprocess_iterations = args.postprocess_iterations CONSOLE.print('==================================================') CONSOLE.print("Starting extracting texture from refined SuGaR model:") CONSOLE.print('Scene path:', source_path) CONSOLE.print('Iteration to load:', iteration_to_load) CONSOLE.print('Vanilla 3DGS checkpoint path:', gs_checkpoint_path) CONSOLE.print('Refined model path:', refined_model_path) CONSOLE.print('Coarse mesh path:', sugar_mesh_path) CONSOLE.print('Mesh output directory:', mesh_output_dir) CONSOLE.print('Mesh save path:', mesh_save_path) CONSOLE.print('Number of gaussians per surface triangle:', n_gaussians_per_surface_triangle) CONSOLE.print('Square size:', square_size) CONSOLE.print('Postprocess mesh:', postprocess_mesh) CONSOLE.print('==================================================') # Set the GPU torch.cuda.set_device(args.gpu) # ========================== # --- Loading Vanilla 3DGS model --- CONSOLE.print("Source path:", source_path) CONSOLE.print("Gaussian splatting checkpoint path:", gs_checkpoint_path) CONSOLE.print(f"\nLoading Vanilla 3DGS model config {gs_checkpoint_path}...") nerfmodel = GaussianSplattingWrapper( source_path=source_path, output_path=gs_checkpoint_path, iteration_to_load=iteration_to_load, load_gt_images=False, # TODO: Check eval_split=use_train_test_split, eval_split_interval=n_skip_images_for_eval_split, ) CONSOLE.print("Vanilla 3DGS Loaded.") CONSOLE.print(f'{len(nerfmodel.training_cameras)} training images detected.') CONSOLE.print(f'The model has been trained for {iteration_to_load} steps.') CONSOLE.print(len(nerfmodel.gaussians._xyz) / 1e6, "M gaussians detected.") # --- Loading coarse mesh --- o3d_mesh = o3d.io.read_triangle_mesh(sugar_mesh_path) # --- Loading refined SuGaR model --- checkpoint = torch.load(refined_model_path, map_location=nerfmodel.device) refined_sugar = SuGaR( nerfmodel=nerfmodel, points=checkpoint['state_dict']['_points'], colors=SH2RGB(checkpoint['state_dict']['_sh_coordinates_dc'][:, 0, :]), initialize=False, sh_levels=nerfmodel.gaussians.active_sh_degree+1, keep_track_of_knn=False, knn_to_track=0, beta_mode='average', surface_mesh_to_bind=o3d_mesh, n_gaussians_per_surface_triangle=n_gaussians_per_surface_triangle, ) refined_sugar.load_state_dict(checkpoint['state_dict']) refined_sugar.eval() if postprocess_mesh: CONSOLE.print("Postprocessing mesh by removing border triangles with low-opacity gaussians...") with torch.no_grad(): new_verts = refined_sugar.surface_mesh.verts_list()[0].detach().clone() new_faces = refined_sugar.surface_mesh.faces_list()[0].detach().clone() new_normals = refined_sugar.surface_mesh.faces_normals_list()[0].detach().clone() # For each face, get the 3 edges edges0 = new_faces[..., None, (0,1)].sort(dim=-1)[0] edges1 = new_faces[..., None, (1,2)].sort(dim=-1)[0] edges2 = new_faces[..., None, (2,0)].sort(dim=-1)[0] all_edges = torch.cat([edges0, edges1, edges2], dim=-2) # We start by identifying the inside faces and border faces face_mask = refined_sugar.strengths[..., 0] > -1. for i in range(postprocess_iterations): CONSOLE.print("\nStarting postprocessing iteration", i) # We look for edges that appear in the list at least twice (their NN is themselves) edges_neighbors = knn_points(all_edges[face_mask].view(1, -1, 2).float(), all_edges[face_mask].view(1, -1, 2).float(), K=2) # If all edges of a face appear in the list at least twice, then the face is inside the mesh is_inside = (edges_neighbors.dists[0][..., 1].view(-1, 3) < 0.01).all(-1) # We update the mask by removing border faces face_mask[face_mask.clone()] = is_inside # We then add back border faces with high-density face_centers = new_verts[new_faces].mean(-2) face_densities = refined_sugar.compute_density(face_centers[~face_mask]) face_mask[~face_mask.clone()] = face_densities > postprocess_density_threshold # And we create the new mesh and SuGaR model new_faces = new_faces[face_mask] new_normals = new_normals[face_mask] new_scales = refined_sugar._scales.reshape(len(face_mask), -1, 2)[face_mask].view(-1, 2) new_quaternions = refined_sugar._quaternions.reshape(len(face_mask), -1, 2)[face_mask].view(-1, 2) new_densities = refined_sugar.all_densities.reshape(len(face_mask), -1, 1)[face_mask].view(-1, 1) new_sh_coordinates_dc = refined_sugar._sh_coordinates_dc.reshape(len(face_mask), -1, 1, 3)[face_mask].view(-1, 1, 3) new_sh_coordinates_rest = refined_sugar._sh_coordinates_rest.reshape(len(face_mask), -1, 15, 3)[face_mask].view(-1, 15, 3) new_o3d_mesh = o3d.geometry.TriangleMesh() new_o3d_mesh.vertices = o3d.utility.Vector3dVector(new_verts.cpu().numpy()) new_o3d_mesh.triangles = o3d.utility.Vector3iVector(new_faces.cpu().numpy()) new_o3d_mesh.vertex_normals = o3d.utility.Vector3dVector(new_normals.cpu().numpy()) new_o3d_mesh.vertex_colors = o3d.utility.Vector3dVector(torch.ones_like(new_verts).cpu().numpy()) refined_sugar = SuGaR( nerfmodel=nerfmodel, points=None, colors=None, initialize=False, sh_levels=nerfmodel.gaussians.active_sh_degree+1, keep_track_of_knn=False, knn_to_track=0, beta_mode='average', surface_mesh_to_bind=new_o3d_mesh, n_gaussians_per_surface_triangle=refined_sugar.n_gaussians_per_surface_triangle, ) refined_sugar._scales[...] = new_scales refined_sugar._quaternions[...] = new_quaternions refined_sugar.all_densities[...] = new_densities refined_sugar._sh_coordinates_dc[...] = new_sh_coordinates_dc refined_sugar._sh_coordinates_rest[...] = new_sh_coordinates_rest CONSOLE.print("Mesh postprocessed.") # Compute texture with torch.no_grad(): verts_uv, faces_uv, texture_img = extract_texture_image_and_uv_from_gaussians( refined_sugar, square_size=square_size, n_sh=1, texture_with_gaussian_renders=True) textures_uv = TexturesUV( maps=texture_img[None], #texture_img[None]), verts_uvs=verts_uv[None], faces_uvs=faces_uv[None], sampling_mode='nearest', ) textured_mesh = Meshes( verts=[refined_sugar.surface_mesh.verts_list()[0]], faces=[refined_sugar.surface_mesh.faces_list()[0]], textures=textures_uv, ) CONSOLE.print("Texture extracted.") CONSOLE.print("Texture shape:", texture_img.shape) CONSOLE.print("Saving textured mesh...") with torch.no_grad(): save_obj( mesh_save_path, verts=textured_mesh.verts_list()[0], faces=textured_mesh.faces_list()[0], verts_uvs=textured_mesh.textures.verts_uvs_list()[0], faces_uvs=textured_mesh.textures.faces_uvs_list()[0], texture_map=textured_mesh.textures.maps_padded()[0].clamp(0., 1.), ) CONSOLE.print("Texture saved at:", mesh_save_path) return mesh_save_path