# ppsurf / app.py (Hugging Face Space file view)
# Author: perler
# Commit d2fb870: "maybe gradio can render obj outputs"
# raw | history | blame | 8.14 kB
#!/usr/bin/env python
from __future__ import annotations
import sys
import os
import datetime
import gradio as gr
import spaces
@spaces.GPU(duration=60 * 3)
def run_on_gpu(input_point_cloud: gr.utils.NamedString,
               gen_resolution_global: int,
               padding_factor: float,
               gen_subsample_manifold_iter: int,
               gen_refine_iter: int) -> str | None:
    """Reconstruct a mesh from an uploaded point cloud with PPSurf.

    Runs ``ppsurf/pps.py predict`` in a child process and returns the path of
    the ``.obj`` file it is expected to write. The child process is needed so
    that PPSurf can spawn its own workers.

    The ``spaces.GPU`` decorator requests GPU time for up to ``60 * 3``
    seconds per call.

    :param input_point_cloud: upload from the ``gr.File`` input. Its ``name``
        attribute is the path of the uploaded file. A plain ``str`` path is
        also accepted.
    :param gen_resolution_global: passed as
        ``--model.init_args.gen_resolution_global``.
    :param padding_factor: passed as ``--data.init_args.padding_factor``.
    :param gen_subsample_manifold_iter: passed as
        ``--model.init_args.gen_subsample_manifold_iter``.
    :param gen_refine_iter: passed as ``--model.init_args.gen_refine_iter``.
    :return: path ``<out_dir>/<input name>/<input name>.obj`` of the
        reconstructed mesh, or ``None`` in any of these cases (each shows a
        ``gr.Warning``):

        - no file was uploaded,
        - the reconstruction process failed,
        - the expected output file was not produced.
    """
    import subprocess
    import uuid

    print('Started inference at {}'.format(datetime.datetime.now()))
    print('Inputs:', input_point_cloud, gen_resolution_global, padding_factor,
          gen_subsample_manifold_iter, gen_refine_iter)
    print('Types:', type(input_point_cloud), type(gen_resolution_global), type(padding_factor),
          type(gen_subsample_manifold_iter), type(gen_refine_iter))

    if input_point_cloud is None:
        gr.Warning('Please upload a point cloud file first.')
        return None

    # Keep the ppsurf package importable in this process. The guard stops
    # repeated calls from appending the same entry to sys.path again and again.
    ppsurf_dir = os.path.abspath('ppsurf')
    if ppsurf_dir not in sys.path:
        sys.path.append(ppsurf_dir)

    # NamedString carries the upload path in `.name`; fall back to the value
    # itself in case a plain path string is passed.
    in_file = str(getattr(input_point_cloud, 'name', input_point_cloud))
    in_name = os.path.basename(in_file)
    # Use a unique directory per request so that concurrent runs don't collide.
    out_dir = '/tmp/outputs/{}'.format(uuid.uuid4().hex)
    out_file = os.path.join(out_dir, in_name, in_name + '.obj')
    os.makedirs(out_dir, exist_ok=True)

    model_path = 'models/ppsurf_50nn/version_0/checkpoints/last.ckpt'
    # sys.executable makes the child use the same interpreter/environment as this app.
    # The integer options are coerced with int() so that a slider value such as
    # 129.0 is passed as "129".
    cmd = [
        sys.executable, 'ppsurf/pps.py', 'predict',
        '-c', 'ppsurf/configs/poco.yaml',
        '-c', 'ppsurf/configs/ppsurf.yaml',
        '-c', 'ppsurf/configs/ppsurf_50nn.yaml',
        '--ckpt_path', model_path,
        '--data.init_args.in_file', in_file,
        '--model.init_args.results_dir', out_dir,
        '--trainer.logger', 'False',
        '--trainer.devices', '1',
        '--model.init_args.gen_resolution_global', str(int(gen_resolution_global)),
        '--data.init_args.padding_factor', str(padding_factor),
        '--model.init_args.gen_subsample_manifold_iter', str(int(gen_subsample_manifold_iter)),
        '--model.init_args.gen_refine_iter', str(int(gen_refine_iter)),
    ]
    try:
        # check=True: without it subprocess.run does not raise on a non-zero exit
        # status, and a failed reconstruction would go unnoticed.
        subprocess.run(cmd, check=True)
    except (OSError, subprocess.SubprocessError) as e:
        gr.Warning("Reconstruction failed:\n{}".format(e))
        return None
    finally:
        print('Finished inference at {}'.format(datetime.datetime.now()))

    if not os.path.isfile(out_file):
        gr.Warning('Reconstruction finished but no output mesh was found at {}'.format(out_file))
        return None
    return out_file
