# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import argparse
import logging
import os
import sys
from glob import escape, glob

import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm

from marigold import MarigoldPipeline
# File extensions (lower-case, with leading dot) accepted as input images.
EXTENSION_LIST = [".jpg", ".jpeg", ".png"]


def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for folder-wise Marigold depth estimation.

    Returns:
        A parser whose namespace carries the checkpoint, I/O folders, inference,
        resolution, colormap, seed, batch-size and Apple-Silicon settings.
    """
    parser = argparse.ArgumentParser(
        description="Run single-image depth estimation using Marigold."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="prs-eth/marigold-lcm-v1-0",
        help="Checkpoint path or hub name.",
    )
    parser.add_argument(
        "--input_rgb_dir",
        type=str,
        required=True,
        help="Path to the input image folder.",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )

    # inference setting
    parser.add_argument(
        "--denoise_steps",
        type=int,
        default=None,
        help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps.",
    )
    parser.add_argument(
        "--ensemble_size",
        type=int,
        default=5,
        help="Number of predictions to be ensembled, more inference gives better results but runs slower.",
    )
    parser.add_argument(
        "--half_precision",
        "--fp16",
        action="store_true",
        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
    )

    # resolution setting
    parser.add_argument(
        "--processing_res",
        type=int,
        default=None,
        help="Maximum resolution of processing. 0 for using input image resolution. Default: the checkpoint's default processing resolution.",
    )
    parser.add_argument(
        "--output_processing_res",
        action="store_true",
        help="When input is resized, output depth at resized operating resolution. Default: False.",
    )
    parser.add_argument(
        "--resample_method",
        choices=["bilinear", "bicubic", "nearest"],
        default="bilinear",
        help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`",
    )

    # depth map colormap
    parser.add_argument(
        "--color_map",
        type=str,
        default="Spectral",
        help="Colormap used to render depth predictions.",
    )

    # other settings
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Reproducibility seed. Omit for unseeded inference.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="Inference batch size. Default: 0 (will be set automatically).",
    )
    parser.add_argument(
        "--apple_silicon",
        action="store_true",
        help="Flag of running on Apple Silicon.",
    )
    return parser


def select_device(apple_silicon: bool) -> torch.device:
    """Pick the torch device to run on, falling back to CPU with a warning.

    Args:
        apple_silicon: When True, prefer the MPS backend; otherwise prefer CUDA.

    Returns:
        ``mps:0`` / ``cuda`` when available and built, else ``cpu``.
    """
    if apple_silicon:
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return torch.device("mps:0")
        logging.warning("MPS is not available. Running on CPU will be slow.")
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    logging.warning("CUDA is not available. Running on CPU will be slow.")
    return torch.device("cpu")


def list_input_images(input_rgb_dir: str) -> list:
    """Return the sorted paths of supported images directly inside a folder.

    Only regular files whose extension (case-insensitive) is in
    ``EXTENSION_LIST`` are kept. As with ``glob("*")``, dot-prefixed entries
    are skipped and a missing folder yields an empty list.

    Args:
        input_rgb_dir: Folder to scan (not recursive).

    Returns:
        Sorted list of image file paths.
    """
    # Escape the folder so names containing glob metacharacters such as
    # "[", "]", "*" or "?" are matched literally rather than as a pattern.
    pattern = os.path.join(escape(input_rgb_dir), "*")
    return sorted(
        path
        for path in glob(pattern)
        # A sub-directory named like an image (e.g. "x.png") would crash Image.open.
        if os.path.isfile(path)
        and os.path.splitext(path)[1].lower() in EXTENSION_LIST
    )


def load_pipeline(
    checkpoint_path: str, half_precision: bool, device: torch.device
) -> "MarigoldPipeline":
    """Load a Marigold pipeline and move it to ``device``.

    Args:
        checkpoint_path: Local path or hub name passed to ``from_pretrained``.
        half_precision: Load the ``fp16`` variant with ``torch.float16`` weights.
        device: Target device for the pipeline.

    Returns:
        The loaded pipeline, already on ``device``.
    """
    if half_precision:
        dtype = torch.float16
        variant = "fp16"
        logging.info(
            f"Running with half precision ({dtype}), might lead to suboptimal result."
        )
    else:
        dtype = torch.float32
        variant = None

    pipe = MarigoldPipeline.from_pretrained(
        checkpoint_path, variant=variant, torch_dtype=dtype
    )

    # xformers memory-efficient attention only works on CUDA; diffusers raises
    # ValueError when CUDA is unavailable, so do not even try on CPU/MPS.
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except ImportError:
            logging.info("xformers is not installed; running without it.")

    return pipe.to(device)


def depth_to_uint16(depth: np.ndarray) -> np.ndarray:
    """Quantize a depth map to 16-bit unsigned integers for PNG export.

    The 65535 scale assumes ``depth`` is normalized to [0, 1]. Out-of-range
    values are clipped rather than wrapping around in the uint16 cast.
    Values are truncated (not rounded), matching the previous export.
    """
    return (np.clip(depth, 0.0, 1.0) * 65535.0).astype(np.uint16)


def warn_if_exists(path: str) -> None:
    """Log a warning when ``path`` already exists and is about to be overwritten."""
    if os.path.exists(path):
        logging.warning(f"Existing file: '{path}' will be overwritten")


def save_prediction(
    rgb_path: str,
    depth_pred: np.ndarray,
    depth_colored: "Image.Image",
    output_dir_npy: str,
    output_dir_png: str,
    output_dir_color: str,
) -> None:
    """Write one prediction as ``.npy``, 16-bit PNG and colorized PNG.

    Output names are ``<input stem>_pred.npy``, ``<input stem>_pred.png`` and
    ``<input stem>_pred_colored.png`` in the respective folders. Existing
    files are overwritten with a warning.
    """
    pred_name_base = os.path.splitext(os.path.basename(rgb_path))[0] + "_pred"

    # Raw float prediction
    npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy")
    warn_if_exists(npy_save_path)
    np.save(npy_save_path, depth_pred)

    # 16-bit grayscale PNG; the array dtype (uint16) determines the PNG bit
    # depth, so no extra ``mode`` argument is needed for ``save``.
    png_save_path = os.path.join(output_dir_png, f"{pred_name_base}.png")
    warn_if_exists(png_save_path)
    Image.fromarray(depth_to_uint16(depth_pred)).save(png_save_path)

    # Colorized visualization produced by the pipeline
    colored_save_path = os.path.join(
        output_dir_color, f"{pred_name_base}_colored.png"
    )
    warn_if_exists(colored_save_path)
    depth_colored.save(colored_save_path)


def main() -> None:
    """Run Marigold depth estimation on every image in ``--input_rgb_dir``.

    Results are written under ``--output_dir`` into ``depth_npy``,
    ``depth_bw`` (16-bit PNG) and ``depth_colored``. Exits with status 1
    when no supported image is found.
    """
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    parser = build_arg_parser()
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    input_rgb_dir = args.input_rgb_dir
    output_dir = args.output_dir

    denoise_steps = args.denoise_steps
    ensemble_size = args.ensemble_size
    if ensemble_size < 1:
        parser.error("--ensemble_size must be at least 1.")
    if ensemble_size > 15:
        logging.warning("Running with large ensemble size will be slow.")
    half_precision = args.half_precision

    processing_res = args.processing_res
    match_input_res = not args.output_processing_res
    if 0 == processing_res and not match_input_res:
        logging.warning(
            "Processing at native resolution without resizing output might NOT lead to exactly the same resolution, due to the padding and pooling properties of conv layers."
        )
    resample_method = args.resample_method

    color_map = args.color_map
    seed = args.seed
    batch_size = args.batch_size
    apple_silicon = args.apple_silicon
    if apple_silicon and 0 == batch_size:
        # On Apple Silicon, replace the automatic batch size (0) with 1.
        batch_size = 1

    # -------------------- Preparation --------------------
    output_dir_color = os.path.join(output_dir, "depth_colored")
    output_dir_png = os.path.join(output_dir, "depth_bw")
    output_dir_npy = os.path.join(output_dir, "depth_npy")
    for directory in (output_dir, output_dir_color, output_dir_png, output_dir_npy):
        os.makedirs(directory, exist_ok=True)
    logging.info(f"output dir = {output_dir}")

    # -------------------- Device --------------------
    device = select_device(apple_silicon)
    logging.info(f"device = {device}")

    # -------------------- Data --------------------
    rgb_filename_list = list_input_images(input_rgb_dir)
    n_images = len(rgb_filename_list)
    if n_images == 0:
        logging.error(f"No image found in '{input_rgb_dir}'")
        sys.exit(1)
    logging.info(f"Found {n_images} images")

    # -------------------- Model --------------------
    pipe = load_pipeline(checkpoint_path, half_precision, device)
    logging.info(
        f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
    )

    # Report the values actually used; `or` would misreport processing_res=0
    # (native resolution) as the checkpoint default.
    effective_steps = (
        denoise_steps if denoise_steps is not None else pipe.default_denoising_steps
    )
    effective_res = (
        processing_res
        if processing_res is not None
        else pipe.default_processing_resolution
    )
    logging.info(
        f"Inference settings: checkpoint = `{checkpoint_path}`, "
        f"with denoise_steps = {effective_steps}, "
        f"ensemble_size = {ensemble_size}, "
        f"processing resolution = {effective_res}, "
        f"seed = {seed}; "
        f"color_map = {color_map}."
    )

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for rgb_path in tqdm(rgb_filename_list, desc="Estimating depth", leave=True):
            with Image.open(rgb_path) as input_image:
                # A fresh generator per image, seeded identically, keeps each
                # image's result independent of processing order.
                if seed is None:
                    generator = None
                else:
                    generator = torch.Generator(device=device)
                    generator.manual_seed(seed)

                pipe_out = pipe(
                    input_image,
                    denoising_steps=denoise_steps,
                    ensemble_size=ensemble_size,
                    processing_res=processing_res,
                    match_input_res=match_input_res,
                    batch_size=batch_size,
                    color_map=color_map,
                    show_progress_bar=True,
                    resample_method=resample_method,
                    generator=generator,
                )

            save_prediction(
                rgb_path,
                pipe_out.depth_np,
                pipe_out.depth_colored,
                output_dir_npy=output_dir_npy,
                output_dir_png=output_dir_png,
                output_dir_color=output_dir_color,
            )


if __name__ == "__main__":
    main()