# ascarlettvfx's picture
# Upload 99 files
# 660acc7 verified
# raw
# history blame
# 10.1 kB
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import argparse
import logging
import os
from glob import glob
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from marigold import MarigoldPipeline
# Lower-case file extensions (with leading dot) accepted as input images.
EXTENSION_LIST = [".jpg", ".jpeg", ".png"]


def build_arg_parser():
    """Build the command-line parser of the depth estimation script.

    Returns:
        argparse.ArgumentParser: parser exposing the checkpoint, input/output
        directories, inference, resolution, colormap and runtime options.
    """
    parser = argparse.ArgumentParser(
        description="Run single-image depth estimation using Marigold."
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="prs-eth/marigold-lcm-v1-0",
        help="Checkpoint path or hub name.",
    )
    parser.add_argument(
        "--input_rgb_dir",
        type=str,
        required=True,
        help="Path to the input image folder.",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )

    # inference setting
    parser.add_argument(
        "--denoise_steps",
        type=int,
        default=None,
        help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps.",
    )
    parser.add_argument(
        "--ensemble_size",
        type=int,
        default=5,
        help="Number of predictions to be ensembled, more inference gives better results but runs slower.",
    )
    parser.add_argument(
        "--half_precision",
        "--fp16",
        action="store_true",
        help="Run with half-precision (16-bit float), might lead to suboptimal result.",
    )

    # resolution setting
    parser.add_argument(
        "--processing_res",
        type=int,
        default=None,
        help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.",
    )
    parser.add_argument(
        "--output_processing_res",
        action="store_true",
        help="When input is resized, out put depth at resized operating resolution. Default: False.",
    )
    parser.add_argument(
        "--resample_method",
        choices=["bilinear", "bicubic", "nearest"],
        default="bilinear",
        help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`",
    )

    # depth map colormap
    parser.add_argument(
        "--color_map",
        type=str,
        default="Spectral",
        help="Colormap used to render depth predictions.",
    )

    # other settings
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Reproducibility seed. Set to `None` for unseeded inference.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=0,
        help="Inference batch size. Default: 0 (will be set automatically).",
    )
    parser.add_argument(
        "--apple_silicon",
        action="store_true",
        help="Flag of running on Apple Silicon.",
    )
    return parser


def select_device(apple_silicon):
    """Pick the torch device to run on, falling back to CPU with a warning.

    Args:
        apple_silicon: when true, try MPS; otherwise try CUDA.

    Returns:
        torch.device: ``mps:0`` or ``cuda`` when available, else ``cpu``.
    """
    if apple_silicon:
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return torch.device("mps:0")
        logging.warning("MPS is not available. Running on CPU will be slow.")
        return torch.device("cpu")

    if torch.cuda.is_available():
        return torch.device("cuda")
    logging.warning("CUDA is not available. Running on CPU will be slow.")
    return torch.device("cpu")


def list_input_images(input_rgb_dir):
    """Return the sorted paths of the images directly inside a directory.

    Only entries whose extension (compared case-insensitively) is listed in
    ``EXTENSION_LIST`` are kept; sub-directories are not searched.

    Args:
        input_rgb_dir: directory to scan.

    Returns:
        list: sorted matching paths, empty when nothing matches.
    """
    candidates = glob(os.path.join(input_rgb_dir, "*"))
    return sorted(
        path
        for path in candidates
        if os.path.splitext(path)[1].lower() in EXTENSION_LIST
    )


def depth_to_uint16(depth_pred):
    """Convert a depth array to 16-bit unsigned integers.

    The array is scaled by 65535, i.e. it is treated as normalised to
    [0, 1] (presumably what the pipeline returns -- confirm against the
    pipeline). Values outside that range are clipped first so the integer
    cast cannot wrap around.

    Args:
        depth_pred: numpy array of depth values.

    Returns:
        numpy.ndarray: ``uint16`` array of the same shape.
    """
    return (np.clip(depth_pred, 0.0, 1.0) * 65535.0).astype(np.uint16)


def _warn_if_exists(path):
    """Log a warning when ``path`` already exists and is about to be replaced."""
    if os.path.exists(path):
        logging.warning(f"Existing file: '{path}' will be overwritten")


def main(argv=None):
    """Run depth estimation on every image of a folder and save the results.

    For each input image three files are written below the output directory:
    the raw prediction (``depth_npy/*.npy``), a 16-bit PNG (``depth_bw/*.png``)
    and a colorized PNG (``depth_colored/*_colored.png``).

    Args:
        argv: argument list to parse; ``None`` makes argparse read the
            process command line.

    Raises:
        SystemExit: with code 1 when the input directory holds no image, or
            by argparse on invalid arguments.
    """
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    args = build_arg_parser().parse_args(argv)

    checkpoint_path = args.checkpoint
    input_rgb_dir = args.input_rgb_dir
    output_dir = args.output_dir

    denoise_steps = args.denoise_steps
    ensemble_size = args.ensemble_size
    if ensemble_size > 15:
        logging.warning("Running with large ensemble size will be slow.")
    half_precision = args.half_precision

    processing_res = args.processing_res
    match_input_res = not args.output_processing_res
    if 0 == processing_res and match_input_res is False:
        logging.warning(
            "Processing at native resolution without resizing output might NOT lead to exactly the same resolution, due to the padding and pooling properties of conv layers."
        )
    resample_method = args.resample_method

    color_map = args.color_map
    seed = args.seed
    batch_size = args.batch_size
    apple_silicon = args.apple_silicon
    if apple_silicon and 0 == batch_size:
        batch_size = 1  # set default batchsize

    # -------------------- Preparation --------------------
    # Output directories
    output_dir_color = os.path.join(output_dir, "depth_colored")
    output_dir_tif = os.path.join(output_dir, "depth_bw")
    output_dir_npy = os.path.join(output_dir, "depth_npy")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_color, exist_ok=True)
    os.makedirs(output_dir_tif, exist_ok=True)
    os.makedirs(output_dir_npy, exist_ok=True)
    logging.info(f"output dir = {output_dir}")

    # -------------------- Device --------------------
    device = select_device(apple_silicon)
    logging.info(f"device = {device}")

    # -------------------- Data --------------------
    rgb_filename_list = list_input_images(input_rgb_dir)
    n_images = len(rgb_filename_list)
    if n_images > 0:
        logging.info(f"Found {n_images} images")
    else:
        logging.error(f"No image found in '{input_rgb_dir}'")
        # SystemExit instead of the site-provided `exit()` helper, which is
        # not guaranteed to exist (e.g. `python -S`, frozen interpreters).
        raise SystemExit(1)

    # -------------------- Model --------------------
    if half_precision:
        dtype = torch.float16
        variant = "fp16"
        logging.info(
            f"Running with half precision ({dtype}), might lead to suboptimal result."
        )
    else:
        dtype = torch.float32
        variant = None

    pipe: MarigoldPipeline = MarigoldPipeline.from_pretrained(
        checkpoint_path, variant=variant, torch_dtype=dtype
    )

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except ImportError:
        # Best effort: keep running without xformers.
        logging.debug("xformers is not available, running without it.")

    pipe = pipe.to(device)
    logging.info(
        f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
    )

    # Print out config
    logging.info(
        f"Inference settings: checkpoint = `{checkpoint_path}`, "
        f"with denoise_steps = {denoise_steps or pipe.default_denoising_steps}, "
        f"ensemble_size = {ensemble_size}, "
        f"processing resolution = {processing_res or pipe.default_processing_resolution}, "
        f"seed = {seed}; "
        f"color_map = {color_map}."
    )

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for rgb_path in tqdm(rgb_filename_list, desc="Estimating depth", leave=True):
            # Random number generator, re-seeded for every image
            if seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)

            # Read input image; the context manager releases the file handle
            # as soon as the prediction is done instead of leaking one per image.
            with Image.open(rgb_path) as input_image:
                # Predict depth
                pipe_out = pipe(
                    input_image,
                    denoising_steps=denoise_steps,
                    ensemble_size=ensemble_size,
                    processing_res=processing_res,
                    match_input_res=match_input_res,
                    batch_size=batch_size,
                    color_map=color_map,
                    show_progress_bar=True,
                    resample_method=resample_method,
                    generator=generator,
                )

            depth_pred: np.ndarray = pipe_out.depth_np
            depth_colored: Image.Image = pipe_out.depth_colored

            # Save as npy
            rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0]
            pred_name_base = rgb_name_base + "_pred"
            npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy")
            _warn_if_exists(npy_save_path)
            np.save(npy_save_path, depth_pred)

            # Save as 16-bit uint png
            depth_to_save = depth_to_uint16(depth_pred)
            png_save_path = os.path.join(output_dir_tif, f"{pred_name_base}.png")
            _warn_if_exists(png_save_path)
            Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

            # Colorize
            colored_save_path = os.path.join(
                output_dir_color, f"{pred_name_base}_colored.png"
            )
            _warn_if_exists(colored_save_path)
            depth_colored.save(colored_save_path)


if "__main__" == __name__:
    main()