|
import matplotlib.pyplot as plt |
|
|
|
|
|
def _update_axis( |
|
axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center" |
|
): |
|
axis.imshow(image, origin="upper") |
|
if title is not None: |
|
axis.set_title(title, fontsize=fontsize, loc=title_loc) |
|
if remove_axis is True: |
|
axis.set_axis_off() |
|
|
|
|
|
def tensor_image_to_grid(
    images: list,
    transform,
    row_count,
    col_count=None,
    figsize=(20, 20),
    fontsize=None,
):
    """Render one image, or a grid of images, with matplotlib and show it.

    Args:
        images: Items to plot. Each item is either an image object or a
            tuple whose first element is the image and whose second element
            is its title.
        transform: Callable invoked as ``transform(image=array)``; the value
            under the ``"image"`` key of its result is what gets plotted.
        row_count: Number of subplot rows in the grid.
        col_count: Number of subplot columns; defaults to ``row_count``.
        figsize: Figure size passed to ``plt.subplots`` for the grid case.
        fontsize: Title font size for grid cells; defaults to ``figsize[0]``.

    Notes:
        Each image must provide ``permute(1, 2, 0)`` and ``numpy()``
        (presumably a channel-first torch tensor -- confirm with callers).
        When exactly one image is given it is drawn on the current pyplot
        figure and the grid arguments are not used. In the grid case images
        are paired with axes in row-major order; images beyond
        ``row_count * col_count`` are not drawn. An image that fails to
        render is skipped with a ``RuntimeWarning``.
    """
    import warnings

    def split_image_title(item):
        # Items may be bare images or (image, title) tuples.
        if isinstance(item, tuple):
            return item[0], item[1]
        return item, None

    def torch_to_image(t):
        # Reorder axes, convert to numpy, then apply the caller's transform.
        return transform(image=t.permute(1, 2, 0).numpy())["image"]

    col_count = row_count if col_count is None else col_count
    if len(images) == 1:
        img, title = split_image_title(images[0])
        plt.imshow(torch_to_image(img))
        # Call plt.title instead of assigning to it: the assignment rebinds
        # the pyplot function and never draws the title.
        if title is not None:
            plt.title(title)
        plt.tight_layout()
        plt.axis("off")
    else:
        # squeeze=False always yields a 2-D array of axes, so reshape(-1)
        # also works for a 1x1 grid.
        _, axii = plt.subplots(
            row_count, col_count, figsize=figsize, squeeze=False
        )
        title_size = figsize[0] if fontsize is None else fontsize
        for index, (ax, item) in enumerate(zip(axii.reshape(-1), images)):
            try:
                img, title = split_image_title(item)
                _update_axis(
                    axis=ax,
                    image=torch_to_image(img),
                    remove_axis=True,
                    title=title,
                    fontsize=title_size,
                )
            except Exception as exc:
                # Best effort: keep plotting the remaining images, but make
                # the failure visible instead of silently discarding it.
                warnings.warn(
                    f"Skipping image at index {index}: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        plt.tight_layout()
    plt.show()
|
|