# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-05-24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import logging
from typing import Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import BaseOutput, make_image_grid
from PIL import Image, ImageDraw, ImageFont
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import pil_to_tensor, resize
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .duplicate_unet import DoubleUNet2DConditionModel
from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depth
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
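# Note on `rescale_noise_cfg`: with `guidance_rescale` = 0.0 the prediction is returned unchanged,
# and with 1.0 it is fully rescaled to the per-sample std of the text-conditioned branch; the paper
# cited above suggests intermediate values (around 0.7) to counteract overexposure.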
class MarigoldDepthOutput(BaseOutput):
"""
Output class for Marigold monocular depth prediction pipeline.
Args:
depth_np (`np.ndarray`):
Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map image, generated by applying the chosen color map to the prediction.
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
"""
depth_np: np.ndarray
depth_colored: Union[None, Image.Image]
uncertainty: Union[None, np.ndarray]
class MarigoldInpaintPipeline(DiffusionPipeline):
"""
    Pipeline for Marigold-based monocular depth estimation and RGB-D inpainting: https://marigoldmonodepth.github.io.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
        unet (`DoubleUNet2DConditionModel`):
            Dual-branch conditional U-Net that denoises the RGB and depth latents, with cross-branch conditioning
            between the two.
vae (`AutoencoderKL`):
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
to and from latent representations.
scheduler (`DDIMScheduler`):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
text_encoder (`CLIPTextModel`):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
scale_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
the model config. When used together with the `shift_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
shift_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
the model config. When used together with the `scale_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
default_denoising_steps (`int`, *optional*):
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
quality with the given model. This value must be set in the model config. When the pipeline is called
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
default_processing_resolution (`int`, *optional*):
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
default value is used. This is required to ensure reasonable results with various model flavors trained
with varying optimal processing resolution values.
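    Example:
        A minimal usage sketch; the checkpoint path and image file below are placeholders, and the
        loaded weights must match this pipeline's dual-branch UNet layout.

        >>> pipe = MarigoldInpaintPipeline.from_pretrained("path/to/marigold-inpaint-checkpoint")
        >>> input_image = Image.open("example.png")  # any RGB image (placeholder path)
        >>> output = pipe(input_image, denoising_steps=10, ensemble_size=5, processing_res=768)
        >>> depth_np = output.depth_np          # np.ndarray with values in [0, 1]
        >>> depth_vis = output.depth_colored    # PIL.Image.Image or None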
"""
rgb_latent_scale_factor = 0.18215
depth_latent_scale_factor = 0.18215
def __init__(
self,
unet: DoubleUNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scale_invariant: Optional[bool] = True,
shift_invariant: Optional[bool] = True,
default_denoising_steps: Optional[int] = None,
default_processing_resolution: Optional[int] = None,
requires_safety_checker: bool = False,
):
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.register_to_config(
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
default_denoising_steps=default_denoising_steps,
default_processing_resolution=default_processing_resolution,
)
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
self.default_denoising_steps = default_denoising_steps
self.default_processing_resolution = default_processing_resolution
self.rgb_scheduler = None
self.depth_scheduler = None
self.empty_text_embed = None
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.separate_list = [0,0]
@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
denoising_steps: Optional[int] = None,
ensemble_size: int = 5,
processing_res: Optional[int] = None,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
generator: Union[torch.Generator, None] = None,
color_map: str = "Spectral",
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
) -> MarigoldDepthOutput:
"""
Function invoked when calling the pipeline.
Args:
            input_image (`Image.Image` or `torch.Tensor`):
Input RGB (or gray-scale) image.
denoising_steps (`int`, *optional*, defaults to `None`):
Number of denoising diffusion steps during inference. The default value `None` results in automatic
selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
for Marigold-LCM models.
            ensemble_size (`int`, *optional*, defaults to `5`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `None`):
Effective processing resolution. When set to `0`, processes at the original image resolution. This
produces crisper predictions, but may also lead to the overall loss of global context. The default
value `None` resolves to the optimal value from the model config.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            generator (`torch.Generator`, *optional*, defaults to `None`):
Random generator for initial noise generation.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
Colormap used to colorize the depth map.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
Returns:
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map image; None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling; None if `ensemble_size = 1`
"""
# Model-specific optimal default values leading to fast and reasonable results.
if denoising_steps is None:
denoising_steps = self.default_denoising_steps
if processing_res is None:
processing_res = self.default_processing_resolution
assert processing_res >= 0
assert ensemble_size >= 1
# Check if denoising step is reasonable
self._check_inference_step(denoising_steps)
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
# ----------------- Image Preprocess -----------------
# Convert to torch tensor
if isinstance(input_image, Image.Image):
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
rgb = pil_to_tensor(input_image)
rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(input_image, torch.Tensor):
rgb = input_image
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = rgb.shape
assert (
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
# Resize image
if processing_res > 0:
rgb = resize_max_res(
rgb,
max_edge_resolution=processing_res,
resample_method=resample_method,
)
# Normalize rgb values
rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
rgb_norm = rgb_norm.to(self.dtype)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
# ----------------- Predicting depth -----------------
# Batch repeated input image
duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size > 0:
_bs = batch_size
else:
_bs = find_batch_size(
ensemble_size=ensemble_size,
input_res=max(rgb_norm.shape[1:]),
dtype=self.dtype,
)
single_rgb_loader = DataLoader(
single_rgb_dataset, batch_size=_bs, shuffle=False
)
# Predict depth maps (batched)
depth_pred_ls = []
if show_progress_bar:
iterable = tqdm(
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
)
else:
iterable = single_rgb_loader
for batch in iterable:
(batched_img,) = batch
depth_pred_raw = self.single_infer(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
show_pbar=show_progress_bar,
generator=generator,
)
depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, dim=0)
torch.cuda.empty_cache() # clear vram cache for ensembling
# ----------------- Test-time ensembling -----------------
if ensemble_size > 1:
depth_pred, pred_uncert = ensemble_depth(
depth_preds,
scale_invariant=self.scale_invariant,
shift_invariant=self.shift_invariant,
max_res=50,
**(ensemble_kwargs or {}),
)
else:
depth_pred = depth_preds
pred_uncert = None
# Resize back to original resolution
if match_input_res:
depth_pred = resize(
depth_pred,
input_size[-2:],
interpolation=resample_method,
antialias=True,
)
# Convert to numpy
depth_pred = depth_pred.squeeze()
depth_pred = depth_pred.cpu().numpy()
if pred_uncert is not None:
pred_uncert = pred_uncert.squeeze().cpu().numpy()
# Clip output range
depth_pred = depth_pred.clip(0, 1)
# Colorize
if color_map is not None:
depth_colored = colorize_depth_maps(
depth_pred, 0, 1, cmap=color_map
).squeeze() # [3, H, W], value in (0, 1)
depth_colored = (depth_colored * 255).astype(np.uint8)
depth_colored_hwc = chw2hwc(depth_colored)
depth_colored_img = Image.fromarray(depth_colored_hwc)
else:
depth_colored_img = None
return MarigoldDepthOutput(
depth_np=depth_pred,
depth_colored=depth_colored_img,
uncertainty=pred_uncert,
)
def _replace_unet_conv_in(self):
# replace the first layer to accept 8 in_channels
_weight = self.unet.conv_in.weight.clone() # [320, 4, 3, 3]
_bias = self.unet.conv_in.bias.clone() # [320]
zero_weight = torch.zeros(_weight.shape).to(_weight.device)
_weight = torch.cat([_weight, zero_weight], dim=1)
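        # Zero-initializing the four added input channels keeps the modified UNet's initial output
        # identical to the pretrained model: the new channels contribute nothing until fine-tuned.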
# _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
# half the activation magnitude
# _weight *= 0.5
# new conv_in channel
_n_convin_out_channel = self.unet.conv_in.out_channels
_new_conv_in = Conv2d(
8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)
_new_conv_in.weight = Parameter(_weight)
_new_conv_in.bias = Parameter(_bias)
self.unet.conv_in = _new_conv_in
logging.info("Unet conv_in layer is replaced")
# replace config
self.unet.config["in_channels"] = 8
logging.info("Unet config is updated")
return
    def _replace_unet_conv_out(self):
        # replace the last layer to produce 8 out_channels (4 RGB-latent + 4 depth-latent channels)
        _weight = self.unet.conv_out.weight.clone()  # [4, 320, 3, 3] for the stock SD UNet
        _bias = self.unet.conv_out.bias.clone()  # [4]
        _weight = _weight.repeat((2, 1, 1, 1))  # duplicate output channels -> [8, 320, 3, 3]
        _bias = _bias.repeat(2)  # -> [8]
        # new conv_out layer
        _n_convout_in_channel = self.unet.conv_out.in_channels
        _new_conv_out = Conv2d(
            _n_convout_in_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
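        # The duplicated 8-channel head is split downstream into [..., :4] for the RGB-branch
        # prediction and [..., 4:] for the depth-branch prediction (see the inpainting loops below).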
_new_conv_out.weight = Parameter(_weight)
_new_conv_out.bias = Parameter(_bias)
self.unet.conv_out = _new_conv_out
logging.info("Unet conv_out layer is replaced")
# replace config
self.unet.config["out_channels"] = 8
logging.info("Unet config is updated")
return
def _check_inference_step(self, n_step: int) -> None:
"""
Check if denoising step is reasonable
Args:
n_step (`int`): denoising steps
"""
assert n_step >= 1
if isinstance(self.scheduler, DDIMScheduler):
if n_step < 10:
logging.warning(
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
)
elif isinstance(self.scheduler, LCMScheduler):
if not 1 <= n_step <= 4:
logging.warning(
f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
)
elif isinstance(self.scheduler, PNDMScheduler):
if n_step < 10:
logging.warning(
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
)
else:
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
def encode_empty_text(self):
"""
Encode text embedding for empty prompt
"""
prompt = ""
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
def encode_text(self, prompt):
"""
        Encode text embedding for a given prompt
"""
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
return text_embed
    def numpy_to_pil(self, images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a numpy image or a batch of images to a list of PIL images.
        """
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def full_depth_rgb_inpaint(self,
rgb_in,
depth_in,
image_mask,
text_embed,
timesteps,
generator,
guidance_scale,
):
depth_latent = self.encode_depth(depth_in)
depth_mask = torch.zeros_like(image_mask)
depth_mask_latent = self.encode_depth(depth_in)
rgb_latent = torch.randn(
depth_latent.shape,
device=self.device,
dtype=self.unet.dtype,
generator=generator,
) * self.rgb_scheduler.init_noise_sigma
rgb_mask = image_mask
rgb_mask_latent = self.encode_rgb(rgb_in * (image_mask.squeeze() < 0.5), generator=generator)
rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
for i, t in enumerate(timesteps):
cat_latent = torch.cat(
[rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent,
depth_mask_latent], dim=1
            ).float()  # [B, 26, h, w]
latent_model_input = torch.cat([cat_latent] * 2)
# predict the noise residual
with torch.no_grad():
partial_noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
depth2rgb_scale=0.2
)[0]
noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
# separate_list=self.separate_list
)[0]
# perform guidance
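            # Guidance combination (as implemented here): batch index 0 holds the empty-prompt
            # prediction and index 1 the text-conditioned one (see `text_embed` in `_rgbd_inpaint`).
            # Starting from the prediction with weakened depth-to-RGB conditioning
            # (depth2rgb_scale=0.2) and no text, the depth-conditioning direction is added with a
            # fixed scale of 2 and the text direction with a fixed scale of 3; the `guidance_scale`
            # argument is not used in this branch.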
rgb_pred_wo_depth_text = partial_noise_pred[0, :4, :, :]
rgb_pred_wo_text = noise_pred[0, :4, :, :]
rgb_pred = noise_pred[1, :4, :, :]
noise_pred = rgb_pred_wo_depth_text + 2 * (rgb_pred_wo_text - rgb_pred_wo_depth_text) + 3 * (rgb_pred - rgb_pred_wo_text)
# compute the previous noisy sample x_t -> x_t-1
rgb_latent = self.rgb_scheduler.step(noise_pred, t, rgb_latent).prev_sample
return rgb_latent, depth_latent
def full_rgb_depth_inpaint(self,
rgb_in,
depth_in,
image_mask,
text_embed,
timesteps,
generator,
guidance_scale
):
rgb_latent = self.encode_rgb(rgb_in)
rgb_mask = torch.zeros_like(image_mask)
rgb_mask_latent = self.encode_rgb(rgb_in)
depth_latent = torch.randn(
rgb_latent.shape,
device=self.device,
dtype=self.unet.dtype,
generator=generator,
) * self.depth_scheduler.init_noise_sigma
depth_mask = image_mask
depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))
rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
for i, t in enumerate(timesteps):
cat_latent = torch.cat(
[rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent,
depth_mask_latent], dim=1
            ).float()  # [B, 26, h, w]
latent_model_input = torch.cat([cat_latent] * 2)
# predict the noise residual
with torch.no_grad():
partial_noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
rgb2depth_scale=0.2
)[0]
noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
# separate_list=self.separate_list
)[0]
# compute the previous noisy sample x_t -> x_t-1
depth_pre_wo_rgb = partial_noise_pred[1, 4:, :, :]
depth_pre = depth_pre_wo_rgb + 4 * (noise_pred[1, 4:, :, :] - depth_pre_wo_rgb)
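            # Guidance (as implemented here): the depth prediction is pushed from the prediction
            # with weakened RGB-to-depth conditioning (rgb2depth_scale=0.2) toward the fully
            # conditioned one with a fixed scale of 4; the `guidance_scale` argument is not used
            # in this branch.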
depth_latent = self.depth_scheduler.step(depth_pre, t, depth_latent, generator=generator).prev_sample
return rgb_latent, depth_latent
def joint_inpaint(self,
rgb_in,
depth_in,
image_mask,
text_embed,
timesteps,
generator,
guidance_scale
):
bs = rgb_in.shape[0]
h, w = int(rgb_in.shape[-2]/8), int(rgb_in.shape[-1]/8)
rgb_latent = torch.randn(
[bs, 4, h, w],
device=self.device,
dtype=self.unet.dtype,
generator=generator,
) * self.rgb_scheduler.init_noise_sigma
rgb_mask = image_mask
rgb_mask_latent = self.encode_rgb(rgb_in * (rgb_mask.squeeze() < 0.5), generator=generator)
depth_latent = torch.randn(
[bs, 4, h, w],
device=self.device,
dtype=self.unet.dtype,
generator=generator,
) * self.depth_scheduler.init_noise_sigma
depth_mask = image_mask
depth_mask_latent = self.encode_depth(depth_in * (image_mask.squeeze() < 0.5))
rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
depth_mask = torch.nn.functional.interpolate(depth_mask, size=rgb_latent.shape[-2:])
for i, t in enumerate(timesteps):
cat_latent = torch.cat(
[rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent, depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent], dim=1
            ).float()  # [B, 26, h, w]
latent_model_input = torch.cat([cat_latent] * 2)
# predict the noise residual
with torch.no_grad():
partial_noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
depth2rgb_scale=0,
rgb2depth_scale=0.2
)[0]
noise_pred = self.unet(
latent_model_input,
rgb_timestep=t,
depth_timestep=t,
encoder_hidden_states=text_embed,
return_dict=False,
)[0]
# perform guidance
noise_pred_untext_undual, noise_pred_undual = partial_noise_pred.chunk(2)
noise_pred_untext, noise_pred_cond = noise_pred.chunk(2)
# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
depth_noise_pred = noise_pred_undual + 3 * (noise_pred_cond - noise_pred_undual)
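            # Guidance (as implemented here): the RGB branch steps with the text-conditioned
            # prediction directly, while the depth branch is pushed from the weakened-conditioning
            # prediction (rgb2depth_scale=0.2) toward the fully conditioned one with a fixed scale
            # of 3; the `guidance_scale` argument is not used in this branch.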
rgb_latent = self.rgb_scheduler.step(noise_pred_cond[:, :4, :, :], t, rgb_latent, return_dict=False)[0]
depth_latent = self.depth_scheduler.step(depth_noise_pred[:, 4:, :, :], t, depth_latent, generator=generator, return_dict=False)[0]
return rgb_latent, depth_latent
@torch.no_grad()
    def _rgbd_inpaint(self,
                      input_image: Union[torch.Tensor, PIL.Image.Image],
                      depth_image: Union[torch.Tensor, PIL.Image.Image],
                      mask: Union[torch.Tensor, PIL.Image.Image],
                      prompt: Union[str, List[str]] = '',
                      guidance_scale: float = 4.5,
                      generator: Union[torch.Generator, None] = None,
                      num_inference_steps: int = 50,
                      resample_method: str = "bilinear",
                      processing_res: int = 512,
                      mode: str = 'full_depth_rgb_inpaint'
                      ) -> PIL.Image.Image:
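        """
        Inpaint RGB and/or depth with the dual-branch UNet. Summary inferred from the code below
        (not from upstream documentation): `mask` marks the region to fill (non-zero = inpaint),
        `mode` selects which modality is synthesized ('full_depth_rgb_inpaint',
        'full_rgb_depth_inpaint', 'joint_inpaint' or 'partial_depth_rgb_inpaint'), and the return
        value is a PIL image grid of [input RGB, input depth, mask, inpainted RGB, inpainted depth].
        Note that `self.rgb_scheduler` and `self.depth_scheduler` must be assigned before calling;
        they are initialized to `None` in `__init__`.
        """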
self._check_inference_step(num_inference_steps)
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
# ----------------- encoder prompt -----------------
if isinstance(prompt, list):
bs = len(prompt)
batch_text_embed = []
for p in prompt:
batch_text_embed.append(self.encode_text(p))
batch_text_embed = torch.cat(batch_text_embed, dim=0)
elif isinstance(prompt, str):
bs = 1
batch_text_embed = self.encode_text(prompt).unsqueeze(0)
else:
raise NotImplementedError
if self.empty_text_embed is None:
self.encode_empty_text()
batch_empty_text_embed = self.empty_text_embed.repeat(
(batch_text_embed.shape[0], 1, 1)
).to(self.device) # [B, 2, 1024]
text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)
# ----------------- Image Preprocess -----------------
# Convert to torch tensor
if isinstance(input_image, Image.Image):
rgb_in = self.image_processor.preprocess(input_image, height=processing_res,
width=processing_res).to(self.dtype).to(self.device)
elif isinstance(input_image, torch.Tensor):
rgb = input_image.unsqueeze(0)
input_size = rgb.shape
assert (
4 == rgb.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
if processing_res > 0:
rgb = resize(rgb, [processing_res, processing_res], resample_method, antialias=True)
rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
rgb_in = rgb_norm.to(self.dtype).to(self.device)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
if isinstance(depth_image, Image.Image):
depth = pil_to_tensor(depth_image)
depth = depth.unsqueeze(0) # [1, rgb, H, W]
elif isinstance(depth_image, torch.Tensor):
if len(depth_image.shape) == 3:
depth = depth_image.unsqueeze(0)
else:
depth = depth_image
depth = depth.repeat(1, 3, 1, 1)
input_size = depth.shape
assert (
4 == depth.dim() and 3 == input_size[-3]
), f"Wrong input shape {input_size}, expected [1, 1, H, W]"
if processing_res > 0:
depth = resize(depth, [processing_res, processing_res], resample_method, antialias=True)
depth_norm: torch.Tensor = (depth - depth.min()) / (
depth.max() - depth.min()) * 2.0 - 1.0 # [0, 255] -> [-1, 1]
depth_in = depth_norm.to(self.dtype).to(self.device)
assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0
        if not isinstance(mask, Image.Image) and (mask.max() - mask.min()) != 0:
            mask = (mask - mask.min()) / (mask.max() - mask.min()) * 255
image_mask = self.mask_processor.preprocess(mask, height=processing_res, width=processing_res).to(self.device)
self.rgb_scheduler.set_timesteps(num_inference_steps, device=self.device)
self.depth_scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.rgb_scheduler.timesteps
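        # Mode semantics (as implemented below):
        #   'full_depth_rgb_inpaint'  - depth is fully known; RGB is synthesized, conditioned on the
        #                               unmasked RGB content and the full depth map.
        #   'full_rgb_depth_inpaint'  - RGB is fully known; depth is synthesized, conditioned on the
        #                               unmasked depth content and the full RGB image.
        #   'joint_inpaint'           - both RGB and depth are synthesized from their masked inputs.
        #   'partial_depth_rgb_inpaint' is dispatched below but not defined in this file.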
        if mode == 'full_rgb_depth_inpaint':
            rgb_latent, depth_latent = self.full_rgb_depth_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
                                                                   generator, guidance_scale=guidance_scale)
        elif mode == 'partial_depth_rgb_inpaint':
            rgb_latent, depth_latent = self.partial_depth_rgb_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
                                                                      generator, guidance_scale=guidance_scale)
        elif mode == 'full_depth_rgb_inpaint':
            rgb_latent, depth_latent = self.full_depth_rgb_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
                                                                   generator, guidance_scale=guidance_scale)
        elif mode == 'joint_inpaint':
            rgb_latent, depth_latent = self.joint_inpaint(rgb_in, depth_in, image_mask, text_embed, timesteps,
                                                          generator, guidance_scale=guidance_scale)
        else:
            raise ValueError(f"Unknown inpainting mode: {mode}")
image = self.decode_image(rgb_latent)
image = self.numpy_to_pil(image)[0]
d_image = self.decode_depth(depth_latent)
d_image = d_image.cpu().permute(0, 2, 3, 1).numpy()
d_image = (d_image - d_image.min()) / (d_image.max() - d_image.min())
d_image = self.numpy_to_pil(d_image)[0]
depth = depth.squeeze().permute(1, 2, 0).cpu().numpy()
depth = (depth - depth.min()) / (depth.max() - depth.min())
ori_depth = self.numpy_to_pil(depth)[0]
        if isinstance(input_image, Image.Image):
            ori_image = input_image.convert("RGB")
        else:
            ori_image = self.numpy_to_pil(input_image.squeeze().permute(1, 2, 0).cpu().numpy() / 255)[0]
image_mask = self.numpy_to_pil(image_mask.permute(0, 2, 3, 1).cpu().numpy())[0]
cat_image = make_image_grid([ori_image, ori_depth, image_mask, image, d_image], rows=1, cols=5)
return cat_image
def encode_rgb(self, rgb_in: torch.Tensor, generator=None) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
image_latents = self.vae.encode(rgb_in).latent_dist.sample(generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
return image_latents
def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
h = self.vae.encoder(depth_in)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
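        # The posterior mean is used directly (no sampling), so depth encoding is deterministic,
        # unlike `encode_rgb` above, which samples from the latent distribution.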
# scale latent
depth_latent = mean * self.depth_latent_scale_factor
return depth_latent
def decode_image(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
z = self.vae.post_quant_conv(latents)
image = self.vae.decoder(z)
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
"""
Decode depth latent into depth map.
Args:
depth_latent (`torch.Tensor`):
Depth latent to be decoded.
Returns:
`torch.Tensor`: Decoded depth map.
"""
# scale latent
depth_latent = depth_latent / self.depth_latent_scale_factor
# decode
z = self.vae.post_quant_conv(depth_latent)
stacked = self.vae.decoder(z)
# mean of output channels
depth_mean = stacked.mean(dim=1, keepdim=True)
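        # The depth map is replicated across three channels before encoding, so averaging the
        # decoded channels recovers a single-channel depth map.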
return depth_mean
def post_process_rgbd(self, prompts, rgb_image, depth_image):
rgbd_images = []
for idx, p in enumerate(prompts):
image1, image2 = rgb_image[idx], depth_image[idx]
width1, height1 = image1.size
width2, height2 = image2.size
font = ImageFont.load_default(size=20)
text = p
draw = ImageDraw.Draw(image1)
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
new_image = Image.new('RGB', (width1 + width2, max(height1, height2) + text_height), (255, 255, 255))
text_x = (new_image.width - text_width) // 2
text_y = 0
draw = ImageDraw.Draw(new_image)
draw.text((text_x, text_y), text, fill="black", font=font)
new_image.paste(image1, (0, text_height))
new_image.paste(image2, (width1, text_height))
rgbd_images.append(pil_to_tensor(new_image))
return rgbd_images