|
import math |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
def tensor_to_image(
    data: Union[Image.Image, torch.Tensor, np.ndarray],
    batched: bool = False,
    format: str = "HWC",
) -> Union[Image.Image, List[Image.Image]]:
    """
    Converts a tensor or array into a PIL image (or a list of PIL images).

    Floating-point inputs (any float dtype) are assumed to lie in ``[0, 1]``;
    they are clipped to that range, scaled by 255 and truncated to ``uint8``.
    Boolean inputs map ``False``/``True`` to ``0``/``255``. ``uint8`` inputs
    are used as-is.

    Args:
        data: A PIL image (returned unchanged), a ``torch.Tensor``, a
            ``np.ndarray`` or anything ``np.asarray`` accepts.
        batched: If True, the leading axis indexes images and a list of
            images is returned.
        format: Channel layout of ``data``. ``"CHW"`` moves the channel axis
            last (only when ``data`` has 4 dims if batched, 3 dims otherwise);
            any other value, e.g. the default ``"HWC"``, leaves it untouched.

    Returns:
        A single ``Image.Image``, or a list of them when ``batched`` is True.

    Raises:
        TypeError: If the data is not float, bool or uint8.
    """
    if isinstance(data, Image.Image):
        return data

    if isinstance(data, torch.Tensor):
        data = data.detach().cpu()
        # NumPy has no bfloat16, so ``.numpy()`` would raise; upcast first.
        if data.dtype == torch.bfloat16:
            data = data.float()
        data = data.numpy()

    data = np.asarray(data)

    if np.issubdtype(data.dtype, np.floating):
        # Clip before casting: out-of-range values would otherwise wrap
        # around in uint8 (e.g. 1.01 * 255 = 257 -> 1).
        data = (np.clip(data, 0.0, 1.0) * 255).astype(np.uint8)
    elif data.dtype == np.bool_:
        data = data.astype(np.uint8) * 255

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if data.dtype != np.uint8:
        raise TypeError(
            f"expected float, bool or uint8 data, got dtype {data.dtype}"
        )

    image_ndim = 4 if batched else 3
    if format == "CHW" and data.ndim == image_ndim:
        axes = (0, 2, 3, 1) if batched else (1, 2, 0)
        data = data.transpose(axes)

    # PIL cannot build an image from a trailing singleton channel axis;
    # drop it so single-channel data becomes a grayscale image.
    if data.ndim == image_ndim and data.shape[-1] == 1:
        data = data[..., 0]

    if batched:
        return [Image.fromarray(d) for d in data]

    return Image.fromarray(data)
|
|
|
|
|
def largest_factor_near_sqrt(n: int) -> int:
    """
    Finds the largest factor of n that does not exceed the square root of n.

    Pairing the result ``f`` with ``n // f`` gives the most nearly square
    factorization of ``n`` with ``f <= n // f``. For a prime ``n`` the
    result is 1.

    Args:
        n (int): A positive integer.

    Returns:
        int: The largest divisor ``f`` of ``n`` with ``f * f <= n``.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    # math.isqrt is exact for arbitrarily large ints, unlike int(math.sqrt(n)),
    # whose float rounding can overshoot and yield a factor above sqrt(n).
    # Starting at floor(sqrt(n)) also handles perfect squares directly.
    for candidate in range(math.isqrt(n), 1, -1):
        if n % candidate == 0:
            return candidate

    # No divisor in [2, sqrt(n)]: n is 1 or prime.
    return 1
|
|
|
|
|
def make_image_grid(
    images: List[Image.Image],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    resize: Optional[int] = None,
) -> Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.

    Images are laid out row-major (left to right, then top to bottom). Every
    cell takes the size of the first image (after optional resizing); other
    images are pasted at their cell's top-left corner without being resized,
    so all images are expected to share one size.

    Args:
        images (List[Image.Image]): The images to tile. Must be non-empty.
        rows (Optional[int]): Number of grid rows. Derived from ``cols`` when
            omitted, or chosen automatically when both are omitted.
        cols (Optional[int]): Number of grid columns. Derived from ``rows``
            when omitted, or chosen automatically when both are omitted.
        resize (Optional[int]): If given, every image is first resized to a
            ``resize`` x ``resize`` square.

    Returns:
        Image.Image: An RGB image of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of the first image.

    Raises:
        ValueError: If ``images`` is empty, ``rows``/``cols`` is not positive,
            or the grid shape does not hold exactly ``len(images)`` images.
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError("make_image_grid requires at least one image")
    for arg_name, value in (("rows", rows), ("cols", cols)):
        if value is not None and value <= 0:
            raise ValueError(f"{arg_name} must be a positive integer, got {value}")

    if rows is None and cols is None:
        # Most nearly square layout with rows <= cols.
        rows = largest_factor_near_sqrt(num_images)
        cols = num_images // rows
    elif rows is None:
        rows = num_images // cols
    elif cols is None:
        cols = num_images // rows

    # Also catches a given rows/cols that does not divide len(images).
    if rows * cols != num_images:
        raise ValueError(
            f"a {rows}x{cols} grid cannot hold exactly {num_images} images"
        )

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
|
|