Spaces:
Runtime error
Runtime error
import warnings | |
warnings.filterwarnings("ignore") | |
import os | |
import sys | |
import glob | |
import time | |
import numpy as np | |
from PIL import Image | |
from pathlib import Path | |
from tqdm.notebook import tqdm | |
import matplotlib.pyplot as plt | |
from skimage.color import rgb2lab, lab2rgb | |
import torch | |
from torch import nn, optim | |
from torchvision import transforms | |
from torchvision.utils import make_grid | |
from torch.utils.data import Dataset, DataLoader | |
class GANLoss(nn.Module):
    """Adversarial loss that builds its own real/fake target tensor.

    The caller passes the discriminator output and a flag saying whether
    that output should be scored against the "real" or the "fake" label;
    the matching label is broadcast to the shape of the predictions and
    fed to the underlying criterion.

    Args:
        gan_mode: ``"vanilla"`` selects ``nn.BCEWithLogitsLoss``;
            ``"lsgan"`` selects ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(
        self,
        gan_mode: str = "vanilla",
        real_label: float = 1.0,
        fake_label: float = 0.0,
    ) -> None:
        super().__init__()
        # Buffers (not plain attributes) so the labels follow the module
        # through .to(device) / .cuda() without being trainable parameters.
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        else:
            # Fail at construction time instead of leaving ``self.loss``
            # undefined and crashing on the first loss computation.
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; "
                "expected 'vanilla' or 'lsgan'."
            )

    def get_labels(
        self, preds: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return the real or fake label expanded to the shape of ``preds``."""
        labels = self.real_label if target_is_real else self.fake_label
        # expand_as yields a broadcast view; no per-element copy is made.
        return labels.expand_as(preds)

    def forward(
        self, preds: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Compute the loss of ``preds`` against the selected label.

        Defined as ``forward`` (rather than overriding ``__call__``) so the
        instance is still invoked as ``criterion(preds, flag)`` while going
        through ``nn.Module``'s regular call path, including forward hooks.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)