Kiwinicki commited on
Commit
8cc9b23
·
verified ·
1 Parent(s): bb75d9f

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +126 -0
model.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import tanh, Tensor
2
+ import torch.nn as nn
3
+ from omegaconf import DictConfig
4
+ from abc import ABC, abstractmethod
5
+
6
+
7
+ class BaseGenerator(ABC, nn.Module):
8
+ def __init__(self, channels: int = 3):
9
+ super().__init__()
10
+ self.channels = channels
11
+
12
+ @abstractmethod
13
+ def forward(self, x: Tensor) -> Tensor:
14
+ pass
15
+
16
+
17
+ class Generator(BaseGenerator):
18
+ def __init__(self, cfg: DictConfig):
19
+ super().__init__(cfg.channels)
20
+ self.cfg = cfg
21
+ self.model = self._construct_model()
22
+
23
+ def _construct_model(self):
24
+ initial_layer = nn.Sequential(
25
+ nn.Conv2d(
26
+ self.cfg.channels,
27
+ self.cfg.num_features,
28
+ kernel_size=7,
29
+ stride=1,
30
+ padding=3,
31
+ padding_mode="reflect",
32
+ ),
33
+ nn.ReLU(inplace=True),
34
+ )
35
+
36
+ down_blocks = nn.Sequential(
37
+ ConvBlock(
38
+ self.cfg.num_features,
39
+ self.cfg.num_features * 2,
40
+ kernel_size=3,
41
+ stride=2,
42
+ padding=1,
43
+ ),
44
+ ConvBlock(
45
+ self.cfg.num_features * 2,
46
+ self.cfg.num_features * 4,
47
+ kernel_size=3,
48
+ stride=2,
49
+ padding=1,
50
+ ),
51
+ )
52
+
53
+ residual_blocks = nn.Sequential(
54
+ *[
55
+ ResidualBlock(self.cfg.num_features * 4)
56
+ for _ in range(self.cfg.num_residuals)
57
+ ]
58
+ )
59
+
60
+ up_blocks = nn.Sequential(
61
+ ConvBlock(
62
+ self.cfg.num_features * 4,
63
+ self.cfg.num_features * 2,
64
+ down=False,
65
+ kernel_size=3,
66
+ stride=2,
67
+ padding=1,
68
+ output_padding=1,
69
+ ),
70
+ ConvBlock(
71
+ self.cfg.num_features * 2,
72
+ self.cfg.num_features,
73
+ down=False,
74
+ kernel_size=3,
75
+ stride=2,
76
+ padding=1,
77
+ output_padding=1,
78
+ ),
79
+ )
80
+
81
+ last_layer = nn.Conv2d(
82
+ self.cfg.num_features,
83
+ self.cfg.channels,
84
+ kernel_size=7,
85
+ stride=1,
86
+ padding=3,
87
+ padding_mode="reflect",
88
+ )
89
+
90
+ return nn.Sequential(
91
+ initial_layer, down_blocks, residual_blocks, up_blocks, last_layer
92
+ )
93
+
94
+ def forward(self, x: Tensor) -> Tensor:
95
+ return tanh(self.model(x))
96
+
97
+
98
+ class ConvBlock(nn.Module):
99
+ def __init__(
100
+ self, in_channels, out_channels, down=True, use_activation=True, **kwargs
101
+ ):
102
+ super().__init__()
103
+ self.conv = nn.Sequential(
104
+ nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
105
+ if down
106
+ else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
107
+ nn.InstanceNorm2d(out_channels),
108
+ nn.ReLU(inplace=True) if use_activation else nn.Identity(),
109
+ )
110
+
111
+ def forward(self, x: Tensor) -> Tensor:
112
+ return self.conv(x)
113
+
114
+
115
+ class ResidualBlock(nn.Module):
116
+ def __init__(self, channels: int):
117
+ super().__init__()
118
+ self.block = nn.Sequential(
119
+ ConvBlock(channels, channels, kernel_size=3, padding=1),
120
+ ConvBlock(
121
+ channels, channels, use_activation=False, kernel_size=3, padding=1
122
+ ),
123
+ )
124
+
125
+ def forward(self, x: Tensor) -> Tensor:
126
+ return x + self.block(x)