def main() -> None:
    """Build the PPSurf Gradio demo page and launch the server.

    Components are created inside ``gr.Blocks`` / ``gr.Row`` / ``gr.Column``
    context managers, so statement order and nesting determine the layout:

    - the header,
    - a row with two description columns,
    - a row with the input column (file upload + four sliders) and the output
      column (3D model viewer),
    - a row with the button that calls ``run_on_gpu``.
    """
    # Markdown title shown at the top of the page.
    description_header = '# PPSurf: Combining Patches and Point Convolutions for Detailed Surface Reconstruction'
    # Left description column: repository link and supported input formats.
    description_col0 = '''## [Github](https://github.com/cg-tuwien/ppsurf)
Supported input file formats:
- PLY, STL, OBJ and other mesh files,
- XYZ as whitespace-separated text file,
- NPY and NPZ (key='arr_0'),
- LAS and LAZ (version 1.0-1.4), COPC and CRS.
Best results for 50k-250k points.
'''
    # Right description column: project page and usage notes.
    # NOTE: "up to 180 seconds" mirrors the @spaces.GPU(duration=60 * 3) limit on
    # run_on_gpu; keep the two in sync.
    description_col1 = '''## [Project Info](https://www.cg.tuwien.ac.at/research/publications/2024/erler_2024_ppsurf/)
This method is meant for scans of single and few objects.
Quality for scenes and landscapes will be lower.
Inference takes up to 180 seconds.
'''

    # Disabled feature: preview the uploaded point cloud in a Model3D viewer.
    # can't render many input types directly in Gradio Model3D
    # so we need to convert to supported format
    # Gradio can't draw point clouds anyway (2024-03-04), so we skip this for now
    # def convert_to_ply(input_point_cloud_upload: gr.utils.NamedString):
    #
    #     # add absolute path to import dirs
    #     import sys
    #     import os
    #     sys.path.append(os.path.abspath('ppsurf'))
    #
    #     # import os
    #     # os.chdir('ppsurf')
    #
    #     print('Inputs:', input_point_cloud_upload, type(input_point_cloud_upload))
    #     input_shape: str = input_point_cloud_upload.name
    #     if not input_shape.endswith('.ply'):
    #         # load file
    #         from ppsurf.source.occupancy_data_module import OccupancyDataModule
    #         pts_np = OccupancyDataModule.load_pts(input_shape)
    #
    #         # convert to ply
    #         import trimesh
    #         mesh = trimesh.Trimesh(vertices=pts_np[:, :3])
    #         input_shape = input_shape + '.ply'
    #         mesh.export(input_shape)
    #
    #     print('ls:\n', subprocess.run(['ls', os.path.dirname(input_shape)]))
    #
    #     # show in viewer
    #     print(type(input_tabs))
    #     # print(type(input_point_cloud_viewer))
    #     # input_tabs.selected = 'pc_viewer'
    #     # input_point_cloud_viewer.value = input_shape

    # SPACE_ID is presumably set by the Hugging Face Spaces runtime. When it is
    # present, append a "Duplicate Space" badge linking to this Space's
    # duplicate page.
    if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
        description_col1 += (f'\n<p>For faster inference without waiting in queue, '
                             f'you may duplicate the space and upgrade to GPU in settings. '
                             f'<a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true">'
                             f'<img style="display: inline; margin-top: 0em; margin-bottom: 0em" '
                             f'src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>')

    with gr.Blocks(css='style.css') as demo:
        # descriptions
        gr.Markdown(description_header)
        with gr.Row():
            with gr.Column():
                gr.Markdown(description_col0)
            with gr.Column():
                gr.Markdown(description_col1)

        # inputs and outputs
        with gr.Row():
            with gr.Column():
                input_point_cloud_upload = gr.File(show_label=False, file_count='single')
                # with gr.Tabs() as input_tabs:  # re-enable when Gradio supports point clouds
                #     with gr.TabItem(label='Input Point Cloud Upload', id='pc_upload'):
                #         input_point_cloud_upload.upload(
                #             fn=convert_to_ply,
                #             inputs=[
                #                 input_point_cloud_upload,
                #             ],
                #             outputs=[
                #                 # input_point_cloud_viewer,  # not available here
                #             ])
                #     with gr.TabItem(label='Input Point Cloud Viewer', id='pc_viewer'):
                #         input_point_cloud_viewer = gr.Model3D(show_label=False)

                # Reconstruction parameters. They are passed to run_on_gpu in
                # this order (see run_button.click below).
                gen_resolution_global = gr.Slider(
                    label='Grid Resolution (larger for more details)',
                    minimum=17, maximum=513, value=129, step=2)
                padding_factor = gr.Slider(
                    label='Padding Factor (larger if object is cut off at boundaries)',
                    minimum=0, maximum=1.0, value=0.05, step=0.05)
                gen_subsample_manifold_iter = gr.Slider(
                    label='Subsample Manifold Iterations (larger for larger point clouds)',
                    minimum=3, maximum=30, value=10, step=1)
                gen_refine_iter = gr.Slider(
                    label='Edge Refinement Iterations (larger for more details)',
                    minimum=3, maximum=30, value=10, step=1)
            with gr.Column():
                # Displays the mesh path returned by run_on_gpu.
                result_3d_model = gr.Model3D(label='Reconstructed 3D model')
                # progress_text = gr.Text(label='Progress')
                # with gr.Tabs():
                #     with gr.TabItem(label='Reconstructed 3D model'):
                #         result_3d_model = gr.Model3D(show_label=False)
                #     with gr.TabItem(label='Output mesh file'):
                #         output_file = gr.File(show_label=False)
        with gr.Row():
            run_button = gr.Button('Reconstruct with PPSurf')
        # The order of `inputs` must match run_on_gpu's positional parameters.
        run_button.click(fn=run_on_gpu,
                         inputs=[
                             input_point_cloud_upload,
                             gen_resolution_global,
                             padding_factor,
                             gen_subsample_manifold_iter,
                             gen_refine_iter,
                         ],
                         outputs=[
                             result_3d_model,
                             # output_file,
                             # progress_text,
                         ])

    # Enable request queueing, holding at most 5 pending requests.
    demo.queue(max_size=5)
    demo.launch(debug=True)
if __name__ == '__main__':
print(os.environ)
main()