umeshhh committed
Commit 4ffe53d · verified · 1 Parent(s): eb4de38

Update app.py

Files changed (1)
  1. app.py +217 -27
app.py CHANGED
@@ -1,29 +1,219 @@
+ import torch
+ import spaces
  import gradio as gr
- from transformers import AutoProcessor, AutoModel
-
- # Model name
- model_name = "facebook/VFusion3D"
-
- # Load processor and model with trusted code
- processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
-
- # Define prediction function
- def predict(input_text):
-     # Convert input into a format the model understands
-     inputs = processor(inputs=input_text, return_tensors="pt")
-     outputs = model(**inputs)
-     return outputs.logits.tolist()
-
- # Gradio interface
- interface = gr.Interface(
-     fn=predict,
-     inputs="text",
-     outputs="text",
-     title="VFusion3D Deployment",
-     description="A demo for facebook/VFusion3D model."
- )
-
- # Launch the app
+ import os
+ import numpy as np
+ import trimesh
+ import mcubes
+ import imageio
+ from PIL import Image
+ from transformers import AutoModel, AutoConfig
+ from rembg import remove, new_session
+ from functools import partial
+ import kiui
+ from gradio_litmodel3d import LitModel3D
+
+ class VFusion3DGenerator:
+     def __init__(self, model_name="facebook/vfusion3d"):
+         """
+         Initialize the VFusion3D model
+
+         Args:
+             model_name (str): Hugging Face model identifier
+         """
+         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+         self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Background removal session
+         self.rembg_session = new_session("isnet-general-use")
+
+     def preprocess_image(self, image, source_size=512):
+         """
+         Preprocess input image for the VFusion3D model
+
+         Args:
+             image (PIL.Image): Input image
+             source_size (int): Target image size
+
+         Returns:
+             torch.Tensor: Preprocessed image tensor
+         """
+         rembg_remove = partial(remove, session=self.rembg_session)
+         image = np.array(image)
+         image = rembg_remove(image)
+         mask = rembg_remove(image, only_mask=True)
+         image = kiui.op.recenter(image, mask, border_ratio=0.20)
+
+         image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
+         # Composite RGBA onto a white background
+         if image.shape[1] == 4:
+             image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
+
+         image = torch.nn.functional.interpolate(
+             image,
+             size=(source_size, source_size),
+             mode='bicubic',
+             align_corners=True
+         )
+         return torch.clamp(image, 0, 1)
+
+     @spaces.GPU  # request a GPU for this call on Hugging Face ZeroGPU Spaces
+     def generate_3d_output(self, image, output_type='mesh', render_size=384, mesh_size=512):
+         """
+         Generate 3D output (mesh or video) from input image
+
+         Args:
+             image (PIL.Image): Input image
+             output_type (str): Type of output ('mesh' or 'video')
+             render_size (int): Rendering size
+             mesh_size (int): Mesh generation size
+
+         Returns:
+             str: Path to generated file
+         """
+         # Preprocess image
+         image = self.preprocess_image(image).to(self.device)
+
+         # Default camera settings (you might want to adjust these)
+         source_camera = self._get_default_source_camera(batch_size=1).to(self.device)
+
+         with torch.no_grad():
+             # Forward pass
+             planes = self.model(image, source_camera)
+
+             if output_type == 'mesh':
+                 return self._generate_mesh(planes, mesh_size)
+             elif output_type == 'video':
+                 return self._generate_video(planes, render_size)
+             else:
+                 raise ValueError(f"output_type must be 'mesh' or 'video', got {output_type!r}")
+
+     def _generate_mesh(self, planes, mesh_size=512):
+         """
+         Generate 3D mesh from neural planes
+
+         Args:
+             planes: Neural representation planes
+             mesh_size (int): Size of the mesh grid
+
+         Returns:
+             str: Path to saved mesh file
+         """
+         # Use scikit-image's marching cubes instead of mcubes
+         from skimage import measure
+
+         grid_out = self.model.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
+
+         # Extract the sigma grid and threshold
+         sigma_grid = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
+
+         # Use marching cubes from scikit-image
+         vtx, faces, _, _ = measure.marching_cubes(sigma_grid, level=1.0)
+
+         # Normalize vertices
+         vtx = vtx / (mesh_size - 1) * 2 - 1
+
+         # Color vertices
+         vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=self.device).unsqueeze(0)
+         vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
+         vtx_colors = (vtx_colors * 255).astype(np.uint8)
+
+         # Create and save mesh
+         mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
+         mesh_path = "generated_mesh.obj"
+         mesh.export(mesh_path, 'obj')
+         return mesh_path
+
+     def _generate_video(self, planes, render_size=384, fps=30):
+         """
+         Generate rotating video from neural planes
+
+         Args:
+             planes: Neural representation planes
+             render_size (int): Size of rendered frames
+             fps (int): Frames per second
+
+         Returns:
+             str: Path to saved video file
+         """
+         render_cameras = self._get_default_render_cameras(batch_size=1).to(self.device)
+         frames = []
+
+         # Render one view at a time to bound memory usage
+         for i in range(0, render_cameras.shape[1], 1):
+             frame_chunk = self.model.synthesizer(
+                 planes,
+                 render_cameras[:, i:i + 1],
+                 render_size,
+                 render_size,
+                 0,
+                 0
+             )
+             frames.append(frame_chunk['images_rgb'])
+
+         frames = torch.cat(frames, dim=1)
+         frames = frames.squeeze(0)
+         frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
+
+         video_path = "generated_video.mp4"
+         imageio.mimwrite(video_path, frames, fps=fps)
+
+         return video_path
+
+     def _get_default_source_camera(self, batch_size=1):
+         """Generate default source camera parameters"""
+         # Placeholder: must match the camera convention the pretrained
+         # model expects (a hedged sketch follows after the diff)
+         raise NotImplementedError("Port the source-camera setup from the reference VFusion3D demo")
+
+     def _get_default_render_cameras(self, batch_size=1):
+         """Generate default render camera parameters"""
+         # Placeholder: must match the original implementation
+         # (see the sketch after the diff)
+         raise NotImplementedError("Port the render-camera setup from the reference VFusion3D demo")
+
+ # Create Gradio Interface
+ def create_vfusion3d_interface():
+     generator = VFusion3DGenerator()
+
+     def generate_mesh_and_preview(img):
+         # Generate the mesh once and reuse the same file for the download
+         # component and the 3D preview
+         mesh_path = generator.generate_3d_output(img, output_type='mesh')
+         return mesh_path, mesh_path
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("# VFusion3D Model Converter")
+                 input_image = gr.Image(type="pil", label="Upload Image")
+
+                 with gr.Row():
+                     mesh_btn = gr.Button("Generate 3D Mesh")
+                     video_btn = gr.Button("Generate Rotation Video")
+
+                 mesh_output = gr.File(label="3D Mesh (.obj)")
+                 video_output = gr.File(label="Rotation Video")
+
+             with gr.Column():
+                 model_viewer = LitModel3D(
+                     label="3D Model Preview",
+                     scale=1.0,
+                     interactive=True
+                 )
+
+         # Button click events
+         mesh_btn.click(
+             fn=generate_mesh_and_preview,
+             inputs=input_image,
+             outputs=[mesh_output, model_viewer]
+         )
+
+         video_btn.click(
+             fn=lambda img: generator.generate_3d_output(img, output_type='video'),
+             inputs=input_image,
+             outputs=video_output
+         )
+
+     return demo
+
+ # Launch the interface
  if __name__ == "__main__":
-     interface.launch()
+     demo = create_vfusion3d_interface()
+     demo.launch()
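
A note on the two placeholder camera helpers above: below is a hedged sketch of one way they might be filled in. Everything in it is an assumption rather than part of this commit — the helper names (`get_default_source_camera`, `get_default_render_cameras`, `_camera_vector`), the 16-value per-camera layout (flattened 3x4 camera-to-world extrinsics plus normalized fx, fy, cx, cy intrinsics, following the OpenLRM-style convention VFusion3D builds on), and the viewing distance, view count, and intrinsic values are all illustrative. Verify them against the facebook/vfusion3d reference demo before relying on them.

```python
import math
import torch

# ASSUMED per-camera layout (OpenLRM-style): flattened 3x4 camera-to-world
# extrinsics (12 values) + normalized intrinsics fx, fy, cx, cy (4 values).
def _camera_vector(extrinsic_3x4):
    intrinsics = torch.tensor([0.75, 0.75, 0.5, 0.5])  # assumed normalized fx, fy, cx, cy
    return torch.cat([extrinsic_3x4.reshape(-1), intrinsics])

def get_default_source_camera(batch_size=1, dist=2.0):
    # One canonical front-facing camera looking at the origin -> shape [B, 16]
    extrinsic = torch.tensor([
        [1.0, 0.0,  0.0,  0.0],
        [0.0, 0.0, -1.0, -dist],
        [0.0, 1.0,  0.0,  0.0],
    ])
    return _camera_vector(extrinsic).unsqueeze(0).repeat(batch_size, 1)

def get_default_render_cameras(batch_size=1, n_views=160, radius=2.0):
    # A turntable ring of n_views cameras orbiting the object -> shape [B, n_views, 16]
    up = torch.tensor([0.0, 0.0, 1.0])
    views = []
    for i in range(n_views):
        theta = 2.0 * math.pi * i / n_views
        eye = torch.tensor([radius * math.sin(theta), -radius * math.cos(theta), 0.0])
        forward = -eye / eye.norm()                  # look at the origin
        right = torch.linalg.cross(up, forward)
        right = right / right.norm()
        true_up = torch.linalg.cross(forward, right)
        # Camera-to-world pose: camera axes as columns, eye position as translation
        rot = torch.stack([right, true_up, forward], dim=1)
        extrinsic = torch.cat([rot, eye.unsqueeze(1)], dim=1)
        views.append(_camera_vector(extrinsic))
    return torch.stack(views).unsqueeze(0).repeat(batch_size, 1, 1)
```

If the checkpoint does use this layout, `_get_default_source_camera` and `_get_default_render_cameras` can simply delegate to these helpers; if the model expects a different vector length or a world-to-camera convention, only `_camera_vector` and the pose construction need to change.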