import functools

import torch
import torch.nn.functional as F

import crepe


###########################################################################
# Model definition
###########################################################################


class Crepe(torch.nn.Module):
    """Crepe model definition"""

    def __init__(self, model='full'):
        super().__init__()

        # Model-specific layer parameters
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256
        else:
            raise ValueError(f'Model {model} is not supported')

        # Shared layer parameters
        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        # Overload with the eps and momentum conversion given by MMdnn
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d,
                                          eps=0.0010000000474974513,
                                          momentum=0.0)

        # Layer definitions
        self.conv1 = torch.nn.Conv2d(
            in_channels=in_channels[0],
            out_channels=out_channels[0],
            kernel_size=kernel_sizes[0],
            stride=strides[0])
        self.conv1_BN = batch_norm_fn(
            num_features=out_channels[0])

        self.conv2 = torch.nn.Conv2d(
            in_channels=in_channels[1],
            out_channels=out_channels[1],
            kernel_size=kernel_sizes[1],
            stride=strides[1])
        self.conv2_BN = batch_norm_fn(
            num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(
            in_channels=in_channels[2],
            out_channels=out_channels[2],
            kernel_size=kernel_sizes[2],
            stride=strides[2])
        self.conv3_BN = batch_norm_fn(
            num_features=out_channels[2])

        self.conv4 = torch.nn.Conv2d(
            in_channels=in_channels[3],
            out_channels=out_channels[3],
            kernel_size=kernel_sizes[3],
            stride=strides[3])
        self.conv4_BN = batch_norm_fn(
            num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(
            in_channels=in_channels[4],
            out_channels=out_channels[4],
            kernel_size=kernel_sizes[4],
            stride=strides[4])
        self.conv5_BN = batch_norm_fn(
            num_features=out_channels[4])

        self.conv6 = torch.nn.Conv2d(
            in_channels=in_channels[5],
            out_channels=out_channels[5],
            kernel_size=kernel_sizes[5],
            stride=strides[5])
        self.conv6_BN = batch_norm_fn(
            num_features=out_channels[5])

        self.classifier = torch.nn.Linear(
            in_features=self.in_features,
            out_features=crepe.PITCH_BINS)

    def forward(self, x, embed=False):
        """Perform model inference"""
        # Forward pass through the first five layers
        x = self.embed(x)

        if embed:
            return x

        # Forward pass through layer six
        x = self.layer(x, self.conv6, self.conv6_BN)

        # shape=(batch, self.in_features)
        x = x.permute(0, 2, 1, 3).reshape(-1, self.in_features)

        # Compute logits and apply a sigmoid to obtain per-bin probabilities
        return torch.sigmoid(self.classifier(x))

    ###########################################################################
    # Forward pass utilities
    ###########################################################################

    def embed(self, x):
        """Map input audio to pitch embedding"""
        # shape=(batch, 1, 1024, 1)
        x = x[:, None, :, None]

        # Forward pass through the first five layers
        x = self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254))
        x = self.layer(x, self.conv2, self.conv2_BN)
        x = self.layer(x, self.conv3, self.conv3_BN)
        x = self.layer(x, self.conv4, self.conv4_BN)
        x = self.layer(x, self.conv5, self.conv5_BN)

        return x

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        """Forward pass through one layer"""
        x = F.pad(x, padding)
        x = conv(x)
        x = F.relu(x)
        x = batch_norm(x)
        return F.max_pool2d(x, (2, 1), (2, 1))
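
###########################################################################
# Usage sketch
###########################################################################


# A minimal, hypothetical smoke test of the model above; it is not part of
# the original module. It assumes crepe.PITCH_BINS is importable (360 in the
# original CREPE paper) and feeds random data in place of real 1024-sample
# audio frames; pretrained weights would normally be loaded before inference.
if __name__ == '__main__':
    model = Crepe(model='tiny')
    model.eval()

    # Batch of eight 1024-sample frames, shape=(batch, 1024)
    frames = torch.rand(8, 1024)

    with torch.no_grad():
        # Per-bin pitch probabilities, shape=(batch, crepe.PITCH_BINS)
        probabilities = model(frames)

        # Intermediate embedding from the first five layers
        embedding = model(frames, embed=True)

    print(probabilities.shape, embedding.shape)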