Spaces:
Runtime error
Commit · b440279
Parent(s): b1809a9
Upload 2 files
Browse files
- app.py +252 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,252 @@
import sys
import os
import re
from typing import List, Optional, Tuple, Union

import click
import dnnlib
import numpy as np
import PIL.Image
import PIL.ImageOps  # needed for PIL.ImageOps.grayscale() below
import torch
from tqdm import tqdm

import legacy
from camera_utils import LookAtPoseSampler
from huggingface_hub import hf_hub_download

from matplotlib import pyplot as plt

from pathlib import Path

import json
import gradio as gr

from training.utils import color_mask, color_list
import plotly.graph_objects as go

import imageio

import argparse

import trimesh
import pyrender
import mcubes

os.environ["PYOPENGL_PLATFORM"] = "egl"

def get_sigma_field_np(nerf, styles, resolution=512, block_resolution=64):
    # Return a numpy array of sigma (density) values sampled on a regular
    # 3D grid, evaluated block by block to bound peak GPU memory.
    bound = nerf.rendering_kwargs['box_warp'] * 0.5
    X = torch.linspace(-bound, bound, resolution).split(block_resolution)

    sigma_np = np.zeros([resolution, resolution, resolution], dtype=np.float32)

    for xi, xs in enumerate(X):
        for yi, ys in enumerate(X):
            for zi, zs in enumerate(X):
                xx, yy, zz = torch.meshgrid(xs, ys, zs)
                pts = torch.stack([xx, yy, zz], dim=-1).unsqueeze(0).to(styles.device)  # B, H, H, H, C
                block_shape = [1, len(xs), len(ys), len(zs)]
                out = nerf.sample_mixed(pts.reshape(1, -1, 3), None, ws=styles, noise_mode='const')
                sigma_out = out['sigma']
                sigma_np[xi * block_resolution: xi * block_resolution + len(xs),
                         yi * block_resolution: yi * block_resolution + len(ys),
                         zi * block_resolution: zi * block_resolution + len(zs)] = \
                    sigma_out.reshape(block_shape[1:]).detach().cpu().numpy()

    return sigma_np, bound

def extract_geometry(nerf, styles, resolution, threshold):
    # Run marching cubes on the sampled sigma field, then rescale vertices
    # from grid coordinates back into world coordinates.
    u, bound = get_sigma_field_np(nerf, styles, resolution)
    vertices, faces = mcubes.marching_cubes(u, threshold)
    b_min_np = np.array([-bound, -bound, -bound])
    b_max_np = np.array([bound, bound, bound])

    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    return vertices.astype('float32'), faces
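
# Usage sketch (hypothetical, not part of the original commit): given a loaded
# generator G and mapped latents ws, the mesh can be extracted and saved via
# trimesh, e.g.
#   verts, faces = extract_geometry(G, ws, resolution=64, threshold=50.0)
#   trimesh.Trimesh(verts, faces).export('mesh.obj')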

def render_video(G, ws, intrinsics, num_frames=120, pitch_range=0.25, yaw_range=0.35,
                 neural_rendering_resolution=128, device='cuda'):
    # Orbit the camera around the object and collect RGB frames plus
    # colorized semantic-label frames.
    frames, frames_label = [], []

    for frame_idx in tqdm(range(num_frames)):
        cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_range * np.sin(2 * 3.14 * frame_idx / num_frames),
            3.14 / 2 - 0.05 + pitch_range * np.cos(2 * 3.14 * frame_idx / num_frames),
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        pose = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        with torch.no_grad():
            out = G.synthesis(ws, pose, noise_mode='const',
                              neural_rendering_resolution=neural_rendering_resolution)
        frames.append(((out['image'].cpu().numpy()[0] + 1) * 127.5).clip(0, 255).astype(np.uint8).transpose(1, 2, 0))
        frames_label.append(color_mask(torch.argmax(out['semantic'], dim=1).cpu().numpy()[0]).astype(np.uint8))

    return frames, frames_label

def return_plot_go(mesh_trimesh):
    # Build an interactive plotly Mesh3d figure from a trimesh mesh,
    # carrying over the per-vertex semantic colors.
    x, y, z = np.asarray(mesh_trimesh.vertices).T
    i, j, k = np.asarray(mesh_trimesh.faces).T
    fig = go.Figure(go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        vertexcolor=np.asarray(mesh_trimesh.visual.vertex_colors),
        lighting=dict(ambient=0.5, diffuse=1, fresnel=4, specular=0.5,
                      roughness=0.05, facenormalsepsilon=0, vertexnormalsepsilon=0),
        lightposition=dict(x=100, y=100, z=1000)))
    return fig
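
# Note (hypothetical usage, not in the original commit): the returned Figure
# is rendered by gr.Plot below; for local debugging one could also call
#   return_plot_go(mesh_trimesh).show()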

network_cat = hf_hub_download("SerdarHelli/pix2pix3d_seg2cat", filename="pix2pix3d_seg2cat.pkl", revision="main")

models = {"seg2cat": network_cat}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
outdir = "./"  # originally "/content/" (a Colab path); a relative dir is safer on Spaces

def get_all(cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames):

    network = models[cfg]  # fixed typo: was `newtork`, leaving `network` undefined below

    with dnnlib.util.open_url(network) as f:
        G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    if cfg == 'seg2cat' or cfg == 'seg2face':
        neural_rendering_resolution = 128
        data_type = 'seg'
        # Initialize pose sampler.
        forward_cam2world_pose = LookAtPoseSampler.sample(
            3.14 / 2, 3.14 / 2,
            torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device),
            radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # shapenet has higher FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        forward_pose = torch.cat([forward_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    elif cfg == 'edge2car':
        # Note: this branch does not define forward_pose/intrinsics; only
        # 'seg2cat' is exposed in the UI below.
        neural_rendering_resolution = 64
        data_type = 'edge'
    else:
        raise ValueError(f'Invalid cfg: {cfg}')  # was print(), which let execution continue with undefined variables

    save_dir = Path(outdir)

    input_label = PIL.Image.open(input)
    input_label = PIL.ImageOps.grayscale(input_label)
    input_label = np.asarray(input_label).astype(np.uint8)
    input_label = torch.from_numpy(input_label).unsqueeze(0).unsqueeze(0).to(device)
    print(input_label.shape)
    input_pose = forward_pose.to(device)

    # Sample a latent code; was RandomState(int(0)), which ignored the seed slider.
    z = torch.from_numpy(np.random.RandomState(int(random_seed)).randn(1, G.z_dim).astype('float32')).to(device)

    with torch.no_grad():
        ws = G.mapping(z, input_pose, {'mask': input_label, 'pose': input_pose})
        out = G.synthesis(ws, input_pose, noise_mode='const', neural_rendering_resolution=neural_rendering_resolution)

    image_color = ((out['image'][0].permute(1, 2, 0).cpu().numpy().clip(-1, 1) + 1) * 127.5).astype(np.uint8)
    image_seg = color_mask(torch.argmax(out['semantic'][0], dim=0).cpu().numpy()).astype(np.uint8)
    mesh_trimesh = trimesh.Trimesh(*extract_geometry(G, ws, resolution=mesh_resolution, threshold=50.))

    # Color the mesh vertices with the generator's semantic predictions.
    verts_np = np.array(mesh_trimesh.vertices)
    colors = torch.zeros((verts_np.shape[0], 3), device=device)
    semantic_colors = torch.zeros((verts_np.shape[0], 6), device=device)
    samples_color = torch.tensor(verts_np, device=device).unsqueeze(0).float()

    head = 0
    max_batch = 10000000
    with tqdm(total=verts_np.shape[0]) as pbar:
        with torch.no_grad():
            while head < verts_np.shape[0]:
                torch.manual_seed(0)
                out = G.sample_mixed(samples_color[:, head:head + max_batch], None, ws,
                                     truncation_psi=truncation_psi, noise_mode='const')
                colors[head:head + max_batch, :] = out['rgb'][0, :, :3]
                seg = out['rgb'][0, :, 32:32 + 6]
                semantic_colors[head:head + max_batch, :] = seg
                head += max_batch
                pbar.update(max_batch)

    semantic_colors = torch.tensor(color_list, device=device)[torch.argmax(semantic_colors, dim=-1)]
    mesh_trimesh.visual.vertex_colors = semantic_colors.cpu().numpy().astype(np.uint8)

    frames, frames_label = render_video(G, ws, intrinsics, num_frames=num_frames, pitch_range=0.25,
                                        yaw_range=0.35, neural_rendering_resolution=neural_rendering_resolution,
                                        device=device)

    # Save the videos.
    video = save_dir / f'{cfg}_color.mp4'
    video_label = save_dir / f'{cfg}_label.mp4'
    imageio.mimsave(video, frames, fps=fps)
    imageio.mimsave(video_label, frames_label, fps=fps)  # removed stray trailing comma
    fig_mesh = return_plot_go(mesh_trimesh)
    # Paths converted to str for the gr.Video outputs.
    return fig_mesh, image_color, image_seg, str(video), str(video_label)
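
# Head-less usage sketch (hypothetical, not part of the original commit),
# assuming a seg2cat label map saved as 'seg.png':
#   fig, img, seg, vid, vid_label = get_all("seg2cat", "seg.png",
#       truncation_psi=1.0, mesh_resolution=64, random_seed=128,
#       fps=30, num_frames=30)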

markdown = f'''
# 3D-aware Conditional Image Synthesis

[Arxiv: "3D-aware Conditional Image Synthesis".](https://arxiv.org/abs/2302.08509)
[Project Page.](https://www.cs.cmu.edu/~pix2pix3D/)
[For the official implementation.](https://github.com/dunbar12138/pix2pix3D)

### Future work, based on interest
- Adding new models for new object types
- New customization options


It is running on {device}.
The process can take a long time, especially when generating videos; the runtime depends on the number of frames, the mesh resolution, and the device it runs on.
'''

with gr.Blocks() as demo:
    gr.Markdown(markdown)
    with gr.Row():
        with gr.Column():
            input = gr.Image(type="filepath", shape=(512, 512))
        with gr.Column():
            cfg = gr.Dropdown(choices=["seg2cat"], label="Choose Model", value="seg2cat")
            truncation_psi = gr.Slider(minimum=0, maximum=2, label='Truncation PSI', value=1)
            mesh_resolution = gr.Slider(minimum=32, maximum=512, label='Mesh Resolution', value=32)
            random_seed = gr.Slider(minimum=0, maximum=2**16, label='Seed', value=128)
            fps = gr.Slider(minimum=10, maximum=120, label='FPS', value=30)
            num_frames = gr.Slider(minimum=10, maximum=120, label='The Number of Frames', value=30)

    with gr.Row():
        btn = gr.Button(value="Generate")

    with gr.Row():
        with gr.Column():
            image_color = gr.Image(type="pil", shape=(256, 256))
        with gr.Column():
            image_label = gr.Image(type="pil", shape=(256, 256))
    with gr.Row():
        mesh = gr.Plot()
    with gr.Row():
        with gr.Column():
            video_color = gr.Video()
        with gr.Column():
            video_label = gr.Video()

    btn.click(get_all,
              [cfg, input, truncation_psi, mesh_resolution, random_seed, fps, num_frames],
              [mesh, image_color, image_label, video_color, video_label])

demo.launch(debug=True, share=True)
requirements.txt
ADDED
@@ -0,0 +1,19 @@
torch
trimesh
pyrender
PyMCubes
pycollada
einops
ninja
imageio-ffmpeg
imgui==1.3.0
glfw==2.2.0
pyopengl==3.1.5
pyspng
psutil
mrcfile
opencv-python
tqdm
scipy
pillow
numpy
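
# Note: app.py also imports gradio, plotly, matplotlib, huggingface_hub, and
# imageio, which are not listed above; on a fresh environment these may need
# to be added (Hugging Face Spaces preinstalls gradio). dnnlib, legacy,
# camera_utils, and training.utils are repo-local pix2pix3D modules, not
# PyPI packages.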