import torch
import spaces
import gradio as gr
import os
import numpy as np
import trimesh
import mcubes
import imageio
from PIL import Image
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from functools import partial
import kiui
from gradio_litmodel3d import LitModel3D
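
# ---------------------------------------------------------------------------
# Camera utilities (minimal sketch, not the official implementation).
#
# VFusion3D follows an LRM-style camera encoding in its official demo: the
# source camera is a flattened 3x4 [R|t] extrinsic plus normalized pinhole
# intrinsics (12 + 4 = 16 values), and each render camera is a flattened 4x4
# extrinsic plus a flattened 3x3 normalized intrinsic matrix (16 + 9 = 25
# values). The helper names, default intrinsics, and orbit trajectory below
# are assumptions for illustration and should be checked against the official
# facebook/vfusion3d demo code.
# ---------------------------------------------------------------------------

def _default_intrinsics():
    """Pinhole intrinsics in the (fx fy / cx cy / w h) layout, shape (3, 2)."""
    fx = fy = 384.0
    cx = cy = 256.0
    w = h = 512.0
    return torch.tensor([[fx, fy], [cx, cy], [w, h]], dtype=torch.float32)


def _normalized_intrinsics(intrinsics):
    """Return fx, fy, cx, cy normalized by image width and height."""
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
    cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
    w, h = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
    return fx / w, fy / h, cx / w, cy / h


def _center_looking_at_camera_pose(camera_position, look_at=None, up_world=None):
    """Build (M, 3, 4) extrinsics for cameras at `camera_position` looking at `look_at`."""
    if look_at is None:
        look_at = torch.zeros(3)
    if up_world is None:
        up_world = torch.tensor([0.0, 0.0, 1.0])
    look_at = look_at.unsqueeze(0).expand_as(camera_position)
    up_world = up_world.unsqueeze(0).expand_as(camera_position)

    z_axis = camera_position - look_at
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    x_axis = torch.cross(up_world, z_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    # Columns are the camera x/y/z axes followed by the camera position: a 3x4 [R|t].
    return torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)


def _build_camera_principle(RT, intrinsics):
    """Flatten (N, 3, 4) extrinsics and append normalized intrinsics -> (N, 16)."""
    fx, fy, cx, cy = _normalized_intrinsics(intrinsics)
    return torch.cat([
        RT.reshape(-1, 12),
        fx.unsqueeze(-1), fy.unsqueeze(-1), cx.unsqueeze(-1), cy.unsqueeze(-1),
    ], dim=-1)


def _build_camera_standard(RT, intrinsics):
    """Flatten (N, 3, 4) extrinsics padded to 4x4 plus a 3x3 normalized intrinsic -> (N, 25)."""
    bottom = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(RT.shape[0], 1, 1)
    extrinsics = torch.cat([RT, bottom], dim=1)
    fx, fy, cx, cy = _normalized_intrinsics(intrinsics)
    zeros, ones = torch.zeros_like(fx), torch.ones_like(fx)
    intrinsic_mat = torch.stack([
        torch.stack([fx, zeros, cx], dim=-1),
        torch.stack([zeros, fy, cy], dim=-1),
        torch.stack([zeros, zeros, ones], dim=-1),
    ], dim=1)
    return torch.cat([extrinsics.reshape(-1, 16), intrinsic_mat.reshape(-1, 9)], dim=-1)
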

class VFusion3DGenerator:
    def __init__(self, model_name="facebook/vfusion3d"):
        """
        Initialize the VFusion3D model
        
        Args:
            model_name (str): Hugging Face model identifier
        """
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        
        # Background removal session
        self.rembg_session = new_session("isnet-general-use")

    def preprocess_image(self, image, source_size=512):
        """
        Preprocess input image for VFusion3D model
        
        Args:
            image (PIL.Image): Input image
            source_size (int): Target image size
        
        Returns:
            torch.Tensor: Preprocessed image tensor
        """
        rembg_remove = partial(remove, session=self.rembg_session)
        image = np.array(image)
        image = rembg_remove(image)
        mask = rembg_remove(image, only_mask=True)
        image = kiui.op.recenter(image, mask, border_ratio=0.20)
        
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        
        image = torch.nn.functional.interpolate(
            image, 
            size=(source_size, source_size), 
            mode='bicubic', 
            align_corners=True
        )
        return torch.clamp(image, 0, 1)

    def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
        """
        Generate 3D output (mesh or video) from input image
        
        Args:
            image (PIL.Image): Input image
            output_type (str): Type of output ('mesh' or 'video')
            render_size (int): Rendering size
            mesh_size (int): Mesh generation size
        
        Returns:
            str: Path to generated file
        """
        # Preprocess image
        image = self.preprocess_image(image).to(self.device)
        
        # Default camera settings (you might want to adjust these)
        source_camera = self._get_default_source_camera(batch_size=1).to(self.device)
        
        with torch.no_grad():
            # Forward pass
            planes = self.model(image, source_camera)
            
            if output_type == 'mesh':
                return self._generate_mesh(planes, mesh_size)
            elif output_type == 'video':
                return self._generate_video(planes, render_size)
    
    def _generate_mesh(self, planes, mesh_size=512):
        """
        Generate 3D mesh from neural planes

        Args:
            planes: Neural representation planes
            mesh_size (int): Size of the mesh grid

        Returns:
            str: Path to saved mesh file
        """
        # scikit-image's marching cubes is used here instead of mcubes
        from skimage import measure

        # Query the synthesizer on a dense grid and extract the density (sigma) volume
        grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()

        # Extract the iso-surface at the density threshold
        vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)

        # Normalize vertices from grid coordinates to [-1, 1]
        vtx = vtx / (mesh_size - 1) * 2 - 1

        # Query per-vertex colors from the neural planes
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
        vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
        vtx_colors = (vtx_colors * 255).astype(np.uint8)

        # Create and save the colored mesh
        mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
        mesh_path = "generated_mesh.obj"
        mesh.export(mesh_path, 'obj')
        return mesh_path

    def _generate_video(self, planes, render_size=384, fps=30):
        """
        Generate rotating video from neural planes
        
        Args:
            planes: Neural representation planes
            render_size (int): Size of rendered frames
            fps (int): Frames per second
        
        Returns:
            str: Path to saved video file
        """
        render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)
        frames = []
        
        for i in range(0, render_cameras.shape[1], 1):
            frame_chunk = self.model.synthesizer(
                planes,
                render_cameras[:, i:i + 1],
                render_size,
                render_size,
                0,
                0
            )
            frames.append(frame_chunk['images_rgb'])
        
        frames = torch.cat(frames, dim=1)
        frames = frames.squeeze(0)
        frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
        
        video_path = "generated_video.mp4"
        imageio.mimwrite(video_path, frames, fps=fps)
        
        return video_path

    def _get_default_source_camera(self, batch_size=1):
        """Canonical source camera (assumed pose; see the camera utilities above)."""
        # NOTE: the canonical viewpoint and distance are assumptions and should
        # match the original VFusion3D demo implementation.
        extrinsics = _center_looking_at_camera_pose(torch.tensor([[1.5, 0.0, 0.0]]))
        intrinsics = _default_intrinsics().unsqueeze(0)
        return _build_camera_principle(extrinsics, intrinsics).repeat(batch_size, 1)

    def _get_default_render_cameras(self, batch_size=1, num_frames=160, radius=1.5):
        """Circular orbit of render cameras around the object (assumed trajectory)."""
        angles = torch.linspace(0, 2 * np.pi, num_frames + 1)[:-1]
        positions = torch.stack([
            radius * torch.cos(angles),
            radius * torch.sin(angles),
            torch.zeros_like(angles),
        ], dim=-1)
        extrinsics = _center_looking_at_camera_pose(positions)
        intrinsics = _default_intrinsics().unsqueeze(0).repeat(num_frames, 1, 1)
        return _build_camera_standard(extrinsics, intrinsics).unsqueeze(0).repeat(batch_size, 1, 1)

# Create Gradio Interface
def create_vfusion3d_interface():
    generator = VFusion3DGenerator()

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("# VFusion3D Model Converter")
                input_image = gr.Image(type="pil", label="Upload Image")
                
                with gr.Row():
                    mesh_btn = gr.Button("Generate 3D Mesh")
                    video_btn = gr.Button("Generate Rotation Video")
                
                mesh_output = gr.File(label="3D Mesh (.obj)")
                video_output = gr.File(label="Rotation Video")
            
            with gr.Column():
                model_viewer = LitModel3D(
                    label="3D Model Preview",
                    scale=1.0,
                    interactive=True
                )
        
        # Button click events
        def generate_mesh_outputs(img):
            # Run mesh generation once and reuse the file for both the download and the preview
            mesh_path = generator.generate_3d_output(img, output_type='mesh')
            return mesh_path, mesh_path

        mesh_btn.click(
            fn=generate_mesh_outputs,
            inputs=input_image,
            outputs=[mesh_output, model_viewer]
        )
        
        video_btn.click(
            fn=lambda img: generator.generate_3d_output(img, output_type='video'),
            inputs=input_image,
            outputs=video_output
        )
    
    return demo

# Launch the interface
if __name__ == "__main__":
    demo = create_vfusion3d_interface()
    demo.launch()