|
|
|
|
|
from typing import Any, Dict, Union |
|
|
|
import torch |
|
from torch.utils.data import DataLoader, TensorDataset |
|
import numpy as np |
|
from tqdm.auto import tqdm |
|
from PIL import Image |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
DDIMScheduler, |
|
AutoencoderKL, |
|
) |
|
from models.unet_2d_condition import UNet2DConditionModel |
|
from diffusers.utils import BaseOutput |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection |
|
import torchvision.transforms.functional as TF |
|
from torchvision.transforms import InterpolationMode |
|
|
|
from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps |
|
from utils.colormap import kitti_colormap |
|
from utils.depth_ensemble import ensemble_depths |
|
from utils.normal_ensemble import ensemble_normals |
|
from utils.batch_size import find_batch_size |
|
import cv2 |
|
|
|
class DepthNormalPipelineOutput(BaseOutput):
    """
    Output class for the joint monocular depth and surface-normal estimation pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map (float32), min-max normalized and clipped to the range [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map as an RGB PIL image (8-bit per channel).
        normal_np (`np.ndarray`):
            Predicted surface-normal map (float32), channels-last `[H, W, 3]`,
            with values clipped to the range [-1, 1].
        normal_colored (`PIL.Image.Image`):
            Normal map visualized as an RGB PIL image, with [-1, 1] mapped linearly to [0, 255].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling;
            `None` when only a single prediction is made (no ensembling).
    """

    depth_np: np.ndarray

    depth_colored: Image.Image

    normal_np: np.ndarray

    normal_colored: Image.Image

    uncertainty: Union[None, np.ndarray]
