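"""Gradio demo: single-image 3D generation with Meta's VFusion3D.

Upload an image; the app removes the background, recenters the object,
encodes it into neural planes with facebook/vfusion3d, and returns either
a vertex-colored mesh (.obj, previewed as .glb) or a turntable video.
"""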
import math
from functools import partial

import spaces  # ZeroGPU shim: import before torch on Hugging Face Spaces
import torch
import gradio as gr
import numpy as np
import trimesh
import imageio
import kiui
from rembg import remove, new_session
from skimage import measure
from transformers import AutoModel, AutoConfig
from gradio_litmodel3d import LitModel3D
class VFusion3DGenerator:
def __init__(self, model_name="facebook/vfusion3d"):
"""
Initialize the VFusion3D model
Args:
model_name (str): Hugging Face model identifier
"""
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
# Background removal session
self.rembg_session = new_session("isnet-general-use")
    def preprocess_image(self, image, source_size=512):
        """
        Preprocess an input image for the VFusion3D model
        Args:
            image (PIL.Image): Input image
            source_size (int): Target image size
        Returns:
            torch.Tensor: (1, 3, source_size, source_size) tensor in [0, 1]
        """
        rembg_remove = partial(remove, session=self.rembg_session)
        image = np.array(image)
        # Remove the background, then recenter the object with a 20% border
        image = rembg_remove(image)
        mask = rembg_remove(image, only_mask=True)
        image = kiui.op.recenter(image, mask, border_ratio=0.20)
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
        # Composite RGBA onto a white background
        if image.shape[1] == 4:
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(
            image,
            size=(source_size, source_size),
            mode='bicubic',
            align_corners=True
        )
        return torch.clamp(image, 0, 1)
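    # Usage sketch (hypothetical file name): passing Image.open("chair.png")
    # through preprocess_image yields a (1, 3, 512, 512) float tensor, the
    # single source view the model consumes.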
    def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
        """
        Generate a 3D output (mesh or video) from an input image
        Args:
            image (PIL.Image): Input image
            output_type (str): Type of output ('mesh' or 'video')
            render_size (int): Rendered frame size in pixels
            mesh_size (int): Marching-cubes grid resolution
        Returns:
            str: Path to the generated file
        """
        # Preprocess image
        image = self.preprocess_image(image).to(self.device)
        # Canonical source camera for the single input view
        source_camera = self._get_default_source_camera(batch_size=1).to(self.device)
        with torch.no_grad():
            # Forward pass: encode the image into neural planes
            planes = self.model(image, source_camera)
            if output_type == 'mesh':
                return self._generate_mesh(planes, mesh_size)
            if output_type == 'video':
                return self._generate_video(planes, render_size)
        raise ValueError(f"Unknown output_type: {output_type!r} (expected 'mesh' or 'video')")
    def _generate_mesh(self, planes, mesh_size=512):
        """
        Generate a vertex-colored 3D mesh from neural planes
        Args:
            planes: Neural representation planes
            mesh_size (int): Resolution of the density grid
        Returns:
            str: Path to the saved mesh file
        """
        # Query the density (sigma) field on a dense grid
        grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
        # Extract the isosurface with scikit-image's marching cubes
        vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)
        # Map vertices from grid indices to [-1, 1]
        vtx = vtx / (mesh_size - 1) * 2 - 1
        # Query per-vertex RGB colors from the planes
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
        vtx_colors = (vtx_colors * 255).astype(np.uint8)
        # Build and save the mesh
        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
        mesh_path = "generated_mesh.obj"
        mesh.export(mesh_path, 'obj')
        return mesh_path
    def _generate_video(self, planes, render_size=384, fps=30, chunk_size=1):
        """
        Render a turntable video from neural planes
        Args:
            planes: Neural representation planes
            render_size (int): Rendered frame size in pixels
            fps (int): Frames per second
            chunk_size (int): Frames rendered per forward pass (raise if VRAM allows)
        Returns:
            str: Path to the saved video file
        """
        render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)
        frames = []
        # Render the camera orbit in chunks to bound peak memory
        for i in range(0, render_cameras.shape[1], chunk_size):
            frame_chunk = self.model.synthesizer(
                planes,
                render_cameras[:, i:i + chunk_size],
                render_size,
                render_size,
                0,
                0
            )
            frames.append(frame_chunk['images_rgb'])
        frames = torch.cat(frames, dim=1).squeeze(0)
        frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        video_path = "generated_video.mp4"
        imageio.mimwrite(video_path, frames, fps=fps)
        return video_path
    def _get_default_source_camera(self, batch_size=1, dist_to_center=2.0):
        """Canonical source camera: flattened 3x4 extrinsics + normalized
        (fx, fy, cx, cy), shape (batch_size, 16). A sketch assuming the
        OpenLRM-style camera encoding used by the reference VFusion3D demo;
        verify against your checkpoint."""
        extrinsics = torch.tensor([[1, 0, 0, 0], [0, 0, -1, -dist_to_center], [0, 1, 0, 0]], dtype=torch.float32)
        intrinsics = torch.tensor([0.75, 0.75, 0.5, 0.5])
        camera = torch.cat([extrinsics.reshape(-1), intrinsics], dim=0)
        return camera.unsqueeze(0).repeat(batch_size, 1)

    def _get_default_render_cameras(self, batch_size=1, n_views=160, radius=2.0, height=0.8):
        """Orbit cameras: flattened 4x4 extrinsics + flattened 3x3 normalized
        intrinsics, shape (batch_size, n_views, 25). Same assumption as above;
        adjust if your checkpoint expects a different encoding."""
        ring_radius = math.sqrt(radius ** 2 - height ** 2)
        angles = torch.arange(n_views, dtype=torch.float32) * (2 * math.pi / n_views)
        positions = torch.stack([ring_radius * torch.cos(angles),
                                 ring_radius * torch.sin(angles),
                                 torch.full_like(angles, height)], dim=-1)
        extrinsics = _center_looking_at_camera_pose(positions)  # (n_views, 4, 4)
        intrinsics = torch.tensor([[0.75, 0.0, 0.5], [0.0, 0.75, 0.5], [0.0, 0.0, 1.0]]).expand(n_views, 3, 3)
        cameras = torch.cat([extrinsics.reshape(n_views, 16), intrinsics.reshape(n_views, 9)], dim=-1)
        return cameras.unsqueeze(0).repeat(batch_size, 1, 1)
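# Hypothetical helper for the camera sketches above (not part of the original
# app): build camera-to-world poses placing cameras at `positions`, looking at
# the origin in a z-up world, returned as (M, 4, 4) matrices.
def _center_looking_at_camera_pose(positions):
    z_axis = torch.nn.functional.normalize(positions, dim=-1)  # camera backward axis
    world_up = torch.tensor([0.0, 0.0, 1.0]).expand_as(z_axis)
    x_axis = torch.nn.functional.normalize(torch.cross(world_up, z_axis, dim=-1), dim=-1)
    y_axis = torch.nn.functional.normalize(torch.cross(z_axis, x_axis, dim=-1), dim=-1)
    poses = torch.stack([x_axis, y_axis, z_axis, positions], dim=-1)  # (M, 3, 4)
    bottom = torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(positions.shape[0], 1, 4)
    return torch.cat([poses, bottom], dim=1)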
# Create Gradio Interface
def create_vfusion3d_interface():
    generator = VFusion3DGenerator()

    def generate_mesh_once(img):
        # Generate the mesh once and reuse the result for both outputs
        # (the original handler ran the full pipeline twice per click).
        mesh_path = generator.generate_3d_output(img, output_type='mesh')
        # LitModel3D previews glTF/GLB most reliably, so convert the OBJ for
        # the viewer while keeping the OBJ available for download.
        glb_path = "generated_mesh.glb"
        trimesh.load(mesh_path).export(glb_path)
        return mesh_path, glb_path

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("# VFusion3D Model Converter")
                input_image = gr.Image(type="pil", label="Upload Image")
                with gr.Row():
                    mesh_btn = gr.Button("Generate 3D Mesh")
                    video_btn = gr.Button("Generate Rotation Video")
                mesh_output = gr.File(label="3D Mesh (.obj)")
                video_output = gr.Video(label="Rotation Video")
            with gr.Column():
                model_viewer = LitModel3D(
                    label="3D Model Preview",
                    scale=1,
                    interactive=True
                )
        # Button click events
        mesh_btn.click(
            fn=generate_mesh_once,
            inputs=input_image,
            outputs=[mesh_output, model_viewer]
        )
        video_btn.click(
            fn=lambda img: generator.generate_3d_output(img, output_type='video'),
            inputs=input_image,
            outputs=video_output
        )
    return demo
# Launch the interface
if __name__ == "__main__":
demo = create_vfusion3d_interface()
demo.launch()