import os
import torch
import spaces
import matplotlib

import numpy as np
import gradio as gr

from PIL import Image
from transformers import pipeline
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider

from depth_anything_v2.dpt import DepthAnythingV2
from loguru import logger
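
# Gradio demo: run Depth Anything V1 and V2 on the same input image and compare
# the resulting colorized depth maps side by side in an ImageSlider.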

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""

title = "# Depth Anything: Compare V1 and V2 side by side."
description1 = """Please refer to **Depth Anything V2** [paper](https://arxiv.org/abs/2406.09414) for more details."""

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEFAULT_V2_MODEL_NAME = "Base"
DEFAULT_V1_MODEL_NAME = "Base"

cmap = matplotlib.colormaps.get_cmap('Spectral_r')

# --------------------------------------------------------------------
# Depth anything V1 configuration
# --------------------------------------------------------------------
depth_anything_v1_name2checkpoint = {
    "Small": "LiheYoung/depth-anything-small-hf",
    "Base": "LiheYoung/depth-anything-base-hf",
    "Large": "LiheYoung/depth-anything-large-hf",
}

depth_anything_v1_pipelines = {}  # cache: model name -> loaded V1 pipeline
# --------------------------------------------------------------------
# Depth anything V2 configuration
# --------------------------------------------------------------------

depth_anything_v2_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
depth_anything_v2_encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large',
    # 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
}
depth_anything_v2_name2encoder = {v: k for k, v in depth_anything_v2_encoder2name.items()}

depth_anything_v2_models = {}  # cache: model name -> loaded V2 model
# --------------------------------------------------------------------


def get_v1_pipe(model_name):
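    """Build a transformers depth-estimation pipeline for the selected V1 checkpoint on DEVICE."""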
    return pipeline(task="depth-estimation", model=depth_anything_v1_name2checkpoint[model_name], device=DEVICE)


def get_v2_model(model_name):
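    """Instantiate DepthAnythingV2 for the selected size, load its Hub checkpoint, and move it to DEVICE in eval mode."""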
    encoder = depth_anything_v2_name2encoder[model_name]
    model = DepthAnythingV2(**depth_anything_v2_configs[encoder])
    filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()
    return model


def predict_depth_v1(image, model_name):
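    """Lazily build and cache the V1 pipeline for `model_name`, then run depth estimation on a PIL image."""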
    if model_name not in depth_anything_v1_pipelines:
        depth_anything_v1_pipelines[model_name] = get_v1_pipe(model_name)
    pipe = depth_anything_v1_pipelines[model_name]
    return pipe(image)


def predict_depth_v2(image, model_name):
    """Lazily build and cache the V2 model for `model_name`, then run inference on a BGR numpy image."""
    if model_name not in depth_anything_v2_models:
        depth_anything_v2_models[model_name] = get_v2_model(model_name)
    # get_v2_model already places the weights on DEVICE; using .to(DEVICE) instead of
    # a hard .cuda() call keeps the demo runnable on CPU-only machines.
    model = depth_anything_v2_models[model_name].to(DEVICE)
    return model.infer_image(image)


def compute_depth_map_v2(image, model_select: str):
    # Gradio delivers RGB; DepthAnythingV2.infer_image expects BGR, hence the channel flip.
    depth = predict_depth_v2(image[:, :, ::-1], model_select)
    # Normalize raw depth to [0, 255] and apply the Spectral colormap.
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth


def compute_depth_map_v1(image, model_select):
    # The V1 transformers pipeline takes a PIL image and returns a dict whose "depth" entry is a PIL image.
    pil_image = Image.fromarray(image)
    depth = predict_depth_v1(pil_image, model_select)
    depth = np.array(depth["depth"]).astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth


# spaces.GPU requests a GPU for the duration of this call on ZeroGPU Spaces;
# torch.no_grad disables gradient tracking since this is pure inference.
@spaces.GPU
@torch.no_grad()
def on_submit(image, model_v1_select, model_v2_select):
    logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
    colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
    colored_depth_v2 = compute_depth_map_v2(image, model_v2_select)
    return colored_depth_v1, colored_depth_v2


with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description1)
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        model_select_v1 = gr.Dropdown(label="Depth Anything V1 Model", choices=list(depth_anything_v1_name2checkpoint.keys()), value=DEFAULT_V1_MODEL_NAME)
        model_select_v2 = gr.Dropdown(label="Depth Anything V2 Model", choices=list(depth_anything_v2_encoder2name.values()), value=DEFAULT_V2_MODEL_NAME)
    with gr.Row():
        gr.Markdown()  # empty spacer so the caption sits above the output (slider) column
        gr.Markdown("Depth Maps: V1 <-> V2")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(elem_id='img-display-output', position=0.5)

    submit = gr.Button(value="Compute Depth")
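    # ImageSlider displays a (left, right) image pair, so the two values returned by
    # on_submit are passed through as that pair to the single output component.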
    submit.click(on_submit, inputs=[input_image, model_select_v1, model_select_v2], outputs=[depth_image_slider])

    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image])


if __name__ == '__main__':
    demo.queue().launch(share=True)