import os
import torch
import spaces
import matplotlib
import numpy as np
import gradio as gr
from PIL import Image
from transformers import pipeline
from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider
from depth_anything_v2.dpt import DepthAnythingV2
from loguru import logger
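# Custom CSS to keep the image displays (and the download button) within the viewport.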
css = """
#img-display-container {
max-height: 100vh;
}
#img-display-input {
max-height: 80vh;
}
#img-display-output {
max-height: 80vh;
}
#download {
height: 62px;
}
"""
title = "# Depth Anything: Watch V1 and V2 side by side."
description1 = """Please refer to **Depth Anything V2** [paper](https://arxiv.org/abs/2406.09414) for more details."""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEFAULT_V2_MODEL_NAME = "Base"
DEFAULT_V1_MODEL_NAME = "Base"
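# Reversed Spectral colormap used to colorize the predicted depth maps.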
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
# --------------------------------------------------------------------
# Depth anything V1 configuration
# --------------------------------------------------------------------
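# Hugging Face Hub checkpoints for the V1 transformers pipeline, keyed by display name.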
depth_anything_v1_name2checkpoint = {
    "Small": "LiheYoung/depth-anything-small-hf",
    "Base": "LiheYoung/depth-anything-base-hf",
    "Large": "LiheYoung/depth-anything-large-hf",
}
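# Lazily populated cache of V1 pipelines, keyed by display name.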
depth_anything_v1_pipelines = {}
# --------------------------------------------------------------------
# Depth anything V2 configuration
# --------------------------------------------------------------------
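# Model configuration for each Depth Anything V2 encoder size.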
depth_anything_v2_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
depth_anything_v2_encoder2name = {
    'vits': 'Small',
    'vitb': 'Base',
    'vitl': 'Large',
    # 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
}
depth_anything_v2_name2encoder = {v: k for k, v in depth_anything_v2_encoder2name.items()}
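# Lazily populated cache of V2 models, keyed by display name.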
depth_anything_v2_models = {}
# --------------------------------------------------------------------
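# Build a V1 depth-estimation pipeline for the selected checkpoint on the current device.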
def get_v1_pipe(model_name):
    return pipeline(task="depth-estimation", model=depth_anything_v1_name2checkpoint[model_name], device=DEVICE)
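# Instantiate a V2 model for the selected size, download its checkpoint from the Hub,
# load the weights, and move the model to the current device in eval mode.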
def get_v2_model(model_name):
    encoder = depth_anything_v2_name2encoder[model_name]
    model = DepthAnythingV2(**depth_anything_v2_configs[encoder])
    filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
    state_dict = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()
    return model
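# Run V1 inference, creating and caching the pipeline on first use.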
def predict_depth_v1(image, model_name):
    if model_name not in depth_anything_v1_pipelines:
        depth_anything_v1_pipelines[model_name] = get_v1_pipe(model_name)
    pipe = depth_anything_v1_pipelines[model_name]
    return pipe(image)
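# Run V2 inference, creating and caching the model on first use. The cached model is
# moved to CUDA before inference (on_submit runs under the @spaces.GPU context).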
def predict_depth_v2(image, model_name):
    if model_name not in depth_anything_v2_models:
        depth_anything_v2_models[model_name] = get_v2_model(model_name)
    model = depth_anything_v2_models[model_name].cuda()
    return model.infer_image(image)
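# Colorize a V2 prediction: reverse the channel order of the input (RGB -> BGR) before
# inference, scale the raw depth to [0, 255], and apply the Spectral_r colormap.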
def compute_depth_map_v2(image, model_select: str):
    depth = predict_depth_v2(image[:, :, ::-1], model_select)
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth
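# Colorize a V1 prediction: the pipeline returns a "depth" image, which is converted
# to a uint8 array and colormapped.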
def compute_depth_map_v1(image, model_select):
    pil_image = Image.fromarray(image)
    depth = predict_depth_v1(pil_image, model_select)
    depth = np.array(depth["depth"]).astype(np.uint8)
    colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
    return colored_depth
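# GPU-decorated entry point: compute both colored depth maps and return them as a
# (V1, V2) pair for the ImageSlider.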
@spaces.GPU
@torch.no_grad()
def on_submit(image, model_v1_select, model_v2_select):
    logger.info(f"Computing depth for V1 model: {model_v1_select} and V2 model: {model_v2_select}")
    colored_depth_v1 = compute_depth_map_v1(image, model_v1_select)
    colored_depth_v2 = compute_depth_map_v2(image, model_v2_select)
    return colored_depth_v1, colored_depth_v2
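# Gradio UI: model selectors for V1 and V2, an input image, an ImageSlider comparing
# the two depth maps, and bundled example images.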
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description1)
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        model_select_v1 = gr.Dropdown(label="Depth Anything V1 Model", choices=list(depth_anything_v1_name2checkpoint.keys()), value=DEFAULT_V1_MODEL_NAME)
        model_select_v2 = gr.Dropdown(label="Depth Anything V2 Model", choices=list(depth_anything_v2_encoder2name.values()), value=DEFAULT_V2_MODEL_NAME)
    with gr.Row():
        gr.Markdown()
        gr.Markdown("Depth Maps: V1 <-> V2")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(elem_id='img-display-output', position=0.5)
    submit = gr.Button(value="Compute Depth")
    submit.click(on_submit, inputs=[input_image, model_select_v1, model_select_v2], outputs=[depth_image_slider])
    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image])

if __name__ == '__main__':
    demo.queue().launch(share=True)