import os
import random
import tempfile
import time
import zipfile
from contextlib import nullcontext
from functools import lru_cache
from typing import Any

import cv2
import gradio as gr
import numpy as np
import torch
import trimesh
from gradio_litmodel3d import LitModel3D
from gradio_pointcloudeditor import PointCloudEditor
from PIL import Image
from transparent_background import Remover

os.system(
    "USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper"
)
os.system("pip install ./deps/pynim-0.0.3-cp310-cp310-linux_x86_64.whl")

import spar3d.utils as spar3d_utils
from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
from spar3d.system import SPAR3D

os.environ["GRADIO_TEMP_DIR"] = os.path.join(
    os.environ.get("TMPDIR", "/tmp"), "gradio"
)

bg_remover = Remover()

# Default settings
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Cached. Doesn't change
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

generated_files = []

# Delete previous gradio temp dir folder
if os.path.exists(os.environ["GRADIO_TEMP_DIR"]):
    print(f"Deleting {os.environ['GRADIO_TEMP_DIR']}")
    import shutil

    shutil.rmtree(os.environ["GRADIO_TEMP_DIR"])

device = spar3d_utils.get_device()

model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
)
model.eval()
model = model.to(device)

example_files = [
    os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
]


def create_zip_file(glb_file, pc_file, illumination_file):
    if not all([glb_file, pc_file, illumination_file]):
        return None

    # Create a temporary zip file
    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "spar3d_output.zip")

    with zipfile.ZipFile(zip_path, "w") as zipf:
        zipf.write(glb_file, "mesh.glb")
        zipf.write(pc_file, "points.ply")
        zipf.write(illumination_file, "illumination.hdr")

    generated_files.append(zip_path)
    return zip_path


def forward_model(
    batch,
    system,
    guidance_scale=3.0,
    seed=0,
    device="cuda",
    remesh_option="none",
    vertex_count=-1,
    texture_resolution=1024,
):
    batch_size = batch["rgb_cond"].shape[0]

    # Prepare the condition for point cloud generation; set the seed
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    cond_tokens = system.forward_pdiff_cond(batch)

    if "pc_cond" not in batch:
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device,
        )
        for x in sample_iter:
            samples = x["xstart"]
        batch["pc_cond"] = samples.permute(0, 2, 1).float()
        batch["pc_cond"] = spar3d_utils.normalize_pc_bbox(batch["pc_cond"])

        # Subsample to 512 points
        batch["pc_cond"] = batch["pc_cond"][
            :, torch.randperm(batch["pc_cond"].shape[1])[:512]
        ]

    # Get the conditioning point cloud
    xyz = batch["pc_cond"][0, :, :3].cpu().numpy()
    color_rgb = (batch["pc_cond"][0, :, 3:6] * 255).cpu().numpy().astype(np.uint8)
    pc_rgb_trimesh = trimesh.PointCloud(vertices=xyz, colors=color_rgb)

    # Forward pass for the final mesh
    trimesh_mesh, _glob_dict = model.generate_mesh(
        batch,
        texture_resolution,
        remesh=remesh_option,
        vertex_count=vertex_count,
        estimate_illumination=True,
    )
    trimesh_mesh = trimesh_mesh[0]
    illumination = _glob_dict["illumination"]

    return trimesh_mesh, pc_rgb_trimesh, illumination.cpu().detach().numpy()[0]


def process_model_run(
    fr_res,
    guidance_scale,
    random_seed,
    pc_cond,
    remesh_option,
    vertex_count_type,
    vertex_count,
    texture_resolution,
):
    start = time.time()
    with torch.no_grad():
        with (
            torch.autocast(device_type=device, dtype=torch.bfloat16)
            if "cuda" in device
            else nullcontext()
        ):
            model_batch = create_batch(fr_res)
            model_batch = {k: v.to(device) for k, v in model_batch.items()}

            trimesh_mesh, trimesh_pc, illumination_map = forward_model(
                model_batch,
                model,
                guidance_scale=guidance_scale,
                seed=random_seed,
                device=device,
                remesh_option=remesh_option.lower(),
                vertex_count=vertex_count,
                texture_resolution=texture_resolution,
            )

    # Create new tmp files for the mesh, point cloud, and illumination map
    temp_dir = tempfile.mkdtemp()
    tmp_file = os.path.join(temp_dir, "mesh.glb")

    trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
    generated_files.append(tmp_file)

    tmp_file_pc = os.path.join(temp_dir, "points.ply")
    trimesh_pc.export(tmp_file_pc)
    generated_files.append(tmp_file_pc)

    tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
    cv2.imwrite(tmp_file_illumination, illumination_map)
    generated_files.append(tmp_file_illumination)

    print("Generation took:", time.time() - start, "s")

    return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc


def create_batch(input_image: Image.Image) -> dict[str, Any]:
    img_cond = (
        torch.from_numpy(
            np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
            / 255.0
        )
        .float()
        .clip(0, 1)
    )
    mask_cond = img_cond[:, :, -1:]
    rgb_cond = torch.lerp(
        torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
    )

    batch_elem = {
        "rgb_cond": rgb_cond,
        "mask_cond": mask_cond,
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0),
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    # Add batch dim
    batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
    return batched


def remove_background(input_image: Image.Image) -> Image.Image:
    return bg_remover.process(input_image.convert("RGB"))


def auto_process(input_image):
    if input_image is None:
        return None, None, None, None

    # Default values
    guidance_scale = 3.0
    random_seed = 0
    foreground_ratio = 1.3
    remesh_option = "None"
    vertex_count_type = "Keep Vertex Count"
    vertex_count = 2000
    texture_resolution = 1024
    no_crop = False
    pc_cond = None

    # First step: remove the background
    rem_removed = remove_background(input_image)

    fr_res = spar3d_utils.foreground_crop(
        rem_removed,
        crop_ratio=foreground_ratio,
        newsize=(COND_WIDTH, COND_HEIGHT),
        no_crop=no_crop,
    )

    # Second step: run the model
    glb_file, pc_file, illumination_file, pc_list = process_model_run(
        fr_res,
        guidance_scale,
        random_seed,
        pc_cond,
        remesh_option,
        vertex_count_type,
        vertex_count,
        texture_resolution,
    )

    zip_file = create_zip_file(glb_file, pc_file, illumination_file)

    return glb_file, illumination_file, zip_file, pc_list


# Simplified interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # SPAR3D: Stable Point-Aware Reconstruction of 3D Objects from Single Images

    Upload an image to generate a 3D model.
    """
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload"],
                image_mode="RGBA",
            )
        with gr.Column():
            output_3d = LitModel3D(
                label="3D Model",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                tonemapping="aces",
                contrast=1.0,
                scale=1.0,
            )
            download_all_btn = gr.File(
                label="Download Model (ZIP)", file_count="single", visible=True
            )

    input_img.upload(
        auto_process,
        inputs=[input_img],
        outputs=[
            output_3d,
            gr.State(),  # illumination file path
            download_all_btn,
            gr.State(),  # point cloud list
        ],
    )

demo.queue().launch(share=False)