from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from ... import ConfigMixin
from ...configuration_utils import register_to_config
from ...image_processor import PipelineImageInput
from ...utils import CONFIG_NAME, logging
from ...utils.import_utils import is_matplotlib_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class MarigoldImageProcessor(ConfigMixin):
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        vae_scale_factor: int = 8,
        do_normalize: bool = True,
        do_range_check: bool = True,
    ):
        super().__init__()
    @staticmethod
    def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Expand a 2- or 3-dimensional image tensor or array into a 4-dimensional batch.
        """
        if isinstance(images, np.ndarray):
            if images.ndim == 2:  # [H,W] -> [1,H,W,1]
                images = images[None, ..., None]
            if images.ndim == 3:  # [H,W,C] -> [1,H,W,C]
                images = images[None]
        elif isinstance(images, torch.Tensor):
            if images.ndim == 2:  # [H,W] -> [1,1,H,W]
                images = images[None, None]
            elif images.ndim == 3:  # [1,H,W] -> [1,1,H,W]
                images = images[None]
        else:
            raise ValueError(f"Unexpected input type: {type(images)}")
        return images
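    # A minimal shape sketch (illustrative, not part of the class): NumPy inputs gain a
    # trailing channel axis, while torch inputs gain leading ones, so a [H,W] array and
    # a [H,W] tensor come out as [1,H,W,1] and [1,1,H,W] respectively.
    #
    #   expand_tensor_or_array(np.zeros((480, 640))).shape    # (1, 480, 640, 1)
    #   expand_tensor_or_array(torch.zeros(480, 640)).shape   # torch.Size([1, 1, 480, 640])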
    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a batch of PyTorch image tensors [N,C,H,W] into a NumPy image batch [N,H,W,C].
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images
    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a NumPy image batch [N,H,W,C] into a batch of PyTorch tensors [N,C,H,W].
        """
        if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger):
            raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.")
        if np.issubdtype(images.dtype, np.complexfloating):
            raise ValueError(f"Input image dtype={images.dtype} cannot be complex.")
        if np.issubdtype(images.dtype, bool):
            raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.")

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images
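    # Round-trip sketch under the layout assumptions above: numpy_to_pt expects a
    # channels-last batch and pt_to_numpy restores it, so the two compose to the
    # identity for float inputs.
    #
    #   x = np.random.rand(2, 64, 64, 3).astype(np.float32)   # [N,H,W,C]
    #   y = MarigoldImageProcessor.numpy_to_pt(x)              # [N,C,H,W]
    #   assert MarigoldImageProcessor.pt_to_numpy(y).shape == x.shape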
    @staticmethod
    def resize_antialias(
        image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None
    ) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        antialias = is_aa and mode in ("bilinear", "bicubic")  # F.interpolate supports antialiasing only for these modes
        image = F.interpolate(image, size, mode=mode, antialias=antialias)

        return image
    @staticmethod
    def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        max_orig = max(h, w)
        new_h = h * max_edge_sz // max_orig
        new_w = w * max_edge_sz // max_orig

        if new_h == 0 or new_w == 0:
            raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]")

        image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True)

        return image
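    # Worked example of the scaling arithmetic (illustrative numbers): for a 1920x1080
    # input and max_edge_sz=768, the longer edge maps to 768 and the shorter edge scales
    # proportionally via integer division.
    #
    #   new_w = 1920 * 768 // 1920  # 768
    #   new_h = 1080 * 768 // 1920  # 432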
    @staticmethod
    def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        h, w = image.shape[-2:]
        ph, pw = -h % align, -w % align  # padding needed to reach the next multiple of `align`

        image = F.pad(image, (0, pw, 0, ph), mode="replicate")

        return image, (ph, pw)
    @staticmethod
    def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        ph, pw = padding
        uh = None if ph == 0 else -ph
        uw = None if pw == 0 else -pw

        image = image[:, :, :uh, :uw]

        return image
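    # Pad/unpad round-trip sketch (illustrative): `-h % align` yields 0 when `h` is
    # already aligned, so e.g. h=601, align=8 gives ph=7 and a 608-row padded image
    # that unpad_image crops back to 601 rows.
    #
    #   x = torch.rand(1, 3, 601, 640)
    #   y, pad = MarigoldImageProcessor.pad_image(x, align=8)   # y: [1,3,608,640], pad=(7,0)
    #   assert MarigoldImageProcessor.unpad_image(y, pad).shape == x.shape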
    @staticmethod
    def load_image_canonical(
        image: Union[torch.Tensor, np.ndarray, Image.Image],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        if isinstance(image, Image.Image):
            image = np.array(image)

        image_dtype_max = None
        if isinstance(image, (np.ndarray, torch.Tensor)):
            image = MarigoldImageProcessor.expand_tensor_or_array(image)
            if image.ndim != 4:
                raise ValueError("Input image is not 2-, 3-, or 4-dimensional.")
        if isinstance(image, np.ndarray):
            if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger):
                raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.")
            if np.issubdtype(image.dtype, np.complexfloating):
                raise ValueError(f"Input image dtype={image.dtype} cannot be complex.")
            if np.issubdtype(image.dtype, bool):
                raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.")
            if np.issubdtype(image.dtype, np.unsignedinteger):
                image_dtype_max = np.iinfo(image.dtype).max
                image = image.astype(np.float32)  # because torch does not have unsigned dtypes beyond torch.uint8
            image = MarigoldImageProcessor.numpy_to_pt(image)

        if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None:
            if image.dtype != torch.uint8:
                raise ValueError(f"Image dtype={image.dtype} is not supported.")
            image_dtype_max = 255

        if not torch.is_tensor(image):
            raise ValueError(f"Input type unsupported: {type(image)}.")

        if image.shape[1] == 1:
            image = image.repeat(1, 3, 1, 1)  # [N,1,H,W] -> [N,3,H,W]
        if image.shape[1] != 3:
            raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.")

        image = image.to(device=device, dtype=dtype)

        if image_dtype_max is not None:
            image = image / image_dtype_max

        return image
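    # Canonicalization sketch (illustrative): a uint8 PIL image becomes a float
    # [N,3,H,W] tensor divided by the dtype maximum (255), landing in [0,1].
    #
    #   pil = Image.fromarray(np.full((64, 64, 3), 128, dtype=np.uint8))
    #   x = MarigoldImageProcessor.load_image_canonical(pil)
    #   # x.shape == (1, 3, 64, 64), x.dtype == torch.float32, values ~0.502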
    @staticmethod
    def check_image_values_range(image: torch.Tensor) -> None:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.min().item() < 0.0 or image.max().item() > 1.0:
            raise ValueError("Input image data is partially outside of the [0,1] range.")
    def preprocess(
        self,
        image: PipelineImageInput,
        processing_resolution: Optional[int] = None,
        resample_method_input: str = "bilinear",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        if isinstance(image, list):
            images = None
            for i, img in enumerate(image):
                img = self.load_image_canonical(img, device, dtype)  # [N,3,H,W]
                if images is None:
                    images = img
                else:
                    if images.shape[2:] != img.shape[2:]:
                        raise ValueError(
                            f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images "
                            f"{images.shape[2:]}"
                        )
                    images = torch.cat((images, img), dim=0)
            image = images
            del images
        else:
            image = self.load_image_canonical(image, device, dtype)  # [N,3,H,W]

        original_resolution = image.shape[2:]

        if self.config.do_range_check:
            self.check_image_values_range(image)

        if self.config.do_normalize:
            image = image * 2.0 - 1.0

        if processing_resolution is not None and processing_resolution > 0:
            image = self.resize_to_max_edge(image, processing_resolution, resample_method_input)  # [N,3,PH,PW]

        image, padding = self.pad_image(image, self.config.vae_scale_factor)  # [N,3,PPH,PPW]

        return image, padding, original_resolution
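    # End-to-end preprocessing sketch (illustrative resolution; `pil_image` is a
    # stand-in name): the image is normalized to [-1,1], resized so its longer edge is
    # 768, then padded to a multiple of the VAE scale factor; `padding` and
    # `original_resolution` let callers undo both steps afterwards.
    #
    #   processor = MarigoldImageProcessor()
    #   out, padding, orig_res = processor.preprocess(pil_image, processing_resolution=768)
    #   # out: [N,3,PPH,PPW] with PPH and PPW divisible by 8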
    @staticmethod
    def colormap(
        image: Union[np.ndarray, torch.Tensor],
        cmap: str = "Spectral",
        bytes: bool = False,
        _force_method: Optional[str] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
        behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
        "binary") without having to install or import matplotlib. For all other cases, the function will attempt to
        use the native implementation.

        Args:
            image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
            cmap: Colormap name.
            bytes: Whether to return the output as uint8 or floating point image.
            _force_method:
                Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
                implementation of the selected color maps (`"custom"`), or rely on autodetection (`None`, default).

        Returns:
            An RGB-colorized tensor corresponding to the input image.
        """
        if not (torch.is_tensor(image) or isinstance(image, np.ndarray)):
            raise ValueError("Argument must be a numpy array or torch tensor.")
        if _force_method not in (None, "matplotlib", "custom"):
            raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")

        supported_cmaps = {
            "binary": [
                (1.0, 1.0, 1.0),
                (0.0, 0.0, 0.0),
            ],
            "Spectral": [  # Taken from matplotlib/_cm.py
                (0.61960784313725492, 0.003921568627450980, 0.25882352941176473),  # 0.0 -> [0]
                (0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
                (0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
                (0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
                (0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
                (1.0, 1.0, 0.74901960784313726),
                (0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
                (0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
                (0.4, 0.76078431372549016, 0.6470588235294118),
                (0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
                (0.36862745098039218, 0.30980392156862746, 0.63529411764705879),  # 1.0 -> [K-1]
            ],
        }

        def method_matplotlib(image, cmap, bytes=False):
            if is_matplotlib_available():
                import matplotlib
            else:
                return None

            arg_is_pt, device = torch.is_tensor(image), None
            if arg_is_pt:
                image, device = image.cpu().numpy(), image.device

            if cmap not in matplotlib.colormaps:
                raise ValueError(
                    f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}"
                )

            cmap = matplotlib.colormaps[cmap]
            out = cmap(image, bytes=bytes)  # [?,4]
            out = out[..., :3]  # [?,3]

            if arg_is_pt:
                out = torch.tensor(out, device=device)

            return out

        def method_custom(image, cmap, bytes=False):
            arg_is_np = isinstance(image, np.ndarray)
            if arg_is_np:
                image = torch.tensor(image)
            if image.dtype == torch.uint8:
                image = image.float() / 255
            else:
                image = image.float()

            is_cmap_reversed = cmap.endswith("_r")
            if is_cmap_reversed:
                cmap = cmap[:-2]

            if cmap not in supported_cmaps:
                raise ValueError(
                    f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
                )

            cmap = supported_cmaps[cmap]
            if is_cmap_reversed:
                cmap = cmap[::-1]
            cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
            K = cmap.shape[0]

            pos = image.clamp(min=0, max=1) * (K - 1)
            left = pos.long()
            right = (left + 1).clamp(max=K - 1)

            d = (pos - left.float()).unsqueeze(-1)
            left_colors = cmap[left]
            right_colors = cmap[right]

            out = (1 - d) * left_colors + d * right_colors  # linear interpolation between adjacent palette entries

            if bytes:
                out = (out * 255).to(torch.uint8)

            if arg_is_np:
                out = out.numpy()

            return out

        if _force_method is None and torch.is_tensor(image) and cmap == "Spectral":
            return method_custom(image, cmap, bytes)

        out = None
        if _force_method != "custom":
            out = method_matplotlib(image, cmap, bytes)

        if _force_method == "matplotlib" and out is None:
            raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.")

        if out is None:
            out = method_custom(image, cmap, bytes)

        return out
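    # Interpolation sketch (illustrative): with the 11-entry "Spectral" palette, a pixel
    # value v in [0,1] lands at position v*10; e.g. v=0.25 sits at 2.5, so the output is
    # the midpoint of palette entries 2 and 3.
    #
    #   img = torch.rand(16, 16)
    #   rgb = MarigoldImageProcessor.colormap(img, cmap="Spectral", bytes=True)
    #   # rgb.shape == (16, 16, 3), rgb.dtype == torch.uint8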
    @staticmethod
    def visualize_depth(
        depth: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        val_min: float = 0.0,
        val_max: float = 1.0,
        color_map: str = "Spectral",
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.

        Args:
            depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray],
                List[torch.Tensor]]`): Depth maps.
            val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range.
            val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range.
            color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
                depth prediction into colored representation.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.
        """
        if val_max <= val_min:
            raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")

        def visualize_depth_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if isinstance(img, PIL.Image.Image):
                if img.mode != "I;16":
                    raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.")
                img = np.array(img).astype(np.float32) / (2**16 - 1)
            if isinstance(img, np.ndarray) or torch.is_tensor(img):
                if img.ndim != 2:
                    raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
                if isinstance(img, np.ndarray):
                    img = torch.from_numpy(img)
                if not torch.is_floating_point(img):
                    raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            else:
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True)  # [H,W,3]
            img = PIL.Image.fromarray(img.cpu().numpy())
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):  # [N,1,H,W]
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
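    # Usage sketch (illustrative): a float depth map in [0,1] becomes a colorized PIL
    # image; passing val_min/val_max rescales metric-range predictions first.
    #
    #   depth = torch.rand(1, 1, 64, 64)
    #   vis = MarigoldImageProcessor.visualize_depth(depth)   # list of 1 PIL.Image.Image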
    @staticmethod
    def export_depth_to_16bit_png(
        depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        val_min: float = 0.0,
        val_max: float = 1.0,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        def export_depth_to_16bit_png_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if img.ndim != 2:
                raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            if not np.issubdtype(img.dtype, np.floating):
                raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = (img * (2**16 - 1)).astype(np.uint16)
            img = PIL.Image.fromarray(img, mode="I;16")
            return img

        if depth is None or isinstance(depth, list) and any(o is None for o in depth):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
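    # Quantization sketch (illustrative arithmetic): a normalized value v is stored as
    # the truncation of v * 65535, so v=0.5 stores as 32767 and decodes back to ~0.49999
    # via the matching division in visualize_depth_one.
    #
    #   png = MarigoldImageProcessor.export_depth_to_16bit_png(torch.rand(64, 64))
    #   # png[0].mode == "I;16"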
    @staticmethod
    def visualize_normals(
        normals: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        flip_x: bool = False,
        flip_y: bool = False,
        flip_z: bool = False,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.

        Args:
            normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Surface normals.
            flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference.
                Default direction is right.
            flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference.
                Default direction is top.
            flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
                Default direction is facing the observer.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.
        """
        flip_vec = None
        if any((flip_x, flip_y, flip_z)):
            flip_vec = torch.tensor(
                [
                    (-1) ** flip_x,
                    (-1) ** flip_y,
                    (-1) ** flip_z,
                ],
                dtype=torch.float32,
            )

        def visualize_normals_one(img, idx=None):
            img = img.permute(1, 2, 0)
            if flip_vec is not None:
                img *= flip_vec.to(img.device)
            img = (img + 1.0) * 0.5
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            img = PIL.Image.fromarray(img)
            return img

        if normals is None or isinstance(normals, list) and any(o is None for o in normals):
            raise ValueError("Input normals is `None`")
        if isinstance(normals, (np.ndarray, torch.Tensor)):
            normals = MarigoldImageProcessor.expand_tensor_or_array(normals)
            if isinstance(normals, np.ndarray):
                normals = MarigoldImageProcessor.numpy_to_pt(normals)  # [N,3,H,W]
            if not (normals.ndim == 4 and normals.shape[1] == 3):
                raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].")
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        elif isinstance(normals, list):
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        else:
            raise ValueError(f"Unexpected input type: {type(normals)}")
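    # Encoding sketch (illustrative): each normal component in [-1,1] maps to
    # (n + 1) / 2 * 255, so the normal (0, 0, -1) facing the observer encodes to the
    # RGB pixel (127, 127, 0) before any axis flips.
    #
    #   n = torch.zeros(1, 3, 64, 64); n[:, 2] = -1.0
    #   vis = MarigoldImageProcessor.visualize_normals(n)   # list of 1 PIL.Image.Image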
    @staticmethod
    def visualize_uncertainty(
        uncertainty: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        saturation_percentile=95,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.

        Args:
            uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Uncertainty maps.
            saturation_percentile (`int`, *optional*, defaults to `95`):
                Specifies the percentile uncertainty value visualized with maximum intensity.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.
        """

        def visualize_uncertainty_one(img, idx=None):
            prefix = "Uncertainty" + (f"[{idx}]" if idx is not None else "")
            if img.min() < 0:
                raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
            img = img.squeeze(0).cpu().numpy()
            saturation_value = np.percentile(img, saturation_percentile)
            img = np.clip(img * 255 / saturation_value, 0, 255)
            img = img.astype(np.uint8)
            img = PIL.Image.fromarray(img)
            return img

        if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty):
            raise ValueError("Input uncertainty is `None`")
        if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
            uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
            if isinstance(uncertainty, np.ndarray):
                uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty)  # [N,1,H,W]
            if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        elif isinstance(uncertainty, list):
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        else:
            raise ValueError(f"Unexpected input type: {type(uncertainty)}")
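    # Scaling sketch (illustrative): values are stretched so that the chosen percentile
    # (95th by default) saturates at full intensity, which keeps a few extreme outliers
    # from washing out the rest of the map.
    #
    #   unc = torch.rand(1, 1, 64, 64)
    #   vis = MarigoldImageProcessor.visualize_uncertainty(unc)  # list of 1 PIL.Image.Image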