|
|
|
class DepthNormalEstimationPipeline(DiffusionPipeline):
    """
    Joint monocular depth and surface-normal estimation with a latent diffusion UNet.

    The UNet denoises a depth latent and a normal latent side by side. It is
    conditioned on the VAE latent of the RGB input (channel-concatenated), a
    CLIP image embedding of the input (``encoder_hidden_states``), and a
    sinusoidal class embedding that selects the output type (depth / normal)
    and the scene domain ("indoor" / "outdoor" / "object").
    """

    # Multiplied into VAE latents after encoding and divided out before decoding.
    latent_scale_factor = 0.18215

    # One-hot scene-domain codes fed (through sin/cos) into the UNet class embedding.
    _DOMAIN_CLASSES = {
        "indoor": [1.0, 0.0, 0.0],
        "outdoor": [0.0, 1.0, 0.0],
        "object": [0.0, 0.0, 1.0],
    }

    def __init__(self,
                 unet: UNet2DConditionModel,
                 vae: AutoencoderKL,
                 scheduler: DDIMScheduler,
                 image_encoder: CLIPVisionModelWithProjection,
                 feature_extractor: CLIPImageProcessor,
                 ):
        """Register the pipeline components and initialize the CLIP-embedding cache."""
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # CLIP image embedding of the image currently being processed; filled lazily
        # by `single_infer` and reset at the start of every `__call__`.
        self.img_embed = None

    @torch.no_grad()
    def __call__(self,
                 input_image: Image.Image,
                 denosing_steps: int = 10,
                 ensemble_size: int = 10,
                 processing_res: int = 768,
                 match_input_res: bool = True,
                 batch_size: int = 0,
                 domain: str = "indoor",
                 color_map: str = "Spectral",
                 show_progress_bar: bool = True,
                 ensemble_kwargs: Dict = None,
                 ) -> DepthNormalPipelineOutput:
        """
        Predict depth and surface normals for a single image.

        Args:
            input_image: Input PIL image (converted to RGB internally).
            denosing_steps: Number of scheduler denoising steps per prediction.
            ensemble_size: Number of independent predictions that are ensembled.
            processing_res: Longest-edge resolution used for inference; 0 keeps the input size.
            match_input_res: Resize the predictions back to the input image size.
            batch_size: Ensemble members processed per forward pass; values <= 0 mean 1.
            domain: Scene domain, one of "indoor", "outdoor", "object".
            color_map: Colormap name passed to ``colorize_depth_maps``.
            show_progress_bar: Show tqdm progress bars.
            ensemble_kwargs: Extra keyword arguments forwarded to ``ensemble_depths``.

        Returns:
            DepthNormalPipelineOutput with the depth map, normal map, their
            visualizations and (when ensembling) the depth uncertainty.

        Raises:
            ValueError: If ``domain`` is not a supported domain.
        """
        device = self.device
        input_size = input_image.size

        # ----- Argument validation -----
        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value Error: `match_input_res=False` is only valid with a `processing_res` value."
        assert processing_res >= 0
        assert denosing_steps >= 1
        assert ensemble_size >= 1
        # Fail fast, before any expensive work, on an unsupported domain.
        if domain not in self._DOMAIN_CLASSES:
            raise ValueError(
                f"Unsupported domain {domain!r}; expected one of {sorted(self._DOMAIN_CLASSES)}."
            )

        # ----- Preprocessing -----
        if processing_res > 0:
            input_image = resize_max_res(
                input_image, max_edge_resolution=processing_res
            )
        input_image = input_image.convert("RGB")
        image = np.array(input_image)

        # HWC uint8 -> CHW tensor normalized to [-1, 1].
        rgb = np.transpose(image, (2, 0, 1))
        rgb_norm = rgb / 255.0 * 2.0 - 1.0
        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype).to(device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----- Ensemble inference -----
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        _bs = batch_size if batch_size > 0 else 1
        single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)

        # `single_infer` caches the CLIP embedding on the instance; drop any embedding
        # left over from a previous image so this image is conditioned on its own.
        self.img_embed = None

        depth_pred_ls = []
        normal_pred_ls = []
        if show_progress_bar:
            iterable_bar = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable_bar = single_rgb_loader

        for (batched_image,) in iterable_bar:
            depth_pred_raw, normal_pred_raw = self.single_infer(
                input_rgb=batched_image,
                num_inference_steps=denosing_steps,
                domain=domain,
                show_pbar=show_progress_bar,
            )
            depth_pred_ls.append(depth_pred_raw.detach().clone())
            normal_pred_ls.append(normal_pred_raw.detach().clone())

        depth_preds = torch.cat(depth_pred_ls, dim=0).squeeze()
        normal_preds = torch.cat(normal_pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()

        # ----- Ensembling -----
        if ensemble_size > 1:
            depth_pred, pred_uncert = ensemble_depths(
                depth_preds, **(ensemble_kwargs or {})
            )
            normal_pred = ensemble_normals(normal_preds)
        else:
            depth_pred = depth_preds
            normal_pred = normal_preds
            pred_uncert = None

        # The output contract declares `uncertainty` as an ndarray.
        if isinstance(pred_uncert, torch.Tensor):
            pred_uncert = pred_uncert.cpu().numpy()

        # ----- Min-max normalize depth to [0, 1] -----
        min_d = torch.min(depth_pred)
        max_d = torch.max(depth_pred)
        depth_range = max_d - min_d
        if depth_range > 0:
            depth_pred = (depth_pred - min_d) / depth_range
        else:
            # Constant prediction: avoid 0 / 0 = NaN.
            depth_pred = torch.zeros_like(depth_pred)

        depth_pred = depth_pred.cpu().numpy().astype(np.float32)
        # CHW -> HWC unconditionally, so the normal map can be resized and turned
        # into an image whether or not it is resized back to the input size.
        normal_pred = chw2hwc(normal_pred.cpu().numpy().astype(np.float32))

        # ----- Resize back to the input resolution -----
        if match_input_res:
            pred_img = Image.fromarray(depth_pred)
            pred_img = pred_img.resize(input_size)
            depth_pred = np.asarray(pred_img)
            normal_pred = cv2.resize(normal_pred, input_size, interpolation=cv2.INTER_NEAREST)

        # Guard against interpolation overshoot.
        depth_pred = depth_pred.clip(0, 1)
        normal_pred = normal_pred.clip(-1, 1)

        # ----- Visualizations -----
        depth_colored = colorize_depth_maps(
            depth_pred, 0, 1, cmap=color_map
        ).squeeze()
        depth_colored = (depth_colored * 255).astype(np.uint8)
        depth_colored_hwc = chw2hwc(depth_colored)
        depth_colored_img = Image.fromarray(depth_colored_hwc)

        # Map normals from [-1, 1] to [0, 255].
        normal_colored = ((normal_pred + 1) / 2 * 255).astype(np.uint8)
        normal_colored_img = Image.fromarray(normal_colored)

        return DepthNormalPipelineOutput(
            depth_np=depth_pred,
            depth_colored=depth_colored_img,
            normal_np=normal_pred,
            normal_colored=normal_colored_img,
            uncertainty=pred_uncert,
        )

    def __encode_img_embed(self, rgb):
        """
        Encode a CLIP image embedding for ``rgb`` and store it in ``self.img_embed``.

        ``rgb`` is expected in [-1, 1]. It is mapped to [0, 1], resized (bicubic,
        antialiased) to the feature extractor's crop size, normalized with the
        extractor's mean/std and passed through the CLIP vision encoder. The
        resulting ``image_embeds`` get a singleton sequence dimension inserted at dim 1.
        """
        clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:, None, None].to(device=self.device, dtype=self.dtype)
        clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:, None, None].to(device=self.device, dtype=self.dtype)

        img_in_proc = TF.resize((rgb + 1) / 2,
                                (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']),
                                interpolation=InterpolationMode.BICUBIC,
                                antialias=True
                                )

        # Normalize in float32, then cast back to the pipeline dtype.
        img_in_proc = ((img_in_proc.float() - clip_image_mean) / clip_image_std).to(self.dtype)
        img_embed = self.image_encoder(img_in_proc).image_embeds.unsqueeze(1).to(self.dtype)

        self.img_embed = img_embed

    @torch.no_grad()
    def single_infer(self, input_rgb: torch.Tensor,
                     num_inference_steps: int,
                     domain: str,
                     show_pbar: bool,):
        """
        Run one full denoising pass, producing depth and normals for a batch.

        Args:
            input_rgb: RGB batch in [-1, 1]; presumably ``[B, 3, H, W]`` (the batch
                dimension is taken from ``input_rgb.shape[0]``).
            num_inference_steps: Number of scheduler timesteps.
            domain: Scene domain, one of "indoor", "outdoor", "object".
            show_pbar: Show a tqdm progress bar over the denoising steps.

        Returns:
            ``(depth, normal)``. ``depth`` has one channel (the mean of the VAE
            decoder channels), clipped to [-1, 1] and mapped to [0, 1]. ``normal``
            holds the decoder channels, L2-normalized over dim 1 and negated.
            Both have ``B`` rows.

        Raises:
            ValueError: If ``domain`` is not a supported domain.
        """
        device = input_rgb.device
        domain_code = self._DOMAIN_CLASSES.get(domain)
        if domain_code is None:
            raise ValueError(
                f"Unsupported domain {domain!r}; expected one of {sorted(self._DOMAIN_CLASSES)}."
            )
        bsz = input_rgb.shape[0]

        # Set timesteps on the scheduler.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Encode the image into the VAE latent space.
        rgb_latent = self.encode_RGB(input_rgb)

        # Rows [0, bsz) are the depth branch and rows [bsz, 2*bsz) the normal branch;
        # both branches start from the same Gaussian noise.
        geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2, 1, 1, 1)
        rgb_latent = rgb_latent.repeat(2, 1, 1, 1)

        # CLIP image embedding: reuse the cached one only if it matches this batch size
        # (the last DataLoader batch may be smaller).
        if self.img_embed is None or self.img_embed.shape[0] != bsz:
            self.__encode_img_embed(input_rgb)
        batch_img_embed = self.img_embed.repeat((2, 1, 1))

        # Geometry switcher: [0, 1] selects depth, [1, 0] selects normals. Each code is
        # repeated once per sample so the rows line up with the layout of `geo_latent`.
        geo_class = torch.tensor([[0., 1.], [1., 0.]], device=device, dtype=self.dtype)
        geo_class = geo_class.repeat_interleave(bsz, dim=0)
        geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)

        domain_class = torch.tensor([domain_code], device=device, dtype=self.dtype).repeat(2 * bsz, 1)
        domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1)

        class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1)

        # Denoising loop.
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for _, t in iterable:
            unet_input = torch.cat([rgb_latent, geo_latent], dim=1)

            # Predict the noise residual.
            noise_pred = self.unet(
                unet_input, t.repeat(2 * bsz), encoder_hidden_states=batch_img_embed, class_labels=class_embedding
            ).sample

            # Compute the previous latent: x_t -> x_{t-1}.
            geo_latent = self.scheduler.step(noise_pred, t, geo_latent).prev_sample

        torch.cuda.empty_cache()

        depth = self.decode_depth(geo_latent[:bsz])
        depth = torch.clip(depth, -1.0, 1.0)
        depth = (depth + 1.0) / 2.0

        normal = self.decode_normal(geo_latent[bsz:])
        normal /= (torch.norm(normal, p=2, dim=1, keepdim=True) + 1e-5)
        normal *= -1.

        return depth, normal

    def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode an RGB image into a scaled VAE latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: The mean of the VAE posterior (no sampling), multiplied
            by ``latent_scale_factor``.
        """
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        # Only the mean is used; the log-variance half is discarded.
        mean, _logvar = torch.chunk(moments, 2, dim=1)
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode a depth latent into a single-channel depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map, the mean of the decoder output
            channels (kept as a singleton channel dimension).
        """
        depth_latent = depth_latent / self.latent_scale_factor
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode a normal latent into a normal map.

        Args:
            normal_latent (`torch.Tensor`):
                Normal latent to be decoded.

        Returns:
            `torch.Tensor`: Raw decoder output (not yet normalized to unit length).
        """
        normal_latent = normal_latent / self.latent_scale_factor
        z = self.vae.post_quant_conv(normal_latent)
        normal = self.vae.decoder(z)
        return normal
|
|
|
|
|
|