File size: 5,195 Bytes
3e423f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import torch
import torch.nn as nn
from tortoise.models.arch_util import (
AttentionBlock,
Downsample,
Upsample,
normalization,
zero_module,
)
class ResBlock(nn.Module):
    """1-D residual block: two norm -> SiLU -> conv stages plus a skip path.

    Args:
        channels: number of input channels.
        dropout: dropout probability applied before the second convolution.
        out_channels: number of output channels; defaults to ``channels``.
        use_conv: if the channel count changes, project the skip path with a
            ``kernel_size`` convolution instead of a 1x1 convolution.
        use_scale_shift_norm: stored on the instance but not used here.
        dims: accepted for interface compatibility only; every convolution in
            this block is ``nn.Conv1d``, so the value is not used.
        up: resample both the residual and skip paths with ``Upsample``.
        down: resample both paths with ``Downsample`` (ignored if ``up``).
        kernel_size: kernel size of the main-path convolutions. Padding is 1
            when ``kernel_size == 3`` and 2 otherwise.
        do_checkpoint: stored on the instance; ``forward`` does not use it.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        # Length-preserving ("same") padding for kernel sizes 3 and 5.
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down
        # Resampling must keep the channel count: the last conv of in_layers
        # and the skip path both expect `channels` inputs after it.
        # NOTE(review): `dims` used to be passed as a third positional argument
        # here. The resamplers accept `out_channels` (see the Downsample call
        # in AudioMiniEncoder), assumed to be that third slot -- confirm in
        # arch_util -- so the old call asked for `dims` output channels.
        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialised final conv (via zero_module) from the file's
            # shared helpers.
            zero_module(
                nn.Conv1d(
                    self.out_channels, self.out_channels, kernel_size, padding=padding
                )
            ),
        )

        # nn.Conv1d takes (in_channels, out_channels, kernel_size, ...). The
        # previous `nn.Conv1d(dims, channels, ...)` used `dims` as the input
        # channel count, so any block with out_channels != channels broke.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)

    def forward(self, x):
        """Return ``skip_connection(x') + out_layers(in_layers(x))``.

        When ``up``/``down`` was requested, the resampler is inserted between
        the first SiLU and the first convolution of ``in_layers``. The same
        kind of resampler is applied to ``x`` (giving ``x'``) before the skip
        path, so both branches agree in length.

        ``x`` is presumably (batch, channels, length), as consumed by
        ``nn.Conv1d`` -- TODO confirm against callers.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
class AudioMiniEncoder(nn.Module):
    """Convolutional + attention encoder that reduces its input to one vector.

    Layout, in construction order:
      * ``init``: a 3-tap ``nn.Conv1d`` from ``spec_dim`` to ``base_channels``.
      * ``res``: ``depth`` stages, each made of ``resnet_blocks`` ResBlocks
        followed by a conv Downsample (factor ``downsample_factor``) that
        doubles the channel count.
      * ``final``: norm -> SiLU -> 1x1 conv projecting to ``embedding_dim``.
      * ``attn``: ``attn_blocks`` AttentionBlocks with ``num_attn_heads`` heads.

    ``forward`` returns the features at index 0 of the last axis.
    """

    def __init__(
        self,
        spec_dim,
        embedding_dim,
        base_channels=128,
        depth=2,
        resnet_blocks=2,
        attn_blocks=4,
        num_attn_heads=4,
        dropout=0,
        downsample_factor=2,
        kernel_size=3,
    ):
        super().__init__()
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))

        width = base_channels
        stages = []
        self.layers = depth
        for _stage in range(depth):
            stages += [
                ResBlock(width, dropout, do_checkpoint=False, kernel_size=kernel_size)
                for _block in range(resnet_blocks)
            ]
            # Each stage ends by doubling the channel width.
            stages.append(
                Downsample(
                    width,
                    use_conv=True,
                    out_channels=width * 2,
                    factor=downsample_factor,
                )
            )
            width *= 2
        self.res = nn.Sequential(*stages)

        self.final = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            nn.Conv1d(width, embedding_dim, 1),
        )

        attention_blocks = [
            AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
            for _head in range(attn_blocks)
        ]
        self.attn = nn.Sequential(*attention_blocks)
        self.dim = embedding_dim

    def forward(self, x):
        """Encode ``x`` and return the slice at index 0 of the last axis.

        ``x`` is presumably (batch, spec_dim, length); the result is then
        (batch, embedding_dim) -- TODO confirm, this assumes AttentionBlock
        preserves its input shape.
        """
        features = self.final(self.res(self.init(x)))
        for attention in self.attn:
            features = attention(features)
        return features[:, :, 0]
class AudioMiniEncoderWithClassifierHead(nn.Module):
    """AudioMiniEncoder followed by a linear classification head.

    Args:
        classes: number of output classes.
        distribute_zero_label: when True, examples labelled 0 are trained
            against a softened target (see ``ZERO_LABEL_MASS``) instead of a
            hard one-hot target.
        **kwargs: forwarded unchanged to ``AudioMiniEncoder``.
    """

    # Probability mass moved away from class 0 and spread evenly over the
    # other classes whenever the label is 0, to compensate for dataset noise.
    ZERO_LABEL_MASS = 0.2

    def __init__(self, classes, distribute_zero_label=True, **kwargs):
        super().__init__()
        self.enc = AudioMiniEncoder(**kwargs)
        self.head = nn.Linear(self.enc.dim, classes)
        self.num_classes = classes
        self.distribute_zero_label = distribute_zero_label

    def forward(self, x, labels=None):
        """Return logits, or the cross-entropy loss when ``labels`` is given.

        Args:
            x: input passed straight to the encoder.
            labels: optional integer class indices. The soft-target path
                indexes ``[:, 0]`` on their one-hot encoding, so they are
                expected to be 1-D (one label per example).

        Returns:
            The logits when ``labels`` is None, otherwise the scalar loss
            from ``nn.functional.cross_entropy`` (default mean reduction).
        """
        logits = self.head(self.enc(x))
        if labels is None:
            return logits
        # Redistribution needs at least one other class to receive the mass;
        # with a single class the old code divided by zero. Index targets give
        # the same loss as a hard one-hot target in that case.
        if self.distribute_zero_label and self.num_classes > 1:
            targets = self._zero_label_soft_targets(labels)
        else:
            targets = labels
        return nn.functional.cross_entropy(logits, targets)

    def _zero_label_soft_targets(self, labels):
        """Build per-class probability targets from integer ``labels``.

        Rows whose label is 0 become ``1 - ZERO_LABEL_MASS`` on class 0 and
        ``ZERO_LABEL_MASS / (num_classes - 1)`` on every other class. All other
        rows stay one-hot. Requires ``num_classes > 1``.
        """
        mass = self.ZERO_LABEL_MASS
        one_hot = nn.functional.one_hot(labels, num_classes=self.num_classes)
        is_zero = (labels == 0).unsqueeze(-1)
        extra = torch.full_like(
            one_hot,
            dtype=torch.float,
            fill_value=mass / (self.num_classes - 1),
        )
        extra[:, 0] = -mass
        # Multiplying by the boolean mask zeroes the adjustment for rows whose
        # label is not 0; adding to the integer one-hot promotes to float.
        return one_hot + extra * is_zero
|