File size: 3,495 Bytes
32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import math
import torch
import torch.nn as nn
logger = logging.getLogger(__name__)
@torch.jit.script
def _fused_tanh_sigmoid(h):
a, b = h.chunk(2, dim=1)
h = a.tanh() * b.sigmoid()
return h
class WNLayer(nn.Module):
    """
    A DiffWave-like WN residual layer.

    Optional global conditioning is added to the input, a dilated convolution
    expands to ``2 * hidden_dim`` channels, optional local conditioning is added,
    a gated tanh/sigmoid activation halves the channels again, and a 1x1
    convolution produces a residual half and a skip half.

    Args:
        hidden_dim: channel count of the residual stream.
        local_dim: channels of the local condition, or ``None`` if this layer
            takes no local condition.
        global_dim: channels of the global condition, or ``None`` if this layer
            takes no global condition.
        kernel_size: kernel size of the dilated convolution.
        dilation: dilation of the dilated convolution.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        super().__init__()
        local_output_dim = hidden_dim * 2
        # Conditioning projections only exist when their input size is given;
        # forward() checks for them before use.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
        # padding="same" keeps the time length unchanged for any dilation.
        self.dconv = nn.Conv1d(
            hidden_dim,
            local_output_dim,
            kernel_size,
            dilation=dilation,
            padding="same",
        )
        self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)

    def forward(self, z, l, g):
        """
        Args:
            z: residual stream, presumably (b, hidden_dim, t).
            l: local condition (b, local_dim, t), or None.
            g: global condition (b, global_dim), which is unsqueezed to
                (b, global_dim, 1) so it broadcasts over time; or an already
                3-D tensor; or None.

        Returns:
            (o, s): the residual output ``(z_half + z) / sqrt(2)`` and the skip
            output, each with ``hidden_dim`` channels.

        Raises:
            ValueError: if ``g`` (resp. ``l``) is given but the layer was built
                with ``global_dim=None`` (resp. ``local_dim=None``).
        """
        # Fail with a clear message instead of an AttributeError on gconv/lconv.
        if g is not None and not hasattr(self, "gconv"):
            raise ValueError(
                "global condition g was given but this layer was built with global_dim=None"
            )
        if l is not None and not hasattr(self, "lconv"):
            raise ValueError(
                "local condition l was given but this layer was built with local_dim=None"
            )

        identity = z
        if g is not None:
            if g.dim() == 2:
                g = g.unsqueeze(-1)
            z = z + self.gconv(g)
        z = self.dconv(z)
        if l is not None:
            z = z + self.lconv(l)
        z = _fused_tanh_sigmoid(z)
        z, s = self.out(z).chunk(2, dim=1)
        # 1/sqrt(2) residual scaling, presumably (as in DiffWave) to keep the
        # variance of the residual sum from growing layer over layer.
        o = (z + identity) / math.sqrt(2)
        return o, s
class WN(nn.Module):
    """
    DiffWave-style stack of ``WNLayer`` blocks with summed skip connections.

    A 1x1 ``start`` conv lifts the input to ``hidden_dim`` channels. Each of the
    ``n_layers`` layers updates the residual stream and emits a skip tensor. The
    skips are summed, scaled by ``1 / sqrt(n_layers)``, and projected to
    ``output_dim`` channels by a 1x1 ``end`` conv. Layer ``i`` uses dilation
    ``2 ** (i % dilation_cycle)``.

    Args:
        input_dim: channels of the input ``z``.
        output_dim: channels of the output.
        local_dim: channels of the local condition ``l``, or ``None`` for none.
        global_dim: channels of the global condition ``g``, or ``None`` for none.
        n_layers: number of ``WNLayer`` blocks (must be >= 1).
        kernel_size: kernel size of the dilated convs (must be odd).
        dilation_cycle: period of the dilation schedule (must be >= 1).
        hidden_dim: width of the residual stream (must be even).

    Raises:
        ValueError: if any of the constraints above is violated.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        super().__init__()
        # Real exceptions rather than `assert`, which `python -O` strips out.
        if kernel_size % 2 != 1:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        if hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")
        if n_layers < 1:
            # forward() needs at least one skip and divides by sqrt(n_layers).
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        if dilation_cycle < 1:
            # Used as a modulus in the dilation schedule below.
            raise ValueError(f"dilation_cycle must be >= 1, got {dilation_cycle}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim
        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
        if local_dim is not None:
            # Default InstanceNorm1d (no affine params): normalises each channel
            # of the local condition over time, per sample.
            self.local_norm = nn.InstanceNorm1d(local_dim)
        self.layers = nn.ModuleList(
            [
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=2 ** (i % dilation_cycle),
                )
                for i in range(n_layers)
            ]
        )
        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)

        Returns:
            Output of the ``end`` projection, with ``output_dim`` channels.

        Raises:
            ValueError: if ``l`` (resp. ``g``) is given but the model was built
                with ``local_dim=None`` (resp. ``global_dim=None``).
        """
        if l is not None and self.local_dim is None:
            raise ValueError(
                "local condition l was given but the model was built with local_dim=None"
            )
        if g is not None and self.global_dim is None:
            raise ValueError(
                "global condition g was given but the model was built with global_dim=None"
            )

        z = self.start(z)
        if l is not None:
            l = self.local_norm(l)

        # Running sum of skips: avoids keeping every skip alive and then
        # materialising an extra (n_layers, b, c, t) tensor with torch.stack.
        skip_sum = None
        for layer in self.layers:
            z, s = layer(z, l, g)
            skip_sum = s if skip_sum is None else skip_sum + s
        skip_sum = skip_sum / math.sqrt(len(self.layers))
        return self.end(skip_sum)

    def summarize(self, length=100):
        """Print MACs and parameter count for an input of shape (1, input_dim, length).

        Requires the third-party ``ptflops`` package, imported lazily here.
        """
        from ptflops import get_model_complexity_info

        # Only the shape is reported, so build it directly instead of allocating
        # a random tensor (which would also advance torch's global RNG).
        input_shape = torch.Size((1, self.input_dim, length))
        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )
        print(f"Input shape: {input_shape}")
        print(f"Computational complexity: {macs}")
        print(f"Number of parameters: {params}")
if __name__ == "__main__":
    # Smoke run: build a default-sized model and print its complexity summary.
    model = WN(output_dim=64, input_dim=64)
    model.summarize(length=100)
|