Spaces:
Runtime error
Runtime error
import matplotlib.pyplot as plt | |
import numpy as np | |
import torchvision.utils as vutils | |
import torchvision.transforms as transforms | |
from skimage.exposure import match_histograms | |
import torch | |
# Utility functions used by the main program.
# Match the color histogram of each super-resolution output to its original image.
def color_histogram_mapping(images, references):
    """Match the color histogram of each image to its paired reference image.

    Each ``images[i]`` is histogram-matched channel by channel against
    ``references[i]`` using ``skimage.exposure.match_histograms``.

    Args:
        images: Sequence (or batched tensor) of channel-first ``(C, H, W)``
            tensors to recolor. They may live on any device and may require grad.
        references: Sequence (or batched tensor) of channel-first tensors
            whose color distribution is copied. ``references[i]`` is paired
            with ``images[i]``; any extra references are ignored.

    Returns:
        A channel-first ``(N, C, H, W)`` tensor of matched images, on the
        same device as ``images``. For empty ``images``, an empty tensor with
        batch size 0 is returned.

    Raises:
        ValueError: If there are fewer references than images.
    """
    if len(references) < len(images):
        raise ValueError(
            f"expected at least {len(images)} references, got {len(references)}"
        )
    if len(images) == 0:
        # With no images the stacked array would not be 4-D, so the NCHW
        # permute below would fail; return an empty batch instead.
        if isinstance(images, torch.Tensor):
            trailing = tuple(images.shape[1:])
        else:
            trailing = (0, 0, 0)
        return torch.empty((0, *trailing))

    device = images[0].device
    matched_list = []
    for image, reference in zip(images, references):
        # skimage needs channel-last NumPy arrays. Those must be on the CPU
        # and detached from any autograd graph before calling .numpy().
        image_hwc = image.detach().cpu().permute(1, 2, 0).numpy()
        reference_hwc = reference.detach().cpu().permute(1, 2, 0).numpy()
        matched_list.append(
            match_histograms(image_hwc, reference_hwc, channel_axis=-1)
        )
    # Convert back to a channel-first (N, C, H, W) batch on the caller's device.
    return torch.from_numpy(np.stack(matched_list)).permute(0, 3, 1, 2).to(device)
def visualize_generations(seed, images, nrow=5):
    """Display a batch of images as a single grid in a new matplotlib figure.

    Args:
        seed: Label shown in the figure title as ``Seed: <seed>``.
        images: Image batch accepted by ``torchvision.utils.make_grid``
            (presumably ``(N, C, H, W)``; TODO confirm with callers). It may
            live on any device and may require grad.
        nrow: Number of images per grid row. The default of 5 preserves the
            previous layout.
    """
    grid = vutils.make_grid(images, padding=2, nrow=nrow, normalize=True)
    # make_grid returns a (C, H, W) tensor, but imshow expects (H, W, C).
    # The permutation must be (1, 2, 0); (2, 1, 0) would give (W, H, C)
    # and display the grid transposed. matplotlib also needs a CPU array
    # that is outside the autograd graph.
    grid_hwc = grid.detach().cpu().permute(1, 2, 0).numpy()
    plt.figure(figsize=(16, 16))
    plt.title(f"Seed: {seed}")
    plt.axis("off")
    plt.imshow(grid_hwc)
    plt.show()
# Undo the mean=0.5 / std=0.5 normalization so the images display correctly.
def denormalize_images(images, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a per-channel ``transforms.Normalize(mean, std)`` so images can be displayed.

    ``Normalize`` maps ``x -> (x - mean) / std``. This function applies the
    inverse, ``x -> x * std + mean``, by running ``Normalize`` again with
    mean ``-mean / std`` and std ``1 / std``.

    Args:
        images: Normalized float image tensor in the ``(..., C, H, W)``
            layout that ``transforms.Normalize`` expects.
        mean: Per-channel means used by the forward normalization.
            Defaults to 0.5 for each of three channels.
        std: Per-channel standard deviations used by the forward
            normalization. Each must be non-zero. Defaults to 0.5 for each
            of three channels.

    Returns:
        The denormalized tensor. With the default mean/std, values in
        [-1, 1] are mapped back to [0, 1].
    """
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    return inv_normalize(images)