|
""" |
|
Plotting utilities to visualize training logs. |
|
""" |
|
|
|
import imageio |
|
import os |
|
import torchvision.utils as v_utils |
|
|
|
|
|
def plot_samples_per_epoch(gen_batch, output_dir, epoch, iteration, nsample): |
|
""" |
|
Plot and save output samples per epoch |
|
""" |
|
fname = "samples_epoch_{:d}_{:d}.jpg".format(epoch, iteration) |
|
fpath = os.path.join(output_dir, fname) |
|
nrow = gen_batch.shape[0] // nsample |
|
|
|
image = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True) |
|
v_utils.save_image(image, fpath) |
|
return image |
|
|
|
|
|
def plot_val_samples(gen_batch, output_dir, fname, nrow): |
|
""" |
|
Plot and dsave output samples for validations |
|
""" |
|
fpath = os.path.join(output_dir, fname) |
|
image = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True) |
|
v_utils.save_image(image, fpath) |
|
return image |
|
|
|
|
|
def plot_image(img, output_dir, fname): |
|
""" |
|
img in tensor format |
|
""" |
|
|
|
fpath = os.path.join(output_dir, fname) |
|
|
|
v_utils.save_image(img, fpath, nrow=4, padding=2, normalize=True) |
|
return imageio.imread(fpath) |
|
|
|
|