Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision.utils import make_grid
|
3 |
+
from torchvision import transforms
|
4 |
+
import torchvision.transforms.functional as TF
|
5 |
+
from torch import nn, optim
|
6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
7 |
+
from torch.utils.data import DataLoader, Dataset
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
import requests
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
class Upsample(nn.Module):
|
13 |
+
def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dropout=True):
|
14 |
+
super(Upsample, self).__init__()
|
15 |
+
self.dropout = dropout
|
16 |
+
self.block = nn.Sequential(
|
17 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=nn.InstanceNorm2d),
|
18 |
+
nn.InstanceNorm2d(out_channels),
|
19 |
+
nn.ReLU(inplace=True)
|
20 |
+
)
|
21 |
+
self.dropout_layer = nn.Dropout2d(0.5)
|
22 |
+
|
23 |
+
def forward(self, x, shortcut=None):
|
24 |
+
x = self.block(x)
|
25 |
+
if self.dropout:
|
26 |
+
x = self.dropout_layer(x)
|
27 |
+
|
28 |
+
if shortcut is not None:
|
29 |
+
x = torch.cat([x, shortcut], dim=1)
|
30 |
+
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class Downsample(nn.Module):
|
35 |
+
def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, apply_instancenorm=True):
|
36 |
+
super(Downsample, self).__init__()
|
37 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=nn.InstanceNorm2d)
|
38 |
+
self.norm = nn.InstanceNorm2d(out_channels)
|
39 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
40 |
+
self.apply_norm = apply_instancenorm
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.conv(x)
|
44 |
+
if self.apply_norm:
|
45 |
+
x = self.norm(x)
|
46 |
+
x = self.relu(x)
|
47 |
+
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class CycleGAN_Unet_Generator(nn.Module):
|
52 |
+
def __init__(self, filter=64):
|
53 |
+
super(CycleGAN_Unet_Generator, self).__init__()
|
54 |
+
self.downsamples = nn.ModuleList([
|
55 |
+
Downsample(3, filter, kernel_size=4, apply_instancenorm=False), # (b, filter, 128, 128)
|
56 |
+
Downsample(filter, filter * 2), # (b, filter * 2, 64, 64)
|
57 |
+
Downsample(filter * 2, filter * 4), # (b, filter * 4, 32, 32)
|
58 |
+
Downsample(filter * 4, filter * 8), # (b, filter * 8, 16, 16)
|
59 |
+
Downsample(filter * 8, filter * 8), # (b, filter * 8, 8, 8)
|
60 |
+
Downsample(filter * 8, filter * 8), # (b, filter * 8, 4, 4)
|
61 |
+
Downsample(filter * 8, filter * 8), # (b, filter * 8, 2, 2)
|
62 |
+
])
|
63 |
+
|
64 |
+
self.upsamples = nn.ModuleList([
|
65 |
+
Upsample(filter * 8, filter * 8),
|
66 |
+
Upsample(filter * 16, filter * 8),
|
67 |
+
Upsample(filter * 16, filter * 8),
|
68 |
+
Upsample(filter * 16, filter * 4, dropout=False),
|
69 |
+
Upsample(filter * 8, filter * 2, dropout=False),
|
70 |
+
Upsample(filter * 4, filter, dropout=False)
|
71 |
+
])
|
72 |
+
|
73 |
+
self.last = nn.Sequential(
|
74 |
+
nn.ConvTranspose2d(filter * 2, 3, kernel_size=4, stride=2, padding=1),
|
75 |
+
nn.Tanh()
|
76 |
+
)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
skips = []
|
80 |
+
for l in self.downsamples:
|
81 |
+
x = l(x)
|
82 |
+
skips.append(x)
|
83 |
+
|
84 |
+
skips = reversed(skips[:-1])
|
85 |
+
for l, s in zip(self.upsamples, skips):
|
86 |
+
x = l(x, s)
|
87 |
+
|
88 |
+
out = self.last(x)
|
89 |
+
|
90 |
+
return out
|
91 |
+
|
92 |
+
class ImageTransform:
|
93 |
+
def __init__(self, img_size=256):
|
94 |
+
self.transform = {
|
95 |
+
'train': transforms.Compose([
|
96 |
+
transforms.Resize((img_size, img_size)),
|
97 |
+
transforms.RandomHorizontalFlip(),
|
98 |
+
transforms.RandomVerticalFlip(),
|
99 |
+
transforms.ToTensor(),
|
100 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
101 |
+
]),
|
102 |
+
'test': transforms.Compose([
|
103 |
+
transforms.Resize((img_size, img_size)),
|
104 |
+
transforms.ToTensor(),
|
105 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
106 |
+
})}
|
107 |
+
|
108 |
+
def __call__(self, img, phase='train'):
|
109 |
+
img = self.transform[phase](img)
|
110 |
+
|
111 |
+
return img
|
112 |
+
|
113 |
+
|
114 |
+
path = hf_hub_download('huggan/NeonGAN', 'model.bin')
|
115 |
+
model_gen_n = torch.load(path, map_location=torch.device('cpu'))
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|