# Author: Ved Gupta
# Commit 1c7b15f: "initially repo created" (file size 1.22 kB at that revision)
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator outputs against a constant target.

    The "real" and "fake" target values are stored as registered buffers, so they
    follow the module across ``.to(device)`` / ``.cuda()`` calls. They are broadcast
    to the shape of the discriminator output on every call.

    Args:
        gan_mode: ``"vanilla"`` uses ``nn.BCEWithLogitsLoss`` (so predictions are
            expected to be raw logits, not probabilities); ``"lsgan"`` uses
            ``nn.MSELoss``.
        real_label: Target value used when ``target_is_real`` is true.
        fake_label: Target value used when ``target_is_real`` is false.

    Raises:
        ValueError: If ``gan_mode`` is not ``"vanilla"`` or ``"lsgan"``.
    """

    def __init__(self, gan_mode="vanilla", real_label=1.0, fake_label=0.0):
        super().__init__()
        # float() so integer labels such as 1 / 0 still produce a floating-point
        # tensor: BCEWithLogitsLoss does not accept integer targets.
        self.register_buffer("real_label", torch.tensor(float(real_label)))
        self.register_buffer("fake_label", torch.tensor(float(fake_label)))
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        else:
            # Fail at construction time instead of leaving ``self.loss`` unset,
            # which would only surface later as an AttributeError on first use.
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real or fake target broadcast to the shape of ``preds``.

        Args:
            preds: Discriminator output tensor.
            target_is_real: Selects ``real_label`` when true, ``fake_label`` otherwise.

        Returns:
            A tensor with the same shape as ``preds``. ``expand_as`` returns a view
            of the 0-d buffer, so no full-size label tensor is allocated.
        """
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the adversarial loss of ``preds`` against the chosen target.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module.__call__`` still runs registered hooks; calling the module
        as ``criterion(preds, target_is_real)`` works exactly as before.

        Args:
            preds: Discriminator output tensor.
            target_is_real: Whether ``preds`` should be scored against the real label.

        Returns:
            Scalar loss tensor (mean reduction of the underlying criterion).
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)