"""Gradio demo for VistaDream: single-view scene reconstruction.

Pipeline: coarse Gaussian-scene generation from one input image, MCS
refinement, then RGB/depth video rendering of the refined scene.
"""
import spaces
import os
import torch
import gradio as gr
from PIL import Image
from pipe.cfgs import load_cfg
from pipe.c2f_recons import Pipeline
from ops.gs.basic import Gaussian_Scene
from datetime import datetime

# Build the pipeline once at module load; GPU-heavy work is deferred to the
# @spaces.GPU-decorated functions below.
cfg = load_cfg('pipe/cfgs/basic.yaml')
vistadream = Pipeline(cfg)
from ops.visual_check import Check
checkor = Check()


def get_temp_path():
    """Create and return a fresh timestamped output directory path.

    Returns:
        str: directory path of the form ``data/gradio_temp/<timestamp>/``
        (with trailing slash), guaranteed to exist.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"data/gradio_temp/{timestamp}/"
    # BUG FIX: the timestamped subdirectory itself was never created (only
    # data/gradio_temp was), so later torch.save() / video renders into it
    # would fail with FileNotFoundError. makedirs also covers the parent.
    os.makedirs(output_path, exist_ok=True)
    return output_path


@spaces.GPU(duration=120)
def scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """Build the coarse Gaussian scene from a single input image.

    Args:
        rgb: input image (PIL, from the Gradio image component).
        num_coarse_views: number of views sampled along the spiral trajectory.
        num_mcs_views: number of views used by MCS refinement (stashed on the
            pipeline for the later refinement stage).
        mcs_rect_w: MCS rectification weight.
        mcs_steps: MCS optimization iterations.
    """
    torch.cuda.init()
    vistadream.scene = Gaussian_Scene(cfg)
    # Trajectory generation settings.
    vistadream.traj_type = 'spiral'
    vistadream.scene.traj_type = 'spiral'
    vistadream.n_sample = num_coarse_views
    # Scene-generation settings.
    vistadream.opt_iters_per_frame = 512
    vistadream.outpaint_extend_times = 0.45
    vistadream.outpaint_selections = ['Left', 'Right', 'Top', 'Bottom']
    # Scene-refinement settings, consumed later by scene_refinement().
    vistadream.mcs_n_view = num_mcs_views
    vistadream.mcs_rect_w = mcs_rect_w
    vistadream.mcs_iterations = mcs_steps
    # Run coarse reconstruction.
    vistadream._coarse_scene(rgb)
    torch.cuda.empty_cache()


@spaces.GPU(duration=120)
def scene_refinement():
    """Run MCS refinement on the current scene and persist it.

    Returns:
        str: the output directory the refined scene was saved into.
    """
    vistadream._MCS_Refinement()
    output_path = get_temp_path()
    torch.cuda.empty_cache()
    torch.save(vistadream.scene, output_path + 'scene.pth')
    return output_path


def render_video(output_path):
    """Render RGB and depth videos of the current scene into *output_path*.

    Returns:
        tuple[str, str]: paths of the rendered RGB and depth videos.
    """
    scene = vistadream.scene
    # NOTE(review): the original called vistadream.checkor, but the
    # module-level `checkor = Check()` instance was created and otherwise
    # unused, and Pipeline is not shown to define a .checkor attribute —
    # use the module-level checker instead; confirm against Pipeline.
    checkor._render_video(scene, save_dir=output_path + '.')
    return output_path + 'video_rgb.mp4', output_path + 'video_dpt.mp4'


def process(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps):
    """End-to-end Gradio callback: coarse generation -> refinement -> videos."""
    scene_generate(rgb, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps)
    path = scene_refinement()
    # BUG FIX: the original referenced the undefined name `output_path` here
    # (NameError at runtime); scene_refinement() returns the directory as `path`.
    rgb.save(path + 'input.png')
    return render_video(path)


with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## VistaDream")
        gr.Markdown("### Sampling multiview consistent images for single-view scene reconstruction")
        gr.HTML("""
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil")
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    num_coarse_views = gr.Slider(label="Coarse-Expand", minimum=5, maximum=25, value=10, step=1)
                    num_mcs_views = gr.Slider(label="MCS Optimization Views", minimum=4, maximum=10, value=8, step=1)
                    mcs_rect_w = gr.Slider(label="MCS Rectification Weight", minimum=0.3, maximum=0.8, value=0.7, step=0.1)
                    mcs_steps = gr.Slider(label="MCS Steps", minimum=8, maximum=15, value=10, step=1)
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # BUG FIX: gr.Video's first positional argument is
                        # `value` (a media path), not the label — the caption
                        # must be passed as label=.
                        rgb_video = gr.Video(label="Output RGB renderings")
                    with gr.Column():
                        dpt_video = gr.Video(label="Output DPT renderings")
        examples = gr.Examples(
            examples=[],
            inputs=[input_image, rgb_video, dpt_video],
        )
    ips = [input_image, num_coarse_views, num_mcs_views, mcs_rect_w, mcs_steps]
    run_button.click(fn=process, inputs=ips, outputs=[rgb_video, dpt_video])

demo.launch(server_name='0.0.0.0')