marsfu2009 committed on
Commit 555e103 · 1 Parent(s): c8c7fac

Upload 3 files

Files changed (3)
  1. load_celebA.py +39 -0
  2. main.py +84 -0
  3. model.py +91 -0
load_celebA.py ADDED
@@ -0,0 +1,39 @@
+"""
+CelebFaces Attributes (CelebA) Dataset
+https://www.kaggle.com/datasets/jessicali9530/celeba-dataset
+"""
+
+import os
+
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class CelebADataset(Dataset):
+
+    def __init__(self, root, img_shape=(64, 64)) -> None:
+        super().__init__()
+        self.root = root
+        self.img_shape = img_shape
+        self.filenames = sorted(os.listdir(root))
+        # Build the transform pipeline once instead of on every __getitem__ call
+        self.pipeline = transforms.Compose([
+            transforms.CenterCrop(168),
+            transforms.Resize(self.img_shape),
+            transforms.ToTensor()
+        ])
+
+    def __len__(self) -> int:
+        return len(self.filenames)
+
+    def __getitem__(self, index: int) -> torch.Tensor:
+        path = os.path.join(self.root, self.filenames[index])
+        img = Image.open(path).convert('RGB')
+        return self.pipeline(img)
+
+
+def get_dataloader(root='data/celebA/img_align_celeba', batch_size=16, **kwargs):
+    dataset = CelebADataset(root, **kwargs)
+    return DataLoader(dataset, batch_size, shuffle=True)
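
A quick sanity check of the loader (a minimal sketch; it assumes the Kaggle archive has been unpacked to the default data/celebA/img_align_celeba path):

from dldemos.VAE.load_celebA import get_dataloader

dataloader = get_dataloader()
batch = next(iter(dataloader))
# CenterCrop(168) + Resize((64, 64)) + ToTensor() should yield a float
# tensor of shape (16, 3, 64, 64) with values in [0, 1]
print(batch.shape, batch.min().item(), batch.max().item())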
main.py ADDED
@@ -0,0 +1,84 @@
+from time import time
+
+import torch
+import torch.nn.functional as F
+from torchvision.transforms import ToPILImage
+
+from dldemos.VAE.load_celebA import get_dataloader
+from dldemos.VAE.model import VAE
+
+# Hyperparameters
+n_epochs = 10
+kl_weight = 0.00025
+lr = 0.005
+
+
+def loss_fn(y, y_hat, mean, logvar):
+    # Reconstruction term plus the closed-form KL divergence between
+    # N(mean, exp(logvar)) and the standard normal prior
+    recons_loss = F.mse_loss(y_hat, y)
+    kl_loss = torch.mean(
+        -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)
+    loss = recons_loss + kl_loss * kl_weight
+    return loss
+
+
+def train(device, dataloader, model):
+    optimizer = torch.optim.Adam(model.parameters(), lr)
+    dataset_len = len(dataloader.dataset)
+
+    begin_time = time()
+    # train
+    for i in range(n_epochs):
+        loss_sum = 0
+        for x in dataloader:
+            x = x.to(device)
+            y_hat, mean, logvar = model(x)
+            loss = loss_fn(x, y_hat, mean, logvar)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            # .item() detaches the scalar so the graph is freed each step
+            loss_sum += loss.item()
+        loss_sum /= dataset_len
+        training_time = time() - begin_time
+        minute = int(training_time // 60)
+        second = int(training_time % 60)
+        print(f'epoch {i}: loss {loss_sum} {minute}:{second:02d}')
+        torch.save(model.state_dict(), 'dldemos/VAE/model.pth')
+
+
+def reconstruct(device, dataloader, model):
+    model.eval()
+    batch = next(iter(dataloader))
+    x = batch[0:1, ...].to(device)
+    output = model(x)[0]
+    output = output[0].detach().cpu()
+    input_img = batch[0].detach().cpu()
+    # Stack reconstruction and input along the height axis for comparison
+    combined = torch.cat((output, input_img), 1)
+    img = ToPILImage()(combined)
+    img.save('work_dirs/tmp.jpg')
+
+
+def generate(device, model):
+    model.eval()
+    output = model.sample(device)
+    output = output[0].detach().cpu()
+    img = ToPILImage()(output)
+    img.save('work_dirs/tmp.jpg')
+
+
+def main():
+    device = 'cuda:0'
+    dataloader = get_dataloader()
+
+    model = VAE().to(device)
+
+    # If you already have a checkpoint, load it; otherwise comment this out
+    model.load_state_dict(torch.load('dldemos/VAE/model.pth', 'cuda:0'))
+
+    # Choose the function(s) to run
+    train(device, dataloader, model)
+    reconstruct(device, dataloader, model)
+    generate(device, model)
+
+
+if __name__ == '__main__':
+    main()
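
The kl_loss term above is the closed-form KL divergence between the encoder's diagonal Gaussian N(mean, exp(logvar)) and the standard normal prior, summed over latent dimensions and averaged over the batch. A minimal standalone check of that formula against torch.distributions (random tensors, no model needed):

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(4, 128)
logvar = torch.randn(4, 128)

# Closed form as used in loss_fn
kl_closed = torch.mean(
    -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), 1), 0)

# Same quantity via torch.distributions
q = Normal(mean, torch.exp(logvar / 2))
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_ref = kl_divergence(q, p).sum(1).mean(0)

assert torch.allclose(kl_closed, kl_ref, atol=1e-5)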
model.py ADDED
@@ -0,0 +1,91 @@
+"""
+Full definition of a VAE model, all of it in this single file.
+References:
+1) An Introduction to Variational Autoencoders:
+https://arxiv.org/abs/1906.02691
+"""
+
+import torch
+import torch.nn as nn
+
+
+class VAE(nn.Module):
+    """VAE for 64x64 face generation.
+
+    The hidden dimensions can be tuned.
+    """
+
+    def __init__(self, hiddens=[16, 32, 64, 128, 256], latent_dim=128) -> None:
+        super().__init__()
+
+        # encoder: stride-2 convolutions, each halving the spatial size
+        prev_channels = 3
+        modules = []
+        img_length = 64
+        for cur_channels in hiddens:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(prev_channels,
+                              cur_channels,
+                              kernel_size=3,
+                              stride=2,
+                              padding=1), nn.BatchNorm2d(cur_channels),
+                    nn.ReLU()))
+            prev_channels = cur_channels
+            img_length //= 2
+        self.encoder = nn.Sequential(*modules)
+        self.mean_linear = nn.Linear(prev_channels * img_length * img_length,
+                                     latent_dim)
+        self.var_linear = nn.Linear(prev_channels * img_length * img_length,
+                                    latent_dim)
+        self.latent_dim = latent_dim
+
+        # decoder: transposed convolutions mirroring the encoder
+        modules = []
+        self.decoder_projection = nn.Linear(
+            latent_dim, prev_channels * img_length * img_length)
+        self.decoder_input_chw = (prev_channels, img_length, img_length)
+        for i in range(len(hiddens) - 1, 0, -1):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(hiddens[i],
+                                       hiddens[i - 1],
+                                       kernel_size=3,
+                                       stride=2,
+                                       padding=1,
+                                       output_padding=1),
+                    nn.BatchNorm2d(hiddens[i - 1]), nn.ReLU()))
+        modules.append(
+            nn.Sequential(
+                nn.ConvTranspose2d(hiddens[0],
+                                   hiddens[0],
+                                   kernel_size=3,
+                                   stride=2,
+                                   padding=1,
+                                   output_padding=1),
+                nn.BatchNorm2d(hiddens[0]), nn.ReLU(),
+                nn.Conv2d(hiddens[0], 3, kernel_size=3, stride=1, padding=1),
+                nn.ReLU()))
+        self.decoder = nn.Sequential(*modules)
+
+    def forward(self, x):
+        encoded = self.encoder(x)
+        encoded = torch.flatten(encoded, 1)
+        mean = self.mean_linear(encoded)
+        logvar = self.var_linear(encoded)
+        # Reparameterization trick: z = mean + eps * std with eps ~ N(0, I)
+        eps = torch.randn_like(logvar)
+        std = torch.exp(logvar / 2)
+        z = eps * std + mean
+        x = self.decoder_projection(z)
+        x = torch.reshape(x, (-1, *self.decoder_input_chw))
+        decoded = self.decoder(x)
+
+        return decoded, mean, logvar
+
+    def sample(self, device='cuda'):
+        # Decode a latent vector drawn from the standard normal prior
+        z = torch.randn(1, self.latent_dim).to(device)
+        x = self.decoder_projection(z)
+        x = torch.reshape(x, (-1, *self.decoder_input_chw))
+        decoded = self.decoder(x)
+        return decoded
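
A minimal shape check for the model (a sketch that runs on CPU, so sample's default device is overridden):

import torch
from dldemos.VAE.model import VAE

model = VAE().eval()
x = torch.randn(2, 3, 64, 64)
decoded, mean, logvar = model(x)
# The encoder halves 64 five times down to 2x2, so mean_linear/var_linear
# see 256 * 2 * 2 = 1024 features; the decoder mirrors this back up
print(decoded.shape, mean.shape, logvar.shape)  # (2, 3, 64, 64), (2, 128), (2, 128)

sample = model.sample(device='cpu')
print(sample.shape)  # (1, 3, 64, 64)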