import torch
import spaces  # Hugging Face Spaces (ZeroGPU) support
import gradio as gr
import numpy as np
import trimesh
import imageio
from PIL import Image
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from functools import partial
from skimage import measure
import kiui
from gradio_litmodel3d import LitModel3D


class VFusion3DGenerator:
    def __init__(self, model_name="facebook/vfusion3d"):
        """
        Initialize the VFusion3D model.

        Args:
            model_name (str): Hugging Face model identifier
        """
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        # Background removal session
        self.rembg_session = new_session("isnet-general-use")

    def preprocess_image(self, image, source_size=512):
        """
        Preprocess an input image for the VFusion3D model.

        Args:
            image (PIL.Image): Input image
            source_size (int): Target image size

        Returns:
            torch.Tensor: Preprocessed image tensor of shape (1, 3, source_size, source_size)
        """
        rembg_remove = partial(remove, session=self.rembg_session)

        # Remove the background, then recenter the object using its alpha mask
        image = np.array(image)
        image = rembg_remove(image)
        mask = rembg_remove(image, only_mask=True)
        image = kiui.op.recenter(image, mask, border_ratio=0.20)

        # To float CHW tensor in [0, 1]; composite RGBA onto a white background
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])

        image = torch.nn.functional.interpolate(
            image,
            size=(source_size, source_size),
            mode='bicubic',
            align_corners=True,
        )
        return torch.clamp(image, 0, 1)

    def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
        """
        Generate a 3D output (mesh or video) from an input image.

        Args:
            image (PIL.Image): Input image
            output_type (str): Type of output ('mesh' or 'video')
            render_size (int): Rendering size
            mesh_size (int): Mesh generation grid size

        Returns:
            str: Path to the generated file
        """
        # Preprocess image
        image = self.preprocess_image(image).to(self.device)

        # Default camera settings (you might want to adjust these)
        source_camera = self._get_default_source_camera(batch_size=1).to(self.device)

        with torch.no_grad():
            # Forward pass: predict the neural planes
            planes = self.model(image, source_camera)

            if output_type == 'mesh':
                return self._generate_mesh(planes, mesh_size)
            elif output_type == 'video':
                return self._generate_video(planes, render_size)
            else:
                raise ValueError(f"Unknown output_type: {output_type!r}")

    def _generate_mesh(self, planes, mesh_size=512):
        """
        Generate a colored 3D mesh from the neural planes.

        Args:
            planes: Neural representation planes
            mesh_size (int): Size of the mesh grid

        Returns:
            str: Path to the saved mesh file
        """
        # Query the density grid and extract a surface using scikit-image's
        # marching cubes (used here instead of mcubes)
        grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
        vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)

        # Normalize vertices to [-1, 1]
        vtx = vtx / (mesh_size - 1) * 2 - 1

        # Query per-vertex colors from the planes
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
        vtx_colors = (vtx_colors * 255).astype(np.uint8)

        # Create and save the mesh
        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
        mesh_path = "generated_mesh.obj"
        mesh.export(mesh_path, 'obj')

        return mesh_path
    def _generate_video(self, planes, render_size=384, fps=30):
        """
        Generate a rotating (turntable) video from the neural planes.

        Args:
            planes: Neural representation planes
            render_size (int): Size of rendered frames
            fps (int): Frames per second

        Returns:
            str: Path to the saved video file
        """
        render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)

        # Render one view at a time to keep memory usage low
        frames = []
        for i in range(render_cameras.shape[1]):
            frame_chunk = self.model.synthesizer(
                planes,
                render_cameras[:, i:i + 1],
                render_size,
                render_size,
                0,
                0,
            )
            frames.append(frame_chunk['images_rgb'])

        # (1, N, C, H, W) -> (N, H, W, C) uint8 frames
        frames = torch.cat(frames, dim=1)
        frames = frames.squeeze(0)
        frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

        video_path = "generated_video.mp4"
        imageio.mimwrite(video_path, frames, fps=fps)

        return video_path

    def _get_default_source_camera(self, batch_size=1):
        """Generate default source camera parameters."""
        # Placeholder: this must build the source-camera tensor expected by
        # facebook/vfusion3d (see the reference implementation for the camera
        # conventions). Raising here is clearer than silently returning None.
        raise NotImplementedError(
            "Implement source camera generation to match the VFusion3D reference implementation."
        )

    def _get_default_render_cameras(self, batch_size=1):
        """Generate default render camera parameters."""
        # Placeholder: this must build the orbiting render-camera tensor expected
        # by facebook/vfusion3d (see the reference implementation for the camera
        # conventions). Raising here is clearer than silently returning None.
        raise NotImplementedError(
            "Implement render camera generation to match the VFusion3D reference implementation."
        )


# Create the Gradio interface
def create_vfusion3d_interface():
    generator = VFusion3DGenerator()

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("# VFusion3D Model Converter")
                input_image = gr.Image(type="pil", label="Upload Image")

                with gr.Row():
                    mesh_btn = gr.Button("Generate 3D Mesh")
                    video_btn = gr.Button("Generate Rotation Video")

                mesh_output = gr.File(label="3D Mesh (.obj)")
                video_output = gr.File(label="Rotation Video")

            with gr.Column():
                model_viewer = LitModel3D(
                    label="3D Model Preview",
                    scale=1.0,
                    interactive=True,
                )

        def generate_mesh_outputs(img):
            # Generate the mesh once and reuse the same file for both the
            # download component and the 3D preview
            mesh_path = generator.generate_3d_output(img, output_type='mesh')
            return mesh_path, mesh_path

        # Button click events
        mesh_btn.click(
            fn=generate_mesh_outputs,
            inputs=input_image,
            outputs=[mesh_output, model_viewer],
        )

        video_btn.click(
            fn=lambda img: generator.generate_3d_output(img, output_type='video'),
            inputs=input_image,
            outputs=video_output,
        )

    return demo


# Launch the interface
if __name__ == "__main__":
    demo = create_vfusion3d_interface()
    demo.launch()
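
# Headless usage sketch (left commented out so nothing runs on import): this
# assumes the two camera helpers above have been implemented and that an image
# file named "input.png" exists -- both are assumptions for this example only.
#
#   generator = VFusion3DGenerator()
#   mesh_path = generator.generate_3d_output(Image.open("input.png"), output_type='mesh')
#   video_path = generator.generate_3d_output(Image.open("input.png"), output_type='video')
#   print(mesh_path, video_path)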