mkaraki committed

Commit e629925
Parent(s): 41fb947

Add initial checkpoint and gradio code

Files changed:
- README.md (+16 -3)
- checkpoints/epoch1100.ckpt (+3 -0)
- grd.py (+121 -0)
README.md CHANGED

@@ -1,3 +1,16 @@
----
-
-
+---
+title: Sakamata Font DCGAN
+sdk: gradio
+sdk_version: 3.1.7
+app_file: grd.py
+license: other
+---
+
+# Sakamata Font DCGAN
+
+This is an experimental project that generates fake handwritten characters with a DCGAN.
+
+Dataset: [SakamataFontProject](https://github.com/sakamata-ch/SakamataFontProject)
+
+This project operates under the Hololive Derivative Works Guidelines.
+You must read and agree to the guidelines before using any artifacts of this project.
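The YAML front matter above is Hugging Face Spaces configuration: `sdk: gradio` plus `app_file: grd.py` tells Spaces to launch the script added below. A minimal sketch of running the same demo outside Spaces, assuming the Gradio version pinned in the front matter (the exact install commands are my assumption, not part of the commit):

# Hypothetical local run of the Space:
#   pip install gradio==3.1.7 torch torchvision
#   python grd.py
# Gradio then serves the demo locally (on http://127.0.0.1:7860 by default).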
checkpoints/epoch1100.ckpt ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b74c472f4796be1e4683eb1357aa78188fa1058f5cf59fcb9dc3829f76026f6
+size 43006121
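These three lines are a Git LFS pointer, not the weights themselves; the ~43 MB checkpoint binary is fetched by LFS when the repo is cloned. Judging from how grd.py consumes it below, the file is a torch.save dictionary holding at least a 'netG_state_dict' entry. A minimal inspection sketch (any keys beyond 'netG_state_dict' are an assumption):

import torch

# Load onto CPU so inspection works without a GPU.
ckpt = torch.load("checkpoints/epoch1100.ckpt", map_location="cpu")
print(sorted(ckpt.keys()))                  # expect 'netG_state_dict' among the keys
sd = ckpt["netG_state_dict"]
print(next(iter(sd)))                       # first parameter key (may carry a 'module.' prefix)
print(sum(t.numel() for t in sd.values()))  # total generator parameter count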
grd.py
ADDED
@@ -0,0 +1,121 @@
+# pip install gradio
+import gradio as gr
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import torch.utils.data
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+import numpy as np
+
+
+# Set to True to strip the "module." prefix that nn.DataParallel adds to
+# state_dict keys; this fixes `RuntimeError: Error(s) in loading state_dict
+# for Generator` when loading a checkpoint saved from a DataParallel model.
+omit_module = True
+
+# Spatial size of training images. All images will be resized to this
+# size using a transformer.
+image_size = 64
+
+# Number of channels in the training images. For color images this is 3
+nc = 1
+
+# Size of z latent vector (i.e. size of generator input)
+nz = 100
+
+# Size of feature maps in generator
+ngf = 64
+
+# Size of feature maps in discriminator
+ndf = 64
+
+# Learning rate for optimizers
+lr = 0.0002
+
+# Beta1 hyperparam for Adam optimizers
+beta1 = 0.5
+
+# Number of GPUs available. Use 0 for CPU mode.
+ngpu = 1
+
+
+device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
+
+
+# custom weights initialization called on netG and netD
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        nn.init.normal_(m.weight.data, 1.0, 0.02)
+        nn.init.constant_(m.bias.data, 0)
+
+
+class Generator(nn.Module):
+    def __init__(self, ngpu):
+        super(Generator, self).__init__()
+        self.ngpu = ngpu
+        self.main = nn.Sequential(
+            # input is Z, going into a convolution
+            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
+            nn.BatchNorm2d(ngf * 8),
+            nn.ReLU(True),
+            # state size. (ngf*8) x 4 x 4
+            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 4),
+            nn.ReLU(True),
+            # state size. (ngf*4) x 8 x 8
+            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf * 2),
+            nn.ReLU(True),
+            # state size. (ngf*2) x 16 x 16
+            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
+            nn.BatchNorm2d(ngf),
+            nn.ReLU(True),
+            # state size. (ngf) x 32 x 32
+            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+            nn.Tanh()
+            # state size. (nc) x 64 x 64
+        )
+
+    def forward(self, input):
+        return self.main(input)
+
+
+# Create the generator
+netG = Generator(ngpu).to(device)
+
+# Handle multi-gpu if desired
+if (device.type == 'cuda') and (ngpu > 1):
+    netG = nn.DataParallel(netG, list(range(ngpu)))
+
+# Apply the weights_init function to randomly initialize all weights
+# to mean=0, stdev=0.02.
+netG.apply(weights_init)
+
+
+# map_location lets a checkpoint saved on GPU load in CPU mode.
+checkpoint = torch.load("checkpoints/epoch1100.ckpt", map_location=device)
+
+
+if omit_module:
+    # Checkpoints saved from an nn.DataParallel-wrapped model prefix every
+    # key with "module."; rename the keys so a plain Generator can load them.
+    for i in list(checkpoint['netG_state_dict'].keys()):
+        if str(i).startswith('module.'):
+            checkpoint['netG_state_dict'][i[7:]] = checkpoint['netG_state_dict'].pop(i)
+
+
+netG.load_state_dict(checkpoint['netG_state_dict'])
+
+
+def genImg():
+    # 64 latent vectors -> an 8x8 grid of generated glyphs.
+    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
+    with torch.no_grad():
+        fake = netG(fixed_noise).detach().cpu()
+    fake_grid = vutils.make_grid(fake, padding=2, normalize=True)
+    return transforms.functional.to_pil_image(fake_grid)
+
+
+demo = gr.Interface(fn=genImg, inputs=None, outputs="image")
+demo.launch()
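For reference, a minimal sketch of sampling a single glyph instead of the 8x8 grid genImg assembles with vutils.make_grid; the name gen_single_img and the manual Tanh rescaling are my additions, not part of the commit:

def gen_single_img():
    # One latent vector instead of 64 -> a single 64x64 generated character.
    noise = torch.randn(1, nz, 1, 1, device=device)
    with torch.no_grad():
        fake = netG(noise).detach().cpu()
    # The generator ends in Tanh, so outputs lie in [-1, 1]; rescale to [0, 1].
    img = (fake[0] + 1) / 2
    return transforms.functional.to_pil_image(img)

Wiring it into the same UI is the one-liner gr.Interface(fn=gen_single_img, inputs=None, outputs="image").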