from typing import List import PIL.Image import PIL.ImageOps from packaging import version from PIL import Image if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): PIL_INTERPOLATION = { "linear": PIL.Image.Resampling.BILINEAR, "bilinear": PIL.Image.Resampling.BILINEAR, "bicubic": PIL.Image.Resampling.BICUBIC, "lanczos": PIL.Image.Resampling.LANCZOS, "nearest": PIL.Image.Resampling.NEAREST, } else: PIL_INTERPOLATION = { "linear": PIL.Image.LINEAR, "bilinear": PIL.Image.BILINEAR, "bicubic": PIL.Image.BICUBIC, "lanczos": PIL.Image.LANCZOS, "nearest": PIL.Image.NEAREST, } def pt_to_pil(images): """ Convert a torch image to a PIL image. """ images = (images / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() images = numpy_to_pil(images) return images def numpy_to_pil(images): """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image: """ Prepares a single grid of images. Useful for visualization purposes. """ assert len(images) == rows * cols if resize is not None: images = [img.resize((resize, resize)) for img in images] w, h = images[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(images): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid