import os
import subprocess
import tempfile
from pathlib import Path

import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import torch
import cv2
import numpy as np
from PIL import Image
import gradio as gr
from gradio_rerun import Rerun
import spaces

# Run the script to get pretrained models
if not os.path.exists("checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    subprocess.run(["bash", "get_pretrained_models.sh"])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()

def resize_image(image_path, max_size=1536):
    with Image.open(image_path) as img:
        # Calculate the new size while maintaining aspect ratio
        ratio = max_size / max(img.size)
        new_size = tuple([int(x * ratio) for x in img.size])
        
        # Resize the image
        img = img.resize(new_size, Image.LANCZOS)
        
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
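
# Note: `resize_image` is not wired into the UI below. A hypothetical use
# would be to downscale an upload before inference, e.g.:
#   resized_path = resize_image("input.png")
#   frame = np.array(Image.open(resized_path))
#   depth, focal_length = predict(frame)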

@spaces.GPU(duration=20)
def predict(frame):
    image = transform(frame)
    image = image.to(device)
    prediction = model.infer(image)
    # prediction["depth"] is a torch tensor in meters; move it to the CPU once
    # and convert to numpy (hence meter=1 when logging the depth image below).
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    return depth, prediction["focallength_px"].item()


@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_ml_depth_pro(frame):
    stream = rr.binary_stream()

    assert model is not None, "Model is None"
    assert transform is not None, "Transform is None"
    assert frame is not None, "Frame is None"

    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(
                origin="/world/camera/depth",
            ),
            rrb.Spatial2DView(origin="/world/camera/image"),
        ),
        collapse_panels=True,
    )

    rr.send_blueprint(blueprint)

    # Log a single frame at time step 0 (per-frame video streaming is
    # sketched below, after the frame-extraction loop).
    rr.set_time_sequence("frame", 0)
    rr.log("world/camera/image", rr.Image(frame))

    depth, focal_length = predict(frame)

    rr.log(
        "world/camera",
        rr.Pinhole(
            width=frame.shape[1],
            height=frame.shape[0],
            focal_length=focal_length,
            principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
            image_plane_distance=depth.max(),
        ),
    )

    rr.log(
        "world/camera/depth",
        # depth_range requires Rerun 0.19 stable:
        # rr.DepthImage(depth, meter=1, depth_range=(depth.min(), depth.max())),
        rr.DepthImage(depth, meter=1),
    )

    yield stream.read()


# Load the demo video and pre-extract its frames.
video_path = Path("hd-cat.mp4")
frames = []
video = cv2.VideoCapture(str(video_path))
while True:
    read, frame = video.read()
    if not read:
        break
    frame = cv2.resize(frame, (320, 240))
    # OpenCV decodes to BGR; Rerun and the model expect RGB.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frames.append(frame)
video.release()
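
# A minimal sketch (not part of the original demo) showing how the
# pre-extracted `frames` could be streamed through the same pipeline, one
# Rerun time step per frame. `run_video_stream` is a hypothetical helper that
# reuses only calls already present above; the blueprint setup from
# run_ml_depth_pro is assumed to have been sent already.
@rr.thread_local_stream("rerun_example_ml_depth_pro_video")
def run_video_stream():
    stream = rr.binary_stream()
    for i, frame in enumerate(frames):
        rr.set_time_sequence("frame", i)
        rr.log("world/camera/image", rr.Image(frame))
        depth, focal_length = predict(frame)
        rr.log(
            "world/camera",
            rr.Pinhole(
                width=frame.shape[1],
                height=frame.shape[0],
                focal_length=focal_length,
                principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
                image_plane_distance=depth.max(),
            ),
        )
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))
        yield stream.read()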

with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        with gr.Row():
            img = gr.Image(interactive=True, label="Image")
            with gr.Column():
                stream_ml_depth_pro = gr.Button("Stream ML Depth Pro")
        with gr.Row():
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
            )
        stream_ml_depth_pro.click(run_ml_depth_pro, inputs=[img], outputs=[viewer])


if __name__ == "__main__":
    demo.launch()