import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchaudio

from attention_modules import BertConfig, BertEncoder, BertPooler


class Conv_1d(nn.Module):
    '''Conv1d -> BatchNorm -> ReLU -> MaxPool.'''
    def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
        super(Conv_1d, self).__init__()
        self.conv = nn.Conv1d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn = nn.BatchNorm1d(output_channels)
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool1d(pooling)

    def forward(self, x):
        out = self.mp(self.relu(self.bn(self.conv(x))))
        return out


class Conv_2d(nn.Module):
    '''Conv2d -> BatchNorm -> ReLU -> MaxPool.'''
    def __init__(self, input_channels, output_channels, shape=3, stride=1, pooling=2):
        super(Conv_2d, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()
        self.mp = nn.MaxPool2d(pooling)

    def forward(self, x):
        out = self.mp(self.relu(self.bn(self.conv(x))))
        return out


class Res_2d(nn.Module):
    '''Residual 2D convolution block with a projection shortcut when needed.'''
    def __init__(self, input_channels, output_channels, shape=3, stride=2):
        super(Res_2d, self).__init__()
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv_1 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        self.conv_2 = nn.Conv2d(output_channels, output_channels, shape, padding=shape//2)
        self.bn_2 = nn.BatchNorm2d(output_channels)

        # Projection shortcut, used whenever the main branch changes the
        # resolution or the channel count.
        self.diff = False
        if (stride != 1) or (input_channels != output_channels):
            self.conv_3 = nn.Conv2d(input_channels, output_channels, shape, stride=stride, padding=shape//2)
            self.bn_3 = nn.BatchNorm2d(output_channels)
            self.diff = True
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x)))))
        if self.diff:
            x = self.bn_3(self.conv_3(x))
        out = x + out
        out = self.relu(out)
        return out
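

# Shape note (a sketch, assuming the 3x3 kernels and `shape//2` padding above):
# with stride=2 a Res_2d block halves both spectrogram axes while switching to
# `output_channels`, e.g. (B, 1, 128, T) -> (B, 128, 64, ceil(T/2)); with
# stride=(2, 1) only the frequency axis is halved.

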
class CNNSA(nn.Module):
    '''
    Won et al. 2019
    Toward Interpretable Music Tagging with Self-Attention.
    Feature extraction with a CNN front end + temporal summarization with a
    Transformer (BERT-style) encoder.
    '''
    def __init__(self,
                 n_channels=128,
                 sample_rate=16000,
                 n_fft=512,
                 f_min=0.0,
                 f_max=8000.0,
                 n_mels=128,
                 n_class=50):
        super(CNNSA, self).__init__()

        # Mel spectrogram front end.
        self.spec = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                         n_fft=n_fft,
                                                         f_min=f_min,
                                                         f_max=f_max,
                                                         n_mels=n_mels)
        self.to_db = torchaudio.transforms.AmplitudeToDB()
        self.spec_bn = nn.BatchNorm2d(1)

        # CNN: seven residual blocks. The first three halve both the frequency
        # and time axes; the last four halve only the frequency axis, which
        # collapses the 128 mel bins to a single row.
        self.layer1 = Res_2d(1, n_channels, stride=2)
        self.layer2 = Res_2d(n_channels, n_channels, stride=2)
        self.layer3 = Res_2d(n_channels, n_channels*2, stride=2)
        self.layer4 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
        self.layer5 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
        self.layer6 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))
        self.layer7 = Res_2d(n_channels*2, n_channels*2, stride=(2, 1))

        # Transformer encoder over the time axis; hidden_size matches the CNN
        # output channels (n_channels*2 = 256 with the default n_channels).
        bert_config = BertConfig(vocab_size=256,
                                 hidden_size=256,
                                 num_hidden_layers=2,
                                 num_attention_heads=8,
                                 intermediate_size=1024,
                                 hidden_act="gelu",
                                 hidden_dropout_prob=0.4,
                                 max_position_embeddings=700,
                                 attention_probs_dropout_prob=0.5)
        self.encoder = BertEncoder(bert_config)
        self.pooler = BertPooler(bert_config)
        self.vec_cls = self.get_cls(256)

        # Classification head.
        self.dropout = nn.Dropout(0.5)
        self.dense = nn.Linear(256, n_class)

    def get_cls(self, channel):
        # Fixed CLS embedding, seeded for reproducibility (note: this reseeds
        # NumPy's global RNG) and tiled up to a maximum batch size of 64. A
        # plain tensor is not registered as a parameter, so it is not learned.
        np.random.seed(0)
        single_cls = torch.Tensor(np.random.random((1, channel)))
        vec_cls = torch.cat([single_cls for _ in range(64)], dim=0)
        vec_cls = vec_cls.unsqueeze(1)
        return vec_cls

    def append_cls(self, x):
        # Prepend one CLS vector per example: (B, T, C) -> (B, T+1, C).
        batch, _, _ = x.size()
        part_vec_cls = self.vec_cls[:batch].clone()
        part_vec_cls = part_vec_cls.to(x.device)
        return torch.cat([part_vec_cls, x], dim=1)

    def get_spec(self, ids, audio_length=15*16000, allow_random=False):
        # Load each track, downmix to mono, and cut a 15-second excerpt
        # (random offset if allowed, start of the track otherwise). Assumes
        # every clip is at least `audio_length` samples long.
        wav_list = list()
        for track_id in ids:
            audio_path = os.path.join("/import/c4dm-datasets/Music4All/music4all/audios", track_id + '.mp3')
            wav, sample_rate = torchaudio.load(audio_path)
            mono_wav = torch.mean(wav, dim=0)

            if allow_random:
                random_index = random.randint(0, len(mono_wav) - audio_length)
            else:
                random_index = 0
            mono_wav_cut = mono_wav[random_index: random_index + audio_length]
            wav_list.append(mono_wav_cut)

        data = torch.stack(wav_list, dim=0)

        # Compute the mel spectrograms on the same device as the model rather
        # than on a hard-coded GPU.
        device = next(self.parameters()).device
        spectrogram = self.spec(data.to(device))
        return spectrogram

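    # Note: with the defaults above (16 kHz audio, n_fft=512, and torchaudio's
    # default hop length of n_fft // 2 = 256), a 15 s excerpt yields roughly
    # (batch, n_mels, 938) spectrograms, which the CNN front end reduces to
    # about 118 time steps (well under max_position_embeddings=700).
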
    def forward(self, ids):
        # Mel spectrogram in dB, shaped (batch, 1, n_mels, time).
        spec = self.get_spec(ids)
        spec_db = self.to_db(spec)
        x = spec_db.unsqueeze(1)
        x = self.spec_bn(x)

        # CNN front end: collapse the frequency axis to a single row.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        x = x.squeeze(2)

        # Reorder to (batch, time, channels) and prepend the CLS vector.
        x = x.permute(0, 2, 1)
        x = self.append_cls(x)

        # Transformer encoder; keep the hidden states of the last layer.
        x = self.encoder(x)
        x = x[-1]

        # Classification head, using the layers defined in __init__: pool the
        # CLS position, then project to tag logits.
        x = self.pooler(x)
        x = self.dropout(x)
        x = self.dense(x)

        return x
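

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. Assumptions:
    # the Music4All mp3s referenced in get_spec exist at the hard-coded path,
    # and "some_track_id" is a placeholder for a real Music4All track ID.
    model = CNNSA()
    model.eval()
    with torch.no_grad():
        logits = model(["some_track_id"])
    print(logits.shape)  # expected: (1, n_class) == (1, 50)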