File size: 7,302 Bytes
5ed9923
 
 
 
 
 
 
 
 
 
 
 
 
 
633a18a
5ed9923
 
 
 
 
 
 
 
 
 
 
 
12c2bcf
5ed9923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc4bba8
5ed9923
633a18a
 
 
5ed9923
 
 
 
 
49b3e3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed9923
 
49b3e3d
 
 
 
5ed9923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c2bcf
 
5ed9923
12c2bcf
5ed9923
 
 
 
 
 
 
 
 
12c2bcf
 
5ed9923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6feb9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
# The MASt3R Gradio demo, modified for predicting 3D Gaussian Splats

# --- Original License ---
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import functools
import os
import sys
import tempfile

import gradio
import torch
from huggingface_hub import hf_hub_download

sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')
from dust3r.utils.image import load_images
from mast3r.utils.misc import hash_md5
import main
import utils.export as export


def get_reconstructed_scene(outdir, model, device, silent, image_size, ios_mode, filelist):
    """Run the model on one or two uploaded images and export the predicted Gaussians.

    Args:
        outdir: Directory the output ``gaussians.ply`` file is written to.
        model: The loaded model; called as ``model(view1, view2)`` and expected to
            return a pair of predictions.
        device: Torch device the image tensors are moved to before inference.
        silent: If False, ``load_images`` runs verbosely.
        image_size: Target size passed to ``load_images``.
        ios_mode: True when ``filelist`` comes from a ``gradio.Gallery`` whose items
            are ``(filepath, caption)`` pairs rather than plain paths.
        filelist: The uploaded images (one or two entries).

    Returns:
        Path of the written ``.ply`` file.

    Raises:
        ValueError: If no file, or more than two files, were provided.
    """
    # Validate explicitly (not with ``assert``, which is stripped under -O); an
    # empty upload arrives as None or [].
    if not filelist:
        raise ValueError("Please provide one or two images")
    if ios_mode:
        # Gallery items are (filepath, caption) pairs; keep only the path, but
        # tolerate items that are already plain path strings.
        filelist = [f if isinstance(f, str) else f[0] for f in filelist]
    if len(filelist) > 2:
        raise ValueError("Please provide one or two images")
    if len(filelist) == 1:
        # The network takes two views: pair a single image with itself.
        filelist = [filelist[0], filelist[0]]

    imgs = load_images(filelist, size=image_size, verbose=not silent)
    for img in imgs:
        img['img'] = img['img'].to(device)
        img['original_img'] = img['original_img'].to(device)
        img['true_shape'] = torch.from_numpy(img['true_shape'])

    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        pred1, pred2 = model(imgs[0], imgs[1])

    plyfile = os.path.join(outdir, 'gaussians.ply')
    export.save_as_ply(pred1, pred2, plyfile)
    return plyfile

