Anonymous-123's picture
Add application file
ec0fdfd
raw
history blame
1.91 kB
from pathlib import Path
from numpy.core.shape_base import block
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import functional as TF
from typing import Optional, Union
import matplotlib.pyplot as plt
import numpy as np
from PIL.Image import Image
def show_tensor_image(tensor: torch.Tensor, range_zero_one: bool = False):
    """Show every image of a batched tensor, one blocking window per image.

    Args:
        tensor (torch.Tensor): Tensor of shape [N, 3, H, W] in range [-1, 1]
            or in range [0, 1].
        range_zero_one (bool): Pass True when ``tensor`` is already in
            [0, 1]. Otherwise the values are first mapped from [-1, 1] to
            [0, 1] via ``(tensor + 1) / 2``.
    """
    if not range_zero_one:
        tensor = (tensor + 1) / 2

    # ``Tensor.clamp`` is out-of-place: the result must be rebound, otherwise
    # out-of-range values are passed on to the PIL conversion unclamped.
    tensor = tensor.clamp(0, 1)

    # Iterating a tensor walks its first (batch) dimension.
    for i, image in enumerate(tensor):
        plt.title(f"Fig_{i}")
        plt.imshow(TF.to_pil_image(image))
        # Blocks until the window is closed, then moves on to the next image.
        plt.show(block=True)
def show_editied_masked_image(
    title: str,
    source_image: Image,
    edited_image: Image,
    mask: Optional[Image] = None,
    path: Optional[Union[str, Path]] = None,
    distance: Optional[str] = None,
):
    """Plot the source image, the optional mask and the edited image side by side.

    Args:
        title: Prompt text; rendered as ``Prompt: "<title>"`` in the figure title.
        source_image: Image drawn in the first panel ("Source Image").
        edited_image: Image drawn in the last panel ("Edited Image").
        mask: Optional image drawn in a middle "Mask" panel. When omitted the
            figure has two panels instead of three.
        path: If given, the figure is saved there with ``bbox_inches="tight"``;
            otherwise it is shown in a blocking window.
        distance: Optional text appended to the figure title in parentheses.
    """
    # Each entry: (panel caption, image, whether the panel is the mask).
    panels = [("Source Image", source_image, False)]
    if mask is not None:
        panels.append(("Mask", mask, True))
    panels.append(("Edited Image", edited_image, False))

    rows = 1
    cols = len(panels)  # 3 with a mask, 2 without

    fig = plt.figure(figsize=(12, 5))
    try:
        figure_title = f'Prompt: "{title}"'
        if distance is not None:
            figure_title += f" ({distance})"
        # The title goes on the implicit full-figure axes, whose frame is
        # hidden; the image panels are added on top of it below.
        plt.title(figure_title)
        plt.axis("off")

        for fig_idx, (panel_name, image, is_mask) in enumerate(panels, start=1):
            fig.add_subplot(rows, cols, fig_idx)
            _set_image_plot_name(panel_name)
            plt.imshow(image)
            if is_mask:
                # NOTE(review): plt.gray() also changes the global default
                # colormap for later plots; kept to preserve rendering.
                plt.gray()

        if path is not None:
            plt.savefig(path, bbox_inches="tight")
        else:
            plt.show(block=True)
    finally:
        # Always release this specific figure, even when plotting or saving
        # fails, and never close some other figure that happens to be current.
        plt.close(fig)
def _set_image_plot_name(name):
    """Title the current axes with ``name`` and remove its x and y ticks."""
    plt.title(name)
    # Passing an empty list clears the tick marks on that axis.
    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks([])