import math

import torch
import spaces  # side-effect import for Hugging Face Spaces ZeroGPU support
import gradio as gr
import numpy as np
import trimesh
import imageio
import kiui
from functools import partial

from skimage import measure
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from gradio_litmodel3d import LitModel3D

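
# --- Camera helpers --------------------------------------------------------
# The two `_get_default_*_camera` methods below were left as stubs. The
# helpers here are a minimal sketch following the LRM-style camera
# conventions used by the public VFusion3D demo: a 16-dim source camera
# (flattened 3x4 extrinsics + fx/fy/cx/cy) and 25-dim render cameras
# (flattened 4x4 extrinsics + flattened 3x3 intrinsics). The orbit radius,
# FOV, and view count are assumptions, not values taken from this file;
# check them against the upstream demo if renders look wrong.

def _center_looking_at_camera_pose(camera_position):
    """Build (N, 3, 4) camera-to-world extrinsics looking at the origin, +z up."""
    z_axis = torch.nn.functional.normalize(camera_position, dim=-1)
    up_world = torch.tensor([0.0, 0.0, 1.0]).expand_as(camera_position)
    x_axis = torch.nn.functional.normalize(torch.cross(up_world, z_axis, dim=-1), dim=-1)
    y_axis = torch.nn.functional.normalize(torch.cross(z_axis, x_axis, dim=-1), dim=-1)
    rotation = torch.stack([x_axis, y_axis, z_axis], dim=-1)
    return torch.cat([rotation, camera_position.unsqueeze(-1)], dim=-1)


def _default_source_camera(batch_size=1):
    """Fixed front-facing source camera, flattened to (batch, 16)."""
    # Camera 2 units away on -y, looking at the origin (assumed canonical pose).
    extrinsics = torch.tensor([[1.0, 0.0, 0.0, 0.0,
                                0.0, 0.0, -1.0, -2.0,
                                0.0, 1.0, 0.0, 0.0]])
    # Normalized pinhole intrinsics (fx, fy, cx, cy); assumed values.
    intrinsics = torch.tensor([[0.75, 0.75, 0.5, 0.5]])
    return torch.cat([extrinsics, intrinsics], dim=-1).repeat(batch_size, 1)


def _default_render_cameras(batch_size=1, n_views=160, radius=1.5):
    """Orbit of `n_views` render cameras, flattened to (batch, n_views, 25)."""
    thetas = torch.linspace(0.0, 2.0 * math.pi, n_views + 1)[:-1]
    positions = torch.stack([radius * torch.cos(thetas),
                             radius * torch.sin(thetas),
                             torch.zeros_like(thetas)], dim=-1)
    rt = _center_looking_at_camera_pose(positions)                 # (N, 3, 4)
    bottom = torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(n_views, 1, 4)
    extrinsics = torch.cat([rt, bottom], dim=1)                    # (N, 4, 4)
    intrinsics = torch.tensor([[0.75, 0.0, 0.5],
                               [0.0, 0.75, 0.5],
                               [0.0, 0.0, 1.0]]).expand(n_views, 3, 3)
    cameras = torch.cat([extrinsics.reshape(n_views, 16),
                         intrinsics.reshape(n_views, 9)], dim=-1)  # (N, 25)
    return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
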

class VFusion3DGenerator:
    def __init__(self, model_name="facebook/vfusion3d"):
        """
        Initialize the VFusion3D model.

        Args:
            model_name (str): Hugging Face model identifier
        """
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        # Background-removal session, created once and reused across requests.
        self.rembg_session = new_session("isnet-general-use")

    def preprocess_image(self, image, source_size=512):
        """
        Preprocess an input image for the VFusion3D model.

        Args:
            image (PIL.Image): Input image
            source_size (int): Target image size

        Returns:
            torch.Tensor: Preprocessed image tensor of shape (1, 3, source_size, source_size)
        """
        rembg_remove = partial(remove, session=self.rembg_session)

        # Remove the background, extract an alpha mask, and recenter the
        # object with a small border.
        image = np.array(image)
        image = rembg_remove(image)
        mask = rembg_remove(image, only_mask=True)
        image = kiui.op.recenter(image, mask, border_ratio=0.20)

        # HWC uint8 -> NCHW float in [0, 1].
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0

        # Composite RGBA over a white background.
        if image.shape[1] == 4:
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])

        image = torch.nn.functional.interpolate(
            image,
            size=(source_size, source_size),
            mode='bicubic',
            align_corners=True,
        )
        return torch.clamp(image, 0, 1)

    def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
        """
        Generate a 3D output (mesh or video) from an input image.

        Args:
            image (PIL.Image): Input image
            output_type (str): Type of output ('mesh' or 'video')
            render_size (int): Rendering size
            mesh_size (int): Mesh grid resolution

        Returns:
            str: Path to the generated file
        """
        image = self.preprocess_image(image).to(self.device)
        source_camera = self._get_default_source_camera(batch_size=1).to(self.device)

        # Encode the image into triplane features.
        with torch.no_grad():
            planes = self.model(image, source_camera)

        if output_type == 'mesh':
            return self._generate_mesh(planes, mesh_size)
        elif output_type == 'video':
            return self._generate_video(planes, render_size)
        else:
            raise ValueError(f"Unknown output_type: {output_type!r}")

    def _generate_mesh(self, planes, mesh_size=512):
        """
        Generate a colored 3D mesh from triplane features.

        Args:
            planes: Triplane feature representation
            mesh_size (int): Resolution of the density grid

        Returns:
            str: Path to the saved mesh file
        """
        # Query the density field on a regular grid.
        grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()

        # Extract an iso-surface from the density grid.
        vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)

        # Map grid coordinates to the [-1, 1] cube.
        vtx = vtx / (mesh_size - 1) * 2 - 1

        # Query per-vertex colors from the triplanes.
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
        vtx_colors = (vtx_colors * 255).astype(np.uint8)

        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
        mesh_path = "generated_mesh.obj"
        mesh.export(mesh_path, 'obj')
        return mesh_path

    def _generate_video(self, planes, render_size=384, fps=30):
        """
        Render a turntable video from triplane features.

        Args:
            planes: Triplane feature representation
            render_size (int): Size of rendered frames
            fps (int): Frames per second

        Returns:
            str: Path to the saved video file
        """
        render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)
        frames = []

        # Render one camera at a time to keep memory usage low.
        for i in range(render_cameras.shape[1]):
            frame_chunk = self.model.synthesizer(
                planes,
                render_cameras[:, i:i + 1],
                render_size,
                render_size,
                0,
                0,
            )
            frames.append(frame_chunk['images_rgb'])

        # (1, N, 3, H, W) -> (N, H, W, 3) uint8 frames.
        frames = torch.cat(frames, dim=1).squeeze(0)
        frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

        video_path = "generated_video.mp4"
        imageio.mimwrite(video_path, frames, fps=fps)
        return video_path

    def _get_default_source_camera(self, batch_size=1):
        """Generate default source camera parameters (see the sketch above)."""
        return _default_source_camera(batch_size)

    def _get_default_render_cameras(self, batch_size=1):
        """Generate default render camera parameters (see the sketch above)."""
        return _default_render_cameras(batch_size)


def create_vfusion3d_interface():
    generator = VFusion3DGenerator()

    def generate_mesh(img):
        # Generate once and feed the same file to both outputs (the original
        # lambda ran the full pipeline twice).
        mesh_path = generator.generate_3d_output(img, output_type='mesh')
        return mesh_path, mesh_path

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("# VFusion3D Model Converter")
                input_image = gr.Image(type="pil", label="Upload Image")

                with gr.Row():
                    mesh_btn = gr.Button("Generate 3D Mesh")
                    video_btn = gr.Button("Generate Rotation Video")

                mesh_output = gr.File(label="3D Mesh (.obj)")
                video_output = gr.File(label="Rotation Video")

            with gr.Column():
                model_viewer = LitModel3D(
                    label="3D Model Preview",
                    scale=1.0,
                    interactive=True,
                )

        mesh_btn.click(
            fn=generate_mesh,
            inputs=input_image,
            outputs=[mesh_output, model_viewer],
        )

        video_btn.click(
            fn=lambda img: generator.generate_3d_output(img, output_type='video'),
            inputs=input_image,
            outputs=video_output,
        )

    return demo


if __name__ == "__main__":
    demo = create_vfusion3d_interface()
    demo.launch()
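
# Assumed local setup (package names inferred from the imports above, not
# pinned or verified):
#   pip install torch transformers gradio rembg kiui trimesh "imageio[ffmpeg]" \
#       scikit-image gradio-litmodel3d spaces
#   python <this_script>.py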