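"""
Gradio demo: single-image 3D generation with Meta's VFusion3D
(https://huggingface.co/facebook/vfusion3d).

Pipeline: remove the background and recenter the object, encode the image into
triplane features with VFusion3D, then decode the triplanes into either a
vertex-colored mesh (marching cubes) or an orbiting turntable video.
"""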
import torch
import gradio as gr
import numpy as np
import trimesh
import imageio
from PIL import Image
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from functools import partial
from skimage import measure
import kiui
from gradio_litmodel3d import LitModel3D
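# Dependency sketch (assumed, not pinned by the original file): torch, transformers,
# gradio, gradio_litmodel3d, rembg, kiui, trimesh, scikit-image, numpy, Pillow, and
# imageio[ffmpeg] (writing .mp4 with imageio.mimwrite needs the ffmpeg backend).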
class VFusion3DGenerator:
def __init__(self, model_name="facebook/vfusion3d"):
"""
Initialize the VFusion3D model
Args:
model_name (str): Hugging Face model identifier
"""
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
# Background removal session
self.rembg_session = new_session("isnet-general-use")
def preprocess_image(self, image, source_size=512):
"""
Preprocess input image for VFusion3D model
Args:
image (PIL.Image): Input image
source_size (int): Target image size
Returns:
torch.Tensor: Preprocessed image tensor
"""
        rembg_remove = partial(remove, session=self.rembg_session)
        image = np.array(image)
        # Cut out the foreground, then take its alpha mask for recentering.
        image = rembg_remove(image)
        mask = rembg_remove(image, only_mask=True)
        image = kiui.op.recenter(image, mask, border_ratio=0.20)
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
        # Composite RGBA onto a white background.
        if image.shape[1] == 4:
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(
            image,
            size=(source_size, source_size),
            mode='bicubic',
            align_corners=True,
        )
        return torch.clamp(image, 0, 1)
def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
"""
Generate 3D output (mesh or video) from input image
Args:
image (PIL.Image): Input image
output_type (str): Type of output ('mesh' or 'video')
render_size (int): Rendering size
mesh_size (int): Mesh generation size
Returns:
str: Path to generated file
"""
# Preprocess image
image = self.preprocess_image(image).to(self.device)
        # Canonical source-camera conditioning (see _get_default_source_camera)
        source_camera = self._get_default_source_camera(batch_size=1).to(self.device)
        with torch.no_grad():
            # Encode the image into triplane features, then decode to the requested output
            planes = self.model(image, source_camera)
            if output_type == 'mesh':
                return self._generate_mesh(planes, mesh_size)
            elif output_type == 'video':
                return self._generate_video(planes, render_size)
        raise ValueError(f"output_type must be 'mesh' or 'video', got {output_type!r}")
def _generate_mesh(self, planes, mesh_size=512):
"""
Generate 3D mesh from neural planes
Args:
planes: Neural representation planes
mesh_size (int): Size of the mesh grid
Returns:
str: Path to saved mesh file
"""
        # Query the synthesizer's density field on a dense grid and run
        # scikit-image's marching cubes over it (used here instead of mcubes)
        grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
# Extract the sigma grid and threshold
sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
# Use marching cubes from scikit-image
vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)
        # Map vertices from grid-index coordinates into [-1, 1]
        vtx = vtx / (mesh_size - 1) * 2 - 1
# Color vertices
vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
vtx_colors = (vtx_colors * 255).astype(np.uint8)
# Create and save mesh
mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
mesh_path = "generated_mesh.obj"
mesh.export(mesh_path, 'obj')
return mesh_path
def _generate_video(self, planes, render_size=384, fps=30):
"""
Generate rotating video from neural planes
Args:
planes: Neural representation planes
render_size (int): Size of rendered frames
fps (int): Frames per second
Returns:
str: Path to saved video file
"""
render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)
        frames = []
        # Render one view at a time to keep memory bounded
        for i in range(render_cameras.shape[1]):
            frame_chunk = self.model.synthesizer(
                planes,
                render_cameras[:, i:i + 1],
                render_size,
                render_size,
                0,
                0,
            )
            frames.append(frame_chunk['images_rgb'])
frames = torch.cat(frames, dim=1)
frames = frames.squeeze(0)
frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
video_path = "generated_video.mp4"
imageio.mimwrite(video_path, frames, fps=fps)
return video_path
    def _get_default_source_camera(self, batch_size=1):
        """Canonical source camera: flattened 3x4 extrinsics + normalized intrinsics
        (fx, fy, cx, cy), 16 values per sample. A sketch of the OpenLRM-style
        conditioning used by the VFusion3D demo; verify the canonical pose and
        intrinsics against the reference implementation for this checkpoint."""
        # Assumed canonical pose: camera at distance 2 looking at the origin
        extrinsics = torch.tensor([[1., 0., 0., 0.],
                                   [0., 0., -1., -2.],
                                   [0., 1., 0., 0.]]).reshape(1, 12)
        intrinsics = torch.tensor([[0.75, 0.75, 0.5, 0.5]])  # assumed normalized values
        return torch.cat([extrinsics, intrinsics], dim=-1).repeat(batch_size, 1)
    def _get_default_render_cameras(self, batch_size=1, n_views=160):
        """Orbit render cameras, shape (B, n_views, 25): per view a flattened 4x4
        extrinsic (16 values) plus a flattened 3x3 intrinsic (9 values). A sketch
        of the OpenLRM-style layout; verify against the reference demo."""
        extrinsics = self._orbit_camera_extrinsics(n_views)  # (n_views, 4, 4)
        intrinsics = torch.tensor([[0.75, 0., 0.5],
                                   [0., 0.75, 0.5],
                                   [0., 0., 1.]]).reshape(1, 9).repeat(n_views, 1)
        cameras = torch.cat([extrinsics.reshape(n_views, 16), intrinsics], dim=-1)
        return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
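    # The helper below is an assumption, not part of the original file: a simple
    # circular look-at orbit producing the (n_views, 4, 4) camera-to-world
    # matrices consumed by _get_default_render_cameras. Radius and elevation are
    # illustrative defaults.
    def _orbit_camera_extrinsics(self, n_views, radius=2.0, elevation_deg=20.0):
        """Camera-to-world 4x4 matrices on a circular orbit around the origin."""
        azimuths = np.radians(np.linspace(0, 360, n_views, endpoint=False))
        elev = np.radians(elevation_deg)
        poses = []
        for azi in azimuths:
            # Camera position on the orbit (world +z is up)
            eye = np.array([radius * np.cos(elev) * np.cos(azi),
                            radius * np.cos(elev) * np.sin(azi),
                            radius * np.sin(elev)])
            forward = -eye / np.linalg.norm(eye)  # look toward the origin
            right = np.cross(forward, np.array([0., 0., 1.]))
            right /= np.linalg.norm(right)
            up = np.cross(right, forward)
            pose = np.eye(4)
            pose[:3, 0], pose[:3, 1], pose[:3, 2], pose[:3, 3] = right, up, -forward, eye
            poses.append(pose)
        return torch.tensor(np.stack(poses), dtype=torch.float32)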
# Create Gradio Interface
def create_vfusion3d_interface():
generator = VFusion3DGenerator()
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
gr.Markdown("# VFusion3D Model Converter")
input_image = gr.Image(type="pil", label="Upload Image")
with gr.Row():
mesh_btn = gr.Button("Generate 3D Mesh")
video_btn = gr.Button("Generate Rotation Video")
mesh_output = gr.File(label="3D Mesh (.obj)")
video_output = gr.File(label="Rotation Video")
with gr.Column():
model_viewer = LitModel3D(
label="3D Model Preview",
scale=1.0,
interactive=True
)
        # Button click events
        def generate_mesh(img):
            # Run the pipeline once and reuse the file for both outputs
            # (the original lambda ran the full model twice per click)
            mesh_path = generator.generate_3d_output(img, output_type='mesh')
            return mesh_path, mesh_path

        mesh_btn.click(
            fn=generate_mesh,
            inputs=input_image,
            outputs=[mesh_output, model_viewer],
        )
video_btn.click(
fn=lambda img: generator.generate_3d_output(img, output_type='video'),
inputs=input_image,
outputs=video_output
)
return demo
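# Programmatic use without the Gradio UI (hypothetical example):
#   gen = VFusion3DGenerator()
#   mesh_path = gen.generate_3d_output(Image.open("input.png"), output_type='mesh')
#   video_path = gen.generate_3d_output(Image.open("input.png"), output_type='video')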
# Launch the interface
if __name__ == "__main__":
demo = create_vfusion3d_interface()
    demo.launch()