VfiTest / utils /plot.py
SuyeonJ's picture
Upload folder using huggingface_hub
8d015d4 verified
raw
history blame
1.08 kB
"""
Plotting utilities to visualize training logs.
"""
import imageio
import os
import torchvision.utils as v_utils
def plot_samples_per_epoch(gen_batch, output_dir, epoch, iteration, nsample):
    """
    Tile a batch of generated samples into a grid, save it, and return it.

    The grid is written to
    ``output_dir/samples_epoch_<epoch>_<iteration>.jpg``.

    Args:
        gen_batch: batch of images to tile. ``gen_batch.shape[0]`` is used as
            the number of images (presumably an (N, C, H, W) tensor; TODO
            confirm against callers).
        output_dir: directory the image file is written into.
        epoch: epoch number. It is formatted with ``{:d}``, so it must be an
            integer.
        iteration: iteration number. It is formatted with ``{:d}``, so it
            must be an integer.
        nsample: divisor of the batch size. Each grid row holds
            ``gen_batch.shape[0] // nsample`` images, and never fewer than one.

    Returns:
        The grid tensor built by ``torchvision.utils.make_grid`` (2px padding,
        ``normalize=True``).
    """
    fname = "samples_epoch_{:d}_{:d}.jpg".format(epoch, iteration)
    fpath = os.path.join(output_dir, fname)
    # Integer division gives 0 when the batch holds fewer than ``nsample``
    # images (e.g. a short final batch). make_grid divides by ``nrow``, so
    # keep at least one image per row instead of crashing.
    nrow = max(1, gen_batch.shape[0] // nsample)
    image = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True)
    v_utils.save_image(image, fpath)
    return image
def plot_val_samples(gen_batch, output_dir, fname, nrow):
    """
    Tile a batch of validation samples into a grid, save it, and return it.

    Args:
        gen_batch: batch of images to tile.
        output_dir: directory the image file is written into.
        fname: name of the output file inside ``output_dir``.
        nrow: number of images placed in each row of the grid.

    Returns:
        The grid tensor built by ``torchvision.utils.make_grid`` (2px padding,
        ``normalize=True``).
    """
    grid = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True)
    v_utils.save_image(grid, os.path.join(output_dir, fname))
    return grid
def plot_image(img, output_dir, fname):
    """
    Save a tensor image to ``output_dir/fname`` and return it as read back.

    Args:
        img: image tensor (or batch) passed to
            ``torchvision.utils.save_image``. The call uses ``nrow=4``,
            ``padding=2`` and ``normalize=True``, so batched input is tiled
            four per row.
        output_dir: directory the image file is written into.
        fname: name of the output file inside ``output_dir``.

    Returns:
        The written file as loaded by imageio's v2 ``imread`` (presumably an
        array of the encoded image; confirm against callers).
    """
    fpath = os.path.join(output_dir, fname)
    v_utils.save_image(img, fpath, nrow=4, padding=2, normalize=True)
    # imageio >= 2.16 deprecates the top-level ``imageio.imread`` in favour of
    # the explicit ``imageio.v2`` namespace, which returns the same result
    # without a warning. Fall back to the top-level module on older releases.
    reader = getattr(imageio, "v2", imageio)
    return reader.imread(fpath)