Spaces:
Build error
Build error
import spaces | |
import os | |
import torch | |
import gradio as gr | |
from PIL import Image | |
from pipe.cfgs import load_cfg | |
from pipe.c2f_recons import Pipeline | |
from ops.gs.basic import Gaussian_Scene | |
from datetime import datetime | |
# Load the pipeline configuration once at import time; every request handler
# in this app shares this config and the single pipeline object built from it.
cfg = load_cfg('pipe/cfgs/basic.yaml')
vistadream = Pipeline(cfg)

from ops.visual_check import Check

# NOTE(review): nothing visible in this file reads `checkor` (video rendering goes
# through `vistadream.checkor` instead) — confirm whether this instance is needed.
checkor = Check()
def get_temp_path():
    """Create and return a fresh, timestamped output directory for one run.

    The directory is ``data/gradio_temp/<YYYYmmdd_HHMMSS>/``, relative to the
    current working directory. The returned string keeps its trailing slash
    because callers build file paths by plain string concatenation
    (e.g. ``output_path + 'scene.pth'``).

    Returns:
        str: the directory path, ending in ``/``. The directory exists on return.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"data/gradio_temp/{timestamp}/"
    # Create the timestamped directory itself, not only its parent: callers
    # write files into it straight away, which fails if it does not exist.
    # exist_ok tolerates two runs that land in the same second.
    os.makedirs(output_path, exist_ok=True)
    return output_path
def scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """Configure the shared pipeline for one request and build the coarse scene.

    Replaces ``vistadream.scene`` with a new ``Gaussian_Scene(cfg)``, copies
    the UI settings onto the pipeline, then runs ``vistadream._coarse_scene``.

    Args:
        rgb: value of the ``gr.Image(type="pil")`` input (presumably a PIL
            image); forwarded unchanged to ``_coarse_scene``.
        num_coarse_views: number of trajectory views, stored as ``n_sample``.
        num_mcs_views: number of MCS optimization views (``mcs_n_view``).
        mcs_rect_w: MCS rectification weight (``mcs_rect_w``).
        mcs_steps: number of MCS iterations (``mcs_iterations``).
    """
    torch.cuda.init()
    # Replace whatever scene the previous request left on the shared pipeline.
    vistadream.scene = Gaussian_Scene(cfg)
    # Trajectory generation.
    vistadream.traj_type = 'spiral'
    vistadream.scene.traj_type = 'spiral'
    # gr.Slider may deliver numbers as floats (e.g. 10.0); counts are cast to
    # int so they are safe to use as sample counts / loop bounds downstream.
    vistadream.n_sample = int(num_coarse_views)
    # Coarse scene generation (fixed settings, not exposed in the UI).
    vistadream.opt_iters_per_frame = 512
    vistadream.outpaint_extend_times = 0.45
    vistadream.outpaint_selections = ['Left', 'Right', 'Top', 'Bottom']
    # Scene refinement (MCS) settings.
    vistadream.mcs_n_view = int(num_mcs_views)
    vistadream.mcs_rect_w = float(mcs_rect_w)
    vistadream.mcs_iterations = int(mcs_steps)
    # Build the coarse scene from the input image.
    vistadream._coarse_scene(rgb)
    torch.cuda.empty_cache()
def scene_refinement():
    """Run MCS refinement on the shared pipeline and checkpoint the scene.

    Returns:
        str: the per-run directory from ``get_temp_path()`` (ending in ``/``)
        into which ``scene.pth`` was written.
    """
    vistadream._MCS_Refinement()
    save_dir = get_temp_path()
    torch.cuda.empty_cache()
    # torch.save serializes (pickles) the whole scene object.
    checkpoint_file = save_dir + 'scene.pth'
    torch.save(vistadream.scene, checkpoint_file)
    return save_dir
def render_video(output_path):
    """Render videos of the current scene into ``output_path``.

    Args:
        output_path: per-run directory string ending in ``/``.

    Returns:
        tuple[str, str]: paths ``<output_path>video_rgb.mp4`` and
        ``<output_path>video_dpt.mp4`` — presumably the files that
        ``_render_video`` writes (RGB and, presumably, depth); confirm.
    """
    current_scene = vistadream.scene
    renderer = vistadream.checkor
    # save_dir is the directory plus a trailing '.', i.e. ".../<timestamp>/.".
    renderer._render_video(current_scene, save_dir=output_path + '.')
    rgb_file = output_path + 'video_rgb.mp4'
    dpt_file = output_path + 'video_dpt.mp4'
    return rgb_file, dpt_file
def process(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """Gradio click handler: input image + settings -> two rendered video paths.

    Builds the coarse scene, refines it, saves the input image next to the
    refined scene checkpoint, and renders the output videos.

    NOTE(review): `spaces` is imported by this file but no function carries
    `@spaces.GPU`; on a ZeroGPU Space this handler needs it — confirm hardware.

    Returns:
        tuple[str, str]: the (RGB video, DPT video) paths from ``render_video``.

    Raises:
        gr.Error: if no input image was provided.
    """
    if rgb is None:
        raise gr.Error("Please upload an input image first.")
    scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps)
    path = scene_refinement()
    # `path` is the per-run directory returned by scene_refinement (ends in '/').
    rgb.save(path + 'input.png')
    return render_video(path)
# Gradio UI: input image + advanced sliders on the left, two output videos on
# the right; the Run button calls `process`.
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## VistaDream")
        gr.Markdown("### Sampling multiview consistent images for single-view scene reconstruction")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/WHU-USI3DV/VistaDream">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://vistadream-project-page.github.io/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2410.16892">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_coarse_views = gr.Slider(label="Coarse-Expand", minimum=5, maximum=25, value=10, step=1)
                    num_mcs_views = gr.Slider(label="MCS Optimization Views", minimum=4, maximum=10, value=8, step=1)
                    mcs_rect_w = gr.Slider(label="MCS Rectification Weight", minimum=0.3, maximum=0.8, value=0.7, step=0.1)
                    mcs_steps = gr.Slider(label="MCS Steps", minimum=8, maximum=15, value=10, step=1)
            with gr.Column():
                with gr.Row():
                    # The first positional parameter of gr.Video is `value`
                    # (an initial video file), so the caption must go in `label`.
                    with gr.Column():
                        rgb_video = gr.Video(label="Output RGB renderings")
                    with gr.Column():
                        dpt_video = gr.Video(label="Output DPT renderings")
        # NOTE(review): the examples list is empty, so this renders no examples.
        examples = gr.Examples(
            examples=[
            ],
            inputs=[input_image, rgb_video, dpt_video]
        )
    # Event wiring must happen inside the Blocks context.
    ips = [input_image, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps]
    run_button.click(fn=process, inputs=ips, outputs=[rgb_video, dpt_video])
demo.launch(server_name='0.0.0.0')