mkaraki commited on
Commit
e629925
1 Parent(s): 41fb947

Add initial checkpoint and gradio code

Browse files
Files changed (3) hide show
  1. README.md +16 -3
  2. checkpoints/epoch1100.ckpt +3 -0
  3. grd.py +121 -0
README.md CHANGED
@@ -1,3 +1,16 @@
1
- ---
2
- license: other
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sakamata Font DCGAN
3
+ sdk: gradio
4
+ sdk_version: 3.1.7
5
+ app_file: grd.py
6
+ license: other
7
+ ---
8
+
9
+ # Sakamata Font DCGAN
10
+
11
+ This is experimental project that make fake handwritten character by DCGAN.
12
+
13
+ Dataset: [SakamataFontProject](https://github.com/sakamata-ch/SakamataFontProject)
14
+
15
+ This project working under Hololive Derivative Works Guidelines.
16
+ You have to read and agree for guideline if you want to use artifact of this project.
checkpoints/epoch1100.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b74c472f4796be1e4683eb1357aa78188fa1058f5cf59fcb9dc3829f76026f6
3
+ size 43006121
grd.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pip install gradio
2
+ import gradio as gr
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.parallel
7
+ import torch.backends.cudnn as cudnn
8
+ import torch.optim as optim
9
+ import torch.utils.data
10
+ import torchvision.datasets as dset
11
+ import torchvision.transforms as transforms
12
+ import torchvision.utils as vutils
13
+ import numpy as np
14
+
15
+
16
+ # If `RuntimeError: Error(s) in loading state_dict for Generator` error occurs:
17
+ omit_module = True
18
+
19
+ # Spatial size of training images. All images will be resized to this
20
+ # size using a transformer.
21
+ image_size = 64
22
+
23
+ # Number of channels in the training images. For color images this is 3
24
+ nc = 1
25
+
26
+ # Size of z latent vector (i.e. size of generator input)
27
+ nz = 100
28
+
29
+ # Size of feature maps in generator
30
+ ngf = 64
31
+
32
+ # Size of feature maps in discriminator
33
+ ndf = 64
34
+
35
+ # Learning rate for optimizers
36
+ lr = 0.0002
37
+
38
+ # Beta1 hyperparam for Adam optimizers
39
+ beta1 = 0.5
40
+
41
+ # Number of GPUs available. Use 0 for CPU mode.
42
+ ngpu = 1
43
+
44
+
45
+ device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
46
+
47
+
48
+ # custom weights initialization called on netG and netD
49
+ def weights_init(m):
50
+ classname = m.__class__.__name__
51
+ if classname.find('Conv') != -1:
52
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
53
+ elif classname.find('BatchNorm') != -1:
54
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
55
+ nn.init.constant_(m.bias.data, 0)
56
+
57
+
58
+ class Generator(nn.Module):
59
+ def __init__(self, ngpu):
60
+ super(Generator, self).__init__()
61
+ self.ngpu = ngpu
62
+ self.main = nn.Sequential(
63
+ # input is Z, going into a convolution
64
+ nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
65
+ nn.BatchNorm2d(ngf * 8),
66
+ nn.ReLU(True),
67
+ # state size. (ngf*8) x 4 x 4
68
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
69
+ nn.BatchNorm2d(ngf * 4),
70
+ nn.ReLU(True),
71
+ # state size. (ngf*4) x 8 x 8
72
+ nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
73
+ nn.BatchNorm2d(ngf * 2),
74
+ nn.ReLU(True),
75
+ # state size. (ngf*2) x 16 x 16
76
+ nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
77
+ nn.BatchNorm2d(ngf),
78
+ nn.ReLU(True),
79
+ # state size. (ngf) x 32 x 32
80
+ nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
81
+ nn.Tanh()
82
+ # state size. (nc) x 64 x 64
83
+ )
84
+
85
+ def forward(self, input):
86
+ return self.main(input)
87
+
88
+
89
+ # Create the generator
90
+ netG = Generator(ngpu).to(device)
91
+
92
+ # Handle multi-gpu if desired
93
+ if (device.type == 'cuda') and (ngpu > 1):
94
+ netG = nn.DataParallel(netG, list(range(ngpu)))
95
+
96
+ # Apply the weights_init function to randomly initialize all weights
97
+ # to mean=0, stdev=0.02.
98
+ netG.apply(weights_init)
99
+
100
+
101
+ checkpoint = torch.load("checkpoints/epoch1100.ckpt")
102
+
103
+
104
+ if omit_module:
105
+ for i in list(checkpoint['netG_state_dict'].keys()):
106
+ if (str(i).startswith('module.')):
107
+ checkpoint['netG_state_dict'][i[7:]] = checkpoint['netG_state_dict'].pop(i)
108
+
109
+
110
+ netG.load_state_dict(checkpoint['netG_state_dict'])
111
+
112
+
113
+ def genImg():
114
+ fixed_noise = torch.randn(64, nz, 1, 1, device=device)
115
+ with torch.no_grad():
116
+ fake = netG(fixed_noise).detach().cpu()
117
+ fake_grid = vutils.make_grid(fake, padding=2, normalize=True)
118
+ return transforms.functional.to_pil_image(fake_grid)
119
+
120
+ demo = gr.Interface(fn=genImg, inputs=None, outputs="image")
121
+ demo.launch()