if __name__ == '__main__':

    image_size = 512
    silent = False
    # True: uploads come from a gradio.Gallery (items are (path, caption) pairs);
    # False: uploads come from a multi-file gradio.File widget.
    ios_mode = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Fetch the checkpoint from the Hugging Face Hub and load it onto `device`.
    model_name = "brandonsmart/splatt3r_v1.0"
    filename = "epoch=19-step=1200.ckpt"
    weights_path = hf_hub_download(repo_id=model_name, filename=filename)
    model = main.MAST3RGaussians.load_from_checkpoint(weights_path, device)
    # load_from_checkpoint does not switch the module to evaluation mode; the
    # demo only runs inference, so do it explicitly.
    model.eval()
    chkpt_tag = hash_md5(weights_path)

    # Define example inputs and their corresponding precalculated outputs
    examples = [
        ["demo_examples/scannet++_1_img_1.jpg", "demo_examples/scannet++_1_img_2.jpg", "demo_examples/scannet++_1.ply"],
        ["demo_examples/scannet++_2_img_1.jpg", "demo_examples/scannet++_2_img_2.jpg", "demo_examples/scannet++_2.ply"],
        ["demo_examples/scannet++_3_img_1.jpg", "demo_examples/scannet++_3_img_2.jpg", "demo_examples/scannet++_3.ply"],
        ["demo_examples/scannet++_4_img_1.jpg", "demo_examples/scannet++_4_img_2.jpg", "demo_examples/scannet++_4.ply"],
        ["demo_examples/scannet++_5_img_1.jpg", "demo_examples/scannet++_5_img_2.jpg", "demo_examples/scannet++_5.ply"],
        ["demo_examples/scannet++_6_img_1.jpg", "demo_examples/scannet++_6_img_2.jpg", "demo_examples/scannet++_6.ply"],
        ["demo_examples/scannet++_7_img_1.jpg", "demo_examples/scannet++_7_img_2.jpg", "demo_examples/scannet++_7.ply"],
        ["demo_examples/scannet++_8_img_1.jpg", "demo_examples/scannet++_8_img_2.jpg", "demo_examples/scannet++_8.ply"],
        ["demo_examples/in_the_wild_1_img_1.jpg", "demo_examples/in_the_wild_1_img_2.jpg", "demo_examples/in_the_wild_1.ply"],
        ["demo_examples/in_the_wild_2_img_1.jpg", "demo_examples/in_the_wild_2_img_2.jpg", "demo_examples/in_the_wild_2.ply"],
        ["demo_examples/in_the_wild_3_img_1.jpg", "demo_examples/in_the_wild_3_img_2.jpg", "demo_examples/in_the_wild_3.ply"],
        ["demo_examples/in_the_wild_4_img_1.jpg", "demo_examples/in_the_wild_4_img_2.jpg", "demo_examples/in_the_wild_4.ply"],
        ["demo_examples/in_the_wild_5_img_1.jpg", "demo_examples/in_the_wild_5_img_2.jpg", "demo_examples/in_the_wild_5.ply"],
        ["demo_examples/in_the_wild_6_img_1.jpg", "demo_examples/in_the_wild_6_img_2.jpg", "demo_examples/in_the_wild_6.ply"],
        ["demo_examples/in_the_wild_7_img_1.jpg", "demo_examples/in_the_wild_7_img_2.jpg", "demo_examples/in_the_wild_7.ply"],
        ["demo_examples/in_the_wild_8_img_1.jpg", "demo_examples/in_the_wild_8_img_2.jpg", "demo_examples/in_the_wild_8.ply"],
    ]

    # The entries above are paths inside the model repo; replace each one with
    # the local path returned by hf_hub_download.
    examples = [
        [hf_hub_download(repo_id=model_name, filename=repo_path) for repo_path in example]
        for example in examples
    ]

    with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:

        # Carried over from the MASt3R demo: a per-checkpoint subdirectory is
        # created, but nothing below uses it.
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)

        # Bind everything except the uploaded file list, which Gradio supplies.
        recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size, ios_mode)

        if not ios_mode:
            # In File mode the upload widget is an extra Examples input, so each
            # example also carries its two image paths (before the .ply path).
            for example in examples:
                example.insert(2, (example[0], example[1]))

        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        with gradio.Blocks(css=css, title="Splatt3R Demo") as demo:

            gradio.HTML('<h2 style="text-align: center;">Splatt3R Demo</h2>')

            with gradio.Column():
                gradio.Markdown('''
                    Please upload exactly one or two images below to be used for reconstruction.
                    If non-square images are uploaded, they will be cropped to squares for reconstruction.
                ''')
                if ios_mode:
                    inputfiles = gradio.Gallery(type="filepath")
                else:
                    inputfiles = gradio.File(file_count="multiple")
                run_btn = gradio.Button("Run")
                gradio.Markdown('''
                    ## Output
                    Below we show the generated 3D Gaussian Splat.
                    The generated splats are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
                    In the mean time your previous generations may be visible.
                    The arrow in the top right of the window below can be used to download the .ply for rendering with other viewers,
                    such as [here](https://projects.markkellogg.org/threejs/demo_gaussian_splats_3d.php?art=1&cu=0,-1,0&cp=0,1,0&cla=1,0,0&aa=false&2d=false&sh=0) or [here](https://playcanvas.com/supersplat/editor).
                ''')
                outmodel = gradio.Model3D(
                    clear_color=[1.0, 1.0, 1.0, 0.0],
                )
                run_btn.click(fn=recon_fun, inputs=[inputfiles], outputs=[outmodel])

                gradio.Markdown('''
                    ## Examples
                    A gallery of examples generated from ScanNet++ and from 'in the wild' images taken with a mobile phone.
                    These examples are 30-40MB, so please allow up to 30 seconds for them to be downloaded from Hugging Face before rendering.
                    In the mean time your previous generations may be visible.
                ''')

                # Hidden image components that receive the two example input
                # images, so the example table can show them.
                snapshot_1 = gradio.Image(None, visible=False)
                snapshot_2 = gradio.Image(None, visible=False)
                if ios_mode:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, outmodel],
                        examples_per_page=5
                    )
                else:
                    gradio.Examples(
                        examples=examples,
                        inputs=[snapshot_1, snapshot_2, inputfiles, outmodel],
                        examples_per_page=5
                    )

        # launch() blocks, so the temporary output directory stays alive while
        # the server runs.
        demo.launch()