oiv_ld_phenotyping / src /com_plot.py
treizh's picture
Upload folder using huggingface_hub
fc262e7 verified
raw
history blame
1.52 kB
import warnings

import matplotlib.pyplot as plt
def _update_axis(
    axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center"
):
    """Draw ``image`` on ``axis``, optionally title it, and optionally hide the axis.

    Args:
        axis: Matplotlib-style axes object; only ``imshow``, ``set_title`` and
            ``set_axis_off`` are called on it.
        image: Anything ``axis.imshow`` accepts; drawn with ``origin="upper"``.
        title: Title text; no title is set when ``None``.
        fontsize: Font size passed to ``set_title``.
        remove_axis: When truthy, ``axis.set_axis_off()`` is called.
        title_loc: Title alignment passed as ``set_title(loc=...)``.
    """
    axis.imshow(image, origin="upper")
    if title is not None:
        axis.set_title(title, fontsize=fontsize, loc=title_loc)
    # Plain truthiness rather than `is True`, so values like 1 or
    # numpy.True_ also hide the axis.
    if remove_axis:
        axis.set_axis_off()
def tensor_image_to_grid(
    images: list,
    transform,
    row_count,
    col_count=None,
    figsize=(20, 20),
    fontsize=None,
):
    """Display image tensors (optionally titled) with matplotlib, then show the figure.

    A single image is drawn on the current pyplot figure. Otherwise a
    ``row_count`` x ``col_count`` grid of subplots is created and images fill
    it in row-major order. Unused cells are hidden. Images that do not fit
    are dropped with a warning.

    Args:
        images: Items are either an image tensor or an ``(image, title)``
            tuple. Each tensor goes through ``permute(1, 2, 0)`` before
            conversion, so it must be 3-D with dim 0 moved last (presumably
            ``(C, H, W)`` torch tensors; confirm with callers).
        transform: Callable invoked as ``transform(image=array)`` that returns
            a mapping whose ``"image"`` entry is the array to display. ``None``
            displays the converted array unchanged.
        row_count: Number of grid rows (multi-image case).
        col_count: Number of grid columns; defaults to ``row_count``.
        figsize: Size of the grid figure. Its first element is also the
            default title font size in the grid.
        fontsize: Title font size. ``None`` means ``figsize[0]`` in the grid
            and matplotlib's default for a single image.

    A grid image that fails to convert or draw is skipped with a
    ``warnings.warn`` message instead of aborting the whole grid.
    """

    def split_image_title(item):
        # An item may carry its own title as an (image, title) pair.
        if isinstance(item, tuple):
            return item[0], item[1]
        return item, None

    def torch_to_image(tensor):
        # detach()/cpu() let tensors that require grad or live off-CPU be
        # converted; .numpy() raises on those otherwise.
        array = tensor.detach().cpu().permute(1, 2, 0).numpy()
        if transform is None:
            return array
        return transform(image=array)["image"]

    col_count = row_count if col_count is None else col_count

    if len(images) == 1:
        img, title = split_image_title(images[0])
        plt.imshow(torch_to_image(img))
        if title is not None:
            # Call plt.title. Assigning to it would replace the pyplot
            # function with a string instead of setting a title.
            title_kwargs = {} if fontsize is None else {"fontsize": fontsize}
            plt.title(title, **title_kwargs)
        plt.axis("off")
    else:
        # squeeze=False always returns a 2-D array of Axes, even for a 1x1
        # grid, so flattening works for every grid shape.
        _, axes_grid = plt.subplots(
            row_count, col_count, figsize=figsize, squeeze=False
        )
        flat_axes = list(axes_grid.flat)
        if len(images) > len(flat_axes):
            warnings.warn(
                f"tensor_image_to_grid: {len(images)} images for a "
                f"{row_count}x{col_count} grid; only the first "
                f"{len(flat_axes)} are shown",
                stacklevel=2,
            )
        title_size = figsize[0] if fontsize is None else fontsize
        for index, (ax, image) in enumerate(zip(flat_axes, images)):
            try:
                img, title = split_image_title(image)
                _update_axis(
                    axis=ax,
                    image=torch_to_image(img),
                    remove_axis=True,
                    title=title,
                    fontsize=title_size,
                )
            except Exception as err:
                # Best effort: one bad image should not abort the whole grid,
                # but the failure is reported rather than silently dropped.
                warnings.warn(
                    f"tensor_image_to_grid: could not draw image {index}: {err!r}",
                    stacklevel=2,
                )
        # Hide cells that received no image so they don't show empty frames.
        for ax in flat_axes[len(images):]:
            ax.set_axis_off()
    plt.tight_layout()
    plt.show()