|
from time import time |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms import ToPILImage |
|
|
|
from dldemos.VAE.load_celebA import get_dataloader |
|
from dldemos.VAE.model import VAE |
|
|
|
|
|
n_epochs = 10 |
|
kl_weight = 0.00025 |
|
lr = 0.005 |
|
|
|
|
|
def loss_fn(y, y_hat, mean, logvar): |
|
recons_loss = F.mse_loss(y_hat, y) |
|
kl_loss = torch.mean( |
|
-0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0) |
|
loss = recons_loss + kl_loss * kl_weight |
|
return loss |
|
|
|
|
|
def train(device, dataloader, model): |
|
optimizer = torch.optim.Adam(model.parameters(), lr) |
|
dataset_len = len(dataloader.dataset) |
|
|
|
begin_time = time() |
|
|
|
for i in range(n_epochs): |
|
loss_sum = 0 |
|
for x in dataloader: |
|
x = x.to(device) |
|
y_hat, mean, logvar = model(x) |
|
loss = loss_fn(x, y_hat, mean, logvar) |
|
optimizer.zero_grad() |
|
loss.backward() |
|
optimizer.step() |
|
loss_sum += loss |
|
loss_sum /= dataset_len |
|
training_time = time() - begin_time |
|
minute = int(training_time // 60) |
|
second = int(training_time % 60) |
|
print(f'epoch {i}: loss {loss_sum} {minute}:{second}') |
|
torch.save(model.state_dict(), 'dldemos/VAE/model.pth') |
|
|
|
|
|
def reconstruct(device, dataloader, model): |
|
model.eval() |
|
batch = next(iter(dataloader)) |
|
x = batch[0:1, ...].to(device) |
|
output = model(x)[0] |
|
output = output[0].detach().cpu() |
|
input = batch[0].detach().cpu() |
|
combined = torch.cat((output, input), 1) |
|
img = ToPILImage()(combined) |
|
img.save('work_dirs/tmp.jpg') |
|
|
|
|
|
def generate(device, model): |
|
model.eval() |
|
output = model.sample(device) |
|
output = output[0].detach().cpu() |
|
img = ToPILImage()(output) |
|
img.save('work_dirs/tmp.jpg') |
|
|
|
|
|
def main(): |
|
device = 'cuda:0' |
|
dataloader = get_dataloader() |
|
|
|
model = VAE().to(device) |
|
|
|
|
|
model.load_state_dict(torch.load('dldemos/VAE/model.pth', 'cuda:0')) |
|
|
|
|
|
train(device, dataloader, model) |
|
reconstruct(device, dataloader, model) |
|
generate(device, model) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |