# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://marigolddepthcompletion.github.io/
# https://rollingdepth.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# https://github.com/prs-eth/Marigold-DC#-citation
# https://github.com/prs-eth/rollingdepth#-citation
# --------------------------------------------------------------------------
import os
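# Dump the installed package versions to the Space logs to ease debugging of the environment.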
os.system("pip freeze")
import spaces
import gradio as gr
import torch
from diffusers import MarigoldDepthPipeline, DDIMScheduler
from gradio_dualvision import DualVisionApp
from huggingface_hub import login
from PIL import Image
CHECKPOINT = "prs-eth/marigold-depth-v1-1"
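# Optionally authenticate with the Hugging Face Hub using a token provided via the Space secrets.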
if "HF_TOKEN_LOGIN" in os.environ:
login(token=os.environ["HF_TOKEN_LOGIN"])
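# Run on GPU with bfloat16 when CUDA is available; otherwise fall back to CPU with float32.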
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
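# "trailing" timestep spacing improves few-step DDIM sampling; see the diffusers
# Marigold usage guide linked in the header above.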
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe = pipe.to(device=device, dtype=dtype)
# Use xformers memory-efficient attention when the package is available.
try:
    import xformers  # noqa: F401

    pipe.enable_xformers_memory_efficient_attention()
except Exception:
    # xformers is optional; fall back to the default attention implementation.
    pass
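

# DualVisionApp (from the gradio_dualvision package) renders the side-by-side comparison
# slider UI; the subclass below supplies the header, the parameter controls, and the
# inference logic.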
class MarigoldDepthApp(DualVisionApp):
    DEFAULT_SEED = 2024
    DEFAULT_ENSEMBLE_SIZE = 1
    DEFAULT_DENOISE_STEPS = 4
    DEFAULT_PROCESSING_RES = 768

    def make_header(self):
        gr.Markdown(
            """
            ## Marigold Depth Estimation
            <p align="center">
                <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%E2%99%A5%20Project%20-Website-blue">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-AF3436">
                </a>
                <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
                <a title="Image Normals" href="https://huggingface.co/spaces/prs-eth/marigold-normals" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Image%20Normals%20-Demo-yellow" alt="imagenormals">
                </a>
                <a title="Image Intrinsics" href="https://huggingface.co/spaces/prs-eth/marigold-iid" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Image%20Intrinsics%20-Demo-yellow" alt="imageintrinsics">
                </a>
                <a title="LiDAR Depth" href="https://huggingface.co/spaces/prs-eth/marigold-dc" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20LiDAR%20Depth%20-Demo-yellow" alt="lidardepth">
                </a>
                <a title="Video Depth" href="https://huggingface.co/spaces/prs-eth/rollingdepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Video%20Depth%20-Demo-yellow" alt="videodepth">
                </a>
                <a title="Depth-to-3D" href="https://huggingface.co/spaces/prs-eth/depth-to-3d-print" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Depth--to--3D%20-Demo-yellow" alt="depthto3d">
                </a>
                <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/twitter/follow/antonobukhov1?label=Subscribe%20for%20updates!" alt="social">
                </a>
            </p>
            <p align="center" style="margin-top: 0px;">
                Upload a photo or select an example below to compute depth maps in real time.
                Use the slider to reveal areas of interest.
                Use the radio buttons to switch between modalities.
                Check our other demo badges above for new or relocated functionality.
            </p>
            """
        )

    def build_user_components(self):
        with gr.Column():
            ensemble_size = gr.Slider(
                label="Ensemble size",
                minimum=1,
                maximum=10,
                step=1,
                value=self.DEFAULT_ENSEMBLE_SIZE,
            )
            denoise_steps = gr.Slider(
                label="Number of denoising steps",
                minimum=1,
                maximum=20,
                step=1,
                value=self.DEFAULT_DENOISE_STEPS,
            )
            processing_res = gr.Radio(
                [
                    ("Native", 0),
                    ("Recommended", 768),
                ],
                label="Processing resolution",
                value=self.DEFAULT_PROCESSING_RES,
            )
        return {
            "ensemble_size": ensemble_size,
            "denoise_steps": denoise_steps,
            "processing_res": processing_res,
        }

    def process(self, image_in: Image.Image, **kwargs):
        ensemble_size = kwargs.get("ensemble_size", self.DEFAULT_ENSEMBLE_SIZE)
        denoise_steps = kwargs.get("denoise_steps", self.DEFAULT_DENOISE_STEPS)
        processing_res = kwargs.get("processing_res", self.DEFAULT_PROCESSING_RES)
        # A fixed seed keeps the demo deterministic across runs.
        generator = torch.Generator(device=device).manual_seed(self.DEFAULT_SEED)
        pipe_out = pipe(
            image_in,
            ensemble_size=ensemble_size,
            num_inference_steps=denoise_steps,
            processing_resolution=processing_res,
            # Native resolution (0) can be large, so process one ensemble member at a time;
            # otherwise batch two members per forward pass.
            batch_size=1 if processing_res == 0 else 2,
            output_uncertainty=ensemble_size >= 3,
            generator=generator,
        )
        depth_vis = pipe.image_processor.visualize_depth(pipe_out.prediction)[0]
        depth_16bit = pipe.image_processor.export_depth_to_16bit_png(pipe_out.prediction)[0]
        out_modalities = {
            "Depth Visualization": depth_vis,
            "Depth 16-bit": depth_16bit,
        }
        # Uncertainty is only meaningful when several ensemble members are aggregated.
        if ensemble_size >= 3:
            uncertainty = pipe.image_processor.visualize_uncertainty(pipe_out.uncertainty)[0]
            out_modalities["Uncertainty"] = uncertainty
        out_settings = {
            "ensemble_size": ensemble_size,
            "denoise_steps": denoise_steps,
            "processing_res": processing_res,
        }
        return out_modalities, out_settings
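
# For reference, the same depth inference can be run with plain diffusers outside this
# app; a minimal sketch using the pipeline configured above (the file paths are
# hypothetical examples):
#
#     image = Image.open("example.jpg")
#     out = pipe(image, num_inference_steps=4, processing_resolution=768)
#     pipe.image_processor.visualize_depth(out.prediction)[0].save("depth_vis.png")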

with MarigoldDepthApp(
    title="Marigold Depth",
    examples_path="files",
    examples_per_page=12,
    squeeze_canvas=True,
    spaces_zero_gpu_enabled=True,
) as demo:
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )