import warnings

import matplotlib.pyplot as plt
def _update_axis(
    axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center"
):
    """Draw ``image`` on a matplotlib Axes and optionally give it a title.

    Args:
        axis: Axes-like object; its ``imshow``, ``set_title`` and
            ``set_axis_off`` methods are called.
        image: Image data forwarded to ``axis.imshow`` with ``origin="upper"``.
        title: Title text. No title is set when ``None``.
        fontsize: Font size used for the title.
        remove_axis: When truthy, ``axis.set_axis_off()`` is called.
        title_loc: Title alignment forwarded as ``set_title(loc=...)``.
    """
    axis.imshow(image, origin="upper")
    if title is not None:
        axis.set_title(title, fontsize=fontsize, loc=title_loc)
    # Plain truthiness (PEP 8) rather than ``is True``, so flags such as 1 or
    # numpy booleans are honoured instead of being silently ignored.
    if remove_axis:
        axis.set_axis_off()
def tensor_image_to_grid(
    images: list,
    transform,
    row_count,
    col_count=None,
    figsize=(20, 20),
    fontsize=None,
):
    """Display tensor images, optionally titled, in a matplotlib grid.

    Args:
        images: Items to draw. Each item is either an image tensor or an
            ``(image, title)`` tuple. Tensors are rearranged with
            ``permute(1, 2, 0)``, so they are presumably channel-first
            ``(C, H, W)`` torch tensors -- TODO confirm with callers.
        transform: Callable invoked as ``transform(image=array)`` whose result
            is indexed with ``"image"``. This is presumably an
            albumentations-style transform (assumption based on the call
            convention). It receives the permuted numpy array.
        row_count: Number of grid rows.
        col_count: Number of grid columns; defaults to ``row_count``.
        figsize: Figure size passed to ``plt.subplots`` for the grid.
        fontsize: Title font size; defaults to ``figsize[0]``.

    A single image is drawn with ``plt.imshow`` on the current figure, and
    ``figsize`` is not applied there. Images beyond ``row_count * col_count``
    are not drawn. A grid item that fails to convert or draw triggers a
    warning and is skipped. Ends by calling ``plt.show()``.
    """

    def split_image_title(item):
        # Items may carry an optional title as an ``(image, title)`` tuple.
        if isinstance(item, tuple):
            return item[0], item[1]
        return item, None

    def torch_to_image(tensor):
        # ``.numpy()`` raises for tensors that require grad or are not on the
        # CPU; detaching and moving first makes both cases work.
        array = tensor.detach().cpu().permute(1, 2, 0).numpy()
        return transform(image=array)["image"]

    col_count = row_count if col_count is None else col_count
    title_fontsize = figsize[0] if fontsize is None else fontsize

    if len(images) == 1:
        img, title = split_image_title(images[0])
        plt.imshow(torch_to_image(img))
        if title is not None:
            # Must be a call: assigning ``plt.title = ...`` would replace the
            # pyplot function with a string for the whole process.
            plt.title(title, fontsize=title_fontsize)
        plt.axis("off")
    else:
        # squeeze=False always yields a 2-D array of Axes, so a 1x1 grid does
        # not come back as a bare Axes object that cannot be flattened.
        _, axes = plt.subplots(
            row_count, col_count, figsize=figsize, squeeze=False
        )
        for index, (ax, item) in enumerate(zip(axes.flat, images)):
            try:
                img, title = split_image_title(item)
                _update_axis(
                    axis=ax,
                    image=torch_to_image(img),
                    remove_axis=True,
                    title=title,
                    fontsize=title_fontsize,
                )
            except Exception as exc:  # best effort: keep drawing the others
                warnings.warn(
                    f"tensor_image_to_grid: could not draw image {index}: {exc!r}",
                    stacklevel=2,
                )
    plt.tight_layout()
    plt.show()